use iterator instead of list to go over partials

This commit is contained in:
Sarah Hoffmann
2025-04-11 09:38:24 +02:00
parent 497e27bb9a
commit 3980791cfd
3 changed files with 11 additions and 12 deletions

View File

@@ -146,7 +146,7 @@ class SearchBuilder:
if address: if address:
sdata.lookups = [dbf.FieldLookup('nameaddress_vector', sdata.lookups = [dbf.FieldLookup('nameaddress_vector',
[t.token for r in address [t.token for r in address
for t in self.query.get_partials_list(r)], for t in self.query.iter_partials(r)],
lookups.Restrict)] lookups.Restrict)]
yield dbs.PostcodeSearch(penalty, sdata) yield dbs.PostcodeSearch(penalty, sdata)
@@ -159,7 +159,7 @@ class SearchBuilder:
expected_count = sum(t.count for t in hnrs) expected_count = sum(t.count for t in hnrs)
partials = {t.token: t.addr_count for trange in address partials = {t.token: t.addr_count for trange in address
for t in self.query.get_partials_list(trange)} for t in self.query.iter_partials(trange)}
if not partials: if not partials:
# can happen when none of the partials is indexed # can happen when none of the partials is indexed
@@ -203,9 +203,9 @@ class SearchBuilder:
are and tries to find a lookup that optimizes index use. are and tries to find a lookup that optimizes index use.
""" """
penalty = 0.0 # extra penalty penalty = 0.0 # extra penalty
name_partials = {t.token: t for t in self.query.get_partials_list(name)} name_partials = {t.token: t for t in self.query.iter_partials(name)}
addr_partials = [t for r in address for t in self.query.get_partials_list(r)] addr_partials = [t for r in address for t in self.query.iter_partials(r)]
addr_tokens = list({t.token for t in addr_partials}) addr_tokens = list({t.token for t in addr_partials})
exp_count = min(t.count for t in name_partials.values()) / (3**(len(name_partials) - 1)) exp_count = min(t.count for t in name_partials.values()) / (3**(len(name_partials) - 1))
@@ -282,8 +282,7 @@ class SearchBuilder:
ranks = [dbf.RankedTokens(t.penalty, [t.token]) for t in name_fulls] ranks = [dbf.RankedTokens(t.penalty, [t.token]) for t in name_fulls]
ranks.sort(key=lambda r: r.penalty) ranks.sort(key=lambda r: r.penalty)
# Fallback, sum of penalty for partials # Fallback, sum of penalty for partials
name_partials = self.query.get_partials_list(trange) default = sum(t.penalty for t in self.query.iter_partials(trange)) + 0.2
default = sum(t.penalty for t in name_partials) + 0.2
return dbf.FieldRanking(db_field, default, ranks) return dbf.FieldRanking(db_field, default, ranks)
def get_addr_ranking(self, trange: qmod.TokenRange) -> dbf.FieldRanking: def get_addr_ranking(self, trange: qmod.TokenRange) -> dbf.FieldRanking:
@@ -320,8 +319,7 @@ class SearchBuilder:
if len(ranks) >= 10: if len(ranks) >= 10:
# Too many variants, bail out and only add # Too many variants, bail out and only add
# Worst-case Fallback: sum of penalty of partials # Worst-case Fallback: sum of penalty of partials
name_partials = self.query.get_partials_list(trange) default = sum(t.penalty for t in self.query.iter_partials(trange)) + 0.2
default = sum(t.penalty for t in name_partials) + 0.2
ranks.append(dbf.RankedTokens(rank.penalty + default, [])) ranks.append(dbf.RankedTokens(rank.penalty + default, []))
# Bail out of outer loop # Bail out of outer loop
todo.clear() todo.clear()

View File

@@ -301,10 +301,11 @@ class QueryStruct:
assert ttype != TOKEN_PARTIAL assert ttype != TOKEN_PARTIAL
return self.nodes[trange.start].get_tokens(trange.end, ttype) or [] return self.nodes[trange.start].get_tokens(trange.end, ttype) or []
def get_partials_list(self, trange: TokenRange) -> List[Token]: def iter_partials(self, trange: TokenRange) -> Iterator[Token]:
""" Create a list of partial tokens between the given nodes. """ Iterate over the partial tokens between the given nodes.
Missing partials are ignored.
""" """
return list(filter(None, (self.nodes[i].partial for i in range(trange.start, trange.end)))) return (n.partial for n in self.nodes[trange.start:trange.end] if n.partial is not None)
def iter_token_lists(self) -> Iterator[Tuple[int, QueryNode, TokenList]]: def iter_token_lists(self) -> Iterator[Tuple[int, QueryNode, TokenList]]:
""" Iterator over all token lists except partial tokens in the query. """ Iterator over all token lists except partial tokens in the query.

View File

@@ -82,7 +82,7 @@ def test_query_struct_with_tokens():
assert q.get_tokens(query.TokenRange(0, 2), query.TOKEN_WORD) == [] assert q.get_tokens(query.TokenRange(0, 2), query.TOKEN_WORD) == []
assert len(q.get_tokens(query.TokenRange(1, 2), query.TOKEN_WORD)) == 2 assert len(q.get_tokens(query.TokenRange(1, 2), query.TOKEN_WORD)) == 2
partials = q.get_partials_list(query.TokenRange(0, 2)) partials = list(q.iter_partials(query.TokenRange(0, 2)))
assert len(partials) == 2 assert len(partials) == 2
assert [t.token for t in partials] == [1, 2] assert [t.token for t in partials] == [1, 2]