make word generation from a query an instance method of QueryStruct

This commit is contained in:
Sarah Hoffmann
2025-02-26 17:22:14 +01:00
parent e362a965e1
commit 6759edfb5d
3 changed files with 53 additions and 37 deletions

View File

@@ -8,7 +8,6 @@
Implementation of query analysis for the ICU tokenizer. Implementation of query analysis for the ICU tokenizer.
""" """
from typing import Tuple, Dict, List, Optional, Iterator, Any, cast from typing import Tuple, Dict, List, Optional, Iterator, Any, cast
from collections import defaultdict
import dataclasses import dataclasses
import difflib import difflib
import re import re
@@ -47,29 +46,6 @@ PENALTY_IN_TOKEN_BREAK = {
} }
WordDict = Dict[str, List[qmod.TokenRange]]
def extract_words(query: qmod.QueryStruct, start: int, words: WordDict) -> None:
    """ Collect all word combinations that begin with the term leading
        into node 'start' and add them to 'words'.

        Each entry maps the space-joined word string to the token ranges
        (with accumulated penalty) where it occurs in the query. A word
        may span at most 20 consecutive terms.
    """
    terms = query.nodes
    count = len(terms)
    initial_penalty = PENALTY_IN_TOKEN_BREAK[qmod.BREAK_WORD]
    for begin in range(start, count):
        # Parts of the word collected so far; joined lazily per append.
        parts = [terms[begin].term_lookup]
        cost = initial_penalty
        words[' '.join(parts)].append(qmod.TokenRange(begin - 1, begin, penalty=cost))
        end = begin + 1
        # Extend the word term by term, up to 19 additional terms.
        while end < count and end < begin + 20:
            parts.append(terms[end].term_lookup)
            # Crossing a node adds that node's break penalty.
            cost += terms[end - 1].penalty
            words[' '.join(parts)].append(qmod.TokenRange(begin - 1, end, penalty=cost))
            end += 1
@dataclasses.dataclass @dataclasses.dataclass
class ICUToken(qmod.Token): class ICUToken(qmod.Token):
""" Specialised token for ICU tokenizer. """ Specialised token for ICU tokenizer.
@@ -203,8 +179,9 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
if not query.source: if not query.source:
return query return query
words = self.split_query(query) self.split_query(query)
log().var_dump('Transliterated query', lambda: query.get_transliterated_query()) log().var_dump('Transliterated query', lambda: query.get_transliterated_query())
words = query.extract_words(base_penalty=PENALTY_IN_TOKEN_BREAK[qmod.BREAK_WORD])
for row in await self.lookup_in_db(list(words.keys())): for row in await self.lookup_in_db(list(words.keys())):
for trange in words[row.word_token]: for trange in words[row.word_token]:
@@ -235,14 +212,9 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
""" """
return cast(str, self.normalizer.transliterate(text)).strip('-: ') return cast(str, self.normalizer.transliterate(text)).strip('-: ')
def split_query(self, query: qmod.QueryStruct) -> WordDict: def split_query(self, query: qmod.QueryStruct) -> None:
""" Transliterate the phrases and split them into tokens. """ Transliterate the phrases and split them into tokens.
Returns a dictionary of words for lookup together
with their position.
""" """
phrase_start = 1
words: WordDict = defaultdict(list)
for phrase in query.source: for phrase in query.source:
query.nodes[-1].ptype = phrase.ptype query.nodes[-1].ptype = phrase.ptype
phrase_split = re.split('([ :-])', phrase.text) phrase_split = re.split('([ :-])', phrase.text)
@@ -263,13 +235,8 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
query.nodes[-1].adjust_break(breakchar, query.nodes[-1].adjust_break(breakchar,
PENALTY_IN_TOKEN_BREAK[breakchar]) PENALTY_IN_TOKEN_BREAK[breakchar])
extract_words(query, phrase_start, words)
phrase_start = len(query.nodes)
query.nodes[-1].adjust_break(qmod.BREAK_END, PENALTY_IN_TOKEN_BREAK[qmod.BREAK_END]) query.nodes[-1].adjust_break(qmod.BREAK_END, PENALTY_IN_TOKEN_BREAK[qmod.BREAK_END])
return words
async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]': async def lookup_in_db(self, words: List[str]) -> 'sa.Result[Any]':
""" Return the token information from the database for the """ Return the token information from the database for the
given word tokens. given word tokens.

