fix comparison between country tokens and country restriction

This commit is contained in:
Sarah Hoffmann
2025-12-04 18:28:04 +01:00
parent 6c8869439f
commit ffd5c32f17
2 changed files with 34 additions and 14 deletions

View File

@@ -413,7 +413,7 @@ class SearchBuilder:
"""
tokens = self.query.get_tokens(trange, qmod.TOKEN_COUNTRY)
if self.details.countries:
tokens = [t for t in tokens if t.lookup_word in self.details.countries]
tokens = [t for t in tokens if t.get_country() in self.details.countries]
return tokens

View File

@@ -2,12 +2,14 @@
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2023 by the Nominatim developer community.
# Copyright (C) 2025 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Tests for creating abstract searches from token assignments.
"""
from typing import Optional
import pytest
import dataclasses
from nominatim_api.search.query import Token, TokenRange, QueryStruct, Phrase
import nominatim_api.search.query as qmod
@@ -17,12 +19,15 @@ from nominatim_api.types import SearchDetails
import nominatim_api.search.db_searches as dbs
@dataclasses.dataclass
class MyToken(Token):
cc: Optional[str] = None
def get_category(self):
return 'this', 'that'
def get_country(self):
return self.lookup_word
return self.cc
def make_query(*args):
@@ -33,18 +38,24 @@ def make_query(*args):
q.add_node(qmod.BREAK_END, qmod.PHRASE_ANY)
for start, tlist in enumerate(args):
for end, ttype, tinfo in tlist:
for tid, word in tinfo:
q.add_token(TokenRange(start, end), ttype,
MyToken(penalty=0.5 if ttype == qmod.TOKEN_PARTIAL else 0.0,
token=tid, count=1, addr_count=1,
lookup_word=word))
for end, ttype, tinfos in tlist:
for tinfo in tinfos:
if isinstance(tinfo, tuple):
q.add_token(TokenRange(start, end), ttype,
MyToken(penalty=0.5 if ttype == qmod.TOKEN_PARTIAL else 0.0,
token=tinfo[0], count=1, addr_count=1,
lookup_word=tinfo[1]))
else:
q.add_token(TokenRange(start, end), ttype, tinfo)
return q
def test_country_search():
q = make_query([(1, qmod.TOKEN_COUNTRY, [(2, 'de'), (3, 'en')])])
q = make_query([(1, qmod.TOKEN_COUNTRY, [
MyToken(penalty=0.0, token=2, count=1, addr_count=1, lookup_word='Germany', cc='de'),
MyToken(penalty=0.0, token=3, count=1, addr_count=1, lookup_word='UK', cc='en'),
])])
builder = SearchBuilder(q, SearchDetails())
searches = list(builder.build(TokenAssignment(country=TokenRange(0, 1))))
@@ -58,7 +69,10 @@ def test_country_search():
def test_country_search_with_country_restriction():
q = make_query([(1, qmod.TOKEN_COUNTRY, [(2, 'de'), (3, 'en')])])
q = make_query([(1, qmod.TOKEN_COUNTRY, [
MyToken(penalty=0.0, token=2, count=1, addr_count=1, lookup_word='Germany', cc='de'),
MyToken(penalty=0.0, token=3, count=1, addr_count=1, lookup_word='UK', cc='en'),
])])
builder = SearchBuilder(q, SearchDetails.from_kwargs({'countries': 'en,fr'}))
searches = list(builder.build(TokenAssignment(country=TokenRange(0, 1))))
@@ -72,7 +86,10 @@ def test_country_search_with_country_restriction():
def test_country_search_with_conflicting_country_restriction():
q = make_query([(1, qmod.TOKEN_COUNTRY, [(2, 'de'), (3, 'en')])])
q = make_query([(1, qmod.TOKEN_COUNTRY, [
MyToken(penalty=0.0, token=2, count=1, addr_count=1, lookup_word='Germany', cc='de'),
MyToken(penalty=0.0, token=3, count=1, addr_count=1, lookup_word='UK', cc='en'),
])])
builder = SearchBuilder(q, SearchDetails.from_kwargs({'countries': 'fr'}))
searches = list(builder.build(TokenAssignment(country=TokenRange(0, 1))))
@@ -97,8 +114,11 @@ def test_postcode_search_simple():
def test_postcode_with_country():
q = make_query([(1, qmod.TOKEN_POSTCODE, [(34, '2367')])],
[(2, qmod.TOKEN_COUNTRY, [(1, 'xx')])])
q = make_query(
[(1, qmod.TOKEN_POSTCODE, [(34, '2367')])],
[(2, qmod.TOKEN_COUNTRY, [
MyToken(penalty=0.0, token=1, count=1, addr_count=1, lookup_word='none', cc='xx'),
])])
builder = SearchBuilder(q, SearchDetails())
searches = list(builder.build(TokenAssignment(postcode=TokenRange(0, 1),