simplify QueryNode penalty and initial assignment

This commit is contained in:
Sarah Hoffmann
2025-07-09 15:36:11 +02:00
parent 1aeb8a262c
commit 4a9253a0a9
3 changed files with 33 additions and 47 deletions

View File

@@ -37,17 +37,16 @@ DB_TO_TOKEN_TYPE = {
'C': qmod.TOKEN_COUNTRY 'C': qmod.TOKEN_COUNTRY
} }
PENALTY_IN_TOKEN_BREAK = { PENALTY_BREAK = {
qmod.BREAK_START: 0.5, qmod.BREAK_START: -0.5,
qmod.BREAK_END: 0.5, qmod.BREAK_END: -0.5,
qmod.BREAK_PHRASE: 0.5, qmod.BREAK_PHRASE: -0.5,
qmod.BREAK_SOFT_PHRASE: 0.5, qmod.BREAK_SOFT_PHRASE: -0.5,
qmod.BREAK_WORD: 0.1, qmod.BREAK_WORD: 0.0,
qmod.BREAK_PART: 0.0, qmod.BREAK_PART: 0.2,
qmod.BREAK_TOKEN: 0.0 qmod.BREAK_TOKEN: 0.4
} }
@dataclasses.dataclass @dataclasses.dataclass
class ICUToken(qmod.Token): class ICUToken(qmod.Token):
""" Specialised token for ICU tokenizer. """ Specialised token for ICU tokenizer.
@@ -78,13 +77,13 @@ class ICUToken(qmod.Token):
self.penalty += (distance/len(self.lookup_word)) self.penalty += (distance/len(self.lookup_word))
@staticmethod @staticmethod
def from_db_row(row: SaRow, base_penalty: float = 0.0) -> 'ICUToken': def from_db_row(row: SaRow) -> 'ICUToken':
""" Create an ICUToken from the row of the word table. """ Create an ICUToken from the row of the word table.
""" """
count = 1 if row.info is None else row.info.get('count', 1) count = 1 if row.info is None else row.info.get('count', 1)
addr_count = 1 if row.info is None else row.info.get('addr_count', 1) addr_count = 1 if row.info is None else row.info.get('addr_count', 1)
penalty = base_penalty penalty = 0.0
if row.type == 'w': if row.type == 'w':
penalty += 0.3 penalty += 0.3
elif row.type == 'W': elif row.type == 'W':
@@ -174,11 +173,14 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
self.split_query(query) self.split_query(query)
log().var_dump('Transliterated query', lambda: query.get_transliterated_query()) log().var_dump('Transliterated query', lambda: query.get_transliterated_query())
words = query.extract_words(base_penalty=PENALTY_IN_TOKEN_BREAK[qmod.BREAK_WORD]) words = query.extract_words()
for row in await self.lookup_in_db(list(words.keys())): for row in await self.lookup_in_db(list(words.keys())):
for trange in words[row.word_token]: for trange in words[row.word_token]:
token = ICUToken.from_db_row(row, trange.penalty or 0.0) # Create a new token for each position because the token
# penalty can vary depending on the position in the query.
# (See rerank_tokens() below.)
token = ICUToken.from_db_row(row)
if row.type == 'S': if row.type == 'S':
if row.info['op'] in ('in', 'near'): if row.info['op'] in ('in', 'near'):
if trange.start == 0: if trange.start == 0:
@@ -200,6 +202,7 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
lookup_word=pc, word_token=term, lookup_word=pc, word_token=term,
info=None)) info=None))
self.rerank_tokens(query) self.rerank_tokens(query)
self.compute_break_penalties(query)
log().table_dump('Word tokens', _dump_word_tokens(query)) log().table_dump('Word tokens', _dump_word_tokens(query))
@@ -232,10 +235,9 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
query.add_node(qmod.BREAK_TOKEN, phrase.ptype, query.add_node(qmod.BREAK_TOKEN, phrase.ptype,
PENALTY_IN_TOKEN_BREAK[qmod.BREAK_TOKEN], PENALTY_IN_TOKEN_BREAK[qmod.BREAK_TOKEN],
term, word) term, word)
query.nodes[-1].adjust_break(breakchar, query.nodes[-1].btype = breakchar
PENALTY_IN_TOKEN_BREAK[breakchar])
query.nodes[-1].adjust_break(qmod.BREAK_END, PENALTY_IN_TOKEN_BREAK[qmod.BREAK_END]) query.nodes[-1].btype = qmod.BREAK_END
async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]': async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
""" Return the token information from the database for the """ Return the token information from the database for the
@@ -300,6 +302,12 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
for token in tokens: for token in tokens:
cast(ICUToken, token).rematch(norm) cast(ICUToken, token).rematch(norm)
def compute_break_penalties(self, query: qmod.QueryStruct) -> None:
""" Set the break penalties for the nodes in the query.
"""
for node in query.nodes:
node.penalty = PENALTY_BREAK[node.btype]
def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]: def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
yield ['type', 'from', 'to', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info'] yield ['type', 'from', 'to', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']

View File

