Merge pull request #3719 from lonvia/query-direction
Estimate query direction
@@ -2,7 +2,7 @@
 #
 # This file is part of Nominatim. (https://nominatim.org)
 #
-# Copyright (C) 2024 by the Nominatim developer community.
+# Copyright (C) 2025 by the Nominatim developer community.
 # For a full list of authors see the git log.
 """
 Conversion from token assignment to an abstract DB search.

@@ -146,7 +146,7 @@ class SearchBuilder:
         if address:
             sdata.lookups = [dbf.FieldLookup('nameaddress_vector',
                                              [t.token for r in address
-                                              for t in self.query.get_partials_list(r)],
+                                              for t in self.query.iter_partials(r)],
                                              lookups.Restrict)]
         yield dbs.PostcodeSearch(penalty, sdata)

@@ -159,7 +159,7 @@ class SearchBuilder:
         expected_count = sum(t.count for t in hnrs)

         partials = {t.token: t.addr_count for trange in address
-                    for t in self.query.get_partials_list(trange)}
+                    for t in self.query.iter_partials(trange)}

         if not partials:
             # can happen when none of the partials is indexed

@@ -203,9 +203,9 @@ class SearchBuilder:
             are and tries to find a lookup that optimizes index use.
         """
         penalty = 0.0  # extra penalty
-        name_partials = {t.token: t for t in self.query.get_partials_list(name)}
+        name_partials = {t.token: t for t in self.query.iter_partials(name)}

-        addr_partials = [t for r in address for t in self.query.get_partials_list(r)]
+        addr_partials = [t for r in address for t in self.query.iter_partials(r)]
         addr_tokens = list({t.token for t in addr_partials})

         exp_count = min(t.count for t in name_partials.values()) / (3**(len(name_partials) - 1))

@@ -282,8 +282,7 @@ class SearchBuilder:
         ranks = [dbf.RankedTokens(t.penalty, [t.token]) for t in name_fulls]
         ranks.sort(key=lambda r: r.penalty)
         # Fallback, sum of penalty for partials
-        name_partials = self.query.get_partials_list(trange)
-        default = sum(t.penalty for t in name_partials) + 0.2
+        default = sum(t.penalty for t in self.query.iter_partials(trange)) + 0.2
         return dbf.FieldRanking(db_field, default, ranks)

     def get_addr_ranking(self, trange: qmod.TokenRange) -> dbf.FieldRanking:

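All the SearchBuilder call sites above consume the partial tokens exactly once, so they can move from the materialised get_partials_list() to the lazy iter_partials() added to query.py further down. A minimal sketch of what the new accessor does, assuming nodes is the query's node list:

    from typing import Iterator

    def iter_partials_sketch(nodes: list, start: int, end: int) -> Iterator:
        # One partial token per node; nodes without a known word are skipped.
        return (n.partial for n in nodes[start:end] if n.partial is not None)
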
@@ -296,35 +295,35 @@ class SearchBuilder:

         while todo:
             neglen, pos, rank = heapq.heappop(todo)

+            # partial node
+            partial = self.query.nodes[pos].partial
+            if partial is not None:
+                if pos + 1 < trange.end:
+                    penalty = rank.penalty + partial.penalty \
+                        + PENALTY_WORDCHANGE[self.query.nodes[pos + 1].btype]
+                    heapq.heappush(todo, (neglen - 1, pos + 1,
+                                          dbf.RankedTokens(penalty, rank.tokens)))
+                else:
+                    ranks.append(dbf.RankedTokens(rank.penalty + partial.penalty,
+                                                  rank.tokens))
+            # full words
             for tlist in self.query.nodes[pos].starting:
-                if tlist.ttype in (qmod.TOKEN_PARTIAL, qmod.TOKEN_WORD):
+                if tlist.ttype == qmod.TOKEN_WORD:
                     if tlist.end < trange.end:
                         chgpenalty = PENALTY_WORDCHANGE[self.query.nodes[tlist.end].btype]
-                        if tlist.ttype == qmod.TOKEN_PARTIAL:
-                            penalty = rank.penalty + chgpenalty \
-                                + max(t.penalty for t in tlist.tokens)
-                            heapq.heappush(todo, (neglen - 1, tlist.end,
-                                                  dbf.RankedTokens(penalty, rank.tokens)))
-                        else:
-                            for t in tlist.tokens:
-                                heapq.heappush(todo, (neglen - 1, tlist.end,
-                                                      rank.with_token(t, chgpenalty)))
+                        for t in tlist.tokens:
+                            heapq.heappush(todo, (neglen - 1, tlist.end,
+                                                  rank.with_token(t, chgpenalty)))
                     elif tlist.end == trange.end:
-                        if tlist.ttype == qmod.TOKEN_PARTIAL:
-                            ranks.append(dbf.RankedTokens(rank.penalty
-                                                          + max(t.penalty for t in tlist.tokens),
-                                                          rank.tokens))
-                        else:
-                            ranks.extend(rank.with_token(t, 0.0) for t in tlist.tokens)
-                        if len(ranks) >= 10:
-                            # Too many variants, bail out and only add
-                            # Worst-case Fallback: sum of penalty of partials
-                            name_partials = self.query.get_partials_list(trange)
-                            default = sum(t.penalty for t in name_partials) + 0.2
-                            ranks.append(dbf.RankedTokens(rank.penalty + default, []))
-                            # Bail out of outer loop
-                            todo.clear()
-                            break
+                        ranks.extend(rank.with_token(t, 0.0) for t in tlist.tokens)
+
+            if len(ranks) >= 10:
+                # Too many variants, bail out and only add
+                # Worst-case Fallback: sum of penalty of partials
+                default = sum(t.penalty for t in self.query.iter_partials(trange)) + 0.2
+                ranks.append(dbf.RankedTokens(rank.penalty + default, []))
+                # Bail out of outer loop
+                break

         ranks.sort(key=lambda r: len(r.tokens))
         default = ranks[0].penalty + 0.3

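The rework above keeps the search frontier in a min-heap keyed by (neglen, pos, rank). neglen is decremented on every push, so chains that already contain more tokens compare smaller and are expanded first. A standalone illustration of that ordering:

    import heapq

    todo = []
    heapq.heappush(todo, (-1, 'chain with one token'))
    heapq.heappush(todo, (-3, 'chain with three tokens'))
    # A min-heap pops the smallest key first, i.e. the longest chain.
    print(heapq.heappop(todo))   # (-3, 'chain with three tokens')
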
@@ -2,7 +2,7 @@
 #
 # This file is part of Nominatim. (https://nominatim.org)
 #
-# Copyright (C) 2024 by the Nominatim developer community.
+# Copyright (C) 2025 by the Nominatim developer community.
 # For a full list of authors see the git log.
 """
 Public interface to the search code.

@@ -50,6 +50,9 @@ class ForwardGeocoder:
             self.query_analyzer = await make_query_analyzer(self.conn)

         query = await self.query_analyzer.analyze_query(phrases)
+        query.compute_direction_penalty()
+        log().var_dump('Query direction penalty',
+                       lambda: f"[{'LR' if query.dir_penalty < 0 else 'RL'}] {query.dir_penalty}")

         searches: List[AbstractSearch] = []
         if query.num_token_slots() > 0:

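The new debug line renders the sign convention as a reading direction. With an invented sample value it prints:

    dir_penalty = -0.29   # sample value; negative penalises right-to-left
    print(f"[{'LR' if dir_penalty < 0 else 'RL'}] {dir_penalty}")
    # -> [LR] -0.29
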
@@ -2,7 +2,7 @@
 #
 # This file is part of Nominatim. (https://nominatim.org)
 #
-# Copyright (C) 2024 by the Nominatim developer community.
+# Copyright (C) 2025 by the Nominatim developer community.
 # For a full list of authors see the git log.
 """
 Implementation of query analysis for the ICU tokenizer.

@@ -267,32 +267,47 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
     def rerank_tokens(self, query: qmod.QueryStruct) -> None:
         """ Add penalties to tokens that depend on presence of other token.
         """
-        for i, node, tlist in query.iter_token_lists():
-            if tlist.ttype == qmod.TOKEN_POSTCODE:
-                tlen = len(cast(ICUToken, tlist.tokens[0]).word_token)
-                for repl in node.starting:
-                    if repl.end == tlist.end and repl.ttype != qmod.TOKEN_POSTCODE \
-                       and (repl.ttype != qmod.TOKEN_HOUSENUMBER or tlen > 4):
-                        repl.add_penalty(0.39)
-            elif (tlist.ttype == qmod.TOKEN_HOUSENUMBER
-                  and len(tlist.tokens[0].lookup_word) <= 3):
-                if any(c.isdigit() for c in tlist.tokens[0].lookup_word):
-                    for repl in node.starting:
-                        if repl.end == tlist.end and repl.ttype != qmod.TOKEN_HOUSENUMBER:
-                            repl.add_penalty(0.5 - tlist.tokens[0].penalty)
-            elif tlist.ttype not in (qmod.TOKEN_COUNTRY, qmod.TOKEN_PARTIAL):
-                norm = ' '.join(n.term_normalized for n in query.nodes[i + 1:tlist.end + 1]
-                                if n.btype != qmod.BREAK_TOKEN)
-                if not norm:
-                    # Can happen when the token only covers a partial term
-                    norm = query.nodes[i + 1].term_normalized
-                for token in tlist.tokens:
-                    cast(ICUToken, token).rematch(norm)
+        for start, end, tlist in query.iter_tokens_by_edge():
+            if len(tlist) > 1:
+                # If it looks like a Postcode, give preference.
+                if qmod.TOKEN_POSTCODE in tlist:
+                    for ttype, tokens in tlist.items():
+                        if ttype != qmod.TOKEN_POSTCODE and \
+                           (ttype != qmod.TOKEN_HOUSENUMBER or
+                            start + 1 > end or
+                            len(query.nodes[end].term_lookup) > 4):
+                            for token in tokens:
+                                token.penalty += 0.39
+
+                # If it looks like a simple housenumber, prefer that.
+                if qmod.TOKEN_HOUSENUMBER in tlist:
+                    hnr_lookup = tlist[qmod.TOKEN_HOUSENUMBER][0].lookup_word
+                    if len(hnr_lookup) <= 3 and any(c.isdigit() for c in hnr_lookup):
+                        penalty = 0.5 - tlist[qmod.TOKEN_HOUSENUMBER][0].penalty
+                        for ttype, tokens in tlist.items():
+                            if ttype != qmod.TOKEN_HOUSENUMBER:
+                                for token in tokens:
+                                    token.penalty += penalty
+
+            # rerank tokens against the normalized form
+            norm = ' '.join(n.term_normalized for n in query.nodes[start + 1:end + 1]
+                            if n.btype != qmod.BREAK_TOKEN)
+            if not norm:
+                # Can happen when the token only covers a partial term
+                norm = query.nodes[start + 1].term_normalized
+            for ttype, tokens in tlist.items():
+                if ttype != qmod.TOKEN_COUNTRY:
+                    for token in tokens:
+                        cast(ICUToken, token).rematch(norm)


 def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
     yield ['type', 'from', 'to', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
     for i, node in enumerate(query.nodes):
+        if node.partial is not None:
+            t = cast(ICUToken, node.partial)
+            yield [qmod.TOKEN_PARTIAL, str(i), str(i + 1), t.token,
+                   t.word_token, t.lookup_word, t.penalty, t.count, t.info]
         for tlist in node.starting:
             for token in tlist.tokens:
                 t = cast(ICUToken, token)

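rerank_tokens() now receives one dictionary per graph edge instead of one TokenList at a time. A self-contained sketch of the shape it iterates over, with invented stand-in values:

    # Invented stand-ins for one item of query.iter_tokens_by_edge():
    TOKEN_POSTCODE, TOKEN_HOUSENUMBER = 'P', 'H'
    start, end = 2, 3
    tlist = {
        TOKEN_POSTCODE: ['12345 read as a postcode'],
        TOKEN_HOUSENUMBER: ['12345 read as a housenumber'],
    }
    assert len(tlist) > 1   # competing readings exist, so reranking applies

With competing readings, the non-postcode interpretations receive 0.39 extra penalty, except for short housenumber terms, which instead push their own preference onto the other token types.
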
@@ -2,7 +2,7 @@
 #
 # This file is part of Nominatim. (https://nominatim.org)
 #
-# Copyright (C) 2024 by the Nominatim developer community.
+# Copyright (C) 2025 by the Nominatim developer community.
 # For a full list of authors see the git log.
 """
 Datastructures for a tokenized query.

@@ -12,6 +12,17 @@ from abc import ABC, abstractmethod
 from collections import defaultdict
 import dataclasses

+# Precomputed denominator for the computation of the linear regression slope
+# used to determine the query direction.
+# The x value for the regression computation will be the position of the
+# token in the query. Thus we know the x values will be [0, query length).
+# As the denominator only depends on the x values, we can pre-compute here
+# the denominator to use for a given query length.
+# Note that a query length of two or less is special-cased and will not use
+# the values from this array. Thus it is not a problem that they are 0.
+LINFAC = [i * (sum(si * si for si in range(i)) - (i - 1) * i * (i - 1) / 4)
+          for i in range(50)]
+

 BreakType = str
 """ Type of break between tokens.

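LINFAC is the classic least-squares slope denominator n*sum(x*x) - sum(x)**2, pre-evaluated for x = 0 .. n-1. An independent check of the factored form used above (not part of the patch):

    for n in range(2, 50):
        xs = range(n)
        classic = n * sum(x * x for x in xs) - sum(xs) ** 2
        factored = n * (sum(x * x for x in xs) - (n - 1) * n * (n - 1) / 4)
        assert classic == factored   # equals LINFAC[n]
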
@@ -183,15 +194,32 @@ class QueryNode:
     """ Penalty for the break at this node.
     """
     term_lookup: str
-    """ Transliterated term following this node.
+    """ Transliterated term ending at this node.
     """
     term_normalized: str
-    """ Normalised form of term following this node.
+    """ Normalised form of term ending at this node.
         When the token resulted from a split during transliteration,
         then this string contains the complete source term.
     """

     starting: List[TokenList] = dataclasses.field(default_factory=list)
+    """ List of all full tokens starting at this node.
+    """
+    partial: Optional[Token] = None
+    """ Base token going to the next node.
+        May be None when the query has parts for which no words are known.
+        Note that the query may still be parsable when there are other
+        types of tokens spanning over the gap.
+    """
+
+    def name_address_ratio(self) -> float:
+        """ Return the probability that the partial token belonging to
+            this node forms part of a name (as opposed to part of the address).
+        """
+        if self.partial is None:
+            return 0.5
+
+        return self.partial.count / (self.partial.count + self.partial.addr_count)
+
     def adjust_break(self, btype: BreakType, penalty: float) -> None:
         """ Change the break type and penalty for this node.

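The new ratio is a plain relative frequency over the token's occurrence counts. A worked example with invented numbers:

    # A partial word seen 900 times in names and 100 times in addresses:
    count, addr_count = 900, 100
    ratio = count / (count + addr_count)   # 0.9 -> strongly name-like
    # Nodes without a known partial token report the neutral value 0.5.
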
@@ -234,12 +262,20 @@ class QueryStruct:
         need to be direct neighbours. Thus the query is represented as a
         directed acyclic graph.

+        A query also has a direction penalty 'dir_penalty'. This describes
+        the likelihood of whether the query should be read from left-to-right
+        or vice versa. A negative 'dir_penalty' should be read as a penalty on
+        right-to-left reading, while a positive value represents a penalty
+        for left-to-right reading. The default value is 0, which is equivalent
+        to having no information about the reading.
+
         When created, a query contains a single node: the start of the
         query. Further nodes can be added by appending to 'nodes'.
     """

     def __init__(self, source: List[Phrase]) -> None:
         self.source = source
+        self.dir_penalty = 0.0
         self.nodes: List[QueryNode] = \
             [QueryNode(BREAK_START, source[0].ptype if source else PHRASE_ANY,
                        0.0, '', '')]

@@ -269,37 +305,63 @@ class QueryStruct:
             be added to, then the token is silently dropped.
         """
         snode = self.nodes[trange.start]
-        full_phrase = snode.btype in (BREAK_START, BREAK_PHRASE)\
-                      and self.nodes[trange.end].btype in (BREAK_PHRASE, BREAK_END)
-        if _phrase_compatible_with(snode.ptype, ttype, full_phrase):
-            tlist = snode.get_tokens(trange.end, ttype)
-            if tlist is None:
-                snode.starting.append(TokenList(trange.end, ttype, [token]))
-            else:
-                tlist.append(token)
+        if ttype == TOKEN_PARTIAL:
+            assert snode.partial is None
+            if _phrase_compatible_with(snode.ptype, TOKEN_PARTIAL, False):
+                snode.partial = token
+        else:
+            full_phrase = snode.btype in (BREAK_START, BREAK_PHRASE)\
+                          and self.nodes[trange.end].btype in (BREAK_PHRASE, BREAK_END)
+            if _phrase_compatible_with(snode.ptype, ttype, full_phrase):
+                tlist = snode.get_tokens(trange.end, ttype)
+                if tlist is None:
+                    snode.starting.append(TokenList(trange.end, ttype, [token]))
+                else:
+                    tlist.append(token)
+
+    def compute_direction_penalty(self) -> None:
+        """ Recompute the direction probability from the partial tokens
+            of each node.
+        """
+        n = len(self.nodes) - 1
+        if n == 1 or n >= 50:
+            self.dir_penalty = 0
+        elif n == 2:
+            self.dir_penalty = (self.nodes[1].name_address_ratio()
+                                - self.nodes[0].name_address_ratio()) / 3
+        else:
+            ratios = [n.name_address_ratio() for n in self.nodes[:-1]]
+            self.dir_penalty = (n * sum(i * r for i, r in enumerate(ratios))
+                                - sum(ratios) * n * (n - 1) / 2) / LINFAC[n]

     def get_tokens(self, trange: TokenRange, ttype: TokenType) -> List[Token]:
         """ Get the list of tokens of a given type, spanning the given
             nodes. The nodes must exist. If no tokens exist, an
             empty list is returned.
+
+            Cannot be used to get the partial token.
         """
+        assert ttype != TOKEN_PARTIAL
         return self.nodes[trange.start].get_tokens(trange.end, ttype) or []

-    def get_partials_list(self, trange: TokenRange) -> List[Token]:
-        """ Create a list of partial tokens between the given nodes.
-            The list is composed of the first token of type PARTIAL
-            going to the subsequent node. Such PARTIAL tokens are
-            assumed to exist.
-        """
-        return [next(iter(self.get_tokens(TokenRange(i, i+1), TOKEN_PARTIAL)))
-                for i in range(trange.start, trange.end)]
+    def iter_partials(self, trange: TokenRange) -> Iterator[Token]:
+        """ Iterate over the partial tokens between the given nodes.
+            Missing partials are ignored.
+        """
+        return (n.partial for n in self.nodes[trange.start:trange.end] if n.partial is not None)

-    def iter_token_lists(self) -> Iterator[Tuple[int, QueryNode, TokenList]]:
-        """ Iterator over all token lists in the query.
+    def iter_tokens_by_edge(self) -> Iterator[Tuple[int, int, Dict[TokenType, List[Token]]]]:
+        """ Iterator over all tokens except partial ones grouped by edge.
+
+            Returns the start and end node indexes and a dictionary
+            of list of tokens by token type.
         """
         for i, node in enumerate(self.nodes):
+            by_end: Dict[int, Dict[TokenType, List[Token]]] = defaultdict(dict)
             for tlist in node.starting:
-                yield i, node, tlist
+                by_end[tlist.end][tlist.ttype] = tlist.tokens
+            for end, endlist in by_end.items():
+                yield i, end, endlist

     def find_lookup_word_by_id(self, token: int) -> str:
         """ Find the first token with the given token ID and return

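compute_direction_penalty() fits a least-squares line through the per-position name ratios and uses the slope as the penalty. A worked example for a four-word query whose name-like words come first (ratios invented):

    ratios = [0.9, 0.8, 0.3, 0.1]   # name_address_ratio() per node
    n = len(ratios)
    linfac = n * (sum(i * i for i in range(n)) - (n - 1) * n * (n - 1) / 4)
    slope = (n * sum(i * r for i, r in enumerate(ratios))
             - sum(ratios) * n * (n - 1) / 2) / linfac
    print(slope)   # about -0.29: penalise right-to-left, i.e. prefer left-to-right
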
@@ -308,6 +370,8 @@ class QueryStruct:
             debugging.
         """
         for node in self.nodes:
+            if node.partial is not None and node.partial.token == token:
+                return f"[P]{node.partial.lookup_word}"
             for tlist in node.starting:
                 for t in tlist.tokens:
                     if t.token == token:

@@ -339,16 +403,18 @@ class QueryStruct:

         words: Dict[str, List[TokenRange]] = defaultdict(list)

-        for first in range(start, endpos - 1):
-            word = self.nodes[first + 1].term_lookup
+        for first, first_node in enumerate(self.nodes[start + 1:endpos], start):
+            word = first_node.term_lookup
             penalty = base_penalty
             words[word].append(TokenRange(first, first + 1, penalty=penalty))
-            if self.nodes[first + 1].btype != BREAK_PHRASE:
-                for last in range(first + 2, min(first + 20, endpos)):
-                    word = ' '.join((word, self.nodes[last].term_lookup))
-                    penalty += self.nodes[last - 1].penalty
+            if first_node.btype != BREAK_PHRASE:
+                penalty += first_node.penalty
+                max_last = min(first + 20, endpos)
+                for last, last_node in enumerate(self.nodes[first + 2:max_last], first + 2):
+                    word = ' '.join((word, last_node.term_lookup))
                     words[word].append(TokenRange(first, last, penalty=penalty))
-                    if self.nodes[last].btype == BREAK_PHRASE:
+                    if last_node.btype == BREAK_PHRASE:
                         break
+                    penalty += last_node.penalty

         return words

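The rewritten loops pair each index with its node directly by enumerating over a slice with a start offset. A quick check of the offset convention, with a plain list standing in for the node array:

    nodes = ['n0', 'n1', 'n2', 'n3', 'n4']
    start, endpos = 0, 5
    for first, first_node in enumerate(nodes[start + 1:endpos], start):
        assert first_node == nodes[first + 1]   # same node the old code indexed
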
@@ -286,8 +286,12 @@ class _TokenSequence:
             log().var_dump('skip forward', (base.postcode, first))
             return

+        penalty = self.penalty
+        if self.direction == 1 and query.dir_penalty > 0:
+            penalty += query.dir_penalty
+
         log().comment('first word = name')
-        yield dataclasses.replace(base, penalty=self.penalty,
+        yield dataclasses.replace(base, penalty=penalty,
                                   name=first, address=base.address[1:])

         # To paraphrase:

@@ -300,14 +304,15 @@ class _TokenSequence:
                 or (query.nodes[first.start].ptype != qmod.PHRASE_ANY):
             return

-        penalty = self.penalty
-
         # Penalty for:
         #  * <name>, <street>, <housenumber> , ...
         #  * queries that are comma-separated
         if (base.housenumber and base.housenumber > first) or len(query.source) > 1:
             penalty += 0.25

+        if self.direction == 0 and query.dir_penalty > 0:
+            penalty += query.dir_penalty
+
         for i in range(first.start + 1, first.end):
             name, addr = first.split(i)
             log().comment(f'split first word = name ({i - first.start})')

|
|||||||
log().var_dump('skip backward', (base.postcode, last))
|
log().var_dump('skip backward', (base.postcode, last))
|
||||||
return
|
return
|
||||||
|
|
||||||
|
penalty = self.penalty
|
||||||
|
if self.direction == -1 and query.dir_penalty < 0:
|
||||||
|
penalty -= query.dir_penalty
|
||||||
|
|
||||||
if self.direction == -1 or len(base.address) > 1 or base.postcode:
|
if self.direction == -1 or len(base.address) > 1 or base.postcode:
|
||||||
log().comment('last word = name')
|
log().comment('last word = name')
|
||||||
yield dataclasses.replace(base, penalty=self.penalty,
|
yield dataclasses.replace(base, penalty=penalty,
|
||||||
name=last, address=base.address[:-1])
|
name=last, address=base.address[:-1])
|
||||||
|
|
||||||
# To paraphrase:
|
# To paraphrase:
|
||||||
@@ -341,12 +350,14 @@ class _TokenSequence:
                 or (query.nodes[last.start].ptype != qmod.PHRASE_ANY):
             return

-        penalty = self.penalty
         if base.housenumber and base.housenumber < last:
             penalty += 0.4
         if len(query.source) > 1:
             penalty += 0.25

+        if self.direction == 0 and query.dir_penalty < 0:
+            penalty -= query.dir_penalty
+
         for i in range(last.start + 1, last.end):
             addr, name = last.split(i)
             log().comment(f'split last word = name ({i - last.start})')

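Taken together, the four insertions in this file apply the direction penalty with one consistent convention. A hedged summary as a small function (not part of the patch):

    def direction_adjustment(direction: int, dir_penalty: float) -> float:
        # Sketch: forward readings pay a positive dir_penalty,
        # backward readings pay the magnitude of a negative one.
        if direction >= 0 and dir_penalty > 0:   # left-to-right reading
            return dir_penalty
        if direction <= 0 and dir_penalty < 0:   # right-to-left reading
            return -dir_penalty                  # minus a negative adds penalty
        return 0.0
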
@@ -379,11 +390,11 @@ class _TokenSequence:
         if base.postcode and base.postcode.start == 0:
             self.penalty += 0.1

-        # Right-to-left reading of the address
+        # Left-to-right reading of the address
         if self.direction != -1:
             yield from self._get_assignments_address_forward(base, query)

-        # Left-to-right reading of the address
+        # Right-to-left reading of the address
         if self.direction != 1:
             yield from self._get_assignments_address_backward(base, query)

@@ -409,11 +420,22 @@ def yield_token_assignments(query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
         node = query.nodes[state.end_pos]

         for tlist in node.starting:
-            newstate = state.advance(tlist.ttype, tlist.end, node.btype)
-            if newstate is not None:
-                if newstate.end_pos == query.num_token_slots():
-                    if newstate.recheck_sequence():
-                        log().var_dump('Assignment', newstate)
-                        yield from newstate.get_assignments(query)
-                elif not newstate.is_final():
-                    todo.append(newstate)
+            yield from _append_state_to_todo(
+                query, todo,
+                state.advance(tlist.ttype, tlist.end, node.btype))
+
+        if node.partial is not None:
+            yield from _append_state_to_todo(
+                query, todo,
+                state.advance(qmod.TOKEN_PARTIAL, state.end_pos + 1, node.btype))
+
+
+def _append_state_to_todo(query: qmod.QueryStruct, todo: List[_TokenSequence],
+                          newstate: Optional[_TokenSequence]) -> Iterator[TokenAssignment]:
+    if newstate is not None:
+        if newstate.end_pos == query.num_token_slots():
+            if newstate.recheck_sequence():
+                log().var_dump('Assignment', newstate)
+                yield from newstate.get_assignments(query)
+        elif not newstate.is_final():
+            todo.append(newstate)

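With partial tokens no longer stored in node.starting, the traversal advances them explicitly: a full word may jump several nodes ahead (tlist.end), while the per-node partial always moves exactly one slot (state.end_pos + 1). Schematically:

    # node.starting: full-word token spanning several nodes
    #     state(end_pos=2) --WORD('high street')--> state(end_pos=4)
    # node.partial: single base token, always one step
    #     state(end_pos=2) --PARTIAL('high')------> state(end_pos=3)
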
@@ -44,7 +44,6 @@ def test_phrase_incompatible(ptype):


 def test_query_node_empty(qnode):
-    assert not qnode.has_tokens(3, query.TOKEN_PARTIAL)
     assert qnode.get_tokens(3, query.TOKEN_WORD) is None


@@ -57,7 +56,6 @@ def test_query_node_with_content(qnode):
     assert qnode.has_tokens(2, query.TOKEN_PARTIAL)
     assert qnode.has_tokens(2, query.TOKEN_WORD)

-    assert qnode.get_tokens(3, query.TOKEN_PARTIAL) is None
     assert qnode.get_tokens(2, query.TOKEN_COUNTRY) is None
     assert len(qnode.get_tokens(2, query.TOKEN_PARTIAL)) == 2
     assert len(qnode.get_tokens(2, query.TOKEN_WORD)) == 1

@@ -84,7 +82,7 @@ def test_query_struct_with_tokens():
     assert q.get_tokens(query.TokenRange(0, 2), query.TOKEN_WORD) == []
     assert len(q.get_tokens(query.TokenRange(1, 2), query.TOKEN_WORD)) == 2

-    partials = q.get_partials_list(query.TokenRange(0, 2))
+    partials = list(q.iter_partials(query.TokenRange(0, 2)))

     assert len(partials) == 2
     assert [t.token for t in partials] == [1, 2]

@@ -101,7 +99,6 @@ def test_query_struct_incompatible_token():
     q.add_token(query.TokenRange(0, 1), query.TOKEN_PARTIAL, mktoken(1))
     q.add_token(query.TokenRange(1, 2), query.TOKEN_COUNTRY, mktoken(100))

-    assert q.get_tokens(query.TokenRange(0, 1), query.TOKEN_PARTIAL) == []
     assert len(q.get_tokens(query.TokenRange(1, 2), query.TOKEN_COUNTRY)) == 1


@@ -113,7 +110,7 @@ def test_query_struct_amenity_single_word():
     q.add_token(query.TokenRange(0, 1), query.TOKEN_NEAR_ITEM, mktoken(2))
     q.add_token(query.TokenRange(0, 1), query.TOKEN_QUALIFIER, mktoken(3))

-    assert len(q.get_tokens(query.TokenRange(0, 1), query.TOKEN_PARTIAL)) == 1
+    assert q.nodes[0].partial.token == 1
     assert len(q.get_tokens(query.TokenRange(0, 1), query.TOKEN_NEAR_ITEM)) == 1
     assert len(q.get_tokens(query.TokenRange(0, 1), query.TOKEN_QUALIFIER)) == 0


@@ -128,10 +125,10 @@ def test_query_struct_amenity_two_words():
     q.add_token(query.TokenRange(*trange), query.TOKEN_NEAR_ITEM, mktoken(2))
     q.add_token(query.TokenRange(*trange), query.TOKEN_QUALIFIER, mktoken(3))

-    assert len(q.get_tokens(query.TokenRange(0, 1), query.TOKEN_PARTIAL)) == 1
+    assert q.nodes[0].partial.token == 1
     assert len(q.get_tokens(query.TokenRange(0, 1), query.TOKEN_NEAR_ITEM)) == 0
     assert len(q.get_tokens(query.TokenRange(0, 1), query.TOKEN_QUALIFIER)) == 1

-    assert len(q.get_tokens(query.TokenRange(1, 2), query.TOKEN_PARTIAL)) == 1
+    assert q.nodes[1].partial.token == 1
     assert len(q.get_tokens(query.TokenRange(1, 2), query.TOKEN_NEAR_ITEM)) == 0
     assert len(q.get_tokens(query.TokenRange(1, 2), query.TOKEN_QUALIFIER)) == 1

@@ -69,8 +69,8 @@ async def test_single_phrase_with_unknown_terms(conn):
     assert query.source[0].text == 'foo bar'

     assert query.num_token_slots() == 2
-    assert len(query.nodes[0].starting) == 1
-    assert not query.nodes[1].starting
+    assert query.nodes[0].partial.token == 1
+    assert query.nodes[1].partial is None


 @pytest.mark.asyncio

@@ -103,8 +103,8 @@ async def test_splitting_in_transliteration(conn):


 @pytest.mark.asyncio
-@pytest.mark.parametrize('term,order', [('23456', ['P', 'H', 'W', 'w']),
-                                        ('3', ['H', 'W', 'w'])])
+@pytest.mark.parametrize('term,order', [('23456', ['P', 'H', 'W']),
+                                        ('3', ['H', 'W'])])
 async def test_penalty_postcodes_and_housenumbers(conn, term, order):
     ana = await tok.create_query_analyzer(conn)