drop category tokens when they make up a full phrase

This commit is contained in:
Sarah Hoffmann
2023-11-26 20:58:50 +01:00
parent a8b023e57e
commit a7f5c6c8f5
4 changed files with 56 additions and 26 deletions

View File

@@ -70,14 +70,16 @@ class PhraseType(enum.Enum):
     COUNTRY = enum.auto()
     """ Contains the country name or code. """

-    def compatible_with(self, ttype: TokenType) -> bool:
+    def compatible_with(self, ttype: TokenType,
+                        is_full_phrase: bool) -> bool:
         """ Check if the given token type can be used with the phrase type.
         """
         if self == PhraseType.NONE:
-            return True
+            return not is_full_phrase or ttype != TokenType.QUALIFIER
         if self == PhraseType.AMENITY:
-            return ttype in (TokenType.WORD, TokenType.PARTIAL,
-                             TokenType.QUALIFIER, TokenType.CATEGORY)
+            return ttype in (TokenType.WORD, TokenType.PARTIAL)\
+                   or (is_full_phrase and ttype == TokenType.CATEGORY)\
+                   or (not is_full_phrase and ttype == TokenType.QUALIFIER)
         if self == PhraseType.STREET:
             return ttype in (TokenType.WORD, TokenType.PARTIAL, TokenType.HOUSENUMBER)
         if self == PhraseType.POSTCODE:
@@ -244,7 +246,9 @@ class QueryStruct:
             be added to, then the token is silently dropped.
         """
         snode = self.nodes[trange.start]
-        if snode.ptype.compatible_with(ttype):
+        full_phrase = snode.btype in (BreakType.START, BreakType.PHRASE)\
+                      and self.nodes[trange.end].btype in (BreakType.PHRASE, BreakType.END)
+        if snode.ptype.compatible_with(ttype, full_phrase):
             tlist = snode.get_tokens(trange.end, ttype)
             if tlist is None:
                 snode.starting.append(TokenList(trange.end, ttype, [token]))

View File

@@ -28,12 +28,12 @@ def mktoken(tid: int):
                                    ('COUNTRY', 'COUNTRY'),
                                    ('POSTCODE', 'POSTCODE')])
 def test_phrase_compatible(ptype, ttype):
-    assert query.PhraseType[ptype].compatible_with(query.TokenType[ttype])
+    assert query.PhraseType[ptype].compatible_with(query.TokenType[ttype], False)


 @pytest.mark.parametrize('ptype', ['COUNTRY', 'POSTCODE'])
 def test_phrase_incompatible(ptype):
-    assert not query.PhraseType[ptype].compatible_with(query.TokenType.PARTIAL)
+    assert not query.PhraseType[ptype].compatible_with(query.TokenType.PARTIAL, True)


 def test_query_node_empty():
@@ -99,3 +99,36 @@ def test_query_struct_incompatible_token():
     assert q.get_tokens(query.TokenRange(0, 1), query.TokenType.PARTIAL) == []
     assert len(q.get_tokens(query.TokenRange(1, 2), query.TokenType.COUNTRY)) == 1
+
+
+def test_query_struct_amenity_single_word():
+    q = query.QueryStruct([query.Phrase(query.PhraseType.AMENITY, 'bar')])
+    q.add_node(query.BreakType.END, query.PhraseType.NONE)
+
+    q.add_token(query.TokenRange(0, 1), query.TokenType.PARTIAL, mktoken(1))
+    q.add_token(query.TokenRange(0, 1), query.TokenType.CATEGORY, mktoken(2))
+    q.add_token(query.TokenRange(0, 1), query.TokenType.QUALIFIER, mktoken(3))
+
+    assert len(q.get_tokens(query.TokenRange(0, 1), query.TokenType.PARTIAL)) == 1
+    assert len(q.get_tokens(query.TokenRange(0, 1), query.TokenType.CATEGORY)) == 1
+    assert len(q.get_tokens(query.TokenRange(0, 1), query.TokenType.QUALIFIER)) == 0
+
+
+def test_query_struct_amenity_two_words():
+    q = query.QueryStruct([query.Phrase(query.PhraseType.AMENITY, 'foo bar')])
+    q.add_node(query.BreakType.WORD, query.PhraseType.AMENITY)
+    q.add_node(query.BreakType.END, query.PhraseType.NONE)
+
+    for trange in [(0, 1), (1, 2)]:
+        q.add_token(query.TokenRange(*trange), query.TokenType.PARTIAL, mktoken(1))
+        q.add_token(query.TokenRange(*trange), query.TokenType.CATEGORY, mktoken(2))
+        q.add_token(query.TokenRange(*trange), query.TokenType.QUALIFIER, mktoken(3))
+
+    assert len(q.get_tokens(query.TokenRange(0, 1), query.TokenType.PARTIAL)) == 1
+    assert len(q.get_tokens(query.TokenRange(0, 1), query.TokenType.CATEGORY)) == 0
+    assert len(q.get_tokens(query.TokenRange(0, 1), query.TokenType.QUALIFIER)) == 1
+    assert len(q.get_tokens(query.TokenRange(1, 2), query.TokenType.PARTIAL)) == 1
+    assert len(q.get_tokens(query.TokenRange(1, 2), query.TokenType.CATEGORY)) == 0
+    assert len(q.get_tokens(query.TokenRange(1, 2), query.TokenType.QUALIFIER)) == 1

View File

@@ -21,21 +21,18 @@ class MyToken(Token):

 def make_query(*args):
-    q = None
+    q = QueryStruct([Phrase(PhraseType.NONE, '')])

-    for tlist in args:
-        if q is None:
-            q = QueryStruct([Phrase(PhraseType.NONE, '')])
-        else:
-            q.add_node(BreakType.WORD, PhraseType.NONE)
+    for _ in range(max(inner[0] for tlist in args for inner in tlist)):
+        q.add_node(BreakType.WORD, PhraseType.NONE)
+    q.add_node(BreakType.END, PhraseType.NONE)

-        start = len(q.nodes) - 1
+    for start, tlist in enumerate(args):
         for end, ttype, tinfo in tlist:
             for tid, word in tinfo:
                 q.add_token(TokenRange(start, end), ttype,
                             MyToken(0.5 if ttype == TokenType.PARTIAL else 0.0, tid, 1, word, True))
-    q.add_node(BreakType.END, PhraseType.NONE)
+
     return q

View File

@@ -18,21 +18,17 @@ class MyToken(Token):

 def make_query(*args):
-    q = None
+    q = QueryStruct([Phrase(args[0][1], '')])
     dummy = MyToken(3.0, 45, 1, 'foo', True)

-    for btype, ptype, tlist in args:
-        if q is None:
-            q = QueryStruct([Phrase(ptype, '')])
-        else:
-            q.add_node(btype, ptype)
-
-        start = len(q.nodes) - 1
-        for end, ttype in tlist:
-            q.add_token(TokenRange(start, end), ttype, dummy)
-
+    for btype, ptype, _ in args[1:]:
+        q.add_node(btype, ptype)
     q.add_node(BreakType.END, PhraseType.NONE)

+    for start, t in enumerate(args):
+        for end, ttype in t[2]:
+            q.add_token(TokenRange(start, end), ttype, dummy)
+
     return q