mirror of
https://github.com/osm-search/Nominatim.git
synced 2026-02-26 11:08:13 +00:00
simplify QueryNode penalty and initial assignment
This commit is contained in:
@@ -37,17 +37,16 @@ DB_TO_TOKEN_TYPE = {
|
|||||||
'C': qmod.TOKEN_COUNTRY
|
'C': qmod.TOKEN_COUNTRY
|
||||||
}
|
}
|
||||||
|
|
||||||
PENALTY_IN_TOKEN_BREAK = {
|
PENALTY_BREAK = {
|
||||||
qmod.BREAK_START: 0.5,
|
qmod.BREAK_START: -0.5,
|
||||||
qmod.BREAK_END: 0.5,
|
qmod.BREAK_END: -0.5,
|
||||||
qmod.BREAK_PHRASE: 0.5,
|
qmod.BREAK_PHRASE: -0.5,
|
||||||
qmod.BREAK_SOFT_PHRASE: 0.5,
|
qmod.BREAK_SOFT_PHRASE: -0.5,
|
||||||
qmod.BREAK_WORD: 0.1,
|
qmod.BREAK_WORD: 0.0,
|
||||||
qmod.BREAK_PART: 0.0,
|
qmod.BREAK_PART: 0.2,
|
||||||
qmod.BREAK_TOKEN: 0.0
|
qmod.BREAK_TOKEN: 0.4
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class ICUToken(qmod.Token):
|
class ICUToken(qmod.Token):
|
||||||
""" Specialised token for ICU tokenizer.
|
""" Specialised token for ICU tokenizer.
|
||||||
@@ -78,13 +77,13 @@ class ICUToken(qmod.Token):
|
|||||||
self.penalty += (distance/len(self.lookup_word))
|
self.penalty += (distance/len(self.lookup_word))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_db_row(row: SaRow, base_penalty: float = 0.0) -> 'ICUToken':
|
def from_db_row(row: SaRow) -> 'ICUToken':
|
||||||
""" Create an ICUToken from the row of the word table.
|
""" Create an ICUToken from the row of the word table.
|
||||||
"""
|
"""
|
||||||
count = 1 if row.info is None else row.info.get('count', 1)
|
count = 1 if row.info is None else row.info.get('count', 1)
|
||||||
addr_count = 1 if row.info is None else row.info.get('addr_count', 1)
|
addr_count = 1 if row.info is None else row.info.get('addr_count', 1)
|
||||||
|
|
||||||
penalty = base_penalty
|
penalty = 0.0
|
||||||
if row.type == 'w':
|
if row.type == 'w':
|
||||||
penalty += 0.3
|
penalty += 0.3
|
||||||
elif row.type == 'W':
|
elif row.type == 'W':
|
||||||
@@ -174,11 +173,14 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
|
|||||||
|
|
||||||
self.split_query(query)
|
self.split_query(query)
|
||||||
log().var_dump('Transliterated query', lambda: query.get_transliterated_query())
|
log().var_dump('Transliterated query', lambda: query.get_transliterated_query())
|
||||||
words = query.extract_words(base_penalty=PENALTY_IN_TOKEN_BREAK[qmod.BREAK_WORD])
|
words = query.extract_words()
|
||||||
|
|
||||||
for row in await self.lookup_in_db(list(words.keys())):
|
for row in await self.lookup_in_db(list(words.keys())):
|
||||||
for trange in words[row.word_token]:
|
for trange in words[row.word_token]:
|
||||||
token = ICUToken.from_db_row(row, trange.penalty or 0.0)
|
# Create a new token for each position because the token
|
||||||
|
# penalty can vary depending on the position in the query.
|
||||||
|
# (See rerank_tokens() below.)
|
||||||
|
token = ICUToken.from_db_row(row)
|
||||||
if row.type == 'S':
|
if row.type == 'S':
|
||||||
if row.info['op'] in ('in', 'near'):
|
if row.info['op'] in ('in', 'near'):
|
||||||
if trange.start == 0:
|
if trange.start == 0:
|
||||||
@@ -200,6 +202,7 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
|
|||||||
lookup_word=pc, word_token=term,
|
lookup_word=pc, word_token=term,
|
||||||
info=None))
|
info=None))
|
||||||
self.rerank_tokens(query)
|
self.rerank_tokens(query)
|
||||||
|
self.compute_break_penalties(query)
|
||||||
|
|
||||||
log().table_dump('Word tokens', _dump_word_tokens(query))
|
log().table_dump('Word tokens', _dump_word_tokens(query))
|
||||||
|
|
||||||
@@ -232,10 +235,9 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
|
|||||||
query.add_node(qmod.BREAK_TOKEN, phrase.ptype,
|
query.add_node(qmod.BREAK_TOKEN, phrase.ptype,
|
||||||
PENALTY_IN_TOKEN_BREAK[qmod.BREAK_TOKEN],
|
PENALTY_IN_TOKEN_BREAK[qmod.BREAK_TOKEN],
|
||||||
term, word)
|
term, word)
|
||||||
query.nodes[-1].adjust_break(breakchar,
|
query.nodes[-1].btype = breakchar
|
||||||
PENALTY_IN_TOKEN_BREAK[breakchar])
|
|
||||||
|
|
||||||
query.nodes[-1].adjust_break(qmod.BREAK_END, PENALTY_IN_TOKEN_BREAK[qmod.BREAK_END])
|
query.nodes[-1].btype = qmod.BREAK_END
|
||||||
|
|
||||||
async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
|
async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
|
||||||
""" Return the token information from the database for the
|
""" Return the token information from the database for the
|
||||||
@@ -300,6 +302,12 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
|
|||||||
for token in tokens:
|
for token in tokens:
|
||||||
cast(ICUToken, token).rematch(norm)
|
cast(ICUToken, token).rematch(norm)
|
||||||
|
|
||||||
|
def compute_break_penalties(self, query: qmod.QueryStruct) -> None:
|
||||||
|
""" Set the break penalties for the nodes in the query.
|
||||||
|
"""
|
||||||
|
for node in query.nodes:
|
||||||
|
node.penalty = PENALTY_BREAK[node.btype]
|
||||||
|
|
||||||
|
|
||||||
def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
|
def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
|
||||||
yield ['type', 'from', 'to', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
|
yield ['type', 'from', 'to', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
|
||||||
|
|||||||
@@ -191,7 +191,9 @@ class QueryNode:
|
|||||||
ptype: PhraseType
|
ptype: PhraseType
|
||||||
|
|
||||||
penalty: float
|
penalty: float
|
||||||
""" Penalty for the break at this node.
|
""" Penalty for having a word break at this position. The penalty
|
||||||
|
may be negative, when a word break is more likely than continuing
|
||||||
|
the word after the node.
|
||||||
"""
|
"""
|
||||||
term_lookup: str
|
term_lookup: str
|
||||||
""" Transliterated term ending at this node.
|
""" Transliterated term ending at this node.
|
||||||
@@ -221,12 +223,6 @@ class QueryNode:
|
|||||||
|
|
||||||
return self.partial.count / (self.partial.count + self.partial.addr_count)
|
return self.partial.count / (self.partial.count + self.partial.addr_count)
|
||||||
|
|
||||||
def adjust_break(self, btype: BreakType, penalty: float) -> None:
|
|
||||||
""" Change the break type and penalty for this node.
|
|
||||||
"""
|
|
||||||
self.btype = btype
|
|
||||||
self.penalty = penalty
|
|
||||||
|
|
||||||
def has_tokens(self, end: int, *ttypes: TokenType) -> bool:
|
def has_tokens(self, end: int, *ttypes: TokenType) -> bool:
|
||||||
""" Check if there are tokens of the given types ending at the
|
""" Check if there are tokens of the given types ending at the
|
||||||
given node.
|
given node.
|
||||||
@@ -277,8 +273,7 @@ class QueryStruct:
|
|||||||
self.source = source
|
self.source = source
|
||||||
self.dir_penalty = 0.0
|
self.dir_penalty = 0.0
|
||||||
self.nodes: List[QueryNode] = \
|
self.nodes: List[QueryNode] = \
|
||||||
[QueryNode(BREAK_START, source[0].ptype if source else PHRASE_ANY,
|
[QueryNode(BREAK_START, source[0].ptype if source else PHRASE_ANY)]
|
||||||
0.0, '', '')]
|
|
||||||
|
|
||||||
def num_token_slots(self) -> int:
|
def num_token_slots(self) -> int:
|
||||||
""" Return the length of the query in vertex steps.
|
""" Return the length of the query in vertex steps.
|
||||||
@@ -286,13 +281,12 @@ class QueryStruct:
|
|||||||
return len(self.nodes) - 1
|
return len(self.nodes) - 1
|
||||||
|
|
||||||
def add_node(self, btype: BreakType, ptype: PhraseType,
|
def add_node(self, btype: BreakType, ptype: PhraseType,
|
||||||
break_penalty: float = 0.0,
|
|
||||||
term_lookup: str = '', term_normalized: str = '') -> None:
|
term_lookup: str = '', term_normalized: str = '') -> None:
|
||||||
""" Append a new break node with the given break type.
|
""" Append a new break node with the given break type.
|
||||||
The phrase type denotes the type for any tokens starting
|
The phrase type denotes the type for any tokens starting
|
||||||
at the node.
|
at the node.
|
||||||
"""
|
"""
|
||||||
self.nodes.append(QueryNode(btype, ptype, break_penalty, term_lookup, term_normalized))
|
self.nodes.append(QueryNode(btype, ptype, 0.0, term_lookup, term_normalized))
|
||||||
|
|
||||||
def add_token(self, trange: TokenRange, ttype: TokenType, token: Token) -> None:
|
def add_token(self, trange: TokenRange, ttype: TokenType, token: Token) -> None:
|
||||||
""" Add a token to the query. 'start' and 'end' are the indexes of the
|
""" Add a token to the query. 'start' and 'end' are the indexes of the
|
||||||
@@ -386,17 +380,14 @@ class QueryStruct:
|
|||||||
"""
|
"""
|
||||||
return ''.join(''.join((n.term_lookup, n.btype)) for n in self.nodes)
|
return ''.join(''.join((n.term_lookup, n.btype)) for n in self.nodes)
|
||||||
|
|
||||||
def extract_words(self, base_penalty: float = 0.0,
|
def extract_words(self, start: int = 0,
|
||||||
start: int = 0,
|
|
||||||
endpos: Optional[int] = None) -> Dict[str, List[TokenRange]]:
|
endpos: Optional[int] = None) -> Dict[str, List[TokenRange]]:
|
||||||
""" Add all combinations of words that can be formed from the terms
|
""" Add all combinations of words that can be formed from the terms
|
||||||
between the given start and endnode. The terms are joined with
|
between the given start and endnode. The terms are joined with
|
||||||
spaces for each break. Words can never go across a BREAK_PHRASE.
|
spaces for each break. Words can never go across a BREAK_PHRASE.
|
||||||
|
|
||||||
The function returns a dictionary of possible words with their
|
The function returns a dictionary of possible words with their
|
||||||
position within the query and a penalty. The penalty is computed
|
position within the query.
|
||||||
from the base_penalty plus the penalty for each node the word
|
|
||||||
crosses.
|
|
||||||
"""
|
"""
|
||||||
if endpos is None:
|
if endpos is None:
|
||||||
endpos = len(self.nodes)
|
endpos = len(self.nodes)
|
||||||
@@ -405,16 +396,13 @@ class QueryStruct:
|
|||||||
|
|
||||||
for first, first_node in enumerate(self.nodes[start + 1:endpos], start):
|
for first, first_node in enumerate(self.nodes[start + 1:endpos], start):
|
||||||
word = first_node.term_lookup
|
word = first_node.term_lookup
|
||||||
penalty = base_penalty
|
words[word].append(TokenRange(first, first + 1))
|
||||||
words[word].append(TokenRange(first, first + 1, penalty=penalty))
|
|
||||||
if first_node.btype != BREAK_PHRASE:
|
if first_node.btype != BREAK_PHRASE:
|
||||||
penalty += first_node.penalty
|
|
||||||
max_last = min(first + 20, endpos)
|
max_last = min(first + 20, endpos)
|
||||||
for last, last_node in enumerate(self.nodes[first + 2:max_last], first + 2):
|
for last, last_node in enumerate(self.nodes[first + 2:max_last], first + 2):
|
||||||
word = ' '.join((word, last_node.term_lookup))
|
word = ' '.join((word, last_node.term_lookup))
|
||||||
words[word].append(TokenRange(first, last, penalty=penalty))
|
words[word].append(TokenRange(first, last))
|
||||||
if last_node.btype == BREAK_PHRASE:
|
if last_node.btype == BREAK_PHRASE:
|
||||||
break
|
break
|
||||||
penalty += last_node.penalty
|
|
||||||
|
|
||||||
return words
|
return words
|
||||||
|
|||||||
@@ -23,16 +23,6 @@ class TypedRange:
|
|||||||
trange: qmod.TokenRange
|
trange: qmod.TokenRange
|
||||||
|
|
||||||
|
|
||||||
PENALTY_TOKENCHANGE = {
|
|
||||||
qmod.BREAK_START: 0.0,
|
|
||||||
qmod.BREAK_END: 0.0,
|
|
||||||
qmod.BREAK_PHRASE: 0.0,
|
|
||||||
qmod.BREAK_SOFT_PHRASE: 0.0,
|
|
||||||
qmod.BREAK_WORD: 0.1,
|
|
||||||
qmod.BREAK_PART: 0.2,
|
|
||||||
qmod.BREAK_TOKEN: 0.4
|
|
||||||
}
|
|
||||||
|
|
||||||
TypedRangeSeq = List[TypedRange]
|
TypedRangeSeq = List[TypedRange]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user