move partial token into a separate field in the query struct

There is exactly one partial token to be expected per node and the
token is usually present.
Sarah Hoffmann
2025-04-11 08:57:34 +02:00
parent 1db717b886
commit 497e27bb9a
6 changed files with 78 additions and 51 deletions
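
In short: instead of storing partial tokens in the per-node `starting` token
lists and looking them up by type, each query node now carries its single
partial token in a dedicated field. A minimal sketch of the structural idea
(simplified; the real QueryNode in query.py below has more fields):

    import dataclasses
    from typing import List, Optional

    @dataclasses.dataclass
    class QueryNode:
        # Full tokens (words, house numbers, ...) stay in per-type
        # token lists attached to the node.
        starting: List['TokenList'] = dataclasses.field(default_factory=list)
        # New: the one partial token leading to the next node; None when
        # no word is known for this part of the query.
        partial: Optional['Token'] = None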

View File

@@ -2,7 +2,7 @@
 #
 # This file is part of Nominatim. (https://nominatim.org)
 #
-# Copyright (C) 2024 by the Nominatim developer community.
+# Copyright (C) 2025 by the Nominatim developer community.
 # For a full list of authors see the git log.
 """
 Conversion from token assignment to an abstract DB search.
@@ -296,26 +296,27 @@ class SearchBuilder:
         while todo:
             neglen, pos, rank = heapq.heappop(todo)
+            # partial node
+            partial = self.query.nodes[pos].partial
+            if partial is not None:
+                if pos + 1 < trange.end:
+                    penalty = rank.penalty + partial.penalty \
+                              + PENALTY_WORDCHANGE[self.query.nodes[pos + 1].btype]
+                    heapq.heappush(todo, (neglen - 1, pos + 1,
+                                          dbf.RankedTokens(penalty, rank.tokens)))
+                else:
+                    ranks.append(dbf.RankedTokens(rank.penalty + partial.penalty,
+                                                  rank.tokens))
+            # full words
             for tlist in self.query.nodes[pos].starting:
-                if tlist.ttype in (qmod.TOKEN_PARTIAL, qmod.TOKEN_WORD):
+                if tlist.ttype == qmod.TOKEN_WORD:
                     if tlist.end < trange.end:
                         chgpenalty = PENALTY_WORDCHANGE[self.query.nodes[tlist.end].btype]
-                        if tlist.ttype == qmod.TOKEN_PARTIAL:
-                            penalty = rank.penalty + chgpenalty \
-                                      + max(t.penalty for t in tlist.tokens)
+                        for t in tlist.tokens:
                             heapq.heappush(todo, (neglen - 1, tlist.end,
-                                                  dbf.RankedTokens(penalty, rank.tokens)))
-                        else:
-                            for t in tlist.tokens:
-                                heapq.heappush(todo, (neglen - 1, tlist.end,
-                                                      rank.with_token(t, chgpenalty)))
+                                                  rank.with_token(t, chgpenalty)))
                     elif tlist.end == trange.end:
-                        if tlist.ttype == qmod.TOKEN_PARTIAL:
-                            ranks.append(dbf.RankedTokens(rank.penalty
-                                                          + max(t.penalty for t in tlist.tokens),
-                                                          rank.tokens))
-                        else:
-                            ranks.extend(rank.with_token(t, 0.0) for t in tlist.tokens)
+                        ranks.extend(rank.with_token(t, 0.0) for t in tlist.tokens)

                     if len(ranks) >= 10:
                         # Too many variants, bail out and only add
                         # Worst-case Fallback: sum of penalty of partials
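
The hunk is truncated just after the bail-out comment. The idea the comment
names — once ten ranked variants have accumulated, stop and add a single
pessimistic estimate instead — would look roughly like the following sketch
(hypothetical code, not the literal continuation of the file):

    # Hypothetical sketch of the worst-case fallback: one estimate whose
    # penalty is the sum of the partial-token penalties over the range.
    fallback = sum(node.partial.penalty
                   for node in self.query.nodes[trange.start:trange.end]
                   if node.partial is not None)
    ranks.append(dbf.RankedTokens(rank.penalty + fallback, rank.tokens))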

View File

@@ -2,7 +2,7 @@
 #
 # This file is part of Nominatim. (https://nominatim.org)
 #
-# Copyright (C) 2024 by the Nominatim developer community.
+# Copyright (C) 2025 by the Nominatim developer community.
 # For a full list of authors see the git log.
 """
 Implementation of query analysis for the ICU tokenizer.
@@ -280,7 +280,7 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                 for repl in node.starting:
                     if repl.end == tlist.end and repl.ttype != qmod.TOKEN_HOUSENUMBER:
                         repl.add_penalty(0.5 - tlist.tokens[0].penalty)
-            elif tlist.ttype not in (qmod.TOKEN_COUNTRY, qmod.TOKEN_PARTIAL):
+            elif tlist.ttype != qmod.TOKEN_COUNTRY:
                 norm = ' '.join(n.term_normalized for n in query.nodes[i + 1:tlist.end + 1]
                                 if n.btype != qmod.BREAK_TOKEN)
                 if not norm:
@@ -293,6 +293,10 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
 def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
     yield ['type', 'from', 'to', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
     for i, node in enumerate(query.nodes):
+        if node.partial is not None:
+            t = cast(ICUToken, node.partial)
+            yield [qmod.TOKEN_PARTIAL, str(i), str(i + 1), t.token,
+                   t.word_token, t.lookup_word, t.penalty, t.count, t.info]
         for tlist in node.starting:
             for token in tlist.tokens:
                 t = cast(ICUToken, token)

View File

@@ -2,7 +2,7 @@
 #
 # This file is part of Nominatim. (https://nominatim.org)
 #
-# Copyright (C) 2024 by the Nominatim developer community.
+# Copyright (C) 2025 by the Nominatim developer community.
 # For a full list of authors see the git log.
 """
 Datastructures for a tokenized query.
@@ -192,6 +192,14 @@ class QueryNode:
     """
     starting: List[TokenList] = dataclasses.field(default_factory=list)
     """ List of all full tokens starting at this node.
     """
+    partial: Optional[Token] = None
+    """ Base token going to the next node.
+
+        May be None when the query has parts for which no words are known.
+        Note that the query may still be parsable when there are other
+        types of tokens spanning over the gap.
+    """
+
     def adjust_break(self, btype: BreakType, penalty: float) -> None:
         """ Change the break type and penalty for this node.
@@ -269,33 +277,37 @@ class QueryStruct:
             be added to, then the token is silently dropped.
         """
         snode = self.nodes[trange.start]
-        full_phrase = snode.btype in (BREAK_START, BREAK_PHRASE)\
-            and self.nodes[trange.end].btype in (BREAK_PHRASE, BREAK_END)
-        if _phrase_compatible_with(snode.ptype, ttype, full_phrase):
-            tlist = snode.get_tokens(trange.end, ttype)
-            if tlist is None:
-                snode.starting.append(TokenList(trange.end, ttype, [token]))
-            else:
-                tlist.append(token)
+        if ttype == TOKEN_PARTIAL:
+            assert snode.partial is None
+            if _phrase_compatible_with(snode.ptype, TOKEN_PARTIAL, False):
+                snode.partial = token
+        else:
+            full_phrase = snode.btype in (BREAK_START, BREAK_PHRASE)\
+                and self.nodes[trange.end].btype in (BREAK_PHRASE, BREAK_END)
+            if _phrase_compatible_with(snode.ptype, ttype, full_phrase):
+                tlist = snode.get_tokens(trange.end, ttype)
+                if tlist is None:
+                    snode.starting.append(TokenList(trange.end, ttype, [token]))
+                else:
+                    tlist.append(token)

     def get_tokens(self, trange: TokenRange, ttype: TokenType) -> List[Token]:
         """ Get the list of tokens of a given type, spanning the given
             nodes. The nodes must exist. If no tokens exist, an
             empty list is returned.
+
+            Cannot be used to get the partial token.
         """
+        assert ttype != TOKEN_PARTIAL
         return self.nodes[trange.start].get_tokens(trange.end, ttype) or []

     def get_partials_list(self, trange: TokenRange) -> List[Token]:
         """ Create a list of partial tokens between the given nodes.
-            The list is composed of the first token of type PARTIAL
-            going to the subsequent node. Such PARTIAL tokens are
-            assumed to exist.
         """
-        return [next(iter(self.get_tokens(TokenRange(i, i+1), TOKEN_PARTIAL)))
-                for i in range(trange.start, trange.end)]
+        return list(filter(None, (self.nodes[i].partial for i in range(trange.start, trange.end))))

     def iter_token_lists(self) -> Iterator[Tuple[int, QueryNode, TokenList]]:
-        """ Iterator over all token lists in the query.
+        """ Iterator over all token lists except partial tokens in the query.
         """
         for i, node in enumerate(self.nodes):
             for tlist in node.starting:
@@ -308,6 +320,8 @@ class QueryStruct:
             debugging.
         """
         for node in self.nodes:
+            if node.partial is not None and node.partial.token == token:
+                return f"[P]{node.partial.lookup_word}"
             for tlist in node.starting:
                 for t in tlist.tokens:
                     if t.token == token:
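
Taken together, the reworked QueryStruct API can be exercised roughly like
this (a sketch only: it assumes a QueryStruct instance `q` with at least one
token slot already set up, and the `mktoken` helper used in the tests further
down):

    from nominatim_api.search.query import TokenRange, TOKEN_PARTIAL, TOKEN_WORD

    # Partial tokens now land directly on the node...
    q.add_token(TokenRange(0, 1), TOKEN_PARTIAL, mktoken(1))
    # ...while full tokens still go into the node's 'starting' lists.
    q.add_token(TokenRange(0, 1), TOKEN_WORD, mktoken(2))

    assert q.nodes[0].partial.token == 1
    # get_partials_list() now just collects the per-node fields, skipping
    # nodes for which no partial token is known.
    assert q.get_partials_list(TokenRange(0, 1)) == [q.nodes[0].partial]
    # get_tokens() may no longer be called with TOKEN_PARTIAL: it asserts.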

View File

@@ -409,11 +409,22 @@ def yield_token_assignments(query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
         node = query.nodes[state.end_pos]

         for tlist in node.starting:
-            newstate = state.advance(tlist.ttype, tlist.end, node.btype)
-            if newstate is not None:
-                if newstate.end_pos == query.num_token_slots():
-                    if newstate.recheck_sequence():
-                        log().var_dump('Assignment', newstate)
-                        yield from newstate.get_assignments(query)
-                elif not newstate.is_final():
-                    todo.append(newstate)
+            yield from _append_state_to_todo(
+                query, todo,
+                state.advance(tlist.ttype, tlist.end, node.btype))
+
+        if node.partial is not None:
+            yield from _append_state_to_todo(
+                query, todo,
+                state.advance(qmod.TOKEN_PARTIAL, state.end_pos + 1, node.btype))
+
+
+def _append_state_to_todo(query: qmod.QueryStruct, todo: List[_TokenSequence],
+                          newstate: Optional[_TokenSequence]) -> Iterator[TokenAssignment]:
+    if newstate is not None:
+        if newstate.end_pos == query.num_token_slots():
+            if newstate.recheck_sequence():
+                log().var_dump('Assignment', newstate)
+                yield from newstate.get_assignments(query)
+        elif not newstate.is_final():
+            todo.append(newstate)
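
The helper relies on a small Python pattern worth noting: _append_state_to_todo
is a generator, so callers `yield from` it to pass on any finished assignments
while its side effect, appending unfinished states to `todo`, still happens.
A stripped-down, self-contained illustration with generic names (not the
Nominatim types):

    from typing import Iterator, List, Optional

    def _advance(todo: List[int], state: Optional[int]) -> Iterator[str]:
        # A generator used for its side effect (growing the work list)
        # as well as for its yielded results (finished states).
        if state is not None:
            if state >= 3:
                yield f'done: {state}'
            else:
                todo.append(state + 1)

    todo: List[int] = [0]
    results: List[str] = []
    while todo:
        results.extend(_advance(todo, todo.pop()))
    print(results)  # prints ['done: 3']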

View File

@@ -44,7 +44,6 @@ def test_phrase_incompatible(ptype):

 def test_query_node_empty(qnode):
     assert not qnode.has_tokens(3, query.TOKEN_PARTIAL)
     assert qnode.get_tokens(3, query.TOKEN_WORD) is None
@@ -57,7 +56,6 @@ def test_query_node_with_content(qnode):
     assert qnode.has_tokens(2, query.TOKEN_PARTIAL)
     assert qnode.has_tokens(2, query.TOKEN_WORD)

     assert qnode.get_tokens(3, query.TOKEN_PARTIAL) is None
     assert qnode.get_tokens(2, query.TOKEN_COUNTRY) is None
     assert len(qnode.get_tokens(2, query.TOKEN_PARTIAL)) == 2
     assert len(qnode.get_tokens(2, query.TOKEN_WORD)) == 1
@@ -101,7 +99,6 @@ def test_query_struct_incompatible_token():
     q.add_token(query.TokenRange(0, 1), query.TOKEN_PARTIAL, mktoken(1))
     q.add_token(query.TokenRange(1, 2), query.TOKEN_COUNTRY, mktoken(100))

-    assert q.get_tokens(query.TokenRange(0, 1), query.TOKEN_PARTIAL) == []
     assert len(q.get_tokens(query.TokenRange(1, 2), query.TOKEN_COUNTRY)) == 1
@@ -113,7 +110,7 @@ def test_query_struct_amenity_single_word():
     q.add_token(query.TokenRange(0, 1), query.TOKEN_NEAR_ITEM, mktoken(2))
     q.add_token(query.TokenRange(0, 1), query.TOKEN_QUALIFIER, mktoken(3))

-    assert len(q.get_tokens(query.TokenRange(0, 1), query.TOKEN_PARTIAL)) == 1
+    assert q.nodes[0].partial.token == 1
     assert len(q.get_tokens(query.TokenRange(0, 1), query.TOKEN_NEAR_ITEM)) == 1
     assert len(q.get_tokens(query.TokenRange(0, 1), query.TOKEN_QUALIFIER)) == 0
@@ -128,10 +125,10 @@ def test_query_struct_amenity_two_words():
         q.add_token(query.TokenRange(*trange), query.TOKEN_NEAR_ITEM, mktoken(2))
         q.add_token(query.TokenRange(*trange), query.TOKEN_QUALIFIER, mktoken(3))

-    assert len(q.get_tokens(query.TokenRange(0, 1), query.TOKEN_PARTIAL)) == 1
+    assert q.nodes[0].partial.token == 1
     assert len(q.get_tokens(query.TokenRange(0, 1), query.TOKEN_NEAR_ITEM)) == 0
     assert len(q.get_tokens(query.TokenRange(0, 1), query.TOKEN_QUALIFIER)) == 1

-    assert len(q.get_tokens(query.TokenRange(1, 2), query.TOKEN_PARTIAL)) == 1
+    assert q.nodes[1].partial.token == 1
     assert len(q.get_tokens(query.TokenRange(1, 2), query.TOKEN_NEAR_ITEM)) == 0
     assert len(q.get_tokens(query.TokenRange(1, 2), query.TOKEN_QUALIFIER)) == 1

View File

@@ -69,8 +69,8 @@ async def test_single_phrase_with_unknown_terms(conn):
     assert query.source[0].text == 'foo bar'

     assert query.num_token_slots() == 2
-    assert len(query.nodes[0].starting) == 1
-    assert not query.nodes[1].starting
+    assert query.nodes[0].partial.token == 1
+    assert query.nodes[1].partial is None


 @pytest.mark.asyncio
@@ -103,8 +103,8 @@ async def test_splitting_in_transliteration(conn):

 @pytest.mark.asyncio
-@pytest.mark.parametrize('term,order', [('23456', ['P', 'H', 'W', 'w']),
-                                        ('3', ['H', 'W', 'w'])])
+@pytest.mark.parametrize('term,order', [('23456', ['P', 'H', 'W']),
+                                        ('3', ['H', 'W'])])
 async def test_penalty_postcodes_and_housenumbers(conn, term, order):
     ana = await tok.create_query_analyzer(conn)