search: merge QueryPart array with QueryNodes

The basic information on terms is pretty much always used together
with the node information. Merging them saves some allocations
while making lookups easier at the same time.
Sarah Hoffmann
2025-02-26 14:37:08 +01:00
parent eff60ba6be
commit e362a965e1
3 changed files with 100 additions and 83 deletions
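In concrete terms, the change folds the parallel QueryPart array into the QueryNode list, so one index reaches both break and term information. A minimal standalone sketch of the data-layout change (simplified stand-in classes, not the actual Nominatim types shown in the diffs below):

    import dataclasses

    # Before: term data in a QueryPart list kept parallel to query.nodes.
    @dataclasses.dataclass
    class QueryPart:
        token: str          # transliterated term
        normalized: str     # normalized source word
        penalty: float      # penalty of the break following the term

    # After: the same data lives on the node that ends the term, so a
    # single index into query.nodes reaches break and term information.
    @dataclasses.dataclass
    class QueryNode:
        btype: str            # break type at this node
        penalty: float        # penalty for that break
        term_lookup: str      # transliterated term
        term_normalized: str  # normalized source word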

src/nominatim_api/search/icu_tokenizer.py

@@ -47,40 +47,27 @@ PENALTY_IN_TOKEN_BREAK = {
 }
 
 
-@dataclasses.dataclass
-class QueryPart:
-    """ Normalized and transliterated form of a single term in the query.
-
-        When the term came out of a split during the transliteration,
-        the normalized string is the full word before transliteration.
-        Check the subsequent break type to figure out if the word is
-        continued.
-
-        Penalty is the break penalty for the break following the token.
-    """
-    token: str
-    normalized: str
-    penalty: float
-
-
-QueryParts = List[QueryPart]
 WordDict = Dict[str, List[qmod.TokenRange]]
 
 
-def extract_words(terms: List[QueryPart], start: int, words: WordDict) -> None:
-    """ Add all combinations of words in the terms list after the
-        given position to the word list.
+def extract_words(query: qmod.QueryStruct, start: int, words: WordDict) -> None:
+    """ Add all combinations of words in the terms list starting with
+        the term leading into node 'start'.
+
+        The words found will be added into the 'words' dictionary with
+        their start and end position.
     """
-    total = len(terms)
+    nodes = query.nodes
+    total = len(nodes)
     base_penalty = PENALTY_IN_TOKEN_BREAK[qmod.BREAK_WORD]
     for first in range(start, total):
-        word = terms[first].token
+        word = nodes[first].term_lookup
         penalty = base_penalty
-        words[word].append(qmod.TokenRange(first, first + 1, penalty=penalty))
+        words[word].append(qmod.TokenRange(first - 1, first, penalty=penalty))
         for last in range(first + 1, min(first + 20, total)):
-            word = ' '.join((word, terms[last].token))
-            penalty += terms[last - 1].penalty
-            words[word].append(qmod.TokenRange(first, last + 1, penalty=penalty))
+            word = ' '.join((word, nodes[last].term_lookup))
+            penalty += nodes[last - 1].penalty
+            words[word].append(qmod.TokenRange(first - 1, last, penalty=penalty))
 
 
 @dataclasses.dataclass
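Note the index shift in extract_words: node 0 is the start node carrying an empty term, so the term stored on node i occupies token slot (i-1, i). A runnable sketch of the new loop, with plain strings standing in for nodes and tuples for qmod.TokenRange (penalties omitted):

    from collections import defaultdict

    # nodes[0] is the BREAK_START node with an empty term; each later
    # node carries the term that leads into it.
    nodes = ['', 'hauptstr', '23', 'berlin']    # term_lookup values only

    words = defaultdict(list)
    for first in range(1, len(nodes)):          # phrase_start is now 1
        word = nodes[first]
        words[word].append((first - 1, first))  # slot (first-1, first)
        for last in range(first + 1, len(nodes)):
            word = ' '.join((word, nodes[last]))
            words[word].append((first - 1, last))

    # e.g. words['hauptstr 23'] == [(0, 2)] -- the same token range the
    # old code produced as TokenRange(first, last + 1) over QueryParts.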
@@ -216,8 +203,8 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
         if not query.source:
             return query
 
-        parts, words = self.split_query(query)
-        log().var_dump('Transliterated query', lambda: _dump_transliterated(query, parts))
+        words = self.split_query(query)
+        log().var_dump('Transliterated query', lambda: query.get_transliterated_query())
 
         for row in await self.lookup_in_db(list(words.keys())):
             for trange in words[row.word_token]:
@@ -234,8 +221,8 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                 else:
                     query.add_token(trange, DB_TO_TOKEN_TYPE[row.type], token)
 
-        self.add_extra_tokens(query, parts)
-        self.rerank_tokens(query, parts)
+        self.add_extra_tokens(query)
+        self.rerank_tokens(query)
 
         log().table_dump('Word tokens', _dump_word_tokens(query))
@@ -248,15 +235,13 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
         """
         return cast(str, self.normalizer.transliterate(text)).strip('-: ')
 
-    def split_query(self, query: qmod.QueryStruct) -> Tuple[QueryParts, WordDict]:
+    def split_query(self, query: qmod.QueryStruct) -> WordDict:
         """ Transliterate the phrases and split them into tokens.
 
-            Returns the list of transliterated tokens together with their
-            normalized form and a dictionary of words for lookup together
+            Returns a dictionary of words for lookup together
             with their position.
         """
-        parts: QueryParts = []
-        phrase_start = 0
+        phrase_start = 1
         words: WordDict = defaultdict(list)
         for phrase in query.source:
             query.nodes[-1].ptype = phrase.ptype
@@ -272,18 +257,18 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                 if trans:
                     for term in trans.split(' '):
                         if term:
-                            parts.append(QueryPart(term, word,
-                                                   PENALTY_IN_TOKEN_BREAK[qmod.BREAK_TOKEN]))
-                            query.add_node(qmod.BREAK_TOKEN, phrase.ptype)
-                    query.nodes[-1].btype = breakchar
-                    parts[-1].penalty = PENALTY_IN_TOKEN_BREAK[breakchar]
+                            query.add_node(qmod.BREAK_TOKEN, phrase.ptype,
+                                           PENALTY_IN_TOKEN_BREAK[qmod.BREAK_TOKEN],
+                                           term, word)
+                    query.nodes[-1].adjust_break(breakchar,
+                                                 PENALTY_IN_TOKEN_BREAK[breakchar])
 
-            extract_words(parts, phrase_start, words)
-            phrase_start = len(parts)
+            extract_words(query, phrase_start, words)
+            phrase_start = len(query.nodes)
 
-        query.nodes[-1].btype = qmod.BREAK_END
-        return parts, words
+        query.nodes[-1].adjust_break(qmod.BREAK_END, PENALTY_IN_TOKEN_BREAK[qmod.BREAK_END])
+        return words
 
     async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
         """ Return the token information from the database for the
@@ -292,18 +277,23 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
         t = self.conn.t.meta.tables['word']
         return await self.conn.execute(t.select().where(t.c.word_token.in_(words)))
 
-    def add_extra_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
+    def add_extra_tokens(self, query: qmod.QueryStruct) -> None:
         """ Add tokens to query that are not saved in the database.
         """
-        for part, node, i in zip(parts, query.nodes, range(1000)):
-            if len(part.token) <= 4 and part.token.isdigit()\
-               and not node.has_tokens(i+1, qmod.TOKEN_HOUSENUMBER):
-                query.add_token(qmod.TokenRange(i, i+1), qmod.TOKEN_HOUSENUMBER,
-                                ICUToken(penalty=0.5, token=0,
-                                         count=1, addr_count=1, lookup_word=part.token,
-                                         word_token=part.token, info=None))
+        need_hnr = False
+        for i, node in enumerate(query.nodes):
+            is_full_token = node.btype not in (qmod.BREAK_TOKEN, qmod.BREAK_PART)
+            if need_hnr and is_full_token \
+               and len(node.term_normalized) <= 4 and node.term_normalized.isdigit():
+                query.add_token(qmod.TokenRange(i-1, i), qmod.TOKEN_HOUSENUMBER,
+                                ICUToken(penalty=0.5, token=0,
+                                         count=1, addr_count=1,
+                                         lookup_word=node.term_lookup,
+                                         word_token=node.term_lookup, info=None))
+
+            need_hnr = is_full_token and not node.has_tokens(i+1, qmod.TOKEN_HOUSENUMBER)
 
-    def rerank_tokens(self, query: qmod.QueryStruct, parts: QueryParts) -> None:
+    def rerank_tokens(self, query: qmod.QueryStruct) -> None:
         """ Add penalties to tokens that depend on presence of other token.
         """
         for i, node, tlist in query.iter_token_lists():
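The rewritten add_extra_tokens walks the nodes once: need_hnr records whether the previous full term still lacks a house-number token, and a following short digit-only term then receives a synthetic one for slot (i-1, i). A stripped-down runnable version of that scan (plain tuples instead of the real query nodes):

    # Each entry: (term_normalized, is_full_token, has_housenumber_token)
    nodes = [('', True, False), ('hauptstr', True, False), ('23', True, False)]

    need_hnr = False
    for i, (term, is_full, has_hnr) in enumerate(nodes):
        if need_hnr and is_full and len(term) <= 4 and term.isdigit():
            print(f'synthetic house-number token for slot ({i - 1}, {i})')
        need_hnr = is_full and not has_hnr   # remember state for next node

    # prints: synthetic house-number token for slot (1, 2)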
@@ -320,21 +310,15 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
                     if repl.end == tlist.end and repl.ttype != qmod.TOKEN_HOUSENUMBER:
                         repl.add_penalty(0.5 - tlist.tokens[0].penalty)
             elif tlist.ttype not in (qmod.TOKEN_COUNTRY, qmod.TOKEN_PARTIAL):
-                norm = parts[i].normalized
-                for j in range(i + 1, tlist.end):
-                    if node.btype != qmod.BREAK_TOKEN:
-                        norm += ' ' + parts[j].normalized
+                norm = ' '.join(n.term_normalized for n in query.nodes[i + 1:tlist.end + 1]
+                                if n.btype != qmod.BREAK_TOKEN)
+                if not norm:
+                    # Can happen when the token only covers a partial term
+                    norm = query.nodes[i + 1].term_normalized
                 for token in tlist.tokens:
                     cast(ICUToken, token).rematch(norm)
 
 
-def _dump_transliterated(query: qmod.QueryStruct, parts: QueryParts) -> str:
-    out = query.nodes[0].btype
-    for node, part in zip(query.nodes[1:], parts):
-        out += part.token + node.btype
-    return out
-
-
 def _dump_word_tokens(query: qmod.QueryStruct) -> Iterator[List[Any]]:
     yield ['type', 'token', 'word_token', 'lookup_word', 'penalty', 'count', 'info']
     for node in query.nodes:
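In rerank_tokens, the old explicit accumulation over parts becomes one generator expression over the node slice, skipping terms whose break is BREAK_TOKEN (their normalized form repeats the full source word), plus a fallback for tokens that only cover part of a split term. A standalone illustration (tuples instead of nodes, with '`' used here as a stand-in marker for BREAK_TOKEN):

    # (term_normalized, btype); 'St-Anna' was split into two terms, and
    # both carry the full source word as their normalized form.
    nodes = [('', '<'), ('St-Anna', '`'), ('St-Anna', ' '), ('platz', '>')]

    i, end = 0, 1   # token covering slot (0, 1): first half of the split
    norm = ' '.join(t for t, b in nodes[i + 1:end + 1] if b != '`')
    if not norm:    # every covered term ends on an in-token break
        norm = nodes[i + 1][0]
    print(norm)     # -> 'St-Anna'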

src/nominatim_api/search/query.py

@@ -171,11 +171,33 @@ class TokenList:
 
 @dataclasses.dataclass
 class QueryNode:
     """ A node of the query representing a break between terms.
+
+        The node also contains information on the source term
+        ending at the node. The tokens are created from this information.
     """
     btype: BreakType
     ptype: PhraseType
+
+    penalty: float
+    """ Penalty for the break at this node.
+    """
+    term_lookup: str
+    """ Transliterated term following this node.
+    """
+    term_normalized: str
+    """ Normalised form of term following this node.
+        When the token resulted from a split during transliteration,
+        then this string contains the complete source term.
+    """
+
     starting: List[TokenList] = dataclasses.field(default_factory=list)
 
+    def adjust_break(self, btype: BreakType, penalty: float) -> None:
+        """ Change the break type and penalty for this node.
+        """
+        self.btype = btype
+        self.penalty = penalty
+
     def has_tokens(self, end: int, *ttypes: TokenType) -> bool:
         """ Check if there are tokens of the given types ending at the
             given node.
@@ -218,19 +240,22 @@ class QueryStruct:
 
     def __init__(self, source: List[Phrase]) -> None:
         self.source = source
         self.nodes: List[QueryNode] = \
-            [QueryNode(BREAK_START, source[0].ptype if source else PHRASE_ANY)]
+            [QueryNode(BREAK_START, source[0].ptype if source else PHRASE_ANY,
+                       0.0, '', '')]
 
     def num_token_slots(self) -> int:
         """ Return the length of the query in vertice steps.
         """
         return len(self.nodes) - 1
 
-    def add_node(self, btype: BreakType, ptype: PhraseType) -> None:
+    def add_node(self, btype: BreakType, ptype: PhraseType,
+                 break_penalty: float = 0.0,
+                 term_lookup: str = '', term_normalized: str = '') -> None:
         """ Append a new break node with the given break type.
             The phrase type denotes the type for any tokens starting
             at the node.
         """
-        self.nodes.append(QueryNode(btype, ptype))
+        self.nodes.append(QueryNode(btype, ptype, break_penalty, term_lookup, term_normalized))
 
     def add_token(self, trange: TokenRange, ttype: TokenType, token: Token) -> None:
         """ Add a token to the query. 'start' and 'end' are the indexes of the
@@ -287,3 +312,11 @@ class QueryStruct:
                 if t.token == token:
                     return f"[{tlist.ttype}]{t.lookup_word}"
         return 'None'
+
+    def get_transliterated_query(self) -> str:
+        """ Return a string representation of the transliterated query
+            with the character representation of the different break types.
+
+            For debugging purposes only.
+        """
+        return ''.join(''.join((n.term_lookup, n.btype)) for n in self.nodes)
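The new helper replaces _dump_transliterated from the tokenizer module: it concatenates each node's term with the node's break character. Assuming the single-character break representations used by the query module ('<' start, ' ' word, ',' phrase, '>' end; the exact characters are defined there), the output shape looks like this:

    # (term_lookup, btype) per node; node 0 is the start node.
    nodes = [('', '<'), ('hauptstr', ' '), ('23', ','), ('berlin', '>')]

    print(''.join(term + btype for term, btype in nodes))
    # -> '<hauptstr 23,berlin>'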

test/python/api/search/test_query.py

@@ -21,6 +21,9 @@ def mktoken(tid: int):
     return MyToken(penalty=3.0, token=tid, count=1, addr_count=1,
                    lookup_word='foo')
 
+@pytest.fixture
+def qnode():
+    return query.QueryNode(query.BREAK_PHRASE, query.PHRASE_ANY, 0.0, '', '')
 
 @pytest.mark.parametrize('ptype,ttype', [(query.PHRASE_ANY, 'W'),
                                          (query.PHRASE_AMENITY, 'Q'),
@@ -37,27 +40,24 @@ def test_phrase_incompatible(ptype):
     assert not query._phrase_compatible_with(ptype, query.TOKEN_PARTIAL, True)
 
 
-def test_query_node_empty():
-    qn = query.QueryNode(query.BREAK_PHRASE, query.PHRASE_ANY)
-
-    assert not qn.has_tokens(3, query.TOKEN_PARTIAL)
-    assert qn.get_tokens(3, query.TOKEN_WORD) is None
+def test_query_node_empty(qnode):
+    assert not qnode.has_tokens(3, query.TOKEN_PARTIAL)
+    assert qnode.get_tokens(3, query.TOKEN_WORD) is None
 
 
-def test_query_node_with_content():
-    qn = query.QueryNode(query.BREAK_PHRASE, query.PHRASE_ANY)
-    qn.starting.append(query.TokenList(2, query.TOKEN_PARTIAL, [mktoken(100), mktoken(101)]))
-    qn.starting.append(query.TokenList(2, query.TOKEN_WORD, [mktoken(1000)]))
+def test_query_node_with_content(qnode):
+    qnode.starting.append(query.TokenList(2, query.TOKEN_PARTIAL, [mktoken(100), mktoken(101)]))
+    qnode.starting.append(query.TokenList(2, query.TOKEN_WORD, [mktoken(1000)]))
 
-    assert not qn.has_tokens(3, query.TOKEN_PARTIAL)
-    assert not qn.has_tokens(2, query.TOKEN_COUNTRY)
-    assert qn.has_tokens(2, query.TOKEN_PARTIAL)
-    assert qn.has_tokens(2, query.TOKEN_WORD)
+    assert not qnode.has_tokens(3, query.TOKEN_PARTIAL)
+    assert not qnode.has_tokens(2, query.TOKEN_COUNTRY)
+    assert qnode.has_tokens(2, query.TOKEN_PARTIAL)
+    assert qnode.has_tokens(2, query.TOKEN_WORD)
 
-    assert qn.get_tokens(3, query.TOKEN_PARTIAL) is None
-    assert qn.get_tokens(2, query.TOKEN_COUNTRY) is None
-    assert len(qn.get_tokens(2, query.TOKEN_PARTIAL)) == 2
-    assert len(qn.get_tokens(2, query.TOKEN_WORD)) == 1
+    assert qnode.get_tokens(3, query.TOKEN_PARTIAL) is None
+    assert qnode.get_tokens(2, query.TOKEN_COUNTRY) is None
+    assert len(qnode.get_tokens(2, query.TOKEN_PARTIAL)) == 2
+    assert len(qnode.get_tokens(2, query.TOKEN_WORD)) == 1
 
 
 def test_query_struct_empty():