avoid yielding when extracting words from query

This commit is contained in:
Sarah Hoffmann
2025-02-20 23:32:39 +01:00
parent abc911079e
commit b56edf3d0a

View File

@@ -67,19 +67,20 @@ QueryParts = List[QueryPart]
WordDict = Dict[str, List[qmod.TokenRange]] WordDict = Dict[str, List[qmod.TokenRange]]
def yield_words(terms: List[QueryPart], start: int) -> Iterator[Tuple[str, qmod.TokenRange]]: def extract_words(terms: List[QueryPart], start: int, words: WordDict) -> None:
""" Return all combinations of words in the terms list after the """ Add all combinations of words in the terms list after the
given position. given position to the word list.
""" """
total = len(terms) total = len(terms)
base_penalty = PENALTY_IN_TOKEN_BREAK[qmod.BreakType.WORD]
for first in range(start, total): for first in range(start, total):
word = terms[first].token word = terms[first].token
penalty = PENALTY_IN_TOKEN_BREAK[qmod.BreakType.WORD] penalty = base_penalty
yield word, qmod.TokenRange(first, first + 1, penalty=penalty) words[word].append(qmod.TokenRange(first, first + 1, penalty=penalty))
for last in range(first + 1, min(first + 20, total)): for last in range(first + 1, min(first + 20, total)):
word = ' '.join((word, terms[last].token)) word = ' '.join((word, terms[last].token))
penalty += terms[last - 1].penalty penalty += terms[last - 1].penalty
yield word, qmod.TokenRange(first, last + 1, penalty=penalty) words[word].append(qmod.TokenRange(first, last + 1, penalty=penalty))
@dataclasses.dataclass @dataclasses.dataclass
@@ -256,7 +257,7 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
""" """
parts: QueryParts = [] parts: QueryParts = []
phrase_start = 0 phrase_start = 0
words = defaultdict(list) words: WordDict = defaultdict(list)
for phrase in query.source: for phrase in query.source:
query.nodes[-1].ptype = phrase.ptype query.nodes[-1].ptype = phrase.ptype
phrase_split = re.split('([ :-])', phrase.text) phrase_split = re.split('([ :-])', phrase.text)
@@ -277,8 +278,7 @@ class ICUQueryAnalyzer(AbstractQueryAnalyzer):
query.nodes[-1].btype = qmod.BreakType(breakchar) query.nodes[-1].btype = qmod.BreakType(breakchar)
parts[-1].penalty = PENALTY_IN_TOKEN_BREAK[qmod.BreakType(breakchar)] parts[-1].penalty = PENALTY_IN_TOKEN_BREAK[qmod.BreakType(breakchar)]
for word, wrange in yield_words(parts, phrase_start): extract_words(parts, phrase_start, words)
words[word].append(wrange)
phrase_start = len(parts) phrase_start = len(parts)
query.nodes[-1].btype = qmod.BreakType.END query.nodes[-1].btype = qmod.BreakType.END