@@ -191,7 +191,9 @@ class QueryNode:
ptype: PhraseType ptype: PhraseType
penalty: float penalty: float
""" Penalty for the break at this node. """ Penalty for having a word break at this position. The penalty
may be negative, when a word break is more likely than continuing
the word after the node.
""" """
term_lookup: str term_lookup: str
""" Transliterated term ending at this node. """ Transliterated term ending at this node.
@@ -221,12 +223,6 @@ class QueryNode:
return self.partial.count / (self.partial.count + self.partial.addr_count) return self.partial.count / (self.partial.count + self.partial.addr_count)
def adjust_break(self, btype: BreakType, penalty: float) -> None:
""" Change the break type and penalty for this node.
"""
self.btype = btype
self.penalty = penalty
def has_tokens(self, end: int, *ttypes: TokenType) -> bool: def has_tokens(self, end: int, *ttypes: TokenType) -> bool:
""" Check if there are tokens of the given types ending at the """ Check if there are tokens of the given types ending at the
given node. given node.
@@ -277,8 +273,7 @@ class QueryStruct:
self.source = source self.source = source
self.dir_penalty = 0.0 self.dir_penalty = 0.0
self.nodes: List[QueryNode] = \ self.nodes: List[QueryNode] = \
[QueryNode(BREAK_START, source[0].ptype if source else PHRASE_ANY, [QueryNode(BREAK_START, source[0].ptype if source else PHRASE_ANY)]
0.0, '', '')]
def num_token_slots(self) -> int: def num_token_slots(self) -> int:
""" Return the length of the query in vertex steps. """ Return the length of the query in vertex steps.
@@ -286,13 +281,12 @@ class QueryStruct:
return len(self.nodes) - 1 return len(self.nodes) - 1
def add_node(self, btype: BreakType, ptype: PhraseType, def add_node(self, btype: BreakType, ptype: PhraseType,
break_penalty: float = 0.0,
term_lookup: str = '', term_normalized: str = '') -> None: term_lookup: str = '', term_normalized: str = '') -> None:
""" Append a new break node with the given break type. """ Append a new break node with the given break type.
The phrase type denotes the type for any tokens starting The phrase type denotes the type for any tokens starting
at the node. at the node.
""" """
self.nodes.append(QueryNode(btype, ptype, break_penalty, term_lookup, term_normalized)) self.nodes.append(QueryNode(btype, ptype, 0.0, term_lookup, term_normalized))
def add_token(self, trange: TokenRange, ttype: TokenType, token: Token) -> None: def add_token(self, trange: TokenRange, ttype: TokenType, token: Token) -> None:
""" Add a token to the query. 'start' and 'end' are the indexes of the """ Add a token to the query. 'start' and 'end' are the indexes of the
@@ -386,17 +380,14 @@ class QueryStruct:
""" """
return ''.join(''.join((n.term_lookup, n.btype)) for n in self.nodes) return ''.join(''.join((n.term_lookup, n.btype)) for n in self.nodes)
def extract_words(self, base_penalty: float = 0.0, def extract_words(self, start: int = 0,
start: int = 0,
endpos: Optional[int] = None) -> Dict[str, List[TokenRange]]: endpos: Optional[int] = None) -> Dict[str, List[TokenRange]]:
""" Add all combinations of words that can be formed from the terms """ Add all combinations of words that can be formed from the terms
between the given start and endnode. The terms are joined with between the given start and endnode. The terms are joined with
spaces for each break. Words can never go across a BREAK_PHRASE. spaces for each break. Words can never go across a BREAK_PHRASE.
The function returns a dictionary of possible words with their The function returns a dictionary of possible words with their
position within the query and a penalty. The penalty is computed position within the query.
from the base_penalty plus the penalty for each node the word
crosses.
""" """
if endpos is None: if endpos is None:
endpos = len(self.nodes) endpos = len(self.nodes)
@@ -405,16 +396,13 @@ class QueryStruct:
for first, first_node in enumerate(self.nodes[start + 1:endpos], start): for first, first_node in enumerate(self.nodes[start + 1:endpos], start):
word = first_node.term_lookup word = first_node.term_lookup
penalty = base_penalty words[word].append(TokenRange(first, first + 1))
words[word].append(TokenRange(first, first + 1, penalty=penalty))
if first_node.btype != BREAK_PHRASE: if first_node.btype != BREAK_PHRASE:
penalty += first_node.penalty
max_last = min(first + 20, endpos) max_last = min(first + 20, endpos)
for last, last_node in enumerate(self.nodes[first + 2:max_last], first + 2): for last, last_node in enumerate(self.nodes[first + 2:max_last], first + 2):
word = ' '.join((word, last_node.term_lookup)) word = ' '.join((word, last_node.term_lookup))
words[word].append(TokenRange(first, last, penalty=penalty)) words[word].append(TokenRange(first, last))
if last_node.btype == BREAK_PHRASE: if last_node.btype == BREAK_PHRASE:
break break
penalty += last_node.penalty
return words return words

View File

@@ -23,16 +23,6 @@ class TypedRange:
trange: qmod.TokenRange trange: qmod.TokenRange
PENALTY_TOKENCHANGE = {
qmod.BREAK_START: 0.0,
qmod.BREAK_END: 0.0,
qmod.BREAK_PHRASE: 0.0,
qmod.BREAK_SOFT_PHRASE: 0.0,
qmod.BREAK_WORD: 0.1,
qmod.BREAK_PART: 0.2,
qmod.BREAK_TOKEN: 0.4
}
TypedRangeSeq = List[TypedRange] TypedRangeSeq = List[TypedRange]