Merge pull request #3898 from lonvia/fix-country-restriction

Fix comparision between country tokens and country restriction
This commit is contained in:
Sarah Hoffmann
2025-12-04 20:03:14 +01:00
committed by GitHub
2 changed files with 34 additions and 14 deletions

View File

@@ -413,7 +413,7 @@ class SearchBuilder:
""" """
tokens = self.query.get_tokens(trange, qmod.TOKEN_COUNTRY) tokens = self.query.get_tokens(trange, qmod.TOKEN_COUNTRY)
if self.details.countries: if self.details.countries:
tokens = [t for t in tokens if t.lookup_word in self.details.countries] tokens = [t for t in tokens if t.get_country() in self.details.countries]
return tokens return tokens

View File

@@ -2,12 +2,14 @@
# #
# This file is part of Nominatim. (https://nominatim.org) # This file is part of Nominatim. (https://nominatim.org)
# #
# Copyright (C) 2023 by the Nominatim developer community. # Copyright (C) 2025 by the Nominatim developer community.
# For a full list of authors see the git log. # For a full list of authors see the git log.
""" """
Tests for creating abstract searches from token assignments. Tests for creating abstract searches from token assignments.
""" """
from typing import Optional
import pytest import pytest
import dataclasses
from nominatim_api.search.query import Token, TokenRange, QueryStruct, Phrase from nominatim_api.search.query import Token, TokenRange, QueryStruct, Phrase
import nominatim_api.search.query as qmod import nominatim_api.search.query as qmod
@@ -17,12 +19,15 @@ from nominatim_api.types import SearchDetails
import nominatim_api.search.db_searches as dbs import nominatim_api.search.db_searches as dbs
@dataclasses.dataclass
class MyToken(Token): class MyToken(Token):
cc: Optional[str] = None
def get_category(self): def get_category(self):
return 'this', 'that' return 'this', 'that'
def get_country(self): def get_country(self):
return self.lookup_word return self.cc
def make_query(*args): def make_query(*args):
@@ -33,18 +38,24 @@ def make_query(*args):
q.add_node(qmod.BREAK_END, qmod.PHRASE_ANY) q.add_node(qmod.BREAK_END, qmod.PHRASE_ANY)
for start, tlist in enumerate(args): for start, tlist in enumerate(args):
for end, ttype, tinfo in tlist: for end, ttype, tinfos in tlist:
for tid, word in tinfo: for tinfo in tinfos:
if isinstance(tinfo, tuple):
q.add_token(TokenRange(start, end), ttype, q.add_token(TokenRange(start, end), ttype,
MyToken(penalty=0.5 if ttype == qmod.TOKEN_PARTIAL else 0.0, MyToken(penalty=0.5 if ttype == qmod.TOKEN_PARTIAL else 0.0,
token=tid, count=1, addr_count=1, token=tinfo[0], count=1, addr_count=1,
lookup_word=word)) lookup_word=tinfo[1]))
else:
q.add_token(TokenRange(start, end), ttype, tinfo)
return q return q
def test_country_search(): def test_country_search():
q = make_query([(1, qmod.TOKEN_COUNTRY, [(2, 'de'), (3, 'en')])]) q = make_query([(1, qmod.TOKEN_COUNTRY, [
MyToken(penalty=0.0, token=2, count=1, addr_count=1, lookup_word='Germany', cc='de'),
MyToken(penalty=0.0, token=3, count=1, addr_count=1, lookup_word='UK', cc='en'),
])])
builder = SearchBuilder(q, SearchDetails()) builder = SearchBuilder(q, SearchDetails())
searches = list(builder.build(TokenAssignment(country=TokenRange(0, 1)))) searches = list(builder.build(TokenAssignment(country=TokenRange(0, 1))))
@@ -58,7 +69,10 @@ def test_country_search():
def test_country_search_with_country_restriction(): def test_country_search_with_country_restriction():
q = make_query([(1, qmod.TOKEN_COUNTRY, [(2, 'de'), (3, 'en')])]) q = make_query([(1, qmod.TOKEN_COUNTRY, [
MyToken(penalty=0.0, token=2, count=1, addr_count=1, lookup_word='Germany', cc='de'),
MyToken(penalty=0.0, token=3, count=1, addr_count=1, lookup_word='UK', cc='en'),
])])
builder = SearchBuilder(q, SearchDetails.from_kwargs({'countries': 'en,fr'})) builder = SearchBuilder(q, SearchDetails.from_kwargs({'countries': 'en,fr'}))
searches = list(builder.build(TokenAssignment(country=TokenRange(0, 1)))) searches = list(builder.build(TokenAssignment(country=TokenRange(0, 1))))
@@ -72,7 +86,10 @@ def test_country_search_with_country_restriction():
def test_country_search_with_conflicting_country_restriction(): def test_country_search_with_conflicting_country_restriction():
q = make_query([(1, qmod.TOKEN_COUNTRY, [(2, 'de'), (3, 'en')])]) q = make_query([(1, qmod.TOKEN_COUNTRY, [
MyToken(penalty=0.0, token=2, count=1, addr_count=1, lookup_word='Germany', cc='de'),
MyToken(penalty=0.0, token=3, count=1, addr_count=1, lookup_word='UK', cc='en'),
])])
builder = SearchBuilder(q, SearchDetails.from_kwargs({'countries': 'fr'})) builder = SearchBuilder(q, SearchDetails.from_kwargs({'countries': 'fr'}))
searches = list(builder.build(TokenAssignment(country=TokenRange(0, 1)))) searches = list(builder.build(TokenAssignment(country=TokenRange(0, 1))))
@@ -97,8 +114,11 @@ def test_postcode_search_simple():
def test_postcode_with_country(): def test_postcode_with_country():
q = make_query([(1, qmod.TOKEN_POSTCODE, [(34, '2367')])], q = make_query(
[(2, qmod.TOKEN_COUNTRY, [(1, 'xx')])]) [(1, qmod.TOKEN_POSTCODE, [(34, '2367')])],
[(2, qmod.TOKEN_COUNTRY, [
MyToken(penalty=0.0, token=1, count=1, addr_count=1, lookup_word='none', cc='xx'),
])])
builder = SearchBuilder(q, SearchDetails()) builder = SearchBuilder(q, SearchDetails())
searches = list(builder.build(TokenAssignment(postcode=TokenRange(0, 1), searches = list(builder.build(TokenAssignment(postcode=TokenRange(0, 1),