forked from hans/Nominatim
add inner word break penalty
@@ -7,7 +7,7 @@
 """
 Implementation of query analysis for the ICU tokenizer.
 """
-from typing import Tuple, Dict, List, Optional, NamedTuple, Iterator, Any, cast
+from typing import Tuple, Dict, List, Optional, Iterator, Any, cast
 from collections import defaultdict
 import dataclasses
 import difflib
@@ -36,17 +36,30 @@ DB_TO_TOKEN_TYPE = {
     'C': qmod.TokenType.COUNTRY
 }
 
+PENALTY_IN_TOKEN_BREAK = {
+    qmod.BreakType.START: 0.5,
+    qmod.BreakType.END: 0.5,
+    qmod.BreakType.PHRASE: 0.5,
+    qmod.BreakType.SOFT_PHRASE: 0.5,
+    qmod.BreakType.WORD: 0.1,
+    qmod.BreakType.PART: 0.0,
+    qmod.BreakType.TOKEN: 0.0
+}
+
 
-class QueryPart(NamedTuple):
+@dataclasses.dataclass
+class QueryPart:
     """ Normalized and transliterated form of a single term in the query.
 
         When the term came out of a split during the transliteration,
         the normalized string is the full word before transliteration.
         The word number keeps track of the word before transliteration
         and can be used to identify partial transliterated terms.
+        Penalty is the break penalty for the break following the token.
     """
     token: str
     normalized: str
     word_number: int
+    penalty: float
 
 
 QueryParts = List[QueryPart]
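For illustration only (not part of the commit): each QueryPart now carries the penalty of the break that follows it, looked up in PENALTY_IN_TOKEN_BREAK. A sketch with made-up values, for one input word whose transliteration was split into two terms:

# Sketch: both terms keep the full normalized word. The inner TOKEN break
# is free (0.0); the WORD break after the last term costs 0.1.
p1 = QueryPart('foo', 'foobar', 0, PENALTY_IN_TOKEN_BREAK[qmod.BreakType.TOKEN])  # 0.0
p2 = QueryPart('bar', 'foobar', 0, PENALTY_IN_TOKEN_BREAK[qmod.BreakType.WORD])   # 0.1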
@@ -60,10 +73,12 @@ def yield_words(terms: List[QueryPart], start: int) -> Iterator[Tuple[str, qmod.TokenRange]]:
     total = len(terms)
     for first in range(start, total):
         word = terms[first].token
-        yield word, qmod.TokenRange(first, first + 1)
+        penalty = PENALTY_IN_TOKEN_BREAK[qmod.BreakType.WORD]
+        yield word, qmod.TokenRange(first, first + 1, penalty=penalty)
         for last in range(first + 1, min(first + 20, total)):
             word = ' '.join((word, terms[last].token))
-            yield word, qmod.TokenRange(first, last + 1)
+            penalty += terms[last - 1].penalty
+            yield word, qmod.TokenRange(first, last + 1, penalty=penalty)
 
 
 @dataclasses.dataclass
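For illustration only: yield_words now attaches an accumulated penalty to every range it yields. A single term starts at the plain WORD penalty; each extension adds the penalty of the break it crosses. For the parts p1 and p2 sketched above:

# 'foo'      TokenRange(0, 1, penalty=0.1)        # base WORD penalty
# 'foo bar'  TokenRange(0, 2, penalty=0.1 + 0.0)  # crosses the free TOKEN break
# 'bar'      TokenRange(1, 2, penalty=0.1)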
@@ -96,25 +111,25 @@ class ICUToken(qmod.Token):
         self.penalty += (distance/len(self.lookup_word))
 
     @staticmethod
-    def from_db_row(row: SaRow) -> 'ICUToken':
+    def from_db_row(row: SaRow, base_penalty: float = 0.0) -> 'ICUToken':
         """ Create a ICUToken from the row of the word table.
         """
         count = 1 if row.info is None else row.info.get('count', 1)
         addr_count = 1 if row.info is None else row.info.get('addr_count', 1)
 
-        penalty = 0.0
+        penalty = base_penalty
         if row.type == 'w':
-            penalty = 0.3
+            penalty += 0.3
         elif row.type == 'W':
             if len(row.word_token) == 1 and row.word_token == row.word:
-                penalty = 0.2 if row.word.isdigit() else 0.3
+                penalty += 0.2 if row.word.isdigit() else 0.3
         elif row.type == 'H':
-            penalty = sum(0.1 for c in row.word_token if c != ' ' and not c.isdigit())
+            penalty += sum(0.1 for c in row.word_token if c != ' ' and not c.isdigit())
             if all(not c.isdigit() for c in row.word_token):
                 penalty += 0.2 * (len(row.word_token) - 1)
         elif row.type == 'C':
             if len(row.word_token) == 1:
-                penalty = 0.3
+                penalty += 0.3
 
         if row.info is None:
             lookup_word = row.word
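For illustration only: the range penalty computed in yield_words is handed in as base_penalty, so the type-specific surcharges now add to it instead of overwriting a fresh 0.0:

# Sketch: the same word-table row with and without a range penalty.
plain = ICUToken.from_db_row(row)                     # penalty starts at 0.0
ranged = ICUToken.from_db_row(row, base_penalty=0.1)  # a 'w' row ends up at 0.1 + 0.3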
@@ -204,7 +219,7 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
 
         for row in await self.lookup_in_db(list(words.keys())):
             for trange in words[row.word_token]:
-                token = ICUToken.from_db_row(row)
+                token = ICUToken.from_db_row(row, trange.penalty or 0.0)
                 if row.type == 'S':
                     if row.info['op'] in ('in', 'near'):
                         if trange.start == 0:
@@ -256,9 +271,11 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                 if trans:
                     for term in trans.split(' '):
                         if term:
-                            parts.append(QueryPart(term, word, wordnr))
+                            parts.append(QueryPart(term, word, wordnr,
+                                                   PENALTY_IN_TOKEN_BREAK[qmod.BreakType.TOKEN]))
                             query.add_node(qmod.BreakType.TOKEN, phrase.ptype)
                     query.nodes[-1].btype = qmod.BreakType(breakchar)
+                    parts[-1].penalty = PENALTY_IN_TOKEN_BREAK[qmod.BreakType(breakchar)]
                 wordnr += 1
 
             for word, wrange in yield_words(parts, phrase_start):
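For illustration only, assuming ',' is the character value behind qmod.BreakType.PHRASE: terms are appended with the provisional TOKEN penalty, and the last term of each word is fixed up once the following break character is known. This in-place assignment is why QueryPart changed from a NamedTuple to a mutable dataclass:

parts.append(QueryPart(term, word, wordnr, PENALTY_IN_TOKEN_BREAK[qmod.BreakType.TOKEN]))
# ... after the word is complete, the real break penalty replaces it:
parts[-1].penalty = PENALTY_IN_TOKEN_BREAK[qmod.BreakType(',')]  # PHRASE break, 0.5
# A NamedTuple would raise here; its fields cannot be assigned.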
@@ -280,7 +297,7 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
         """ Add tokens to query that are not saved in the database.
         """
         for part, node, i in zip(parts, query.nodes, range(1000)):
-            if len(part.token) <= 4 and part[0].isdigit()\
+            if len(part.token) <= 4 and part.token.isdigit()\
               and not node.has_tokens(i+1, qmod.TokenType.HOUSENUMBER):
                 query.add_token(qmod.TokenRange(i, i+1), qmod.TokenType.HOUSENUMBER,
                                 ICUToken(penalty=0.5, token=0,
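A note on this hunk: part[0] relied on NamedTuple's positional indexing, where field 0 is token; a plain dataclass does not support indexing, so the access is now spelled part.token. The check itself is unchanged:

# Sketch: the whole token must be digits to count as a housenumber candidate.
QueryPart('2a', '2a', 0, 0.0).token.isdigit()    # False
QueryPart('301', '301', 0, 0.0).token.isdigit()  # True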
@@ -122,6 +122,7 @@ class TokenRange:
     """
     start: int
     end: int
+    penalty: Optional[float] = None
 
     def __lt__(self, other: 'TokenRange') -> bool:
         return self.end <= other.start
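For illustration only: penalty on TokenRange defaults to None, so existing call sites that construct ranges without a penalty keep working; the `trange.penalty or 0.0` in the analyzer above treats those as zero:

r_old = qmod.TokenRange(0, 1)               # penalty is None
r_new = qmod.TokenRange(0, 2, penalty=0.1)
base = r_old.penalty or 0.0                 # 0.0 for old-style ranges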