mirror of
https://github.com/osm-search/Nominatim.git
synced 2026-02-26 11:08:13 +00:00
rebalance word transition penalties
This commit is contained in:
@@ -282,10 +282,14 @@ class SearchBuilder:
|
|||||||
""" Create a ranking expression for a name term in the given range.
|
""" Create a ranking expression for a name term in the given range.
|
||||||
"""
|
"""
|
||||||
name_fulls = self.query.get_tokens(trange, qmod.TOKEN_WORD)
|
name_fulls = self.query.get_tokens(trange, qmod.TOKEN_WORD)
|
||||||
ranks = [dbf.RankedTokens(t.penalty, [t.token]) for t in name_fulls]
|
full_word_penalty = self.query.get_in_word_penalty(trange)
|
||||||
|
ranks = [dbf.RankedTokens(t.penalty + full_word_penalty, [t.token])
|
||||||
|
for t in name_fulls]
|
||||||
ranks.sort(key=lambda r: r.penalty)
|
ranks.sort(key=lambda r: r.penalty)
|
||||||
# Fallback, sum of penalty for partials
|
# Fallback, sum of penalty for partials
|
||||||
default = sum(t.penalty for t in self.query.iter_partials(trange)) + 0.2
|
default = sum(t.penalty for t in self.query.iter_partials(trange))
|
||||||
|
default += sum(n.word_break_penalty
|
||||||
|
for n in self.query.nodes[trange.start + 1:trange.end])
|
||||||
return dbf.FieldRanking(db_field, default, ranks)
|
return dbf.FieldRanking(db_field, default, ranks)
|
||||||
|
|
||||||
def get_addr_ranking(self, trange: qmod.TokenRange) -> dbf.FieldRanking:
|
def get_addr_ranking(self, trange: qmod.TokenRange) -> dbf.FieldRanking:
|
||||||
@@ -303,7 +307,7 @@ class SearchBuilder:
|
|||||||
if partial is not None:
|
if partial is not None:
|
||||||
if pos + 1 < trange.end:
|
if pos + 1 < trange.end:
|
||||||
penalty = rank.penalty + partial.penalty \
|
penalty = rank.penalty + partial.penalty \
|
||||||
+ PENALTY_WORDCHANGE[self.query.nodes[pos + 1].btype]
|
+ self.query.nodes[pos + 1].word_break_penalty
|
||||||
heapq.heappush(todo, (neglen - 1, pos + 1,
|
heapq.heappush(todo, (neglen - 1, pos + 1,
|
||||||
dbf.RankedTokens(penalty, rank.tokens)))
|
dbf.RankedTokens(penalty, rank.tokens)))
|
||||||
else:
|
else:
|
||||||
@@ -313,7 +317,9 @@ class SearchBuilder:
|
|||||||
for tlist in self.query.nodes[pos].starting:
|
for tlist in self.query.nodes[pos].starting:
|
||||||
if tlist.ttype == qmod.TOKEN_WORD:
|
if tlist.ttype == qmod.TOKEN_WORD:
|
||||||
if tlist.end < trange.end:
|
if tlist.end < trange.end:
|
||||||
chgpenalty = PENALTY_WORDCHANGE[self.query.nodes[tlist.end].btype]
|
chgpenalty = self.query.nodes[tlist.end].word_break_penalty \
|
||||||
|
+ self.query.get_in_word_penalty(
|
||||||
|
qmod.TokenRange(pos, tlist.end))
|
||||||
for t in tlist.tokens:
|
for t in tlist.tokens:
|
||||||
heapq.heappush(todo, (neglen - 1, tlist.end,
|
heapq.heappush(todo, (neglen - 1, tlist.end,
|
||||||
rank.with_token(t, chgpenalty)))
|
rank.with_token(t, chgpenalty)))
|
||||||
@@ -323,7 +329,9 @@ class SearchBuilder:
|
|||||||
if len(ranks) >= 10:
|
if len(ranks) >= 10:
|
||||||
# Too many variants, bail out and only add
|
# Too many variants, bail out and only add
|
||||||
# Worst-case Fallback: sum of penalty of partials
|
# Worst-case Fallback: sum of penalty of partials
|
||||||
default = sum(t.penalty for t in self.query.iter_partials(trange)) + 0.2
|
default = sum(t.penalty for t in self.query.iter_partials(trange))
|
||||||
|
default += sum(n.word_break_penalty
|
||||||
|
for n in self.query.nodes[trange.start + 1:trange.end])
|
||||||
ranks.append(dbf.RankedTokens(rank.penalty + default, []))
|
ranks.append(dbf.RankedTokens(rank.penalty + default, []))
|
||||||
# Bail out of outer loop
|
# Bail out of outer loop
|
||||||
break
|
break
|
||||||
@@ -346,6 +354,7 @@ class SearchBuilder:
|
|||||||
if not tokens:
|
if not tokens:
|
||||||
return None
|
return None
|
||||||
sdata.set_strings('countries', tokens)
|
sdata.set_strings('countries', tokens)
|
||||||
|
sdata.penalty += self.query.get_in_word_penalty(assignment.country)
|
||||||
elif self.details.countries:
|
elif self.details.countries:
|
||||||
sdata.countries = dbf.WeightedStrings(self.details.countries,
|
sdata.countries = dbf.WeightedStrings(self.details.countries,
|
||||||
[0.0] * len(self.details.countries))
|
[0.0] * len(self.details.countries))
|
||||||
@@ -353,29 +362,24 @@ class SearchBuilder:
|
|||||||
sdata.set_strings('housenumbers',
|
sdata.set_strings('housenumbers',
|
||||||
self.query.get_tokens(assignment.housenumber,
|
self.query.get_tokens(assignment.housenumber,
|
||||||
qmod.TOKEN_HOUSENUMBER))
|
qmod.TOKEN_HOUSENUMBER))
|
||||||
|
sdata.penalty += self.query.get_in_word_penalty(assignment.housenumber)
|
||||||
if assignment.postcode:
|
if assignment.postcode:
|
||||||
sdata.set_strings('postcodes',
|
sdata.set_strings('postcodes',
|
||||||
self.query.get_tokens(assignment.postcode,
|
self.query.get_tokens(assignment.postcode,
|
||||||
qmod.TOKEN_POSTCODE))
|
qmod.TOKEN_POSTCODE))
|
||||||
|
sdata.penalty += self.query.get_in_word_penalty(assignment.postcode)
|
||||||
if assignment.qualifier:
|
if assignment.qualifier:
|
||||||
tokens = self.get_qualifier_tokens(assignment.qualifier)
|
tokens = self.get_qualifier_tokens(assignment.qualifier)
|
||||||
if not tokens:
|
if not tokens:
|
||||||
return None
|
return None
|
||||||
sdata.set_qualifiers(tokens)
|
sdata.set_qualifiers(tokens)
|
||||||
|
sdata.penalty += self.query.get_in_word_penalty(assignment.qualifier)
|
||||||
elif self.details.categories:
|
elif self.details.categories:
|
||||||
sdata.qualifiers = dbf.WeightedCategories(self.details.categories,
|
sdata.qualifiers = dbf.WeightedCategories(self.details.categories,
|
||||||
[0.0] * len(self.details.categories))
|
[0.0] * len(self.details.categories))
|
||||||
|
|
||||||
if assignment.address:
|
if assignment.address:
|
||||||
if not assignment.name and assignment.housenumber:
|
sdata.set_ranking([self.get_addr_ranking(r) for r in assignment.address])
|
||||||
# housenumber search: the first item needs to be handled like
|
|
||||||
# a name in ranking or penalties are not comparable with
|
|
||||||
# normal searches.
|
|
||||||
sdata.set_ranking([self.get_name_ranking(assignment.address[0],
|
|
||||||
db_field='nameaddress_vector')]
|
|
||||||
+ [self.get_addr_ranking(r) for r in assignment.address[1:]])
|
|
||||||
else:
|
|
||||||
sdata.set_ranking([self.get_addr_ranking(r) for r in assignment.address])
|
|
||||||
else:
|
else:
|
||||||
sdata.rankings = []
|
sdata.rankings = []
|
||||||
|
|
||||||
@@ -421,14 +425,3 @@ class SearchBuilder:
|
|||||||
return dbf.WeightedCategories(list(tokens.keys()), list(tokens.values()))
|
return dbf.WeightedCategories(list(tokens.keys()), list(tokens.values()))
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
PENALTY_WORDCHANGE = {
|
|
||||||
qmod.BREAK_START: 0.0,
|
|
||||||
qmod.BREAK_END: 0.0,
|
|
||||||
qmod.BREAK_PHRASE: 0.0,
|
|
||||||
qmod.BREAK_SOFT_PHRASE: 0.0,
|
|
||||||
qmod.BREAK_WORD: 0.1,
|
|
||||||
qmod.BREAK_PART: 0.2,
|
|
||||||
qmod.BREAK_TOKEN: 0.4
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ PENALTY_BREAK = {
|
|||||||
qmod.BREAK_TOKEN: 0.4
|
qmod.BREAK_TOKEN: 0.4
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class ICUToken(qmod.Token):
|
class ICUToken(qmod.Token):
|
||||||
""" Specialised token for ICU tokenizer.
|
""" Specialised token for ICU tokenizer.
|
||||||
@@ -232,9 +233,7 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
|
|||||||
if trans:
|
if trans:
|
||||||
for term in trans.split(' '):
|
for term in trans.split(' '):
|
||||||
if term:
|
if term:
|
||||||
query.add_node(qmod.BREAK_TOKEN, phrase.ptype,
|
query.add_node(qmod.BREAK_TOKEN, phrase.ptype, term, word)
|
||||||
PENALTY_IN_TOKEN_BREAK[qmod.BREAK_TOKEN],
|
|
||||||
term, word)
|
|
||||||
query.nodes[-1].btype = breakchar
|
query.nodes[-1].btype = breakchar
|
||||||
|
|
||||||
query.nodes[-1].btype = qmod.BREAK_END
|
query.nodes[-1].btype = qmod.BREAK_END
|
||||||
|
|||||||
@@ -214,6 +214,19 @@ class QueryNode:
|
|||||||
types of tokens spanning over the gap.
|
types of tokens spanning over the gap.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@property
def word_break_penalty(self) -> float:
    """ Penalty to apply when a word ends at this node.

        The value is the node's `penalty` clamped from below at zero,
        so a negative node penalty yields no word-break penalty.
    """
    # Clamp against a float literal so the result is always a float,
    # as promised by the return annotation.
    return max(0.0, self.penalty)
|
||||||
|
|
||||||
|
@property
def word_continuation_penalty(self) -> float:
    """ Penalty to apply when a word continues over this node
        (i.e. is a multi-term word).

        The value is the negated node `penalty` clamped from below at
        zero, so a positive node penalty yields no continuation penalty.
    """
    # Clamp against a float literal so the result is always a float,
    # as promised by the return annotation.
    return max(0.0, -self.penalty)
|
||||||
|
|
||||||
def name_address_ratio(self) -> float:
|
def name_address_ratio(self) -> float:
|
||||||
""" Return the probability that the partial token belonging to
|
""" Return the probability that the partial token belonging to
|
||||||
this node forms part of a name (as opposed to part of the address).
|
this node forms part of a name (as opposed to part of the address).
|
||||||
@@ -273,7 +286,8 @@ class QueryStruct:
|
|||||||
self.source = source
|
self.source = source
|
||||||
self.dir_penalty = 0.0
|
self.dir_penalty = 0.0
|
||||||
self.nodes: List[QueryNode] = \
|
self.nodes: List[QueryNode] = \
|
||||||
[QueryNode(BREAK_START, source[0].ptype if source else PHRASE_ANY)]
|
[QueryNode(BREAK_START, source[0].ptype if source else PHRASE_ANY,
|
||||||
|
0.0, '', '')]
|
||||||
|
|
||||||
def num_token_slots(self) -> int:
|
def num_token_slots(self) -> int:
|
||||||
""" Return the length of the query in vertex steps.
|
""" Return the length of the query in vertex steps.
|
||||||
@@ -338,6 +352,13 @@ class QueryStruct:
|
|||||||
assert ttype != TOKEN_PARTIAL
|
assert ttype != TOKEN_PARTIAL
|
||||||
return self.nodes[trange.start].get_tokens(trange.end, ttype) or []
|
return self.nodes[trange.start].get_tokens(trange.end, ttype) or []
|
||||||
|
|
||||||
|
def get_in_word_penalty(self, trange: TokenRange) -> float:
    """ Gets the sum of penalties for all token transitions
        within the given range.

        Only the nodes strictly between `trange.start` and `trange.end`
        are considered; each contributes its `word_continuation_penalty`.
        A range without inner nodes yields 0.0.
    """
    inner_nodes = self.nodes[trange.start + 1:trange.end]
    # Explicit float start value: the result stays a float even when
    # there are no inner nodes to sum over.
    return sum((n.word_continuation_penalty for n in inner_nodes), 0.0)
|
||||||
|
|
||||||
def iter_partials(self, trange: TokenRange) -> Iterator[Token]:
|
def iter_partials(self, trange: TokenRange) -> Iterator[Token]:
|
||||||
""" Iterate over the partial tokens between the given nodes.
|
""" Iterate over the partial tokens between the given nodes.
|
||||||
Missing partials are ignored.
|
Missing partials are ignored.
|
||||||
|
|||||||
@@ -182,7 +182,7 @@ class _TokenSequence:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def advance(self, ttype: qmod.TokenType, end_pos: int,
|
def advance(self, ttype: qmod.TokenType, end_pos: int,
|
||||||
btype: qmod.BreakType) -> Optional['_TokenSequence']:
|
force_break: bool, break_penalty: float) -> Optional['_TokenSequence']:
|
||||||
""" Return a new token sequence state with the given token type
|
""" Return a new token sequence state with the given token type
|
||||||
extended.
|
extended.
|
||||||
"""
|
"""
|
||||||
@@ -195,7 +195,7 @@ class _TokenSequence:
|
|||||||
new_penalty = 0.0
|
new_penalty = 0.0
|
||||||
else:
|
else:
|
||||||
last = self.seq[-1]
|
last = self.seq[-1]
|
||||||
if btype != qmod.BREAK_PHRASE and last.ttype == ttype:
|
if not force_break and last.ttype == ttype:
|
||||||
# extend the existing range
|
# extend the existing range
|
||||||
newseq = self.seq[:-1] + [TypedRange(ttype, last.trange.replace_end(end_pos))]
|
newseq = self.seq[:-1] + [TypedRange(ttype, last.trange.replace_end(end_pos))]
|
||||||
new_penalty = 0.0
|
new_penalty = 0.0
|
||||||
@@ -203,7 +203,7 @@ class _TokenSequence:
|
|||||||
# start a new range
|
# start a new range
|
||||||
newseq = list(self.seq) + [TypedRange(ttype,
|
newseq = list(self.seq) + [TypedRange(ttype,
|
||||||
qmod.TokenRange(last.trange.end, end_pos))]
|
qmod.TokenRange(last.trange.end, end_pos))]
|
||||||
new_penalty = PENALTY_TOKENCHANGE[btype]
|
new_penalty = break_penalty
|
||||||
|
|
||||||
return _TokenSequence(newseq, newdir, self.penalty + new_penalty)
|
return _TokenSequence(newseq, newdir, self.penalty + new_penalty)
|
||||||
|
|
||||||
@@ -307,7 +307,7 @@ class _TokenSequence:
|
|||||||
name, addr = first.split(i)
|
name, addr = first.split(i)
|
||||||
log().comment(f'split first word = name ({i - first.start})')
|
log().comment(f'split first word = name ({i - first.start})')
|
||||||
yield dataclasses.replace(base, name=name, address=[addr] + base.address[1:],
|
yield dataclasses.replace(base, name=name, address=[addr] + base.address[1:],
|
||||||
penalty=penalty + PENALTY_TOKENCHANGE[query.nodes[i].btype])
|
penalty=penalty + query.nodes[i].word_break_penalty)
|
||||||
|
|
||||||
def _get_assignments_address_backward(self, base: TokenAssignment,
|
def _get_assignments_address_backward(self, base: TokenAssignment,
|
||||||
query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
|
query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
|
||||||
@@ -352,7 +352,7 @@ class _TokenSequence:
|
|||||||
addr, name = last.split(i)
|
addr, name = last.split(i)
|
||||||
log().comment(f'split last word = name ({i - last.start})')
|
log().comment(f'split last word = name ({i - last.start})')
|
||||||
yield dataclasses.replace(base, name=name, address=base.address[:-1] + [addr],
|
yield dataclasses.replace(base, name=name, address=base.address[:-1] + [addr],
|
||||||
penalty=penalty + PENALTY_TOKENCHANGE[query.nodes[i].btype])
|
penalty=penalty + query.nodes[i].word_break_penalty)
|
||||||
|
|
||||||
def get_assignments(self, query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
|
def get_assignments(self, query: qmod.QueryStruct) -> Iterator[TokenAssignment]:
|
||||||
""" Yield possible assignments for the current sequence.
|
""" Yield possible assignments for the current sequence.
|
||||||
@@ -412,12 +412,15 @@ def yield_token_assignments(query: qmod.QueryStruct) -> Iterator[TokenAssignment
|
|||||||
for tlist in node.starting:
|
for tlist in node.starting:
|
||||||
yield from _append_state_to_todo(
|
yield from _append_state_to_todo(
|
||||||
query, todo,
|
query, todo,
|
||||||
state.advance(tlist.ttype, tlist.end, node.btype))
|
state.advance(tlist.ttype, tlist.end,
|
||||||
|
True, node.word_break_penalty))
|
||||||
|
|
||||||
if node.partial is not None:
|
if node.partial is not None:
|
||||||
yield from _append_state_to_todo(
|
yield from _append_state_to_todo(
|
||||||
query, todo,
|
query, todo,
|
||||||
state.advance(qmod.TOKEN_PARTIAL, state.end_pos + 1, node.btype))
|
state.advance(qmod.TOKEN_PARTIAL, state.end_pos + 1,
|
||||||
|
node.btype == qmod.BREAK_PHRASE,
|
||||||
|
node.word_break_penalty))
|
||||||
|
|
||||||
|
|
||||||
def _append_state_to_todo(query: qmod.QueryStruct, todo: List[_TokenSequence],
|
def _append_state_to_todo(query: qmod.QueryStruct, todo: List[_TokenSequence],
|
||||||
|
|||||||
Reference in New Issue
Block a user