From 3980791cfdb07d7aa0360dd7795c4943b0251882 Mon Sep 17 00:00:00 2001 From: Sarah Hoffmann Date: Fri, 11 Apr 2025 09:38:24 +0200 Subject: [PATCH] use iterator instead of list to go over partials --- src/nominatim_api/search/db_search_builder.py | 14 ++++++-------- src/nominatim_api/search/query.py | 7 ++++--- test/python/api/search/test_api_search_query.py | 2 +- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/src/nominatim_api/search/db_search_builder.py b/src/nominatim_api/search/db_search_builder.py index 03563f17..c363442d 100644 --- a/src/nominatim_api/search/db_search_builder.py +++ b/src/nominatim_api/search/db_search_builder.py @@ -146,7 +146,7 @@ class SearchBuilder: if address: sdata.lookups = [dbf.FieldLookup('nameaddress_vector', [t.token for r in address - for t in self.query.get_partials_list(r)], + for t in self.query.iter_partials(r)], lookups.Restrict)] yield dbs.PostcodeSearch(penalty, sdata) @@ -159,7 +159,7 @@ class SearchBuilder: expected_count = sum(t.count for t in hnrs) partials = {t.token: t.addr_count for trange in address - for t in self.query.get_partials_list(trange)} + for t in self.query.iter_partials(trange)} if not partials: # can happen when none of the partials is indexed @@ -203,9 +203,9 @@ class SearchBuilder: are and tries to find a lookup that optimizes index use. """ penalty = 0.0 # extra penalty - name_partials = {t.token: t for t in self.query.get_partials_list(name)} + name_partials = {t.token: t for t in self.query.iter_partials(name)} - addr_partials = [t for r in address for t in self.query.get_partials_list(r)] + addr_partials = [t for r in address for t in self.query.iter_partials(r)] addr_tokens = list({t.token for t in addr_partials}) exp_count = min(t.count for t in name_partials.values()) / (3**(len(name_partials) - 1)) @@ -282,8 +282,7 @@ class SearchBuilder: ranks = [dbf.RankedTokens(t.penalty, [t.token]) for t in name_fulls] ranks.sort(key=lambda r: r.penalty) # Fallback, sum of penalty for partials - name_partials = self.query.get_partials_list(trange) - default = sum(t.penalty for t in name_partials) + 0.2 + default = sum(t.penalty for t in self.query.iter_partials(trange)) + 0.2 return dbf.FieldRanking(db_field, default, ranks) def get_addr_ranking(self, trange: qmod.TokenRange) -> dbf.FieldRanking: @@ -320,8 +319,7 @@ class SearchBuilder: if len(ranks) >= 10: # Too many variants, bail out and only add # Worst-case Fallback: sum of penalty of partials - name_partials = self.query.get_partials_list(trange) - default = sum(t.penalty for t in name_partials) + 0.2 + default = sum(t.penalty for t in self.query.iter_partials(trange)) + 0.2 ranks.append(dbf.RankedTokens(rank.penalty + default, [])) # Bail out of outer loop todo.clear() diff --git a/src/nominatim_api/search/query.py b/src/nominatim_api/search/query.py index 8c5983c4..b8541b78 100644 --- a/src/nominatim_api/search/query.py +++ b/src/nominatim_api/search/query.py @@ -301,10 +301,11 @@ class QueryStruct: assert ttype != TOKEN_PARTIAL return self.nodes[trange.start].get_tokens(trange.end, ttype) or [] - def get_partials_list(self, trange: TokenRange) -> List[Token]: - """ Create a list of partial tokens between the given nodes. + def iter_partials(self, trange: TokenRange) -> Iterator[Token]: + """ Iterate over the partial tokens between the given nodes. + Missing partials are ignored. """ - return list(filter(None, (self.nodes[i].partial for i in range(trange.start, trange.end)))) + return (n.partial for n in self.nodes[trange.start:trange.end] if n.partial is not None) def iter_token_lists(self) -> Iterator[Tuple[int, QueryNode, TokenList]]: """ Iterator over all token lists except partial tokens in the query. diff --git a/test/python/api/search/test_api_search_query.py b/test/python/api/search/test_api_search_query.py index e54da1e9..ea3b9772 100644 --- a/test/python/api/search/test_api_search_query.py +++ b/test/python/api/search/test_api_search_query.py @@ -82,7 +82,7 @@ def test_query_struct_with_tokens(): assert q.get_tokens(query.TokenRange(0, 2), query.TOKEN_WORD) == [] assert len(q.get_tokens(query.TokenRange(1, 2), query.TOKEN_WORD)) == 2 - partials = q.get_partials_list(query.TokenRange(0, 2)) + partials = list(q.iter_partials(query.TokenRange(0, 2))) assert len(partials) == 2 assert [t.token for t in partials] == [1, 2]