reorganise token reranking

As the reranking is about changing penalties in the presence of other
tokens, change the data structure to have the other tokens readily
available.
Sarah Hoffmann
2025-04-11 13:38:34 +02:00
parent b680d81f0a
commit 2ef0e20a3f
2 changed files with 43 additions and 26 deletions
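
In short, the change replaces the old per-TokenList iteration with an iteration over edges of the query graph, where each edge carries all candidate tokens grouped by token type. Below is a minimal sketch of that regrouping, using simplified stand-in types rather than the real qmod classes:

from collections import defaultdict
from typing import Dict, List, Tuple

# Flat stand-in for the query's token lists: (start node, end node, token type, token id).
RAW_TOKENS = [
    (0, 1, 'POSTCODE', 101),
    (0, 1, 'HOUSENUMBER', 102),
    (0, 1, 'WORD', 103),
    (1, 2, 'WORD', 104),
]

def group_by_edge(raw: List[Tuple[int, int, str, int]]) -> Dict[Tuple[int, int], Dict[str, List[int]]]:
    """ Group tokens first by the (start, end) edge they cover and then by
        token type - the per-edge shape that iter_tokens_by_edge() yields.
    """
    edges: Dict[Tuple[int, int], Dict[str, List[int]]] = defaultdict(lambda: defaultdict(list))
    for start, end, ttype, tid in raw:
        edges[(start, end)][ttype].append(tid)
    return edges

for (start, end), by_type in group_by_edge(RAW_TOKENS).items():
    # All competing interpretations of one edge are now visible at once.
    print(start, end, dict(by_type))

With every interpretation of an edge in hand, the reranking rules below can compare a postcode reading directly against the housenumber and word readings that compete with it.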

View File

@@ -267,27 +267,38 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
     def rerank_tokens(self, query: qmod.QueryStruct) -> None:
         """ Add penalties to tokens that depend on presence of other token.
         """
-        for i, node, tlist in query.iter_token_lists():
-            if tlist.ttype == qmod.TOKEN_POSTCODE:
-                tlen = len(cast(ICUToken, tlist.tokens[0]).word_token)
-                for repl in node.starting:
-                    if repl.end == tlist.end and repl.ttype != qmod.TOKEN_POSTCODE \
-                       and (repl.ttype != qmod.TOKEN_HOUSENUMBER or tlen > 4):
-                        repl.add_penalty(0.39)
-            elif (tlist.ttype == qmod.TOKEN_HOUSENUMBER
-                  and len(tlist.tokens[0].lookup_word) <= 3):
-                if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
-                    for repl in node.starting:
-                        if repl.end == tlist.end and repl.ttype != qmod.TOKEN_HOUSENUMBER:
-                            repl.add_penalty(0.5 - tlist.tokens[0].penalty)
-            elif tlist.ttype != qmod.TOKEN_COUNTRY:
-                norm = ' '.join(n.term_normalized for n in query.nodes[i + 1:tlist.end + 1]
-                                if n.btype != qmod.BREAK_TOKEN)
-                if not norm:
-                    # Can happen when the token only covers a partial term
-                    norm = query.nodes[i + 1].term_normalized
-                for token in tlist.tokens:
-                    cast(ICUToken, token).rematch(norm)
+        for start, end, tlist in query.iter_tokens_by_edge():
+            if len(tlist) > 1:
+                # If it looks like a Postcode, give preference.
+                if qmod.TOKEN_POSTCODE in tlist:
+                    for ttype, tokens in tlist.items():
+                        if ttype != qmod.TOKEN_POSTCODE and \
+                           (ttype != qmod.TOKEN_HOUSENUMBER or
+                            start + 1 > end or
+                            len(query.nodes[end].term_lookup) > 4):
+                            for token in tokens:
+                                token.penalty += 0.39
+
+                # If it looks like a simple housenumber, prefer that.
+                if qmod.TOKEN_HOUSENUMBER in tlist:
+                    hnr_lookup = tlist[qmod.TOKEN_HOUSENUMBER][0].lookup_word
+                    if len(hnr_lookup) <= 3 and any(c.isdigit() for c in hnr_lookup):
+                        penalty = 0.5 - tlist[qmod.TOKEN_HOUSENUMBER][0].penalty
+                        for ttype, tokens in tlist.items():
+                            if ttype != qmod.TOKEN_HOUSENUMBER:
+                                for token in tokens:
+                                    token.penalty += penalty
+
+            # rerank tokens against the normalized form
+            norm = ' '.join(n.term_normalized for n in query.nodes[start + 1:end + 1]
+                            if n.btype != qmod.BREAK_TOKEN)
+            if not norm:
+                # Can happen when the token only covers a partial term
+                norm = query.nodes[start + 1].term_normalized
+            for ttype, tokens in tlist.items():
+                if ttype != qmod.TOKEN_COUNTRY:
+                    for token in tokens:
+                        cast(ICUToken, token).rematch(norm)


 def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
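
For reference, the two penalty rules read roughly as follows when pulled out of the tokenizer. This is a condensed, hypothetical sketch that uses plain dicts and a simplified token stand-in instead of the real qmod/ICUToken classes:

from dataclasses import dataclass
from typing import Dict, List

@dataclass
class Tok:
    lookup_word: str
    penalty: float = 0.0

def rerank_edge(tlist: Dict[str, List[Tok]], term_lookup: str) -> None:
    """ Apply the postcode and housenumber preferences to the tokens
        of a single edge (mapping of token type -> tokens over that edge).
    """
    if len(tlist) <= 1:
        return
    # Rule 1: a possible postcode reading pushes down competing readings,
    # unless the competitor is a housenumber over a short term.
    if 'POSTCODE' in tlist:
        for ttype, tokens in tlist.items():
            if ttype != 'POSTCODE' and (ttype != 'HOUSENUMBER' or len(term_lookup) > 4):
                for token in tokens:
                    token.penalty += 0.39
    # Rule 2: a short, digit-containing housenumber reading pushes down the rest.
    if 'HOUSENUMBER' in tlist:
        hnr = tlist['HOUSENUMBER'][0].lookup_word
        if len(hnr) <= 3 and any(c.isdigit() for c in hnr):
            penalty = 0.5 - tlist['HOUSENUMBER'][0].penalty
            for ttype, tokens in tlist.items():
                if ttype != 'HOUSENUMBER':
                    for token in tokens:
                        token.penalty += penalty

edge = {'HOUSENUMBER': [Tok('3a')], 'WORD': [Tok('3a')]}
rerank_edge(edge, '3a')
print(edge['WORD'][0].penalty)   # the word reading now carries the extra 0.5 penalty

The second rule tops up the penalty so that the competing readings end up roughly 0.5 above the housenumber token, mirroring the `0.5 - penalty` term in the diff above.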

View File

@@ -183,10 +183,10 @@ class QueryNode:
""" Penalty for the break at this node. """ Penalty for the break at this node.
""" """
term_lookup: str term_lookup: str
""" Transliterated term following this node. """ Transliterated term ending at this node.
""" """
term_normalized: str term_normalized: str
""" Normalised form of term following this node. """ Normalised form of term ending at this node.
When the token resulted from a split during transliteration, When the token resulted from a split during transliteration,
then this string contains the complete source term. then this string contains the complete source term.
""" """
@@ -307,12 +307,18 @@ class QueryStruct:
""" """
return (n.partial for n in self.nodes[trange.start:trange.end] if n.partial is not None) return (n.partial for n in self.nodes[trange.start:trange.end] if n.partial is not None)
def iter_token_lists(self) -> Iterator[Tuple[int, QueryNode, TokenList]]: def iter_tokens_by_edge(self) -> Iterator[Tuple[int, int, Dict[TokenType, List[Token]]]]:
""" Iterator over all token lists except partial tokens in the query. """ Iterator over all tokens except partial ones grouped by edge.
Returns the start and end node indexes and a dictionary
of list of tokens by token type.
""" """
for i, node in enumerate(self.nodes): for i, node in enumerate(self.nodes):
by_end: Dict[int, Dict[TokenType, List[Token]]] = defaultdict(dict)
for tlist in node.starting: for tlist in node.starting:
yield i, node, tlist by_end[tlist.end][tlist.ttype] = tlist.tokens
for end, endlist in by_end.items():
yield i, end, endlist
def find_lookup_word_by_id(self, token: int) -> str: def find_lookup_word_by_id(self, token: int) -> str:
""" Find the first token with the given token ID and return """ Find the first token with the given token ID and return