simplify QueryNode penalty and initial assignment

This commit is contained in:
Sarah Hoffmann
2025-07-09 15:36:11 +02:00
parent 1aeb8a262c
commit 4a9253a0a9
3 changed files with 33 additions and 47 deletions

View File

@@ -37,17 +37,16 @@ DB_TO_TOKEN_TYPE = {
'C': qmod.TOKEN_COUNTRY
}
PENALTY_IN_TOKEN_BREAK = {
qmod.BREAK_START: 0.5,
qmod.BREAK_END: 0.5,
qmod.BREAK_PHRASE: 0.5,
qmod.BREAK_SOFT_PHRASE: 0.5,
qmod.BREAK_WORD: 0.1,
qmod.BREAK_PART: 0.0,
qmod.BREAK_TOKEN: 0.0
PENALTY_BREAK = {
qmod.BREAK_START: -0.5,
qmod.BREAK_END: -0.5,
qmod.BREAK_PHRASE: -0.5,
qmod.BREAK_SOFT_PHRASE: -0.5,
qmod.BREAK_WORD: 0.0,
qmod.BREAK_PART: 0.2,
qmod.BREAK_TOKEN: 0.4
}
@dataclasses.dataclass
class ICUToken(qmod.Token):
""" Specialised token for ICU tokenizer.
@@ -78,13 +77,13 @@ class ICUToken(qmod.Token):
self.penalty += (distance/len(self.lookup_word))
@staticmethod
def from_db_row(row: SaRow, base_penalty: float = 0.0) -> 'ICUToken':
def from_db_row(row: SaRow) -> 'ICUToken':
""" Create a ICUToken from the row of the word table.
"""
count = 1 if row.info is None else row.info.get('count', 1)
addr_count = 1 if row.info is None else row.info.get('addr_count', 1)
penalty = base_penalty
penalty = 0.0
if row.type == 'w':
penalty += 0.3
elif row.type == 'W':
@@ -174,11 +173,14 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
self.split_query(query)
log().var_dump('Transliterated query', lambda: query.get_transliterated_query())
words = query.extract_words(base_penalty=PENALTY_IN_TOKEN_BREAK[qmod.BREAK_WORD])
words = query.extract_words()
for row in await self.lookup_in_db(list(words.keys())):
for trange in words[row.word_token]:
token = ICUToken.from_db_row(row, trange.penalty or 0.0)
# Create a new token for each position because the token
# penalty can vary depending on the position in the query.
# (See rerank_tokens() below.)
token = ICUToken.from_db_row(row)
if row.type == 'S':
if row.info['op'] in ('in', 'near'):
if trange.start == 0:
@@ -200,6 +202,7 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
lookup_word=pc, word_token=term,
info=None))
self.rerank_tokens(query)
self.compute_break_penalties(query)
log().table_dump('Word tokens', _dump_word_tokens(query))
@@ -232,10 +235,9 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
query.add_node(qmod.BREAK_TOKEN, phrase.ptype,
PENALTY_IN_TOKEN_BREAK[qmod.BREAK_TOKEN],
term, word)
query.nodes[-1].adjust_break(breakchar,
PENALTY_IN_TOKEN_BREAK[breakchar])
query.nodes[-1].btype = breakchar
query.nodes[-1].adjust_break(qmod.BREAK_END, PENALTY_IN_TOKEN_BREAK[qmod.BREAK_END])
query.nodes[-1].btype = qmod.BREAK_END
async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
""" Return the token information from the database for the
@@ -300,6 +302,12 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
for token in tokens:
cast(ICUToken, token).rematch(norm)
def compute_break_penalties(self, query: qmod.QueryStruct) -> None:
""" Set the break penalties for the nodes in the query.
"""
for node in query.nodes:
node.penalty = PENALTY_BREAK[node.btype]
def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
yield ['type', 'from', 'to', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']