rebalance word transition penalties

This commit is contained in:
Sarah Hoffmann
2025-07-09 20:35:15 +02:00
parent 4a9253a0a9
commit 4634ad0720
4 changed files with 52 additions and 36 deletions

View File

@@ -214,6 +214,19 @@ class QueryNode:
types of tokens spanning over the gap.
"""
@property
def word_break_penalty(self) -> float:
    """ Penalty to apply when a word ends at this node.

        Only the positive part of the node's `penalty` counts here:
        the value is returned unchanged when it is greater than zero,
        otherwise the result is 0.0.
    """
    # Use a float literal so that the clamped result matches the
    # annotated return type instead of yielding the int 0.
    return max(0.0, self.penalty)
@property
def word_continuation_penalty(self) -> float:
    """ Penalty to apply when a word continues over this node
        (i.e. is a multi-term word).

        Only the negative part of the node's `penalty` counts here:
        the negated value is returned when `penalty` is below zero,
        otherwise the result is 0.0.
    """
    # Use a float literal so that the clamped result matches the
    # annotated return type instead of yielding the int 0.
    return max(0.0, -self.penalty)
def name_address_ratio(self) -> float:
""" Return the probability that the partial token belonging to
this node forms part of a name (as opposed to part of the address).
@@ -273,7 +286,8 @@ class QueryStruct:
self.source = source
self.dir_penalty = 0.0
self.nodes: List[QueryNode] = \
[QueryNode(BREAK_START, source[0].ptype if source else PHRASE_ANY)]
[QueryNode(BREAK_START, source[0].ptype if source else PHRASE_ANY,
0.0, '', '')]
def num_token_slots(self) -> int:
""" Return the length of the query in vertice steps.
@@ -338,6 +352,13 @@ class QueryStruct:
assert ttype != TOKEN_PARTIAL
return self.nodes[trange.start].get_tokens(trange.end, ttype) or []
def get_in_word_penalty(self, trange: TokenRange) -> float:
    """ Return the sum of the word continuation penalties of all
        token transitions within the given range.

        Only nodes strictly between `trange.start` and `trange.end`
        are counted; the nodes at both ends of the range do not
        contribute. A range without inner nodes yields 0.0.
    """
    inner_nodes = self.nodes[trange.start + 1:trange.end]
    # Start from 0.0 so that an empty range still produces a float,
    # as promised by the return annotation.
    return sum((n.word_continuation_penalty for n in inner_nodes), 0.0)
def iter_partials(self, trange: TokenRange) -> Iterator[Token]:
""" Iterate over the partial tokens between the given nodes.
Missing partials are ignored.