search: merge QueryPart array with QueryNodes

The basic information on terms is pretty much always used together
with the node information. Merging them together saves some
allocations while making lookup easier at the same time.
This commit is contained in:
Sarah Hoffmann
2025-02-26 14:37:08 +01:00
parent eff60ba6be
commit e362a965e1
3 changed files with 100 additions and 83 deletions

View File

@@ -47,40 +47,27 @@ PENALTY_IN_TOKEN_BREAK = {
}
@dataclasses.dataclass
class QueryPart:
""" Normalized and transliterated form of a single term in the query.
When the term came out of a split during the transliteration,
the normalized string is the full word before transliteration.
Check the subsequent break type to figure out if the word is
continued.
Penalty is the break penalty for the break following the token.
"""
token: str
normalized: str
penalty: float
QueryParts = List[QueryPart]
WordDict = Dict[str, List[qmod.TokenRange]]
def extract_words(terms: List[QueryPart], start: int, words: WordDict) -> None:
""" Add all combinations of words in the terms list after the
given position to the word list.
def extract_words(query: qmod.QueryStruct, start: int, words: WordDict) -> None:
""" Add all combinations of words in the terms list starting with
the term leading into node 'start'.
The words found will be added into the 'words' dictionary with
their start and end position.
"""
total = len(terms)
nodes = query.nodes
total = len(nodes)
base_penalty = PENALTY_IN_TOKEN_BREAK[qmod.BREAK_WORD]
for first in range(start, total):
word = terms[first].token
word = nodes[first].term_lookup
penalty = base_penalty
words[word].append(qmod.TokenRange(first, first + 1, penalty=penalty))
words[word].append(qmod.TokenRange(first - 1, first, penalty=penalty))
for last in range(first + 1, min(first + 20, total)):
word = ' '.join((word, terms[last].token))
penalty += terms[last - 1].penalty
words[word].append(qmod.TokenRange(first, last + 1, penalty=penalty))
word = ' '.join((word, nodes[last].term_lookup))
penalty += nodes[last - 1].penalty
words[word].append(qmod.TokenRange(first - 1, last, penalty=penalty))
@dataclasses.dataclass
@@ -216,8 +203,8 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
if not query.source:
return query
parts, words = self.split_query(query)
log().var_dump('Transliterated query', lambda: _dump_transliterated(query, parts))
words = self.split_query(query)
log().var_dump('Transliterated query', lambda: query.get_transliterated_query())
for row in await self.lookup_in_db(list(words.keys())):
for trange in words[row.word_token]:
@@ -234,8 +221,8 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
else:
query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)
self.add_extra_tokens(query, parts)
self.rerank_tokens(query, parts)
self.add_extra_tokens(query)
self.rerank_tokens(query)
log().table_dump('Word tokens', _dump_word_tokens(query))
@@ -248,15 +235,13 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
"""
return cast(str, self.normalizer.transliterate(text)).strip('-: ')
def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
def split_query(self, query: qmod.QueryStruct) -> WordDict:
""" Transliterate the phrases and split them into tokens.
Returns the list of transliterated tokens together with their
normalized form and a dictionary of words for lookup together
Returns a dictionary of words for lookup together
with their position.
"""
parts: QueryParts = []
phrase_start = 0
phrase_start = 1
words: WordDict = defaultdict(list)
for phrase in query.source:
query.nodes[-1].ptype = phrase.ptype
@@ -272,18 +257,18 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
if trans:
for term in trans.split(' '):
if term:
parts.append(QueryPart(term, word,
PENALTY_IN_TOKEN_BREAK[qmod.BREAK_TOKEN]))
query.add_node(qmod.BREAK_TOKEN, phrase.ptype)
query.nodes[-1].btype = breakchar
parts[-1].penalty = PENALTY_IN_TOKEN_BREAK[breakchar]
query.add_node(qmod.BREAK_TOKEN, phrase.ptype,
PENALTY_IN_TOKEN_BREAK[qmod.BREAK_TOKEN],
term, word)
query.nodes[-1].adjust_break(breakchar,
PENALTY_IN_TOKEN_BREAK[breakchar])
extract_words(parts, phrase_start, words)
extract_words(query, phrase_start, words)
phrase_start = len(parts)
query.nodes[-1].btype = qmod.BREAK_END
phrase_start = len(query.nodes)
query.nodes[-1].adjust_break(qmod.BREAK_END, PENALTY_IN_TOKEN_BREAK[qmod.BREAK_END])
return parts, words
return words
async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
""" Return the token information from the database for the
@@ -292,18 +277,23 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
t = self.conn.t.meta.tables['word']
return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))
def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
def add_extra_tokens(self, query: qmod.QueryStruct) -> None:
""" Add tokens to query that are not saved in the database.
"""
for part, node, i in zip(parts, query.nodes, range(1000)):
if len(part.token) <= 4 and part.token.isdigit()\
and not node.has_tokens(i+1, qmod.TOKEN_HOUSENUMBER):
query.add_token(qmod.TokenRange(i, i+1), qmod.TOKEN_HOUSENUMBER,
need_hnr = False
for i, node in enumerate(query.nodes):
is_full_token = node.btype not in (qmod.BREAK_TOKEN, qmod.BREAK_PART)
if need_hnr and is_full_token \
and len(node.term_normalized) <= 4 and node.term_normalized.isdigit():
query.add_token(qmod.TokenRange(i-1, i), qmod.TOKEN_HOUSENUMBER,
ICUToken(penalty=0.5, token=0,
count=1, addr_count=1, lookup_word=part.token,
word_token=part.token, info=None))
count=1, addr_count=1,
lookup_word=node.term_lookup,
word_token=node.term_lookup, info=None))
def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
need_hnr = is_full_token and not node.has_tokens(i+1, qmod.TOKEN_HOUSENUMBER)
def rerank_tokens(self, query: qmod.QueryStruct) -> None:
""" Add penalties to tokens that depend on presence of other token.
"""
for i, node, tlist in query.iter_token_lists():
@@ -320,21 +310,15 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
if repl.end == tlist.end and repl.ttype != qmod.TOKEN_HOUSENUMBER:
repl.add_penalty(0.5 - tlist.tokens[0].penalty)
elif tlist.ttype not in (qmod.TOKEN_COUNTRY, qmod.TOKEN_PARTIAL):
norm = parts[i].normalized
for j in range(i + 1, tlist.end):
if node.btype != qmod.BREAK_TOKEN:
norm += ' ' + parts[j].normalized
norm = ' '.join(n.term_normalized for n in query.nodes[i + 1:tlist.end + 1]
if n.btype != qmod.BREAK_TOKEN)
if not norm:
# Can happen when the token only covers a partial term
norm = query.nodes[i + 1].term_normalized
for token in tlist.tokens:
cast(ICUToken, token).rematch(norm)
def _dump_transliterated(query: qmod.QueryStruct, parts: QueryParts) -> str:
out = query.nodes[0].btype
for node, part in zip(query.nodes[1:], parts):
out += part.token + node.btype
return out
def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
for node in query.nodes: