diff --git a/src/nominatim_api/search/icu_tokenizer.py b/src/nominatim_api/search/icu_tokenizer.py
index 1a449276..60e712d5 100644
--- a/src/nominatim_api/search/icu_tokenizer.py
+++ b/src/nominatim_api/search/icu_tokenizer.py
@@ -47,40 +47,27 @@ PENALTY_IN_TOKEN_BREAK = {
 }
 
 
-@dataclasses.dataclass
-class QueryPart:
-    """ Normalized and transliterated form of a single term in the query.
-
-        When the term came out of a split during the transliteration,
-        the normalized string is the full word before transliteration.
-        Check the subsequent break type to figure out if the word is
-        continued.
-
-        Penalty is the break penalty for the break following the token.
-    """
-    token: str
-    normalized: str
-    penalty: float
-
-
-QueryParts = List[QueryPart]
 WordDict = Dict[str, List[qmod.TokenRange]]
 
 
-def extract_words(terms: List[QueryPart], start: int, words: WordDict) -> None:
-    """ Add all combinations of words in the terms list after the
-        given position to the word list.
+def extract_words(query: qmod.QueryStruct, start: int, words: WordDict) -> None:
+    """ Add all combinations of consecutive terms to the word list,
+        starting with the term that leads into node 'start'.
+
+        The words found are added to the 'words' dictionary together
+        with their start and end position.
     """
-    total = len(terms)
+    nodes = query.nodes
+    total = len(nodes)
     base_penalty = PENALTY_IN_TOKEN_BREAK[qmod.BREAK_WORD]
     for first in range(start, total):
-        word = terms[first].token
+        word = nodes[first].term_lookup
         penalty = base_penalty
-        words[word].append(qmod.TokenRange(first, first + 1, penalty=penalty))
+        words[word].append(qmod.TokenRange(first - 1, first, penalty=penalty))
         for last in range(first + 1, min(first + 20, total)):
-            word = ' '.join((word, terms[last].token))
-            penalty += terms[last - 1].penalty
-            words[word].append(qmod.TokenRange(first, last + 1, penalty=penalty))
+            word = ' '.join((word, nodes[last].term_lookup))
+            penalty += nodes[last - 1].penalty
+            words[word].append(qmod.TokenRange(first - 1, last, penalty=penalty))
 
 
 @dataclasses.dataclass
@@ -216,8 +203,8 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
         if not query.source:
             return query
 
-        parts, words = self.split_query(query)
-        log().var_dump('Transliterated query', lambda: _dump_transliterated(query, parts))
+        words = self.split_query(query)
+        log().var_dump('Transliterated query', lambda: query.get_transliterated_query())
 
         for row in await self.lookup_in_db(list(words.keys())):
             for trange in words[row.word_token]:
@@ -234,8 +221,8 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                 else:
                     query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)
 
-        self.add_extra_tokens(query, parts)
-        self.rerank_tokens(query, parts)
+        self.add_extra_tokens(query)
+        self.rerank_tokens(query)
 
         log().table_dump('Word tokens', _dump_word_tokens(query))
 
@@ -248,15 +235,13 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
         """
         return cast(str, self.normalizer.transliterate(text)).strip('-: ')
 
-    def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
+    def split_query(self, query: qmod.QueryStruct) -> WordDict:
         """ Transliterate the phrases and split them into tokens.
 
-            Returns the list of transliterated tokens together with their
-            normalized form and a dictionary of words for lookup together
+            Returns a dictionary of words for lookup together
             with their position.
""" - parts: QueryParts = [] - phrase_start = 0 + phrase_start = 1 words: WordDict = defaultdict(list) for phrase in query.source: query.nodes[-1].ptype = phrase.ptype @@ -272,18 +257,18 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer): if trans: for term in trans.split(' '): if term: - parts.append(QueryPart(term, word, - PENALTY_IN_TOKEN_BREAK[qmod.BREAK_TOKEN])) - query.add_node(qmod.BREAK_TOKEN, phrase.ptype) - query.nodes[-1].btype = breakchar - parts[-1].penalty = PENALTY_IN_TOKEN_BREAK[breakchar] + query.add_node(qmod.BREAK_TOKEN, phrase.ptype, + PENALTY_IN_TOKEN_BREAK[qmod.BREAK_TOKEN], + term, word) + query.nodes[-1].adjust_break(breakchar, + PENALTY_IN_TOKEN_BREAK[breakchar]) - extract_words(parts, phrase_start, words) + extract_words(query, phrase_start, words) - phrase_start = len(parts) - query.nodes[-1].btype = qmod.BREAK_END + phrase_start = len(query.nodes) + query.nodes[-1].adjust_break(qmod.BREAK_END, PENALTY_IN_TOKEN_BREAK[qmod.BREAK_END]) - return parts, words + return words async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]': """ Return the token information from the database for the @@ -292,18 +277,23 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer): t = self.conn.t.meta.tables['word'] return await self.conn.execute(t.select().where(t.c.word_token.in_(words))) - def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None: + def add_extra_tokens(self, query: qmod.QueryStruct) -> None: """ Add tokens to query that are not saved in the database. """ - for part, node, i in zip(parts, query.nodes, range(1000)): - if len(part.token) <= 4 and part.token.isdigit()\ - and not node.has_tokens(i+1, qmod.TOKEN_HOUSENUMBER): - query.add_token(qmod.TokenRange(i, i+1), qmod.TOKEN_HOUSENUMBER, + need_hnr = False + for i, node in enumerate(query.nodes): + is_full_token = node.btype not in (qmod.BREAK_TOKEN, qmod.BREAK_PART) + if need_hnr and is_full_token \ + and len(node.term_normalized) <= 4 and node.term_normalized.isdigit(): + query.add_token(qmod.TokenRange(i-1, i), qmod.TOKEN_HOUSENUMBER, ICUToken(penalty=0.5, token=0, - count=1, addr_count=1, lookup_word=part.token, - word_token=part.token, info=None)) + count=1, addr_count=1, + lookup_word=node.term_lookup, + word_token=node.term_lookup, info=None)) - def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None: + need_hnr = is_full_token and not node.has_tokens(i+1, qmod.TOKEN_HOUSENUMBER) + + def rerank_tokens(self, query: qmod.QueryStruct) -> None: """ Add penalties to tokens that depend on presence of other token. 
""" for i, node, tlist in query.iter_token_lists(): @@ -320,21 +310,15 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer): if repl.end == tlist.end and repl.ttype != qmod.TOKEN_HOUSENUMBER: repl.add_penalty(0.5 - tlist.tokens[0].penalty) elif tlist.ttype not in (qmod.TOKEN_COUNTRY, qmod.TOKEN_PARTIAL): - norm = parts[i].normalized - for j in range(i + 1, tlist.end): - if node.btype != qmod.BREAK_TOKEN: - norm += ' ' + parts[j].normalized + norm = ' '.join(n.term_normalized for n in query.nodes[i + 1:tlist.end + 1] + if n.btype != qmod.BREAK_TOKEN) + if not norm: + # Can happen when the token only covers a partial term + norm = query.nodes[i + 1].term_normalized for token in tlist.tokens: cast(ICUToken, token).rematch(norm) -def _dump_transliterated(query: qmod.QueryStruct, parts: QueryParts) -> str: - out = query.nodes[0].btype - for node, part in zip(query.nodes[1:], parts): - out += part.token + node.btype - return out - - def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]: yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info'] for node in query.nodes: diff --git a/src/nominatim_api/search/query.py b/src/nominatim_api/search/query.py index 8530c4f2..fcd6763b 100644 --- a/src/nominatim_api/search/query.py +++ b/src/nominatim_api/search/query.py @@ -171,11 +171,33 @@ class TokenList: @dataclasses.dataclass class QueryNode: """ A node of the query representing a break between terms. + + The node also contains information on the source term + ending at the node. The tokens are created from this information. """ btype: BreakType ptype: PhraseType + + penalty: float + """ Penalty for the break at this node. + """ + term_lookup: str + """ Transliterated term following this node. + """ + term_normalized: str + """ Normalised form of term following this node. + When the token resulted from a split during transliteration, + then this string contains the complete source term. + """ + starting: List[TokenList] = dataclasses.field(default_factory=list) + def adjust_break(self, btype: BreakType, penalty: float) -> None: + """ Change the break type and penalty for this node. + """ + self.btype = btype + self.penalty = penalty + def has_tokens(self, end: int, *ttypes: TokenType) -> bool: """ Check if there are tokens of the given types ending at the given node. @@ -218,19 +240,22 @@ class QueryStruct: def __init__(self, source: List[Phrase]) -> None: self.source = source self.nodes: List[QueryNode] = \ - [QueryNode(BREAK_START, source[0].ptype if source else PHRASE_ANY)] + [QueryNode(BREAK_START, source[0].ptype if source else PHRASE_ANY, + 0.0, '', '')] def num_token_slots(self) -> int: """ Return the length of the query in vertice steps. """ return len(self.nodes) - 1 - def add_node(self, btype: BreakType, ptype: PhraseType) -> None: + def add_node(self, btype: BreakType, ptype: PhraseType, + break_penalty: float = 0.0, + term_lookup: str = '', term_normalized: str = '') -> None: """ Append a new break node with the given break type. The phrase type denotes the type for any tokens starting at the node. """ - self.nodes.append(QueryNode(btype, ptype)) + self.nodes.append(QueryNode(btype, ptype, break_penalty, term_lookup, term_normalized)) def add_token(self, trange: TokenRange, ttype: TokenType, token: Token) -> None: """ Add a token to the query. 
'start' and 'end' are the indexes of the @@ -287,3 +312,11 @@ class QueryStruct: if t.token == token: return f"[{tlist.ttype}]{t.lookup_word}" return 'None' + + def get_transliterated_query(self) -> str: + """ Return a string representation of the transliterated query + with the character representation of the different break types. + + For debugging purposes only. + """ + return ''.join(''.join((n.term_lookup, n.btype)) for n in self.nodes) diff --git a/test/python/api/search/test_api_search_query.py b/test/python/api/search/test_api_search_query.py index 412a5bf2..08a1f7aa 100644 --- a/test/python/api/search/test_api_search_query.py +++ b/test/python/api/search/test_api_search_query.py @@ -21,6 +21,9 @@ def mktoken(tid: int): return MyToken(penalty=3.0, token=tid, count=1, addr_count=1, lookup_word='foo') +@pytest.fixture +def qnode(): + return query.QueryNode(query.BREAK_PHRASE, query.PHRASE_ANY, 0.0 ,'', '') @pytest.mark.parametrize('ptype,ttype', [(query.PHRASE_ANY, 'W'), (query.PHRASE_AMENITY, 'Q'), @@ -37,27 +40,24 @@ def test_phrase_incompatible(ptype): assert not query._phrase_compatible_with(ptype, query.TOKEN_PARTIAL, True) -def test_query_node_empty(): - qn = query.QueryNode(query.BREAK_PHRASE, query.PHRASE_ANY) - - assert not qn.has_tokens(3, query.TOKEN_PARTIAL) - assert qn.get_tokens(3, query.TOKEN_WORD) is None +def test_query_node_empty(qnode): + assert not qnode.has_tokens(3, query.TOKEN_PARTIAL) + assert qnode.get_tokens(3, query.TOKEN_WORD) is None -def test_query_node_with_content(): - qn = query.QueryNode(query.BREAK_PHRASE, query.PHRASE_ANY) - qn.starting.append(query.TokenList(2, query.TOKEN_PARTIAL, [mktoken(100), mktoken(101)])) - qn.starting.append(query.TokenList(2, query.TOKEN_WORD, [mktoken(1000)])) +def test_query_node_with_content(qnode): + qnode.starting.append(query.TokenList(2, query.TOKEN_PARTIAL, [mktoken(100), mktoken(101)])) + qnode.starting.append(query.TokenList(2, query.TOKEN_WORD, [mktoken(1000)])) - assert not qn.has_tokens(3, query.TOKEN_PARTIAL) - assert not qn.has_tokens(2, query.TOKEN_COUNTRY) - assert qn.has_tokens(2, query.TOKEN_PARTIAL) - assert qn.has_tokens(2, query.TOKEN_WORD) + assert not qnode.has_tokens(3, query.TOKEN_PARTIAL) + assert not qnode.has_tokens(2, query.TOKEN_COUNTRY) + assert qnode.has_tokens(2, query.TOKEN_PARTIAL) + assert qnode.has_tokens(2, query.TOKEN_WORD) - assert qn.get_tokens(3, query.TOKEN_PARTIAL) is None - assert qn.get_tokens(2, query.TOKEN_COUNTRY) is None - assert len(qn.get_tokens(2, query.TOKEN_PARTIAL)) == 2 - assert len(qn.get_tokens(2, query.TOKEN_WORD)) == 1 + assert qnode.get_tokens(3, query.TOKEN_PARTIAL) is None + assert qnode.get_tokens(2, query.TOKEN_COUNTRY) is None + assert len(qnode.get_tokens(2, query.TOKEN_PARTIAL)) == 2 + assert len(qnode.get_tokens(2, query.TOKEN_WORD)) == 1 def test_query_struct_empty():
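
Note (not part of the patch): a minimal sketch of how the node-centred term
storage introduced here fits together. It uses only the API added in this
diff; the query text and the 0.1 break penalty are made up for illustration,
and the printed string assumes the one-character break markers from query.py
('<' for BREAK_START, ' ' for BREAK_WORD, '>' for BREAK_END).

    from nominatim_api.search import query as qmod

    # Hypothetical query '12 main st', pre-split into three terms.
    q = qmod.QueryStruct([qmod.Phrase(qmod.PHRASE_ANY, '12 main st')])

    # Each term is stored on the node it leads into; the start node
    # (index 0) therefore carries empty term fields.
    q.add_node(qmod.BREAK_WORD, qmod.PHRASE_ANY, 0.1, '12', '12')
    q.add_node(qmod.BREAK_WORD, qmod.PHRASE_ANY, 0.1, 'main', 'main')
    q.add_node(qmod.BREAK_WORD, qmod.PHRASE_ANY, 0.1, 'st', 'st')
    q.nodes[-1].adjust_break(qmod.BREAK_END, 0.0)

    # The term ending at node i spans TokenRange(i - 1, i), which is why
    # extract_words() now appends TokenRange(first - 1, first) rather than
    # the old TokenRange(first, first + 1).
    assert q.nodes[1].term_lookup == '12'
    assert q.num_token_slots() == 3

    print(q.get_transliterated_query())   # '<12 main st>'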