reorganise token reranking

As the reranking is about changing penalties in presence of other
tokens, change the datastructure to have the other tokens readily
avilable.
This commit is contained in:
Sarah Hoffmann
2025-04-11 13:38:34 +02:00
parent b680d81f0a
commit 2ef0e20a3f
2 changed files with 43 additions and 26 deletions

View File

@@ -267,27 +267,38 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
def rerank_tokens(self, query: qmod.QueryStruct) -> None:
""" Add penalties to tokens that depend on presence of other token.
"""
for i, node, tlist in query.iter_token_lists():
if tlist.ttype == qmod.TOKEN_POSTCODE:
tlen = len(cast(ICUToken, tlist.tokens[0]).word_token)
for repl in node.starting:
if repl.end == tlist.end and repl.ttype != qmod.TOKEN_POSTCODE \
and (repl.ttype != qmod.TOKEN_HOUSENUMBER or tlen > 4):
repl.add_penalty(0.39)
elif (tlist.ttype == qmod.TOKEN_HOUSENUMBER
and len(tlist.tokens[0].lookup_word) <= 3):
if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
for repl in node.starting:
if repl.end == tlist.end and repl.ttype != qmod.TOKEN_HOUSENUMBER:
repl.add_penalty(0.5 - tlist.tokens[0].penalty)
elif tlist.ttype != qmod.TOKEN_COUNTRY:
norm = ' '.join(n.term_normalized for n in query.nodes[i + 1:tlist.end + 1]
if n.btype != qmod.BREAK_TOKEN)
if not norm:
# Can happen when the token only covers a partial term
norm = query.nodes[i + 1].term_normalized
for token in tlist.tokens:
cast(ICUToken, token).rematch(norm)
for start, end, tlist in query.iter_tokens_by_edge():
if len(tlist) > 1:
# If it looks like a Postcode, give preference.
if qmod.TOKEN_POSTCODE in tlist:
for ttype, tokens in tlist.items():
if ttype != qmod.TOKEN_POSTCODE and \
(ttype != qmod.TOKEN_HOUSENUMBER or
start + 1 > end or
len(query.nodes[end].term_lookup) > 4):
for token in tokens:
token.penalty += 0.39
# If it looks like a simple housenumber, prefer that.
if qmod.TOKEN_HOUSENUMBER in tlist:
hnr_lookup = tlist[qmod.TOKEN_HOUSENUMBER][0].lookup_word
if len(hnr_lookup) <= 3 and any(c.isdigit() for c in hnr_lookup):
penalty = 0.5 - tlist[qmod.TOKEN_HOUSENUMBER][0].penalty
for ttype, tokens in tlist.items():
if ttype != qmod.TOKEN_HOUSENUMBER:
for token in tokens:
token.penalty += penalty
# rerank tokens against the normalized form
norm = ' '.join(n.term_normalized for n in query.nodes[start + 1:end + 1]
if n.btype != qmod.BREAK_TOKEN)
if not norm:
# Can happen when the token only covers a partial term
norm = query.nodes[start + 1].term_normalized
for ttype, tokens in tlist.items():
if ttype != qmod.TOKEN_COUNTRY:
for token in tokens:
cast(ICUToken, token).rematch(norm)
def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]: