do not split names from typed phrases

When phrases are typed, they should only contain exactly one term.
This commit is contained in:
Sarah Hoffmann
2023-07-17 16:25:39 +02:00
parent 7f9cb4e68d
commit 927d2cc824
3 changed files with 112 additions and 28 deletions

View File

@@ -7,7 +7,7 @@
""" """
Datastructures for a tokenized query. Datastructures for a tokenized query.
""" """
from typing import List, Tuple, Optional, NamedTuple, Iterator from typing import List, Tuple, Optional, Iterator
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import dataclasses import dataclasses
import enum import enum
@@ -107,13 +107,29 @@ class Token(ABC):
category objects. category objects.
""" """
@dataclasses.dataclass
class TokenRange(NamedTuple): class TokenRange:
""" Indexes of query nodes over which a token spans. """ Indexes of query nodes over which a token spans.
""" """
start: int start: int
end: int end: int
def __lt__(self, other: 'TokenRange') -> bool:
return self.end <= other.start
def __le__(self, other: 'TokenRange') -> bool:
return NotImplemented
def __gt__(self, other: 'TokenRange') -> bool:
return self.start >= other.end
def __ge__(self, other: 'TokenRange') -> bool:
return NotImplemented
def replace_start(self, new_start: int) -> 'TokenRange': def replace_start(self, new_start: int) -> 'TokenRange':
""" Return a new token range with the new start. """ Return a new token range with the new start.
""" """

View File

@@ -288,18 +288,29 @@ class _TokenSequence:
yield dataclasses.replace(base, penalty=self.penalty, yield dataclasses.replace(base, penalty=self.penalty,
name=first, address=base.address[1:]) name=first, address=base.address[1:])
if (not base.housenumber or first.end >= base.housenumber.start)\ # To paraphrase:
and (not base.qualifier or first.start >= base.qualifier.end): # * if another name term comes after the first one and before the
base_penalty = self.penalty # housenumber
if (base.housenumber and base.housenumber.start > first.start) \ # * a qualifier comes after the name
or len(query.source) > 1: # * the containing phrase is strictly typed
base_penalty += 0.25 if (base.housenumber and first.end < base.housenumber.start)\
for i in range(first.start + 1, first.end): or (base.qualifier and base.qualifier > first)\
name, addr = first.split(i) or (query.nodes[first.start].ptype != qmod.PhraseType.NONE):
penalty = base_penalty + PENALTY_TOKENCHANGE[query.nodes[i].btype] return
log().comment(f'split first word = name ({i - first.start})')
yield dataclasses.replace(base, name=name, penalty=penalty, penalty = self.penalty
address=[addr] + base.address[1:])
# Penalty for:
# * <name>, <street>, <housenumber> , ...
# * queries that are comma-separated
if (base.housenumber and base.housenumber > first) or len(query.source) > 1:
penalty += 0.25
for i in range(first.start + 1, first.end):
name, addr = first.split(i)
log().comment(f'split first word = name ({i - first.start})')
yield dataclasses.replace(base, name=name, address=[addr] + base.address[1:],
penalty=penalty + PENALTY_TOKENCHANGE[query.nodes[i].btype])
def _get_assignments_address_backward(self, base: TokenAssignment, def _get_assignments_address_backward(self, base: TokenAssignment,
@@ -314,19 +325,27 @@ class _TokenSequence:
yield dataclasses.replace(base, penalty=self.penalty, yield dataclasses.replace(base, penalty=self.penalty,
name=last, address=base.address[:-1]) name=last, address=base.address[:-1])
if (not base.housenumber or last.start <= base.housenumber.end)\ # To paraphrase:
and (not base.qualifier or last.end <= base.qualifier.start): # * if another name term comes before the last one and after the
base_penalty = self.penalty # housenumber
if base.housenumber and base.housenumber.start < last.start: # * a qualifier comes before the name
base_penalty += 0.4 # * the containing phrase is strictly typed
if len(query.source) > 1: if (base.housenumber and last.start > base.housenumber.end)\
base_penalty += 0.25 or (base.qualifier and base.qualifier < last)\
for i in range(last.start + 1, last.end): or (query.nodes[last.start].ptype != qmod.PhraseType.NONE):
addr, name = last.split(i) return
penalty = base_penalty + PENALTY_TOKENCHANGE[query.nodes[i].btype]
log().comment(f'split last word = name ({i - last.start})') penalty = self.penalty
yield dataclasses.replace(base, name=name, penalty=penalty, if base.housenumber and base.housenumber < last:
address=base.address[:-1] + [addr]) penalty += 0.4
if len(query.source) > 1:
penalty += 0.25
for i in range(last.start + 1, last.end):
addr, name = last.split(i)
log().comment(f'split last word = name ({i - last.start})')
yield dataclasses.replace(base, name=name, address=base.address[:-1] + [addr],
penalty=penalty + PENALTY_TOKENCHANGE[query.nodes[i].btype])
def get_assignments(self, query: qmod.QueryStruct) -> Iterator[TokenAssignment]: def get_assignments(self, query: qmod.QueryStruct) -> Iterator[TokenAssignment]:

View File

@@ -0,0 +1,49 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2023 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Test data types for search queries.
"""
import pytest
import nominatim.api.search.query as nq
def test_token_range_equal():
    """ Two ranges with identical start and end indexes compare equal
        (and, symmetrically, do not compare unequal).
    """
    first = nq.TokenRange(2, 3)
    second = nq.TokenRange(2, 3)

    assert first == second
    assert not (first != second)
@pytest.mark.parametrize('lop,rop', [((1, 2), (3, 4)),
                                     ((3, 4), (3, 5)),
                                     ((10, 12), (11, 12))])
def test_token_range_unequal(lop, rop):
    """ Ranges that differ in start and/or end index must compare
        unequal (and must not compare equal).
    """
    left = nq.TokenRange(*lop)
    right = nq.TokenRange(*rop)

    assert left != right
    assert not (left == right)
def test_token_range_lt():
    """ A range is '<' another exactly when it ends at or before the
        other's start (strictly-before semantics on node spans).
    """
    # left operand ends no later than the right operand starts -> True
    for lop, rop in (((1, 3), (10, 12)), ((5, 6), (7, 8)), ((1, 4), (4, 5))):
        assert nq.TokenRange(*lop) < nq.TokenRange(*rop)

    # identical or overlapping/following ranges -> False
    for lop, rop in (((5, 6), (5, 6)), ((10, 11), (4, 5))):
        assert not (nq.TokenRange(*lop) < nq.TokenRange(*rop))
def test_token_range_gt():
    """ A range is '>' another exactly when it starts at or after the
        other's end (strictly-after semantics on node spans).

        Note: renamed from the misspelled 'test_token_rankge_gt'; the
        'test_' prefix is preserved, so pytest discovery is unaffected.
    """
    assert nq.TokenRange(3, 4) > nq.TokenRange(1, 2)
    assert nq.TokenRange(100, 200) > nq.TokenRange(10, 11)
    assert nq.TokenRange(10, 11) > nq.TokenRange(4, 10)
    # identical or overlapping/preceding ranges are not '>'
    assert not(nq.TokenRange(5, 6) > nq.TokenRange(5, 6))
    assert not(nq.TokenRange(1, 2) > nq.TokenRange(3, 4))
    assert not(nq.TokenRange(4, 10) > nq.TokenRange(3, 5))
def test_token_range_unimplemented_ops():
    """ The '<=' and '>=' operators are intentionally left unimplemented
        on TokenRange and must raise a TypeError when used.
    """
    lop = nq.TokenRange(1, 3)
    rop = nq.TokenRange(10, 12)

    with pytest.raises(TypeError):
        lop <= rop
    with pytest.raises(TypeError):
        lop >= rop