simplify QueryNode penalty and initial assignment
src/nominatim_api/search/icu_tokenizer.py

@@ -37,17 +37,16 @@ DB_TO_TOKEN_TYPE = {
     'C': qmod.TOKEN_COUNTRY
 }
 
-PENALTY_IN_TOKEN_BREAK = {
-    qmod.BREAK_START: 0.5,
-    qmod.BREAK_END: 0.5,
-    qmod.BREAK_PHRASE: 0.5,
-    qmod.BREAK_SOFT_PHRASE: 0.5,
-    qmod.BREAK_WORD: 0.1,
-    qmod.BREAK_PART: 0.0,
-    qmod.BREAK_TOKEN: 0.0
-}
+PENALTY_BREAK = {
+    qmod.BREAK_START: -0.5,
+    qmod.BREAK_END: -0.5,
+    qmod.BREAK_PHRASE: -0.5,
+    qmod.BREAK_SOFT_PHRASE: -0.5,
+    qmod.BREAK_WORD: 0.0,
+    qmod.BREAK_PART: 0.2,
+    qmod.BREAK_TOKEN: 0.4
+}
 
 
 @dataclasses.dataclass
 class ICUToken(qmod.Token):
     """ Specialised token for ICU tokenizer.
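Note for orientation (not part of the commit): the new PENALTY_BREAK table replaces both PENALTY_IN_TOKEN_BREAK above and the PENALTY_TOKENCHANGE table deleted from token_assignment.py at the end of this diff. Its values are written directly onto the query nodes by the new compute_break_penalties() method further down. A minimal runnable sketch of that assignment, with made-up stand-ins for the qmod constants and QueryNode:

import dataclasses

# Stand-ins for the qmod break-type constants; the concrete values are
# irrelevant here, only their role as dictionary keys matters.
BREAK_START, BREAK_END, BREAK_PHRASE, BREAK_SOFT_PHRASE = '<', '>', ',', ':'
BREAK_WORD, BREAK_PART, BREAK_TOKEN = ' ', '-', '`'

PENALTY_BREAK = {
    BREAK_START: -0.5, BREAK_END: -0.5,
    BREAK_PHRASE: -0.5, BREAK_SOFT_PHRASE: -0.5,
    BREAK_WORD: 0.0, BREAK_PART: 0.2, BREAK_TOKEN: 0.4,
}

@dataclasses.dataclass
class Node:                      # reduced stand-in for qmod.QueryNode
    btype: str
    penalty: float = 0.0

# Mirror of the added compute_break_penalties(): each node takes the table
# value for its break type.
nodes = [Node(BREAK_START), Node(BREAK_WORD), Node(BREAK_PART), Node(BREAK_END)]
for node in nodes:
    node.penalty = PENALTY_BREAK[node.btype]

assert [n.penalty for n in nodes] == [-0.5, 0.0, 0.2, -0.5]

A negative penalty marks a position where ending the word is more plausible than continuing it, exactly as the updated QueryNode docstring in query.py below describes.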
@@ -78,13 +77,13 @@ class ICUToken(qmod.Token):
         self.penalty += (distance/len(self.lookup_word))
 
     @staticmethod
-    def from_db_row(row: SaRow, base_penalty: float = 0.0) -> 'ICUToken':
+    def from_db_row(row: SaRow) -> 'ICUToken':
         """ Create a ICUToken from the row of the word table.
         """
         count = 1 if row.info is None else row.info.get('count', 1)
         addr_count = 1 if row.info is None else row.info.get('addr_count', 1)
 
-        penalty = base_penalty
+        penalty = 0.0
         if row.type == 'w':
             penalty += 0.3
         elif row.type == 'W':
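With the base_penalty parameter gone, from_db_row() now seeds the penalty from the word-table row alone. A hedged sketch of just the visible part of that logic (the 'W' branch is cut off by the hunk and therefore elided):

# Sketch of the simplified penalty seeding; 'w' is the word-table row type
# shown in the hunk, everything past the 'elif' lies outside the diff.
def seed_penalty(row_type: str) -> float:
    penalty = 0.0
    if row_type == 'w':
        penalty += 0.3
    # elif row_type == 'W': ...  (branch continues beyond the hunk)
    return penalty

assert seed_penalty('w') == 0.3
assert seed_penalty('x') == 0.0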
@@ -174,11 +173,14 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
 
         self.split_query(query)
         log().var_dump('Transliterated query', lambda: query.get_transliterated_query())
-        words = query.extract_words(base_penalty=PENALTY_IN_TOKEN_BREAK[qmod.BREAK_WORD])
+        words = query.extract_words()
 
         for row in await self.lookup_in_db(list(words.keys())):
             for trange in words[row.word_token]:
-                token = ICUToken.from_db_row(row, trange.penalty or 0.0)
+                # Create a new token for each position because the token
+                # penalty can vary depending on the position in the query.
+                # (See rerank_tokens() below.)
+                token = ICUToken.from_db_row(row)
                 if row.type == 'S':
                     if row.info['op'] in ('in', 'near'):
                         if trange.start == 0:
@@ -200,6 +202,7 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                                              lookup_word=pc, word_token=term,
                                              info=None))
         self.rerank_tokens(query)
+        self.compute_break_penalties(query)
 
         log().table_dump('Word tokens', _dump_word_tokens(query))
 
@@ -232,10 +235,9 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                 query.add_node(qmod.BREAK_TOKEN, phrase.ptype,
-                               PENALTY_IN_TOKEN_BREAK[qmod.BREAK_TOKEN],
                                term, word)
-            query.nodes[-1].adjust_break(breakchar,
-                                         PENALTY_IN_TOKEN_BREAK[breakchar])
+            query.nodes[-1].btype = breakchar
 
-        query.nodes[-1].adjust_break(qmod.BREAK_END, PENALTY_IN_TOKEN_BREAK[qmod.BREAK_END])
+        query.nodes[-1].btype = qmod.BREAK_END
 
     async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
         """ Return the token information from the database for the
@@ -300,6 +302,12 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
             for token in tokens:
                 cast(ICUToken, token).rematch(norm)
 
+    def compute_break_penalties(self, query: qmod.QueryStruct) -> None:
+        """ Set the break penalties for the nodes in the query.
+        """
+        for node in query.nodes:
+            node.penalty = PENALTY_BREAK[node.btype]
+
 
 def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
     yield ['type', 'from', 'to', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
src/nominatim_api/search/query.py

@@ -191,7 +191,9 @@ class QueryNode:
     ptype: PhraseType
 
     penalty: float
-    """ Penalty for the break at this node.
+    """ Penalty for having a word break at this position. The penalty
+        may be negative, when a word break is more likely than continuing
+        the word after the node.
     """
     term_lookup: str
     """ Transliterated term ending at this node.
@@ -221,12 +223,6 @@ class QueryNode:
 
         return self.partial.count / (self.partial.count + self.partial.addr_count)
 
-    def adjust_break(self, btype: BreakType, penalty: float) -> None:
-        """ Change the break type and penalty for this node.
-        """
-        self.btype = btype
-        self.penalty = penalty
-
     def has_tokens(self, end: int, *ttypes: TokenType) -> bool:
         """ Check if there are tokens of the given types ending at the
             given node.
@@ -277,8 +273,7 @@ class QueryStruct:
         self.source = source
         self.dir_penalty = 0.0
         self.nodes: List[QueryNode] = \
-            [QueryNode(BREAK_START, source[0].ptype if source else PHRASE_ANY,
-                       0.0, '', '')]
+            [QueryNode(BREAK_START, source[0].ptype if source else PHRASE_ANY)]
 
     def num_token_slots(self) -> int:
         """ Return the length of the query in vertice steps.
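The shortened initialiser call suggests that QueryNode's trailing fields now carry defaults; the field list itself is outside this diff, so the defaults below are an assumption matching the removed explicit arguments:

import dataclasses

@dataclasses.dataclass
class QueryNode:                 # reduced to the fields named in this diff
    btype: str
    ptype: str
    penalty: float = 0.0         # assumed defaults mirroring the removed
    term_lookup: str = ''        # explicit arguments 0.0, '', ''
    term_normalized: str = ''

# The old and the new spelling of the initial node are then equivalent:
assert QueryNode('<', 'any') == QueryNode('<', 'any', 0.0, '', '')

Note that add_node() in the next hunk still passes 0.0 positionally, since the two term arguments come after the penalty field.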
@@ -286,13 +281,12 @@ class QueryStruct:
         return len(self.nodes) - 1
 
     def add_node(self, btype: BreakType, ptype: PhraseType,
-                 break_penalty: float = 0.0,
                  term_lookup: str = '', term_normalized: str = '') -> None:
         """ Append a new break node with the given break type.
             The phrase type denotes the type for any tokens starting
             at the node.
         """
-        self.nodes.append(QueryNode(btype, ptype, break_penalty, term_lookup, term_normalized))
+        self.nodes.append(QueryNode(btype, ptype, 0.0, term_lookup, term_normalized))
 
     def add_token(self, trange: TokenRange, ttype: TokenType, token: Token) -> None:
         """ Add a token to the query. 'start' and 'end' are the indexes of the
@@ -386,17 +380,14 @@ class QueryStruct:
         """
         return ''.join(''.join((n.term_lookup, n.btype)) for n in self.nodes)
 
-    def extract_words(self, base_penalty: float = 0.0,
-                      start: int = 0,
+    def extract_words(self, start: int = 0,
                       endpos: Optional[int] = None) -> Dict[str, List[TokenRange]]:
         """ Add all combinations of words that can be formed from the terms
             between the given start and endnode. The terms are joined with
             spaces for each break. Words can never go across a BREAK_PHRASE.
 
             The functions returns a dictionary of possible words with their
-            position within the query and a penalty. The penalty is computed
-            from the base_penalty plus the penalty for each node the word
-            crosses.
+            position within the query.
         """
         if endpos is None:
             endpos = len(self.nodes)
@@ -405,16 +396,13 @@ class QueryStruct:
 
         for first, first_node in enumerate(self.nodes[start + 1:endpos], start):
             word = first_node.term_lookup
-            penalty = base_penalty
-            words[word].append(TokenRange(first, first + 1, penalty=penalty))
+            words[word].append(TokenRange(first, first + 1))
             if first_node.btype != BREAK_PHRASE:
-                penalty += first_node.penalty
                 max_last = min(first + 20, endpos)
                 for last, last_node in enumerate(self.nodes[first + 2:max_last], first + 2):
                     word = ' '.join((word, last_node.term_lookup))
-                    words[word].append(TokenRange(first, last, penalty=penalty))
+                    words[word].append(TokenRange(first, last))
                     if last_node.btype == BREAK_PHRASE:
                         break
-                    penalty += last_node.penalty
 
         return words
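Stripped of its penalty bookkeeping, extract_words() is now pure word enumeration. A self-contained toy version of the simplified loop, with plain tuples standing in for QueryNode and TokenRange:

from collections import defaultdict
from typing import Dict, List, Optional, Tuple

BREAK_PHRASE = ','               # stand-in for qmod.BREAK_PHRASE

# (term_lookup, btype) pairs; the leading entry mimics the BREAK_START node,
# which carries no term and is skipped via the start + 1 slice.
nodes = [('', '<'), ('foo', ' '), ('bar', BREAK_PHRASE), ('baz', '>')]

def extract_words(start: int = 0,
                  endpos: Optional[int] = None) -> Dict[str, List[Tuple[int, int]]]:
    if endpos is None:
        endpos = len(nodes)
    words: Dict[str, List[Tuple[int, int]]] = defaultdict(list)
    for first, (term, btype) in enumerate(nodes[start + 1:endpos], start):
        word = term
        words[word].append((first, first + 1))
        if btype != BREAK_PHRASE:
            max_last = min(first + 20, endpos)
            for last, (lterm, lbtype) in enumerate(nodes[first + 2:max_last], first + 2):
                word = ' '.join((word, lterm))
                words[word].append((first, last))
                if lbtype == BREAK_PHRASE:
                    break
    return dict(words)

# 'foo bar' spans the word break, but nothing crosses the BREAK_PHRASE node:
assert extract_words() == {'foo': [(0, 1)], 'foo bar': [(0, 2)],
                           'bar': [(1, 2)], 'baz': [(2, 3)]}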
src/nominatim_api/search/token_assignment.py

@@ -23,16 +23,6 @@ class TypedRange:
     trange: qmod.TokenRange
 
 
-PENALTY_TOKENCHANGE = {
-    qmod.BREAK_START: 0.0,
-    qmod.BREAK_END: 0.0,
-    qmod.BREAK_PHRASE: 0.0,
-    qmod.BREAK_SOFT_PHRASE: 0.0,
-    qmod.BREAK_WORD: 0.1,
-    qmod.BREAK_PART: 0.2,
-    qmod.BREAK_TOKEN: 0.4
-}
-
 TypedRangeSeq = List[TypedRange]
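For reference, the three penalty tables involved in this commit side by side (values copied from the hunks above): the PART and TOKEN costs of the deleted PENALTY_TOKENCHANGE survive unchanged in the new PENALTY_BREAK, while the phrase-level breaks turn negative, that is, break-friendly.

# Keys abbreviated to the break-type names; values taken verbatim from
# this commit's diff.
OLD_IN_TOKEN_BREAK = {'START': 0.5, 'END': 0.5, 'PHRASE': 0.5,
                      'SOFT_PHRASE': 0.5, 'WORD': 0.1, 'PART': 0.0, 'TOKEN': 0.0}
OLD_TOKENCHANGE = {'START': 0.0, 'END': 0.0, 'PHRASE': 0.0,
                   'SOFT_PHRASE': 0.0, 'WORD': 0.1, 'PART': 0.2, 'TOKEN': 0.4}
NEW_BREAK = {'START': -0.5, 'END': -0.5, 'PHRASE': -0.5,
             'SOFT_PHRASE': -0.5, 'WORD': 0.0, 'PART': 0.2, 'TOKEN': 0.4}

assert all(NEW_BREAK[k] == OLD_TOKENCHANGE[k] for k in ('PART', 'TOKEN'))
assert all(NEW_BREAK[k] < 0 for k in ('START', 'END', 'PHRASE', 'SOFT_PHRASE'))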