View File

@@ -7,8 +7,9 @@
""" """
Datastructures for a tokenized query. Datastructures for a tokenized query.
""" """
from typing import List, Tuple, Optional, Iterator from typing import Dict, List, Tuple, Optional, Iterator
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict
import dataclasses import dataclasses
@@ -320,3 +321,34 @@ class QueryStruct:
For debugging purposes only. For debugging purposes only.
""" """
return ''.join(''.join((n.term_lookup, n.btype)) for n in self.nodes) return ''.join(''.join((n.term_lookup, n.btype)) for n in self.nodes)
def extract_words(self, base_penalty: float = 0.0,
                  start: int = 0,
                  endpos: Optional[int] = None) -> Dict[str, List[TokenRange]]:
    """ Collect all combinations of words that can be formed from the
        terms between 'start' and 'endpos' (defaulting to the end of
        the node list). Terms are joined with spaces at each break.
        Words never extend across a BREAK_PHRASE boundary.

        Returns a dictionary mapping each candidate word to the token
        ranges where it occurs. Each range carries a penalty of
        base_penalty plus the break penalty of every node the word
        crosses.
    """
    if endpos is None:
        endpos = len(self.nodes)

    nodes = self.nodes
    found: Dict[str, List[TokenRange]] = defaultdict(list)
    for begin in range(start, endpos - 1):
        cost = base_penalty
        phrase = nodes[begin + 1].term_lookup
        found[phrase].append(TokenRange(begin, begin + 1, penalty=cost))
        # A word ending on a phrase break may not be extended further.
        if nodes[begin + 1].btype == BREAK_PHRASE:
            continue
        # Extend up to 19 more terms or until the phrase ends.
        for end in range(begin + 2, min(begin + 20, endpos)):
            phrase = ' '.join((phrase, nodes[end].term_lookup))
            cost += nodes[end - 1].penalty
            found[phrase].append(TokenRange(begin, end, penalty=cost))
            if nodes[end].btype == BREAK_PHRASE:
                break

    return found

View File

@@ -46,3 +46,20 @@ def test_token_range_unimplemented_ops():
nq.TokenRange(1, 3) <= nq.TokenRange(10, 12) nq.TokenRange(1, 3) <= nq.TokenRange(10, 12)
with pytest.raises(TypeError): with pytest.raises(TypeError):
nq.TokenRange(1, 3) >= nq.TokenRange(10, 12) nq.TokenRange(1, 3) >= nq.TokenRange(10, 12)
def test_query_extract_words():
    """ Words are extracted with penalties and never cross a phrase break. """
    q = nq.QueryStruct([])
    for btype, penalty, term in [(nq.BREAK_WORD, 0.1, '12'),
                                 (nq.BREAK_TOKEN, 0.0, 'ab'),
                                 (nq.BREAK_PHRASE, 0.0, '12'),
                                 (nq.BREAK_END, 0.5, 'hallo')]:
        q.add_node(btype, nq.PHRASE_ANY, penalty, term, '')

    words = q.extract_words(base_penalty=1.0)

    assert set(words.keys()) \
        == {'12', 'ab', 'hallo', '12 ab', 'ab 12', '12 ab 12'}
    assert sorted(words['12']) == [nq.TokenRange(0, 1, 1.0), nq.TokenRange(2, 3, 1.0)]
    assert words['12 ab'] == [nq.TokenRange(0, 2, 1.1)]
    assert words['hallo'] == [nq.TokenRange(3, 4, 1.0)]