enable all API tests for sqlite and port missing features

This commit is contained in:
Sarah Hoffmann
2023-12-06 20:56:21 +01:00
parent 0d840c8d4e
commit 6d39563b87
15 changed files with 514 additions and 230 deletions

View File

@@ -19,6 +19,7 @@ import sqlalchemy.ext.asyncio as sa_asyncio
from nominatim.errors import UsageError from nominatim.errors import UsageError
from nominatim.db.sqlalchemy_schema import SearchTables from nominatim.db.sqlalchemy_schema import SearchTables
from nominatim.db.async_core_library import PGCORE_LIB, PGCORE_ERROR from nominatim.db.async_core_library import PGCORE_LIB, PGCORE_ERROR
import nominatim.db.sqlite_functions
from nominatim.config import Configuration from nominatim.config import Configuration
from nominatim.api.connection import SearchConnection from nominatim.api.connection import SearchConnection
from nominatim.api.status import get_status, StatusResult from nominatim.api.status import get_status, StatusResult
@@ -122,6 +123,7 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
@sa.event.listens_for(engine.sync_engine, "connect") @sa.event.listens_for(engine.sync_engine, "connect")
def _on_sqlite_connect(dbapi_con: Any, _: Any) -> None: def _on_sqlite_connect(dbapi_con: Any, _: Any) -> None:
dbapi_con.run_async(lambda conn: conn.enable_load_extension(True)) dbapi_con.run_async(lambda conn: conn.enable_load_extension(True))
nominatim.db.sqlite_functions.install_custom_functions(dbapi_con)
cursor = dbapi_con.cursor() cursor = dbapi_con.cursor()
cursor.execute("SELECT load_extension('mod_spatialite')") cursor.execute("SELECT load_extension('mod_spatialite')")
cursor.execute('SELECT SetDecimalPrecision(7)') cursor.execute('SELECT SetDecimalPrecision(7)')

View File

@@ -26,18 +26,38 @@ class LookupAll(LookupType):
inherit_cache = True inherit_cache = True
def __init__(self, table: SaFromClause, column: str, tokens: List[int]) -> None: def __init__(self, table: SaFromClause, column: str, tokens: List[int]) -> None:
super().__init__(getattr(table.c, column), super().__init__(table.c.place_id, getattr(table.c, column), column,
sa.type_coerce(tokens, IntArray)) sa.type_coerce(tokens, IntArray))
@compiles(LookupAll) # type: ignore[no-untyped-call, misc] @compiles(LookupAll) # type: ignore[no-untyped-call, misc]
def _default_lookup_all(element: LookupAll, def _default_lookup_all(element: LookupAll,
compiler: 'sa.Compiled', **kw: Any) -> str: compiler: 'sa.Compiled', **kw: Any) -> str:
col, tokens = list(element.clauses) _, col, _, tokens = list(element.clauses)
return "(%s @> %s)" % (compiler.process(col, **kw), return "(%s @> %s)" % (compiler.process(col, **kw),
compiler.process(tokens, **kw)) compiler.process(tokens, **kw))
@compiles(LookupAll, 'sqlite') # type: ignore[no-untyped-call, misc]
def _sqlite_lookup_all(element: LookupAll,
                       compiler: 'sa.Compiled', **kw: Any) -> str:
    """ SQLite-specific compilation of LookupAll.

        The element's clauses are (place_id, array column, column name,
        token array).  Candidate place_ids are collected from the
        'reverse_search_name' helper table: the 'places' arrays of all
        matching words are intersected with array_intersect_fuzzy(),
        which may return a superset of the true intersection, so a final
        exact check with array_contains() on the original column is
        applied on top.
    """
    place, col, colname, tokens = list(element.clauses)
    # Arrays are stored as comma-separated strings; wrapping such a string
    # in '[' ... ']' and feeding it to json_each() expands it into rows.
    # Smallest 'places' arrays are aggregated first (ORDER BY length) to
    # keep the running intersection small.
    return "(%s IN (SELECT CAST(value as bigint) FROM"\
           " (SELECT array_intersect_fuzzy(places) as p FROM"\
           " (SELECT places FROM reverse_search_name"\
           " WHERE word IN (SELECT value FROM json_each('[' || %s || ']'))"\
           " AND column = %s"\
           " ORDER BY length(places)) as x) as u,"\
           " json_each('[' || u.p || ']'))"\
           " AND array_contains(%s, %s))"\
           % (compiler.process(place, **kw),
              compiler.process(tokens, **kw),
              compiler.process(colname, **kw),
              compiler.process(col, **kw),
              compiler.process(tokens, **kw)
             )
class LookupAny(LookupType): class LookupAny(LookupType):
""" Find all entries that contain at least one of the given tokens. """ Find all entries that contain at least one of the given tokens.
@@ -46,17 +66,28 @@ class LookupAny(LookupType):
inherit_cache = True inherit_cache = True
def __init__(self, table: SaFromClause, column: str, tokens: List[int]) -> None: def __init__(self, table: SaFromClause, column: str, tokens: List[int]) -> None:
super().__init__(getattr(table.c, column), super().__init__(table.c.place_id, getattr(table.c, column), column,
sa.type_coerce(tokens, IntArray)) sa.type_coerce(tokens, IntArray))
@compiles(LookupAny) # type: ignore[no-untyped-call, misc] @compiles(LookupAny) # type: ignore[no-untyped-call, misc]
def _default_lookup_any(element: LookupAny, def _default_lookup_any(element: LookupAny,
compiler: 'sa.Compiled', **kw: Any) -> str: compiler: 'sa.Compiled', **kw: Any) -> str:
col, tokens = list(element.clauses) _, col, _, tokens = list(element.clauses)
return "(%s && %s)" % (compiler.process(col, **kw), return "(%s && %s)" % (compiler.process(col, **kw),
compiler.process(tokens, **kw)) compiler.process(tokens, **kw))
@compiles(LookupAny, 'sqlite') # type: ignore[no-untyped-call, misc]
def _sqlite_lookup_any(element: LookupAny,
                       compiler: 'sa.Compiled', **kw: Any) -> str:
    """ SQLite-specific compilation of LookupAny.

        The element's clauses are (place_id, array column, column name,
        token array).  The 'places' arrays of all given tokens in the
        'reverse_search_name' helper table are merged with array_union()
        and the place_id is checked for membership in the result.
    """
    place, _, colname, tokens = list(element.clauses)
    # Arrays are comma-separated strings; '[' || x || ']' + json_each()
    # expands them into rows for the IN subselect.
    return "%s IN (SELECT CAST(value as bigint) FROM"\
           " (SELECT array_union(places) as p FROM reverse_search_name"\
           " WHERE word IN (SELECT value FROM json_each('[' || %s || ']'))"\
           " AND column = %s) as u,"\
           " json_each('[' || u.p || ']'))" % (compiler.process(place, **kw),
                                               compiler.process(tokens, **kw),
                                               compiler.process(colname, **kw))
class Restrict(LookupType): class Restrict(LookupType):
@@ -76,3 +107,8 @@ def _default_restrict(element: Restrict,
arg1, arg2 = list(element.clauses) arg1, arg2 = list(element.clauses)
return "(coalesce(null, %s) @> %s)" % (compiler.process(arg1, **kw), return "(coalesce(null, %s) @> %s)" % (compiler.process(arg1, **kw),
compiler.process(arg2, **kw)) compiler.process(arg2, **kw))
@compiles(Restrict, 'sqlite') # type: ignore[no-untyped-call, misc]
def _sqlite_restrict(element: Restrict,
                     compiler: 'sa.Compiled', **kw: Any) -> str:
    """ SQLite-specific compilation of Restrict: delegate to the custom
        array_contains() SQL function installed on the connection.
    """
    compiled_args = compiler.process(element.clauses, **kw)
    return f"array_contains({compiled_args})"

View File

@@ -11,7 +11,6 @@ from typing import List, Tuple, AsyncIterator, Dict, Any, Callable
import abc import abc
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import array_agg
from nominatim.typing import SaFromClause, SaScalarSelect, SaColumn, \ from nominatim.typing import SaFromClause, SaScalarSelect, SaColumn, \
SaExpression, SaSelect, SaLambdaSelect, SaRow, SaBind SaExpression, SaSelect, SaLambdaSelect, SaRow, SaBind
@@ -19,7 +18,7 @@ from nominatim.api.connection import SearchConnection
from nominatim.api.types import SearchDetails, DataLayer, GeometryFormat, Bbox from nominatim.api.types import SearchDetails, DataLayer, GeometryFormat, Bbox
import nominatim.api.results as nres import nominatim.api.results as nres
from nominatim.api.search.db_search_fields import SearchData, WeightedCategories from nominatim.api.search.db_search_fields import SearchData, WeightedCategories
from nominatim.db.sqlalchemy_types import Geometry from nominatim.db.sqlalchemy_types import Geometry, IntArray
#pylint: disable=singleton-comparison,not-callable #pylint: disable=singleton-comparison,not-callable
#pylint: disable=too-many-branches,too-many-arguments,too-many-locals,too-many-statements #pylint: disable=too-many-branches,too-many-arguments,too-many-locals,too-many-statements
@@ -110,7 +109,7 @@ def _add_geometry_columns(sql: SaLambdaSelect, col: SaColumn, details: SearchDet
def _make_interpolation_subquery(table: SaFromClause, inner: SaFromClause, def _make_interpolation_subquery(table: SaFromClause, inner: SaFromClause,
numerals: List[int], details: SearchDetails) -> SaScalarSelect: numerals: List[int], details: SearchDetails) -> SaScalarSelect:
all_ids = array_agg(table.c.place_id) # type: ignore[no-untyped-call] all_ids = sa.func.ArrayAgg(table.c.place_id)
sql = sa.select(all_ids).where(table.c.parent_place_id == inner.c.place_id) sql = sa.select(all_ids).where(table.c.parent_place_id == inner.c.place_id)
if len(numerals) == 1: if len(numerals) == 1:
@@ -134,9 +133,7 @@ def _filter_by_layer(table: SaFromClause, layers: DataLayer) -> SaColumn:
orexpr.append(no_index(table.c.rank_address).between(1, 30)) orexpr.append(no_index(table.c.rank_address).between(1, 30))
elif layers & DataLayer.ADDRESS: elif layers & DataLayer.ADDRESS:
orexpr.append(no_index(table.c.rank_address).between(1, 29)) orexpr.append(no_index(table.c.rank_address).between(1, 29))
orexpr.append(sa.and_(no_index(table.c.rank_address) == 30, orexpr.append(sa.func.IsAddressPoint(table))
sa.or_(table.c.housenumber != None,
table.c.address.has_key('addr:housename'))))
elif layers & DataLayer.POI: elif layers & DataLayer.POI:
orexpr.append(sa.and_(no_index(table.c.rank_address) == 30, orexpr.append(sa.and_(no_index(table.c.rank_address) == 30,
table.c.class_.not_in(('place', 'building')))) table.c.class_.not_in(('place', 'building'))))
@@ -188,12 +185,21 @@ async def _get_placex_housenumbers(conn: SearchConnection,
yield result yield result
def _int_list_to_subquery(inp: List[int]) -> 'sa.Subquery':
    """ Create a subselect that returns the given list of integers
        as rows in the column 'nr'.

        Dialect-portable replacement for sa.values(): the list is passed
        as a single JSON literal and expanded into rows with JsonArrayEach
        (json_each() on SQLite).
    """
    vtab = sa.func.JsonArrayEach(sa.type_coerce(inp, sa.JSON))\
             .table_valued(sa.column('value', type_=sa.JSON)) # type: ignore[no-untyped-call]
    # Double cast JSON -> text -> integer to get a plain integer column.
    return sa.select(sa.cast(sa.cast(vtab.c.value, sa.Text), sa.Integer).label('nr')).subquery()
async def _get_osmline(conn: SearchConnection, place_ids: List[int], async def _get_osmline(conn: SearchConnection, place_ids: List[int],
numerals: List[int], numerals: List[int],
details: SearchDetails) -> AsyncIterator[nres.SearchResult]: details: SearchDetails) -> AsyncIterator[nres.SearchResult]:
t = conn.t.osmline t = conn.t.osmline
values = sa.values(sa.Column('nr', sa.Integer()), name='housenumber')\
.data([(n,) for n in numerals]) values = _int_list_to_subquery(numerals)
sql = sa.select(t.c.place_id, t.c.osm_id, sql = sa.select(t.c.place_id, t.c.osm_id,
t.c.parent_place_id, t.c.address, t.c.parent_place_id, t.c.address,
values.c.nr.label('housenumber'), values.c.nr.label('housenumber'),
@@ -216,8 +222,7 @@ async def _get_tiger(conn: SearchConnection, place_ids: List[int],
numerals: List[int], osm_id: int, numerals: List[int], osm_id: int,
details: SearchDetails) -> AsyncIterator[nres.SearchResult]: details: SearchDetails) -> AsyncIterator[nres.SearchResult]:
t = conn.t.tiger t = conn.t.tiger
values = sa.values(sa.Column('nr', sa.Integer()), name='housenumber')\ values = _int_list_to_subquery(numerals)
.data([(n,) for n in numerals])
sql = sa.select(t.c.place_id, t.c.parent_place_id, sql = sa.select(t.c.place_id, t.c.parent_place_id,
sa.literal('W').label('osm_type'), sa.literal('W').label('osm_type'),
sa.literal(osm_id).label('osm_id'), sa.literal(osm_id).label('osm_id'),
@@ -573,7 +578,8 @@ class PostcodeSearch(AbstractSearch):
tsearch = conn.t.search_name tsearch = conn.t.search_name
sql = sql.where(tsearch.c.place_id == t.c.parent_place_id)\ sql = sql.where(tsearch.c.place_id == t.c.parent_place_id)\
.where((tsearch.c.name_vector + tsearch.c.nameaddress_vector) .where((tsearch.c.name_vector + tsearch.c.nameaddress_vector)
.contains(self.lookups[0].tokens)) .contains(sa.type_coerce(self.lookups[0].tokens,
IntArray)))
for ranking in self.rankings: for ranking in self.rankings:
penalty += ranking.sql_penalty(conn.t.search_name) penalty += ranking.sql_penalty(conn.t.search_name)
@@ -692,10 +698,10 @@ class PlaceSearch(AbstractSearch):
sql = sql.order_by(sa.text('accuracy')) sql = sql.order_by(sa.text('accuracy'))
if self.housenumbers: if self.housenumbers:
hnr_regexp = f"\\m({'|'.join(self.housenumbers.values)})\\M" hnr_list = '|'.join(self.housenumbers.values)
sql = sql.where(tsearch.c.address_rank.between(16, 30))\ sql = sql.where(tsearch.c.address_rank.between(16, 30))\
.where(sa.or_(tsearch.c.address_rank < 30, .where(sa.or_(tsearch.c.address_rank < 30,
t.c.housenumber.op('~*')(hnr_regexp))) sa.func.RegexpWord(hnr_list, t.c.housenumber)))
# Cross check for housenumbers, need to do that on a rather large # Cross check for housenumbers, need to do that on a rather large
# set. Worst case there are 40.000 main streets in OSM. # set. Worst case there are 40.000 main streets in OSM.
@@ -703,10 +709,10 @@ class PlaceSearch(AbstractSearch):
# Housenumbers from placex # Housenumbers from placex
thnr = conn.t.placex.alias('hnr') thnr = conn.t.placex.alias('hnr')
pid_list = array_agg(thnr.c.place_id) # type: ignore[no-untyped-call] pid_list = sa.func.ArrayAgg(thnr.c.place_id)
place_sql = sa.select(pid_list)\ place_sql = sa.select(pid_list)\
.where(thnr.c.parent_place_id == inner.c.place_id)\ .where(thnr.c.parent_place_id == inner.c.place_id)\
.where(thnr.c.housenumber.op('~*')(hnr_regexp))\ .where(sa.func.RegexpWord(hnr_list, thnr.c.housenumber))\
.where(thnr.c.linked_place_id == None)\ .where(thnr.c.linked_place_id == None)\
.where(thnr.c.indexed_status == 0) .where(thnr.c.indexed_status == 0)

View File

@@ -188,6 +188,7 @@ def sqlite_json_array_each(element: JsonArrayEach, compiler: 'sa.Compiled', **kw
return "json_each(%s)" % compiler.process(element.clauses, **kw) return "json_each(%s)" % compiler.process(element.clauses, **kw)
class Greatest(sa.sql.functions.GenericFunction[Any]): class Greatest(sa.sql.functions.GenericFunction[Any]):
""" Function to compute maximum of all its input parameters. """ Function to compute maximum of all its input parameters.
""" """
@@ -198,3 +199,23 @@ class Greatest(sa.sql.functions.GenericFunction[Any]):
@compiles(Greatest, 'sqlite') # type: ignore[no-untyped-call, misc] @compiles(Greatest, 'sqlite') # type: ignore[no-untyped-call, misc]
def sqlite_greatest(element: Greatest, compiler: 'sa.Compiled', **kw: Any) -> str: def sqlite_greatest(element: Greatest, compiler: 'sa.Compiled', **kw: Any) -> str:
return "max(%s)" % compiler.process(element.clauses, **kw) return "max(%s)" % compiler.process(element.clauses, **kw)
class RegexpWord(sa.sql.functions.GenericFunction[Any]):
    """ Check if a full word is in a given string.

        Called as RegexpWord(<alternatives>, <string>) where the first
        argument is a '|'-separated list of words.  The match is done
        case-insensitively and on word boundaries; dialect-specific
        compilations are registered for PostgreSQL and SQLite.
    """
    name = 'RegexpWord'
    inherit_cache = True
@compiles(RegexpWord, 'postgresql') # type: ignore[no-untyped-call, misc]
def postgres_regexp_nocase(element: RegexpWord, compiler: 'sa.Compiled', **kw: Any) -> str:
    """ PostgreSQL compilation: case-insensitive regex operator '~*' with
        \\m/\\M word-boundary markers around the alternatives.

        Note the argument swap: the string to search (arg2) goes on the
        left of '~*', the word list (arg1) builds the pattern.
    """
    arg1, arg2 = list(element.clauses)
    return "%s ~* ('\\m(' || %s || ')\\M')::text" % (compiler.process(arg2, **kw), compiler.process(arg1, **kw))
@compiles(RegexpWord, 'sqlite') # type: ignore[no-untyped-call, misc]
def sqlite_regexp_nocase(element: RegexpWord, compiler: 'sa.Compiled', **kw: Any) -> str:
    """ SQLite compilation: regexp(pattern, string) with \\b word
        boundaries around the alternatives.

        NOTE(review): assumes a regexp() SQL function is available on the
        SQLite connection (it is not built in) — confirm against the
        connection setup.
    """
    arg1, arg2 = list(element.clauses)
    return "regexp('\\b(' || %s || ')\\b', %s)" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))

View File

@@ -57,22 +57,16 @@ class IntArray(sa.types.TypeDecorator[Any]):
""" Concate the array with the given array. If one of the """ Concate the array with the given array. If one of the
operants is null, the value of the other will be returned. operants is null, the value of the other will be returned.
""" """
return sa.func.array_cat(self, other, type_=IntArray) return ArrayCat(self.expr, other)
def contains(self, other: SaColumn, **kwargs: Any) -> 'sa.ColumnOperators': def contains(self, other: SaColumn, **kwargs: Any) -> 'sa.ColumnOperators':
""" Return true if the array contains all the value of the argument """ Return true if the array contains all the value of the argument
array. array.
""" """
return cast('sa.ColumnOperators', self.op('@>', is_comparison=True)(other)) return ArrayContains(self.expr, other)
def overlaps(self, other: SaColumn) -> 'sa.Operators':
""" Return true if at least one value of the argument is contained
in the array.
"""
return self.op('&&', is_comparison=True)(other)
class ArrayAgg(sa.sql.functions.GenericFunction[Any]): class ArrayAgg(sa.sql.functions.GenericFunction[Any]):
""" Aggregate function to collect elements in an array. """ Aggregate function to collect elements in an array.
@@ -82,6 +76,48 @@ class ArrayAgg(sa.sql.functions.GenericFunction[Any]):
name = 'array_agg' name = 'array_agg'
inherit_cache = True inherit_cache = True
@compiles(ArrayAgg, 'sqlite') # type: ignore[no-untyped-call, misc] @compiles(ArrayAgg, 'sqlite') # type: ignore[no-untyped-call, misc]
def sqlite_array_agg(element: ArrayAgg, compiler: 'sa.Compiled', **kw: Any) -> str: def sqlite_array_agg(element: ArrayAgg, compiler: 'sa.Compiled', **kw: Any) -> str:
return "group_concat(%s, ',')" % compiler.process(element.clauses, **kw) return "group_concat(%s, ',')" % compiler.process(element.clauses, **kw)
class ArrayContains(sa.sql.expression.FunctionElement[Any]):
    """ Function to check if an array is fully contained in another.

        Compiled to the '@>' operator on PostgreSQL and to the custom
        array_contains() SQL function on SQLite.
    """
    name = 'ArrayContains'
    inherit_cache = True
@compiles(ArrayContains) # type: ignore[no-untyped-call, misc]
def generic_array_contains(element: ArrayContains, compiler: 'sa.Compiled', **kw: Any) -> str:
    """ Default compilation: PostgreSQL's containment operator '@>'.
    """
    arg1, arg2 = list(element.clauses)
    return "(%s @> %s)" % (compiler.process(arg1, **kw),
                           compiler.process(arg2, **kw))
@compiles(ArrayContains, 'sqlite') # type: ignore[no-untyped-call, misc]
def sqlite_array_contains(element: ArrayContains, compiler: 'sa.Compiled', **kw: Any) -> str:
    """ SQLite compilation: delegate to the custom array_contains()
        SQL function installed on the connection.
    """
    compiled_args = compiler.process(element.clauses, **kw)
    return f"array_contains({compiled_args})"
class ArrayCat(sa.sql.expression.FunctionElement[Any]):
    """ Function to concatenate two arrays, yielding an IntArray.

        Compiled to array_cat() on PostgreSQL and to plain string
        concatenation of the comma-separated lists on SQLite.
    """
    type = IntArray()
    identifier = 'ArrayCat'
    inherit_cache = True
@compiles(ArrayCat) # type: ignore[no-untyped-call, misc]
def generic_array_cat(element: ArrayCat, compiler: 'sa.Compiled', **kw: Any) -> str:
    """ Default compilation: PostgreSQL's array_cat() function.
    """
    return "array_cat(%s)" % compiler.process(element.clauses, **kw)
@compiles(ArrayCat, 'sqlite') # type: ignore[no-untyped-call, misc]
def sqlite_array_cat(element: ArrayCat, compiler: 'sa.Compiled', **kw: Any) -> str:
    """ SQLite compilation: arrays are comma-separated strings, so
        concatenation is string concatenation with a ',' in between.

        NOTE(review): an empty array on either side would produce a
        leading/trailing comma — presumably inputs are never empty;
        confirm at the call sites.
    """
    arg1, arg2 = list(element.clauses)
    return "(%s || ',' || %s)" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))

View File

@@ -10,6 +10,7 @@ A custom type that implements a simple key-value store of strings.
from typing import Any from typing import Any
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.dialects.postgresql import HSTORE from sqlalchemy.dialects.postgresql import HSTORE
from sqlalchemy.dialects.sqlite import JSON as sqlite_json from sqlalchemy.dialects.sqlite import JSON as sqlite_json
@@ -37,11 +38,25 @@ class KeyValueStore(sa.types.TypeDecorator[Any]):
one, overwriting values where necessary. When the argument one, overwriting values where necessary. When the argument
is null, nothing happens. is null, nothing happens.
""" """
return self.op('||')(sa.func.coalesce(other, return KeyValueConcat(self.expr, other)
sa.type_coerce('', KeyValueStore)))
class KeyValueConcat(sa.sql.expression.FunctionElement[Any]):
    """ Return the merged key-value store from the input parameters.

        Keys of the second parameter overwrite equal keys of the first;
        a null second parameter leaves the first store unchanged.
    """
    type = KeyValueStore()
    name = 'JsonConcat'
    inherit_cache = True
@compiles(KeyValueConcat) # type: ignore[no-untyped-call, misc]
def default_json_concat(element: KeyValueConcat, compiler: 'sa.Compiled', **kw: Any) -> str:
    """ Default compilation: hstore concatenation '||', guarding the
        second operand with coalesce so a null argument is a no-op.
    """
    arg1, arg2 = list(element.clauses)
    return "(%s || coalesce(%s, ''::hstore))" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
@compiles(KeyValueConcat, 'sqlite') # type: ignore[no-untyped-call, misc]
def sqlite_json_concat(element: KeyValueConcat, compiler: 'sa.Compiled', **kw: Any) -> str:
    """ SQLite compilation: stores are JSON objects, merged with
        json_patch(); coalesce to '{}' makes a null argument a no-op.
    """
    arg1, arg2 = list(element.clauses)
    return "json_patch(%s, coalesce(%s, '{}'))" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
def has_key(self, key: SaColumn) -> 'sa.Operators':
""" Return true if the key is cotained in the store.
"""
return self.op('?', is_comparison=True)(key)

View File

@@ -0,0 +1,122 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2023 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Custom functions for SQLite.
"""
from typing import cast, Optional, Set, Any
import json
# pylint: disable=protected-access
def weigh_search(search_vector: Optional[str], rankings: str, default: float) -> float:
    """ Custom weight function for search results.

        'rankings' is a JSON list of [penalty, [token, ...]] entries.
        Return the penalty of the first entry whose tokens are all
        contained in the comma-separated 'search_vector', or 'default'
        when no entry matches or the vector is null.
    """
    if search_vector is None:
        return default

    present = {int(token) for token in search_vector.split(',')}
    for rank in json.loads(rankings):
        if present.issuperset(rank[1]):
            return cast(float, rank[0])

    return default
class ArrayIntersectFuzzy:
    """ SQLite aggregate computing the common elements of all input
        integer arrays, given as comma-separated strings.

        Very large inputs are skipped to speed up computation, so the
        result may be a superset of the true intersection.
    """

    def __init__(self) -> None:
        self.first = ''
        self.values: Optional[Set[int]] = None

    def step(self, value: Optional[str]) -> None:
        """ Fold the next array into the running intersection.
        """
        if value is None:
            return
        if not self.first:
            # Remember the first array verbatim; it is only parsed when a
            # second array actually arrives.
            self.first = value
            return
        if len(value) >= 10000000:
            return   # fuzzy part: ignore huge arrays
        if self.values is None:
            self.values = {int(x) for x in self.first.split(',')}
        self.values.intersection_update(int(x) for x in value.split(','))

    def finalize(self) -> str:
        """ Return the final result as a comma-separated string.
        """
        if self.values is None:
            return self.first
        return ','.join(map(str, self.values))
class ArrayUnion:
    """ SQLite aggregate computing the set of all elements of the input
        integer arrays, given as comma-separated strings.
    """

    def __init__(self) -> None:
        self.values: Optional[Set[str]] = None

    def step(self, value: Optional[str]) -> None:
        """ Merge the next array into the union.
        """
        if value is None:
            return
        elements = value.split(',')
        if self.values is None:
            self.values = set(elements)
        else:
            self.values.update(elements)

    def finalize(self) -> str:
        """ Return the final result as a comma-separated string.
        """
        if self.values is None:
            return ''
        return ','.join(self.values)
def array_contains(container: Optional[str], containee: Optional[str]) -> Optional[bool]:
    """ Is the array 'containee' completely contained in array 'container'.

        Arrays are comma-separated strings of elements.  Returns None
        (SQL NULL) when either argument is null.
    """
    if container is None or containee is None:
        return None

    # Use a set for O(1) membership tests instead of scanning the
    # container list once per containee element.
    vset = set(container.split(','))
    return all(v in vset for v in containee.split(','))
def array_pair_contains(container1: Optional[str], container2: Optional[str],
                        containee: Optional[str]) -> Optional[bool]:
    """ Is the array 'containee' completely contained in the union of
        array 'container1' and array 'container2'.

        Arrays are comma-separated strings of elements.  Returns None
        (SQL NULL) when any argument is null.
    """
    if container1 is None or container2 is None or containee is None:
        return None

    # Build one set for O(1) membership tests instead of scanning the
    # concatenated lists once per containee element.
    vset = set(container1.split(','))
    vset.update(container2.split(','))
    return all(v in vset for v in containee.split(','))
def install_custom_functions(conn: Any) -> None:
    """ Install helper functions for Nominatim into the given SQLite
        database connection.

        Registers the scalar functions and aggregates defined in this
        module; all are marked deterministic where the API allows it.
    """
    for fname, nargs, func in (('weigh_search', 3, weigh_search),
                               ('array_contains', 2, array_contains),
                               ('array_pair_contains', 3, array_pair_contains)):
        conn.create_function(fname, nargs, func, deterministic=True)

    for aname, agg_class in (('array_intersect_fuzzy', ArrayIntersectFuzzy),
                             ('array_union', ArrayUnion)):
        _create_aggregate(conn, aname, 1, agg_class)
async def _make_aggregate(aioconn: Any, *args: Any) -> None:
    """ Register an aggregate on the raw sqlite3 connection behind an
        async connection wrapper.

        Reaches into private members (_execute/_conn — presumably an
        aiosqlite connection; confirm) because create_aggregate is not
        exposed through the public async API.
    """
    await aioconn._execute(aioconn._conn.create_aggregate, *args)
def _create_aggregate(conn: Any, name: str, nargs: int, aggregate: Any) -> None:
    """ Register the aggregate class under 'name' with 'nargs' arguments
        on the given DBAPI connection wrapper.

        Runs the async registration via the wrapper's await_ bridge and
        routes any failure through its _handle_exception so errors
        surface as proper DBAPI errors.
    """
    try:
        conn.await_(_make_aggregate(conn._connection, name, nargs, aggregate))
    except Exception as error: # pylint: disable=broad-exception-caught
        conn._handle_exception(error)

View File

@@ -205,15 +205,15 @@ class SqliteWriter:
async def create_search_index(self) -> None: async def create_search_index(self) -> None:
""" Create the tables and indexes needed for word lookup. """ Create the tables and indexes needed for word lookup.
""" """
LOG.warning("Creating reverse search table")
rsn = sa.Table('reverse_search_name', self.dest.t.meta,
sa.Column('word', sa.Integer()),
sa.Column('column', sa.Text()),
sa.Column('places', IntArray))
await self.dest.connection.run_sync(rsn.create)
tsrc = self.src.t.search_name tsrc = self.src.t.search_name
for column in ('name_vector', 'nameaddress_vector'): for column in ('name_vector', 'nameaddress_vector'):
table_name = f'reverse_search_{column}'
LOG.warning("Creating reverse search %s", table_name)
rsn = sa.Table(table_name, self.dest.t.meta,
sa.Column('word', sa.Integer()),
sa.Column('places', IntArray))
await self.dest.connection.run_sync(rsn.create)
sql = sa.select(sa.func.unnest(getattr(tsrc.c, column)).label('word'), sql = sa.select(sa.func.unnest(getattr(tsrc.c, column)).label('word'),
sa.func.ArrayAgg(tsrc.c.place_id).label('places'))\ sa.func.ArrayAgg(tsrc.c.place_id).label('places'))\
.group_by('word') .group_by('word')
@@ -224,11 +224,12 @@ class SqliteWriter:
for row in partition: for row in partition:
row.places.sort() row.places.sort()
data.append({'word': row.word, data.append({'word': row.word,
'column': column,
'places': row.places}) 'places': row.places})
await self.dest.execute(rsn.insert(), data) await self.dest.execute(rsn.insert(), data)
await self.dest.connection.run_sync( await self.dest.connection.run_sync(
sa.Index(f'idx_reverse_search_{column}_word', rsn.c.word).create) sa.Index('idx_reverse_search_name_word', rsn.c.word).create)
def select_from(self, table: str) -> SaSelect: def select_from(self, table: str) -> SaSelect:

View File

@@ -16,6 +16,7 @@ import sqlalchemy as sa
import nominatim.api as napi import nominatim.api as napi
from nominatim.db.sql_preprocessor import SQLPreprocessor from nominatim.db.sql_preprocessor import SQLPreprocessor
from nominatim.api.search.query_analyzer_factory import make_query_analyzer
from nominatim.tools import convert_sqlite from nominatim.tools import convert_sqlite
import nominatim.api.logging as loglib import nominatim.api.logging as loglib
@@ -160,6 +161,22 @@ class APITester:
"""))) """)))
def add_word_table(self, content):
data = [dict(zip(['word_id', 'word_token', 'type', 'word', 'info'], c))
for c in content]
async def _do_sql():
async with self.api._async_api.begin() as conn:
if 'word' not in conn.t.meta.tables:
await make_query_analyzer(conn)
word_table = conn.t.meta.tables['word']
await conn.connection.run_sync(word_table.create)
if data:
await conn.execute(conn.t.meta.tables['word'].insert(), data)
self.async_to_sync(_do_sql())
async def exec_async(self, sql, *args, **kwargs): async def exec_async(self, sql, *args, **kwargs):
async with self.api._async_api.begin() as conn: async with self.api._async_api.begin() as conn:
return await conn.execute(sql, *args, **kwargs) return await conn.execute(sql, *args, **kwargs)
@@ -195,6 +212,22 @@ def frontend(request, event_loop, tmp_path):
db = str(tmp_path / 'test_nominatim_python_unittest.sqlite') db = str(tmp_path / 'test_nominatim_python_unittest.sqlite')
def mkapi(apiobj, options={'reverse'}): def mkapi(apiobj, options={'reverse'}):
apiobj.add_data('properties',
[{'property': 'tokenizer', 'value': 'icu'},
{'property': 'tokenizer_import_normalisation', 'value': ':: lower();'},
{'property': 'tokenizer_import_transliteration', 'value': "'1' > '/1/'; 'ä' > 'ä '"},
])
async def _do_sql():
async with apiobj.api._async_api.begin() as conn:
if 'word' in conn.t.meta.tables:
return
await make_query_analyzer(conn)
word_table = conn.t.meta.tables['word']
await conn.connection.run_sync(word_table.create)
apiobj.async_to_sync(_do_sql())
event_loop.run_until_complete(convert_sqlite.convert(Path('/invalid'), event_loop.run_until_complete(convert_sqlite.convert(Path('/invalid'),
db, options)) db, options))
outapi = napi.NominatimAPI(Path('/invalid'), outapi = napi.NominatimAPI(Path('/invalid'),

View File

@@ -15,7 +15,7 @@ from nominatim.api.search.db_searches import CountrySearch
from nominatim.api.search.db_search_fields import WeightedStrings from nominatim.api.search.db_search_fields import WeightedStrings
def run_search(apiobj, global_penalty, ccodes, def run_search(apiobj, frontend, global_penalty, ccodes,
country_penalties=None, details=SearchDetails()): country_penalties=None, details=SearchDetails()):
if country_penalties is None: if country_penalties is None:
country_penalties = [0.0] * len(ccodes) country_penalties = [0.0] * len(ccodes)
@@ -25,15 +25,16 @@ def run_search(apiobj, global_penalty, ccodes,
countries = WeightedStrings(ccodes, country_penalties) countries = WeightedStrings(ccodes, country_penalties)
search = CountrySearch(MySearchData()) search = CountrySearch(MySearchData())
api = frontend(apiobj, options=['search'])
async def run(): async def run():
async with apiobj.api._async_api.begin() as conn: async with api._async_api.begin() as conn:
return await search.lookup(conn, details) return await search.lookup(conn, details)
return apiobj.async_to_sync(run()) return api._loop.run_until_complete(run())
def test_find_from_placex(apiobj): def test_find_from_placex(apiobj, frontend):
apiobj.add_placex(place_id=55, class_='boundary', type='administrative', apiobj.add_placex(place_id=55, class_='boundary', type='administrative',
rank_search=4, rank_address=4, rank_search=4, rank_address=4,
name={'name': 'Lolaland'}, name={'name': 'Lolaland'},
@@ -41,32 +42,32 @@ def test_find_from_placex(apiobj):
centroid=(10, 10), centroid=(10, 10),
geometry='POLYGON((9.5 9.5, 9.5 10.5, 10.5 10.5, 10.5 9.5, 9.5 9.5))') geometry='POLYGON((9.5 9.5, 9.5 10.5, 10.5 10.5, 10.5 9.5, 9.5 9.5))')
results = run_search(apiobj, 0.5, ['de', 'yw'], [0.0, 0.3]) results = run_search(apiobj, frontend, 0.5, ['de', 'yw'], [0.0, 0.3])
assert len(results) == 1 assert len(results) == 1
assert results[0].place_id == 55 assert results[0].place_id == 55
assert results[0].accuracy == 0.8 assert results[0].accuracy == 0.8
def test_find_from_fallback_countries(apiobj): def test_find_from_fallback_countries(apiobj, frontend):
apiobj.add_country('ro', 'POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))') apiobj.add_country('ro', 'POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))')
apiobj.add_country_name('ro', {'name': 'România'}) apiobj.add_country_name('ro', {'name': 'România'})
results = run_search(apiobj, 0.0, ['ro']) results = run_search(apiobj, frontend, 0.0, ['ro'])
assert len(results) == 1 assert len(results) == 1
assert results[0].names == {'name': 'România'} assert results[0].names == {'name': 'România'}
def test_find_none(apiobj): def test_find_none(apiobj, frontend):
assert len(run_search(apiobj, 0.0, ['xx'])) == 0 assert len(run_search(apiobj, frontend, 0.0, ['xx'])) == 0
@pytest.mark.parametrize('coord,numres', [((0.5, 1), 1), ((10, 10), 0)]) @pytest.mark.parametrize('coord,numres', [((0.5, 1), 1), ((10, 10), 0)])
def test_find_near(apiobj, coord, numres): def test_find_near(apiobj, frontend, coord, numres):
apiobj.add_country('ro', 'POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))') apiobj.add_country('ro', 'POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))')
apiobj.add_country_name('ro', {'name': 'România'}) apiobj.add_country_name('ro', {'name': 'România'})
results = run_search(apiobj, 0.0, ['ro'], results = run_search(apiobj, frontend, 0.0, ['ro'],
details=SearchDetails(near=napi.Point(*coord), details=SearchDetails(near=napi.Point(*coord),
near_radius=0.1)) near_radius=0.1))
@@ -92,8 +93,8 @@ class TestCountryParameters:
napi.GeometryFormat.SVG, napi.GeometryFormat.SVG,
napi.GeometryFormat.TEXT]) napi.GeometryFormat.TEXT])
@pytest.mark.parametrize('cc', ['yw', 'ro']) @pytest.mark.parametrize('cc', ['yw', 'ro'])
def test_return_geometries(self, apiobj, geom, cc): def test_return_geometries(self, apiobj, frontend, geom, cc):
results = run_search(apiobj, 0.5, [cc], results = run_search(apiobj, frontend, 0.5, [cc],
details=SearchDetails(geometry_output=geom)) details=SearchDetails(geometry_output=geom))
assert len(results) == 1 assert len(results) == 1
@@ -101,8 +102,8 @@ class TestCountryParameters:
@pytest.mark.parametrize('pid,rids', [(76, [55]), (55, [])]) @pytest.mark.parametrize('pid,rids', [(76, [55]), (55, [])])
def test_exclude_place_id(self, apiobj, pid, rids): def test_exclude_place_id(self, apiobj, frontend, pid, rids):
results = run_search(apiobj, 0.5, ['yw', 'ro'], results = run_search(apiobj, frontend, 0.5, ['yw', 'ro'],
details=SearchDetails(excluded=[pid])) details=SearchDetails(excluded=[pid]))
assert [r.place_id for r in results] == rids assert [r.place_id for r in results] == rids
@@ -110,8 +111,8 @@ class TestCountryParameters:
@pytest.mark.parametrize('viewbox,rids', [((9, 9, 11, 11), [55]), @pytest.mark.parametrize('viewbox,rids', [((9, 9, 11, 11), [55]),
((-10, -10, -3, -3), [])]) ((-10, -10, -3, -3), [])])
def test_bounded_viewbox_in_placex(self, apiobj, viewbox, rids): def test_bounded_viewbox_in_placex(self, apiobj, frontend, viewbox, rids):
results = run_search(apiobj, 0.5, ['yw'], results = run_search(apiobj, frontend, 0.5, ['yw'],
details=SearchDetails.from_kwargs({'viewbox': viewbox, details=SearchDetails.from_kwargs({'viewbox': viewbox,
'bounded_viewbox': True})) 'bounded_viewbox': True}))
@@ -120,8 +121,8 @@ class TestCountryParameters:
@pytest.mark.parametrize('viewbox,numres', [((0, 0, 1, 1), 1), @pytest.mark.parametrize('viewbox,numres', [((0, 0, 1, 1), 1),
((-10, -10, -3, -3), 0)]) ((-10, -10, -3, -3), 0)])
def test_bounded_viewbox_in_fallback(self, apiobj, viewbox, numres): def test_bounded_viewbox_in_fallback(self, apiobj, frontend, viewbox, numres):
results = run_search(apiobj, 0.5, ['ro'], results = run_search(apiobj, frontend, 0.5, ['ro'],
details=SearchDetails.from_kwargs({'viewbox': viewbox, details=SearchDetails.from_kwargs({'viewbox': viewbox,
'bounded_viewbox': True})) 'bounded_viewbox': True}))

View File

@@ -17,7 +17,7 @@ from nominatim.api.search.db_search_fields import WeightedStrings, WeightedCateg
from nominatim.api.search.db_search_lookups import LookupAll from nominatim.api.search.db_search_lookups import LookupAll
def run_search(apiobj, global_penalty, cat, cat_penalty=None, ccodes=[], def run_search(apiobj, frontend, global_penalty, cat, cat_penalty=None, ccodes=[],
details=SearchDetails()): details=SearchDetails()):
class PlaceSearchData: class PlaceSearchData:
@@ -39,21 +39,23 @@ def run_search(apiobj, global_penalty, cat, cat_penalty=None, ccodes=[],
near_search = NearSearch(0.1, WeightedCategories(cat, cat_penalty), place_search) near_search = NearSearch(0.1, WeightedCategories(cat, cat_penalty), place_search)
api = frontend(apiobj, options=['search'])
async def run(): async def run():
async with apiobj.api._async_api.begin() as conn: async with api._async_api.begin() as conn:
return await near_search.lookup(conn, details) return await near_search.lookup(conn, details)
results = apiobj.async_to_sync(run()) results = api._loop.run_until_complete(run())
results.sort(key=lambda r: r.accuracy) results.sort(key=lambda r: r.accuracy)
return results return results
def test_no_results_inner_query(apiobj): def test_no_results_inner_query(apiobj, frontend):
assert not run_search(apiobj, 0.4, [('this', 'that')]) assert not run_search(apiobj, frontend, 0.4, [('this', 'that')])
def test_no_appropriate_results_inner_query(apiobj): def test_no_appropriate_results_inner_query(apiobj, frontend):
apiobj.add_placex(place_id=100, country_code='us', apiobj.add_placex(place_id=100, country_code='us',
centroid=(5.6, 4.3), centroid=(5.6, 4.3),
geometry='POLYGON((0.0 0.0, 10.0 0.0, 10.0 2.0, 0.0 2.0, 0.0 0.0))') geometry='POLYGON((0.0 0.0, 10.0 0.0, 10.0 2.0, 0.0 2.0, 0.0 0.0))')
@@ -62,7 +64,7 @@ def test_no_appropriate_results_inner_query(apiobj):
apiobj.add_placex(place_id=22, class_='amenity', type='bank', apiobj.add_placex(place_id=22, class_='amenity', type='bank',
centroid=(5.6001, 4.2994)) centroid=(5.6001, 4.2994))
assert not run_search(apiobj, 0.4, [('amenity', 'bank')]) assert not run_search(apiobj, frontend, 0.4, [('amenity', 'bank')])
class TestNearSearch: class TestNearSearch:
@@ -79,18 +81,18 @@ class TestNearSearch:
centroid=(-10.3, 56.9)) centroid=(-10.3, 56.9))
def test_near_in_placex(self, apiobj): def test_near_in_placex(self, apiobj, frontend):
apiobj.add_placex(place_id=22, class_='amenity', type='bank', apiobj.add_placex(place_id=22, class_='amenity', type='bank',
centroid=(5.6001, 4.2994)) centroid=(5.6001, 4.2994))
apiobj.add_placex(place_id=23, class_='amenity', type='bench', apiobj.add_placex(place_id=23, class_='amenity', type='bench',
centroid=(5.6001, 4.2994)) centroid=(5.6001, 4.2994))
results = run_search(apiobj, 0.1, [('amenity', 'bank')]) results = run_search(apiobj, frontend, 0.1, [('amenity', 'bank')])
assert [r.place_id for r in results] == [22] assert [r.place_id for r in results] == [22]
def test_multiple_types_near_in_placex(self, apiobj): def test_multiple_types_near_in_placex(self, apiobj, frontend):
apiobj.add_placex(place_id=22, class_='amenity', type='bank', apiobj.add_placex(place_id=22, class_='amenity', type='bank',
importance=0.002, importance=0.002,
centroid=(5.6001, 4.2994)) centroid=(5.6001, 4.2994))
@@ -98,13 +100,13 @@ class TestNearSearch:
importance=0.001, importance=0.001,
centroid=(5.6001, 4.2994)) centroid=(5.6001, 4.2994))
results = run_search(apiobj, 0.1, [('amenity', 'bank'), results = run_search(apiobj, frontend, 0.1, [('amenity', 'bank'),
('amenity', 'bench')]) ('amenity', 'bench')])
assert [r.place_id for r in results] == [22, 23] assert [r.place_id for r in results] == [22, 23]
def test_near_in_classtype(self, apiobj): def test_near_in_classtype(self, apiobj, frontend):
apiobj.add_placex(place_id=22, class_='amenity', type='bank', apiobj.add_placex(place_id=22, class_='amenity', type='bank',
centroid=(5.6, 4.34)) centroid=(5.6, 4.34))
apiobj.add_placex(place_id=23, class_='amenity', type='bench', apiobj.add_placex(place_id=23, class_='amenity', type='bench',
@@ -112,13 +114,13 @@ class TestNearSearch:
apiobj.add_class_type_table('amenity', 'bank') apiobj.add_class_type_table('amenity', 'bank')
apiobj.add_class_type_table('amenity', 'bench') apiobj.add_class_type_table('amenity', 'bench')
results = run_search(apiobj, 0.1, [('amenity', 'bank')]) results = run_search(apiobj, frontend, 0.1, [('amenity', 'bank')])
assert [r.place_id for r in results] == [22] assert [r.place_id for r in results] == [22]
@pytest.mark.parametrize('cc,rid', [('us', 22), ('mx', 23)]) @pytest.mark.parametrize('cc,rid', [('us', 22), ('mx', 23)])
def test_restrict_by_country(self, apiobj, cc, rid): def test_restrict_by_country(self, apiobj, frontend, cc, rid):
apiobj.add_placex(place_id=22, class_='amenity', type='bank', apiobj.add_placex(place_id=22, class_='amenity', type='bank',
centroid=(5.6001, 4.2994), centroid=(5.6001, 4.2994),
country_code='us') country_code='us')
@@ -132,13 +134,13 @@ class TestNearSearch:
centroid=(-10.3001, 56.9), centroid=(-10.3001, 56.9),
country_code='us') country_code='us')
results = run_search(apiobj, 0.1, [('amenity', 'bank')], ccodes=[cc, 'fr']) results = run_search(apiobj, frontend, 0.1, [('amenity', 'bank')], ccodes=[cc, 'fr'])
assert [r.place_id for r in results] == [rid] assert [r.place_id for r in results] == [rid]
@pytest.mark.parametrize('excluded,rid', [(22, 122), (122, 22)]) @pytest.mark.parametrize('excluded,rid', [(22, 122), (122, 22)])
def test_exclude_place_by_id(self, apiobj, excluded, rid): def test_exclude_place_by_id(self, apiobj, frontend, excluded, rid):
apiobj.add_placex(place_id=22, class_='amenity', type='bank', apiobj.add_placex(place_id=22, class_='amenity', type='bank',
centroid=(5.6001, 4.2994), centroid=(5.6001, 4.2994),
country_code='us') country_code='us')
@@ -147,7 +149,7 @@ class TestNearSearch:
country_code='us') country_code='us')
results = run_search(apiobj, 0.1, [('amenity', 'bank')], results = run_search(apiobj, frontend, 0.1, [('amenity', 'bank')],
details=SearchDetails(excluded=[excluded])) details=SearchDetails(excluded=[excluded]))
assert [r.place_id for r in results] == [rid] assert [r.place_id for r in results] == [rid]
@@ -155,12 +157,12 @@ class TestNearSearch:
@pytest.mark.parametrize('layer,rids', [(napi.DataLayer.POI, [22]), @pytest.mark.parametrize('layer,rids', [(napi.DataLayer.POI, [22]),
(napi.DataLayer.MANMADE, [])]) (napi.DataLayer.MANMADE, [])])
def test_with_layer(self, apiobj, layer, rids): def test_with_layer(self, apiobj, frontend, layer, rids):
apiobj.add_placex(place_id=22, class_='amenity', type='bank', apiobj.add_placex(place_id=22, class_='amenity', type='bank',
centroid=(5.6001, 4.2994), centroid=(5.6001, 4.2994),
country_code='us') country_code='us')
results = run_search(apiobj, 0.1, [('amenity', 'bank')], results = run_search(apiobj, frontend, 0.1, [('amenity', 'bank')],
details=SearchDetails(layers=layer)) details=SearchDetails(layers=layer))
assert [r.place_id for r in results] == rids assert [r.place_id for r in results] == rids

View File

@@ -18,7 +18,9 @@ from nominatim.api.search.db_search_fields import WeightedStrings, WeightedCateg
FieldLookup, FieldRanking, RankedTokens FieldLookup, FieldRanking, RankedTokens
from nominatim.api.search.db_search_lookups import LookupAll, LookupAny, Restrict from nominatim.api.search.db_search_lookups import LookupAll, LookupAny, Restrict
def run_search(apiobj, global_penalty, lookup, ranking, count=2, APIOPTIONS = ['search']
def run_search(apiobj, frontend, global_penalty, lookup, ranking, count=2,
hnrs=[], pcs=[], ccodes=[], quals=[], hnrs=[], pcs=[], ccodes=[], quals=[],
details=SearchDetails()): details=SearchDetails()):
class MySearchData: class MySearchData:
@@ -32,11 +34,16 @@ def run_search(apiobj, global_penalty, lookup, ranking, count=2,
search = PlaceSearch(0.0, MySearchData(), count) search = PlaceSearch(0.0, MySearchData(), count)
if frontend is None:
api = apiobj
else:
api = frontend(apiobj, options=APIOPTIONS)
async def run(): async def run():
async with apiobj.api._async_api.begin() as conn: async with api._async_api.begin() as conn:
return await search.lookup(conn, details) return await search.lookup(conn, details)
results = apiobj.async_to_sync(run()) results = api._loop.run_until_complete(run())
results.sort(key=lambda r: r.accuracy) results.sort(key=lambda r: r.accuracy)
return results return results
@@ -59,61 +66,61 @@ class TestNameOnlySearches:
@pytest.mark.parametrize('lookup_type', [LookupAll, Restrict]) @pytest.mark.parametrize('lookup_type', [LookupAll, Restrict])
@pytest.mark.parametrize('rank,res', [([10], [100, 101]), @pytest.mark.parametrize('rank,res', [([10], [100, 101]),
([20], [101, 100])]) ([20], [101, 100])])
def test_lookup_all_match(self, apiobj, lookup_type, rank, res): def test_lookup_all_match(self, apiobj, frontend, lookup_type, rank, res):
lookup = FieldLookup('name_vector', [1,2], lookup_type) lookup = FieldLookup('name_vector', [1,2], lookup_type)
ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, rank)]) ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, rank)])
results = run_search(apiobj, 0.1, [lookup], [ranking]) results = run_search(apiobj, frontend, 0.1, [lookup], [ranking])
assert [r.place_id for r in results] == res assert [r.place_id for r in results] == res
@pytest.mark.parametrize('lookup_type', [LookupAll, Restrict]) @pytest.mark.parametrize('lookup_type', [LookupAll, Restrict])
def test_lookup_all_partial_match(self, apiobj, lookup_type): def test_lookup_all_partial_match(self, apiobj, frontend, lookup_type):
lookup = FieldLookup('name_vector', [1,20], lookup_type) lookup = FieldLookup('name_vector', [1,20], lookup_type)
ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])]) ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])])
results = run_search(apiobj, 0.1, [lookup], [ranking]) results = run_search(apiobj, frontend, 0.1, [lookup], [ranking])
assert len(results) == 1 assert len(results) == 1
assert results[0].place_id == 101 assert results[0].place_id == 101
@pytest.mark.parametrize('rank,res', [([10], [100, 101]), @pytest.mark.parametrize('rank,res', [([10], [100, 101]),
([20], [101, 100])]) ([20], [101, 100])])
def test_lookup_any_match(self, apiobj, rank, res): def test_lookup_any_match(self, apiobj, frontend, rank, res):
lookup = FieldLookup('name_vector', [11,21], LookupAny) lookup = FieldLookup('name_vector', [11,21], LookupAny)
ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, rank)]) ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, rank)])
results = run_search(apiobj, 0.1, [lookup], [ranking]) results = run_search(apiobj, frontend, 0.1, [lookup], [ranking])
assert [r.place_id for r in results] == res assert [r.place_id for r in results] == res
def test_lookup_any_partial_match(self, apiobj): def test_lookup_any_partial_match(self, apiobj, frontend):
lookup = FieldLookup('name_vector', [20], LookupAll) lookup = FieldLookup('name_vector', [20], LookupAll)
ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])]) ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])])
results = run_search(apiobj, 0.1, [lookup], [ranking]) results = run_search(apiobj, frontend, 0.1, [lookup], [ranking])
assert len(results) == 1 assert len(results) == 1
assert results[0].place_id == 101 assert results[0].place_id == 101
@pytest.mark.parametrize('cc,res', [('us', 100), ('mx', 101)]) @pytest.mark.parametrize('cc,res', [('us', 100), ('mx', 101)])
def test_lookup_restrict_country(self, apiobj, cc, res): def test_lookup_restrict_country(self, apiobj, frontend, cc, res):
lookup = FieldLookup('name_vector', [1,2], LookupAll) lookup = FieldLookup('name_vector', [1,2], LookupAll)
ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [10])]) ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [10])])
results = run_search(apiobj, 0.1, [lookup], [ranking], ccodes=[cc]) results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], ccodes=[cc])
assert [r.place_id for r in results] == [res] assert [r.place_id for r in results] == [res]
def test_lookup_restrict_placeid(self, apiobj): def test_lookup_restrict_placeid(self, apiobj, frontend):
lookup = FieldLookup('name_vector', [1,2], LookupAll) lookup = FieldLookup('name_vector', [1,2], LookupAll)
ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [10])]) ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [10])])
results = run_search(apiobj, 0.1, [lookup], [ranking], results = run_search(apiobj, frontend, 0.1, [lookup], [ranking],
details=SearchDetails(excluded=[101])) details=SearchDetails(excluded=[101]))
assert [r.place_id for r in results] == [100] assert [r.place_id for r in results] == [100]
@@ -123,18 +130,18 @@ class TestNameOnlySearches:
napi.GeometryFormat.KML, napi.GeometryFormat.KML,
napi.GeometryFormat.SVG, napi.GeometryFormat.SVG,
napi.GeometryFormat.TEXT]) napi.GeometryFormat.TEXT])
def test_return_geometries(self, apiobj, geom): def test_return_geometries(self, apiobj, frontend, geom):
lookup = FieldLookup('name_vector', [20], LookupAll) lookup = FieldLookup('name_vector', [20], LookupAll)
ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])]) ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])])
results = run_search(apiobj, 0.1, [lookup], [ranking], results = run_search(apiobj, frontend, 0.1, [lookup], [ranking],
details=SearchDetails(geometry_output=geom)) details=SearchDetails(geometry_output=geom))
assert geom.name.lower() in results[0].geometry assert geom.name.lower() in results[0].geometry
@pytest.mark.parametrize('factor,npoints', [(0.0, 3), (1.0, 2)]) @pytest.mark.parametrize('factor,npoints', [(0.0, 3), (1.0, 2)])
def test_return_simplified_geometry(self, apiobj, factor, npoints): def test_return_simplified_geometry(self, apiobj, frontend, factor, npoints):
apiobj.add_placex(place_id=333, country_code='us', apiobj.add_placex(place_id=333, country_code='us',
centroid=(9.0, 9.0), centroid=(9.0, 9.0),
geometry='LINESTRING(8.9 9.0, 9.0 9.0, 9.1 9.0)') geometry='LINESTRING(8.9 9.0, 9.0 9.0, 9.1 9.0)')
@@ -144,7 +151,7 @@ class TestNameOnlySearches:
lookup = FieldLookup('name_vector', [55], LookupAll) lookup = FieldLookup('name_vector', [55], LookupAll)
ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])]) ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])])
results = run_search(apiobj, 0.1, [lookup], [ranking], results = run_search(apiobj, frontend, 0.1, [lookup], [ranking],
details=SearchDetails(geometry_output=napi.GeometryFormat.GEOJSON, details=SearchDetails(geometry_output=napi.GeometryFormat.GEOJSON,
geometry_simplification=factor)) geometry_simplification=factor))
@@ -158,50 +165,52 @@ class TestNameOnlySearches:
@pytest.mark.parametrize('viewbox', ['5.0,4.0,6.0,5.0', '5.7,4.0,6.0,5.0']) @pytest.mark.parametrize('viewbox', ['5.0,4.0,6.0,5.0', '5.7,4.0,6.0,5.0'])
@pytest.mark.parametrize('wcount,rids', [(2, [100, 101]), (20000, [100])]) @pytest.mark.parametrize('wcount,rids', [(2, [100, 101]), (20000, [100])])
def test_prefer_viewbox(self, apiobj, viewbox, wcount, rids): def test_prefer_viewbox(self, apiobj, frontend, viewbox, wcount, rids):
lookup = FieldLookup('name_vector', [1, 2], LookupAll) lookup = FieldLookup('name_vector', [1, 2], LookupAll)
ranking = FieldRanking('name_vector', 0.2, [RankedTokens(0.0, [21])]) ranking = FieldRanking('name_vector', 0.2, [RankedTokens(0.0, [21])])
results = run_search(apiobj, 0.1, [lookup], [ranking]) api = frontend(apiobj, options=APIOPTIONS)
results = run_search(api, None, 0.1, [lookup], [ranking])
assert [r.place_id for r in results] == [101, 100] assert [r.place_id for r in results] == [101, 100]
results = run_search(apiobj, 0.1, [lookup], [ranking], count=wcount, results = run_search(api, None, 0.1, [lookup], [ranking], count=wcount,
details=SearchDetails.from_kwargs({'viewbox': viewbox})) details=SearchDetails.from_kwargs({'viewbox': viewbox}))
assert [r.place_id for r in results] == rids assert [r.place_id for r in results] == rids
@pytest.mark.parametrize('viewbox', ['5.0,4.0,6.0,5.0', '5.55,4.27,5.62,4.31']) @pytest.mark.parametrize('viewbox', ['5.0,4.0,6.0,5.0', '5.55,4.27,5.62,4.31'])
def test_force_viewbox(self, apiobj, viewbox): def test_force_viewbox(self, apiobj, frontend, viewbox):
lookup = FieldLookup('name_vector', [1, 2], LookupAll) lookup = FieldLookup('name_vector', [1, 2], LookupAll)
details=SearchDetails.from_kwargs({'viewbox': viewbox, details=SearchDetails.from_kwargs({'viewbox': viewbox,
'bounded_viewbox': True}) 'bounded_viewbox': True})
results = run_search(apiobj, 0.1, [lookup], [], details=details) results = run_search(apiobj, frontend, 0.1, [lookup], [], details=details)
assert [r.place_id for r in results] == [100] assert [r.place_id for r in results] == [100]
def test_prefer_near(self, apiobj): def test_prefer_near(self, apiobj, frontend):
lookup = FieldLookup('name_vector', [1, 2], LookupAll) lookup = FieldLookup('name_vector', [1, 2], LookupAll)
ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])]) ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])])
results = run_search(apiobj, 0.1, [lookup], [ranking]) api = frontend(apiobj, options=APIOPTIONS)
results = run_search(api, None, 0.1, [lookup], [ranking])
assert [r.place_id for r in results] == [101, 100] assert [r.place_id for r in results] == [101, 100]
results = run_search(apiobj, 0.1, [lookup], [ranking], results = run_search(api, None, 0.1, [lookup], [ranking],
details=SearchDetails.from_kwargs({'near': '5.6,4.3'})) details=SearchDetails.from_kwargs({'near': '5.6,4.3'}))
results.sort(key=lambda r: -r.importance) results.sort(key=lambda r: -r.importance)
assert [r.place_id for r in results] == [100, 101] assert [r.place_id for r in results] == [100, 101]
@pytest.mark.parametrize('radius', [0.09, 0.11]) @pytest.mark.parametrize('radius', [0.09, 0.11])
def test_force_near(self, apiobj, radius): def test_force_near(self, apiobj, frontend, radius):
lookup = FieldLookup('name_vector', [1, 2], LookupAll) lookup = FieldLookup('name_vector', [1, 2], LookupAll)
details=SearchDetails.from_kwargs({'near': '5.6,4.3', details=SearchDetails.from_kwargs({'near': '5.6,4.3',
'near_radius': radius}) 'near_radius': radius})
results = run_search(apiobj, 0.1, [lookup], [], details=details) results = run_search(apiobj, frontend, 0.1, [lookup], [], details=details)
assert [r.place_id for r in results] == [100] assert [r.place_id for r in results] == [100]
@@ -242,72 +251,72 @@ class TestStreetWithHousenumber:
@pytest.mark.parametrize('hnr,res', [('20', [91, 1]), ('20 a', [1]), @pytest.mark.parametrize('hnr,res', [('20', [91, 1]), ('20 a', [1]),
('21', [2]), ('22', [2, 92]), ('21', [2]), ('22', [2, 92]),
('24', [93]), ('25', [])]) ('24', [93]), ('25', [])])
def test_lookup_by_single_housenumber(self, apiobj, hnr, res): def test_lookup_by_single_housenumber(self, apiobj, frontend, hnr, res):
lookup = FieldLookup('name_vector', [1,2], LookupAll) lookup = FieldLookup('name_vector', [1,2], LookupAll)
ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])]) ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=[hnr]) results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], hnrs=[hnr])
assert [r.place_id for r in results] == res + [1000, 2000] assert [r.place_id for r in results] == res + [1000, 2000]
@pytest.mark.parametrize('cc,res', [('es', [2, 1000]), ('pt', [92, 2000])]) @pytest.mark.parametrize('cc,res', [('es', [2, 1000]), ('pt', [92, 2000])])
def test_lookup_with_country_restriction(self, apiobj, cc, res): def test_lookup_with_country_restriction(self, apiobj, frontend, cc, res):
lookup = FieldLookup('name_vector', [1,2], LookupAll) lookup = FieldLookup('name_vector', [1,2], LookupAll)
ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])]) ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'], results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], hnrs=['22'],
ccodes=[cc]) ccodes=[cc])
assert [r.place_id for r in results] == res assert [r.place_id for r in results] == res
def test_lookup_exclude_housenumber_placeid(self, apiobj): def test_lookup_exclude_housenumber_placeid(self, apiobj, frontend):
lookup = FieldLookup('name_vector', [1,2], LookupAll) lookup = FieldLookup('name_vector', [1,2], LookupAll)
ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])]) ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'], results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], hnrs=['22'],
details=SearchDetails(excluded=[92])) details=SearchDetails(excluded=[92]))
assert [r.place_id for r in results] == [2, 1000, 2000] assert [r.place_id for r in results] == [2, 1000, 2000]
def test_lookup_exclude_street_placeid(self, apiobj): def test_lookup_exclude_street_placeid(self, apiobj, frontend):
lookup = FieldLookup('name_vector', [1,2], LookupAll) lookup = FieldLookup('name_vector', [1,2], LookupAll)
ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])]) ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'], results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], hnrs=['22'],
details=SearchDetails(excluded=[1000])) details=SearchDetails(excluded=[1000]))
assert [r.place_id for r in results] == [2, 92, 2000] assert [r.place_id for r in results] == [2, 92, 2000]
def test_lookup_only_house_qualifier(self, apiobj): def test_lookup_only_house_qualifier(self, apiobj, frontend):
lookup = FieldLookup('name_vector', [1,2], LookupAll) lookup = FieldLookup('name_vector', [1,2], LookupAll)
ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])]) ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'], results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], hnrs=['22'],
quals=[('place', 'house')]) quals=[('place', 'house')])
assert [r.place_id for r in results] == [2, 92] assert [r.place_id for r in results] == [2, 92]
def test_lookup_only_street_qualifier(self, apiobj): def test_lookup_only_street_qualifier(self, apiobj, frontend):
lookup = FieldLookup('name_vector', [1,2], LookupAll) lookup = FieldLookup('name_vector', [1,2], LookupAll)
ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])]) ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'], results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], hnrs=['22'],
quals=[('highway', 'residential')]) quals=[('highway', 'residential')])
assert [r.place_id for r in results] == [1000, 2000] assert [r.place_id for r in results] == [1000, 2000]
@pytest.mark.parametrize('rank,found', [(26, True), (27, False), (30, False)]) @pytest.mark.parametrize('rank,found', [(26, True), (27, False), (30, False)])
def test_lookup_min_rank(self, apiobj, rank, found): def test_lookup_min_rank(self, apiobj, frontend, rank, found):
lookup = FieldLookup('name_vector', [1,2], LookupAll) lookup = FieldLookup('name_vector', [1,2], LookupAll)
ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])]) ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'], results = run_search(apiobj, frontend, 0.1, [lookup], [ranking], hnrs=['22'],
details=SearchDetails(min_rank=rank)) details=SearchDetails(min_rank=rank))
assert [r.place_id for r in results] == ([2, 92, 1000, 2000] if found else [2, 92]) assert [r.place_id for r in results] == ([2, 92, 1000, 2000] if found else [2, 92])
@@ -317,17 +326,17 @@ class TestStreetWithHousenumber:
napi.GeometryFormat.KML, napi.GeometryFormat.KML,
napi.GeometryFormat.SVG, napi.GeometryFormat.SVG,
napi.GeometryFormat.TEXT]) napi.GeometryFormat.TEXT])
def test_return_geometries(self, apiobj, geom): def test_return_geometries(self, apiobj, frontend, geom):
lookup = FieldLookup('name_vector', [1, 2], LookupAll) lookup = FieldLookup('name_vector', [1, 2], LookupAll)
results = run_search(apiobj, 0.1, [lookup], [], hnrs=['20', '21', '22'], results = run_search(apiobj, frontend, 0.1, [lookup], [], hnrs=['20', '21', '22'],
details=SearchDetails(geometry_output=geom)) details=SearchDetails(geometry_output=geom))
assert results assert results
assert all(geom.name.lower() in r.geometry for r in results) assert all(geom.name.lower() in r.geometry for r in results)
def test_very_large_housenumber(apiobj): def test_very_large_housenumber(apiobj, frontend):
apiobj.add_placex(place_id=93, class_='place', type='house', apiobj.add_placex(place_id=93, class_='place', type='house',
parent_place_id=2000, parent_place_id=2000,
housenumber='2467463524544', country_code='pt') housenumber='2467463524544', country_code='pt')
@@ -340,7 +349,7 @@ def test_very_large_housenumber(apiobj):
lookup = FieldLookup('name_vector', [1, 2], LookupAll) lookup = FieldLookup('name_vector', [1, 2], LookupAll)
results = run_search(apiobj, 0.1, [lookup], [], hnrs=['2467463524544'], results = run_search(apiobj, frontend, 0.1, [lookup], [], hnrs=['2467463524544'],
details=SearchDetails()) details=SearchDetails())
assert results assert results
@@ -348,7 +357,7 @@ def test_very_large_housenumber(apiobj):
@pytest.mark.parametrize('wcount,rids', [(2, [990, 991]), (30000, [990])]) @pytest.mark.parametrize('wcount,rids', [(2, [990, 991]), (30000, [990])])
def test_name_and_postcode(apiobj, wcount, rids): def test_name_and_postcode(apiobj, frontend, wcount, rids):
apiobj.add_placex(place_id=990, class_='highway', type='service', apiobj.add_placex(place_id=990, class_='highway', type='service',
rank_search=27, rank_address=27, rank_search=27, rank_address=27,
postcode='11225', postcode='11225',
@@ -368,7 +377,7 @@ def test_name_and_postcode(apiobj, wcount, rids):
lookup = FieldLookup('name_vector', [111], LookupAll) lookup = FieldLookup('name_vector', [111], LookupAll)
results = run_search(apiobj, 0.1, [lookup], [], pcs=['11225'], count=wcount, results = run_search(apiobj, frontend, 0.1, [lookup], [], pcs=['11225'], count=wcount,
details=SearchDetails()) details=SearchDetails())
assert results assert results
@@ -398,10 +407,10 @@ class TestInterpolations:
@pytest.mark.parametrize('hnr,res', [('21', [992]), ('22', []), ('23', [991])]) @pytest.mark.parametrize('hnr,res', [('21', [992]), ('22', []), ('23', [991])])
def test_lookup_housenumber(self, apiobj, hnr, res): def test_lookup_housenumber(self, apiobj, frontend, hnr, res):
lookup = FieldLookup('name_vector', [111], LookupAll) lookup = FieldLookup('name_vector', [111], LookupAll)
results = run_search(apiobj, 0.1, [lookup], [], hnrs=[hnr]) results = run_search(apiobj, frontend, 0.1, [lookup], [], hnrs=[hnr])
assert [r.place_id for r in results] == res + [990] assert [r.place_id for r in results] == res + [990]
@@ -410,10 +419,10 @@ class TestInterpolations:
napi.GeometryFormat.KML, napi.GeometryFormat.KML,
napi.GeometryFormat.SVG, napi.GeometryFormat.SVG,
napi.GeometryFormat.TEXT]) napi.GeometryFormat.TEXT])
def test_osmline_with_geometries(self, apiobj, geom): def test_osmline_with_geometries(self, apiobj, frontend, geom):
lookup = FieldLookup('name_vector', [111], LookupAll) lookup = FieldLookup('name_vector', [111], LookupAll)
results = run_search(apiobj, 0.1, [lookup], [], hnrs=['21'], results = run_search(apiobj, frontend, 0.1, [lookup], [], hnrs=['21'],
details=SearchDetails(geometry_output=geom)) details=SearchDetails(geometry_output=geom))
assert results[0].place_id == 992 assert results[0].place_id == 992
@@ -446,10 +455,10 @@ class TestTiger:
@pytest.mark.parametrize('hnr,res', [('21', [992]), ('22', []), ('23', [991])]) @pytest.mark.parametrize('hnr,res', [('21', [992]), ('22', []), ('23', [991])])
def test_lookup_housenumber(self, apiobj, hnr, res): def test_lookup_housenumber(self, apiobj, frontend, hnr, res):
lookup = FieldLookup('name_vector', [111], LookupAll) lookup = FieldLookup('name_vector', [111], LookupAll)
results = run_search(apiobj, 0.1, [lookup], [], hnrs=[hnr]) results = run_search(apiobj, frontend, 0.1, [lookup], [], hnrs=[hnr])
assert [r.place_id for r in results] == res + [990] assert [r.place_id for r in results] == res + [990]
@@ -458,10 +467,10 @@ class TestTiger:
napi.GeometryFormat.KML, napi.GeometryFormat.KML,
napi.GeometryFormat.SVG, napi.GeometryFormat.SVG,
napi.GeometryFormat.TEXT]) napi.GeometryFormat.TEXT])
def test_tiger_with_geometries(self, apiobj, geom): def test_tiger_with_geometries(self, apiobj, frontend, geom):
lookup = FieldLookup('name_vector', [111], LookupAll) lookup = FieldLookup('name_vector', [111], LookupAll)
results = run_search(apiobj, 0.1, [lookup], [], hnrs=['21'], results = run_search(apiobj, frontend, 0.1, [lookup], [], hnrs=['21'],
details=SearchDetails(geometry_output=geom)) details=SearchDetails(geometry_output=geom))
assert results[0].place_id == 992 assert results[0].place_id == 992
@@ -513,10 +522,10 @@ class TestLayersRank30:
(napi.DataLayer.NATURAL, [227]), (napi.DataLayer.NATURAL, [227]),
(napi.DataLayer.MANMADE | napi.DataLayer.NATURAL, [225, 227]), (napi.DataLayer.MANMADE | napi.DataLayer.NATURAL, [225, 227]),
(napi.DataLayer.MANMADE | napi.DataLayer.RAILWAY, [225, 226])]) (napi.DataLayer.MANMADE | napi.DataLayer.RAILWAY, [225, 226])])
def test_layers_rank30(self, apiobj, layer, res): def test_layers_rank30(self, apiobj, frontend, layer, res):
lookup = FieldLookup('name_vector', [34], LookupAny) lookup = FieldLookup('name_vector', [34], LookupAny)
results = run_search(apiobj, 0.1, [lookup], [], results = run_search(apiobj, frontend, 0.1, [lookup], [],
details=SearchDetails(layers=layer)) details=SearchDetails(layers=layer))
assert [r.place_id for r in results] == res assert [r.place_id for r in results] == res

View File

@@ -15,7 +15,7 @@ from nominatim.api.search.db_searches import PoiSearch
from nominatim.api.search.db_search_fields import WeightedStrings, WeightedCategories from nominatim.api.search.db_search_fields import WeightedStrings, WeightedCategories
def run_search(apiobj, global_penalty, poitypes, poi_penalties=None, def run_search(apiobj, frontend, global_penalty, poitypes, poi_penalties=None,
ccodes=[], details=SearchDetails()): ccodes=[], details=SearchDetails()):
if poi_penalties is None: if poi_penalties is None:
poi_penalties = [0.0] * len(poitypes) poi_penalties = [0.0] * len(poitypes)
@@ -27,16 +27,18 @@ def run_search(apiobj, global_penalty, poitypes, poi_penalties=None,
search = PoiSearch(MySearchData()) search = PoiSearch(MySearchData())
api = frontend(apiobj, options=['search'])
async def run(): async def run():
async with apiobj.api._async_api.begin() as conn: async with api._async_api.begin() as conn:
return await search.lookup(conn, details) return await search.lookup(conn, details)
return apiobj.async_to_sync(run()) return api._loop.run_until_complete(run())
@pytest.mark.parametrize('coord,pid', [('34.3, 56.100021', 2), @pytest.mark.parametrize('coord,pid', [('34.3, 56.100021', 2),
('5.0, 4.59933', 1)]) ('5.0, 4.59933', 1)])
def test_simple_near_search_in_placex(apiobj, coord, pid): def test_simple_near_search_in_placex(apiobj, frontend, coord, pid):
apiobj.add_placex(place_id=1, class_='highway', type='bus_stop', apiobj.add_placex(place_id=1, class_='highway', type='bus_stop',
centroid=(5.0, 4.6)) centroid=(5.0, 4.6))
apiobj.add_placex(place_id=2, class_='highway', type='bus_stop', apiobj.add_placex(place_id=2, class_='highway', type='bus_stop',
@@ -44,7 +46,7 @@ def test_simple_near_search_in_placex(apiobj, coord, pid):
details = SearchDetails.from_kwargs({'near': coord, 'near_radius': 0.001}) details = SearchDetails.from_kwargs({'near': coord, 'near_radius': 0.001})
results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5], details=details) results = run_search(apiobj, frontend, 0.1, [('highway', 'bus_stop')], [0.5], details=details)
assert [r.place_id for r in results] == [pid] assert [r.place_id for r in results] == [pid]
@@ -52,7 +54,7 @@ def test_simple_near_search_in_placex(apiobj, coord, pid):
@pytest.mark.parametrize('coord,pid', [('34.3, 56.100021', 2), @pytest.mark.parametrize('coord,pid', [('34.3, 56.100021', 2),
('34.3, 56.4', 2), ('34.3, 56.4', 2),
('5.0, 4.59933', 1)]) ('5.0, 4.59933', 1)])
def test_simple_near_search_in_classtype(apiobj, coord, pid): def test_simple_near_search_in_classtype(apiobj, frontend, coord, pid):
apiobj.add_placex(place_id=1, class_='highway', type='bus_stop', apiobj.add_placex(place_id=1, class_='highway', type='bus_stop',
centroid=(5.0, 4.6)) centroid=(5.0, 4.6))
apiobj.add_placex(place_id=2, class_='highway', type='bus_stop', apiobj.add_placex(place_id=2, class_='highway', type='bus_stop',
@@ -61,7 +63,7 @@ def test_simple_near_search_in_classtype(apiobj, coord, pid):
details = SearchDetails.from_kwargs({'near': coord, 'near_radius': 0.5}) details = SearchDetails.from_kwargs({'near': coord, 'near_radius': 0.5})
results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5], details=details) results = run_search(apiobj, frontend, 0.1, [('highway', 'bus_stop')], [0.5], details=details)
assert [r.place_id for r in results] == [pid] assert [r.place_id for r in results] == [pid]
@@ -83,25 +85,25 @@ class TestPoiSearchWithRestrictions:
self.args = {'near': '34.3, 56.100021', 'near_radius': 0.001} self.args = {'near': '34.3, 56.100021', 'near_radius': 0.001}
def test_unrestricted(self, apiobj): def test_unrestricted(self, apiobj, frontend):
results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5], results = run_search(apiobj, frontend, 0.1, [('highway', 'bus_stop')], [0.5],
details=SearchDetails.from_kwargs(self.args)) details=SearchDetails.from_kwargs(self.args))
assert [r.place_id for r in results] == [1, 2] assert [r.place_id for r in results] == [1, 2]
def test_restict_country(self, apiobj): def test_restict_country(self, apiobj, frontend):
results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5], results = run_search(apiobj, frontend, 0.1, [('highway', 'bus_stop')], [0.5],
ccodes=['de', 'nz'], ccodes=['de', 'nz'],
details=SearchDetails.from_kwargs(self.args)) details=SearchDetails.from_kwargs(self.args))
assert [r.place_id for r in results] == [2] assert [r.place_id for r in results] == [2]
def test_restrict_by_viewbox(self, apiobj): def test_restrict_by_viewbox(self, apiobj, frontend):
args = {'bounded_viewbox': True, 'viewbox': '34.299,56.0,34.3001,56.10001'} args = {'bounded_viewbox': True, 'viewbox': '34.299,56.0,34.3001,56.10001'}
args.update(self.args) args.update(self.args)
results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5], results = run_search(apiobj, frontend, 0.1, [('highway', 'bus_stop')], [0.5],
ccodes=['de', 'nz'], ccodes=['de', 'nz'],
details=SearchDetails.from_kwargs(args)) details=SearchDetails.from_kwargs(args))

View File

@@ -15,7 +15,7 @@ from nominatim.api.search.db_searches import PostcodeSearch
from nominatim.api.search.db_search_fields import WeightedStrings, FieldLookup, \ from nominatim.api.search.db_search_fields import WeightedStrings, FieldLookup, \
FieldRanking, RankedTokens FieldRanking, RankedTokens
def run_search(apiobj, global_penalty, pcs, pc_penalties=None, def run_search(apiobj, frontend, global_penalty, pcs, pc_penalties=None,
ccodes=[], lookup=[], ranking=[], details=SearchDetails()): ccodes=[], lookup=[], ranking=[], details=SearchDetails()):
if pc_penalties is None: if pc_penalties is None:
pc_penalties = [0.0] * len(pcs) pc_penalties = [0.0] * len(pcs)
@@ -29,28 +29,30 @@ def run_search(apiobj, global_penalty, pcs, pc_penalties=None,
search = PostcodeSearch(0.0, MySearchData()) search = PostcodeSearch(0.0, MySearchData())
api = frontend(apiobj, options=['search'])
async def run(): async def run():
async with apiobj.api._async_api.begin() as conn: async with api._async_api.begin() as conn:
return await search.lookup(conn, details) return await search.lookup(conn, details)
return apiobj.async_to_sync(run()) return api._loop.run_until_complete(run())
def test_postcode_only_search(apiobj): def test_postcode_only_search(apiobj, frontend):
apiobj.add_postcode(place_id=100, country_code='ch', postcode='12345') apiobj.add_postcode(place_id=100, country_code='ch', postcode='12345')
apiobj.add_postcode(place_id=101, country_code='pl', postcode='12 345') apiobj.add_postcode(place_id=101, country_code='pl', postcode='12 345')
results = run_search(apiobj, 0.3, ['12345', '12 345'], [0.0, 0.1]) results = run_search(apiobj, frontend, 0.3, ['12345', '12 345'], [0.0, 0.1])
assert len(results) == 2 assert len(results) == 2
assert [r.place_id for r in results] == [100, 101] assert [r.place_id for r in results] == [100, 101]
def test_postcode_with_country(apiobj): def test_postcode_with_country(apiobj, frontend):
apiobj.add_postcode(place_id=100, country_code='ch', postcode='12345') apiobj.add_postcode(place_id=100, country_code='ch', postcode='12345')
apiobj.add_postcode(place_id=101, country_code='pl', postcode='12 345') apiobj.add_postcode(place_id=101, country_code='pl', postcode='12 345')
results = run_search(apiobj, 0.3, ['12345', '12 345'], [0.0, 0.1], results = run_search(apiobj, frontend, 0.3, ['12345', '12 345'], [0.0, 0.1],
ccodes=['de', 'pl']) ccodes=['de', 'pl'])
assert len(results) == 1 assert len(results) == 1
@@ -81,30 +83,30 @@ class TestPostcodeSearchWithAddress:
country_code='pl') country_code='pl')
def test_lookup_both(self, apiobj): def test_lookup_both(self, apiobj, frontend):
lookup = FieldLookup('name_vector', [1,2], 'restrict') lookup = FieldLookup('name_vector', [1,2], 'restrict')
ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])]) ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
results = run_search(apiobj, 0.1, ['12345'], lookup=[lookup], ranking=[ranking]) results = run_search(apiobj, frontend, 0.1, ['12345'], lookup=[lookup], ranking=[ranking])
assert [r.place_id for r in results] == [100, 101] assert [r.place_id for r in results] == [100, 101]
def test_restrict_by_name(self, apiobj): def test_restrict_by_name(self, apiobj, frontend):
lookup = FieldLookup('name_vector', [10], 'restrict') lookup = FieldLookup('name_vector', [10], 'restrict')
results = run_search(apiobj, 0.1, ['12345'], lookup=[lookup]) results = run_search(apiobj, frontend, 0.1, ['12345'], lookup=[lookup])
assert [r.place_id for r in results] == [100] assert [r.place_id for r in results] == [100]
@pytest.mark.parametrize('coord,place_id', [((16.5, 5), 100), @pytest.mark.parametrize('coord,place_id', [((16.5, 5), 100),
((-45.1, 7.004), 101)]) ((-45.1, 7.004), 101)])
def test_lookup_near(self, apiobj, coord, place_id): def test_lookup_near(self, apiobj, frontend, coord, place_id):
lookup = FieldLookup('name_vector', [1,2], 'restrict') lookup = FieldLookup('name_vector', [1,2], 'restrict')
ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])]) ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])])
results = run_search(apiobj, 0.1, ['12345'], results = run_search(apiobj, frontend, 0.1, ['12345'],
lookup=[lookup], ranking=[ranking], lookup=[lookup], ranking=[ranking],
details=SearchDetails(near=napi.Point(*coord), details=SearchDetails(near=napi.Point(*coord),
near_radius=0.6)) near_radius=0.6))
@@ -116,8 +118,8 @@ class TestPostcodeSearchWithAddress:
napi.GeometryFormat.KML, napi.GeometryFormat.KML,
napi.GeometryFormat.SVG, napi.GeometryFormat.SVG,
napi.GeometryFormat.TEXT]) napi.GeometryFormat.TEXT])
def test_return_geometries(self, apiobj, geom): def test_return_geometries(self, apiobj, frontend, geom):
results = run_search(apiobj, 0.1, ['12345'], results = run_search(apiobj, frontend, 0.1, ['12345'],
details=SearchDetails(geometry_output=geom)) details=SearchDetails(geometry_output=geom))
assert results assert results
@@ -126,8 +128,8 @@ class TestPostcodeSearchWithAddress:
@pytest.mark.parametrize('viewbox, rids', [('-46,6,-44,8', [101,100]), @pytest.mark.parametrize('viewbox, rids', [('-46,6,-44,8', [101,100]),
('16,4,18,6', [100,101])]) ('16,4,18,6', [100,101])])
def test_prefer_viewbox(self, apiobj, viewbox, rids): def test_prefer_viewbox(self, apiobj, frontend, viewbox, rids):
results = run_search(apiobj, 0.1, ['12345'], results = run_search(apiobj, frontend, 0.1, ['12345'],
details=SearchDetails.from_kwargs({'viewbox': viewbox})) details=SearchDetails.from_kwargs({'viewbox': viewbox}))
assert [r.place_id for r in results] == rids assert [r.place_id for r in results] == rids
@@ -135,8 +137,8 @@ class TestPostcodeSearchWithAddress:
@pytest.mark.parametrize('viewbox, rid', [('-46,6,-44,8', 101), @pytest.mark.parametrize('viewbox, rid', [('-46,6,-44,8', 101),
('16,4,18,6', 100)]) ('16,4,18,6', 100)])
def test_restrict_to_viewbox(self, apiobj, viewbox, rid): def test_restrict_to_viewbox(self, apiobj, frontend, viewbox, rid):
results = run_search(apiobj, 0.1, ['12345'], results = run_search(apiobj, frontend, 0.1, ['12345'],
details=SearchDetails.from_kwargs({'viewbox': viewbox, details=SearchDetails.from_kwargs({'viewbox': viewbox,
'bounded_viewbox': True})) 'bounded_viewbox': True}))
@@ -145,16 +147,16 @@ class TestPostcodeSearchWithAddress:
@pytest.mark.parametrize('coord,rids', [((17.05, 5), [100, 101]), @pytest.mark.parametrize('coord,rids', [((17.05, 5), [100, 101]),
((-45, 7.1), [101, 100])]) ((-45, 7.1), [101, 100])])
def test_prefer_near(self, apiobj, coord, rids): def test_prefer_near(self, apiobj, frontend, coord, rids):
results = run_search(apiobj, 0.1, ['12345'], results = run_search(apiobj, frontend, 0.1, ['12345'],
details=SearchDetails(near=napi.Point(*coord))) details=SearchDetails(near=napi.Point(*coord)))
assert [r.place_id for r in results] == rids assert [r.place_id for r in results] == rids
@pytest.mark.parametrize('pid,rid', [(100, 101), (101, 100)]) @pytest.mark.parametrize('pid,rid', [(100, 101), (101, 100)])
def test_exclude(self, apiobj, pid, rid): def test_exclude(self, apiobj, frontend, pid, rid):
results = run_search(apiobj, 0.1, ['12345'], results = run_search(apiobj, frontend, 0.1, ['12345'],
details=SearchDetails(excluded=[pid])) details=SearchDetails(excluded=[pid]))
assert [r.place_id for r in results] == [rid] assert [r.place_id for r in results] == [rid]

View File

@@ -19,6 +19,8 @@ import sqlalchemy as sa
import nominatim.api as napi import nominatim.api as napi
import nominatim.api.logging as loglib import nominatim.api.logging as loglib
API_OPTIONS = {'search'}
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def setup_icu_tokenizer(apiobj): def setup_icu_tokenizer(apiobj):
""" Setup the propoerties needed for using the ICU tokenizer. """ Setup the propoerties needed for using the ICU tokenizer.
@@ -30,66 +32,62 @@ def setup_icu_tokenizer(apiobj):
]) ])
def test_search_no_content(apiobj, table_factory): def test_search_no_content(apiobj, frontend):
table_factory('word', apiobj.add_word_table([])
definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB')
assert apiobj.api.search('foo') == [] api = frontend(apiobj, options=API_OPTIONS)
assert api.search('foo') == []
def test_search_simple_word(apiobj, table_factory): def test_search_simple_word(apiobj, frontend):
table_factory('word', apiobj.add_word_table([(55, 'test', 'W', 'test', None),
definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB',
content=[(55, 'test', 'W', 'test', None),
(2, 'test', 'w', 'test', None)]) (2, 'test', 'w', 'test', None)])
apiobj.add_placex(place_id=444, class_='place', type='village', apiobj.add_placex(place_id=444, class_='place', type='village',
centroid=(1.3, 0.7)) centroid=(1.3, 0.7))
apiobj.add_search_name(444, names=[2, 55]) apiobj.add_search_name(444, names=[2, 55])
results = apiobj.api.search('TEST') api = frontend(apiobj, options=API_OPTIONS)
results = api.search('TEST')
assert [r.place_id for r in results] == [444] assert [r.place_id for r in results] == [444]
@pytest.mark.parametrize('logtype', ['text', 'html']) @pytest.mark.parametrize('logtype', ['text', 'html'])
def test_search_with_debug(apiobj, table_factory, logtype): def test_search_with_debug(apiobj, frontend, logtype):
table_factory('word', apiobj.add_word_table([(55, 'test', 'W', 'test', None),
definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB',
content=[(55, 'test', 'W', 'test', None),
(2, 'test', 'w', 'test', None)]) (2, 'test', 'w', 'test', None)])
apiobj.add_placex(place_id=444, class_='place', type='village', apiobj.add_placex(place_id=444, class_='place', type='village',
centroid=(1.3, 0.7)) centroid=(1.3, 0.7))
apiobj.add_search_name(444, names=[2, 55]) apiobj.add_search_name(444, names=[2, 55])
api = frontend(apiobj, options=API_OPTIONS)
loglib.set_log_output(logtype) loglib.set_log_output(logtype)
results = apiobj.api.search('TEST') results = api.search('TEST')
assert loglib.get_and_disable() assert loglib.get_and_disable()
def test_address_no_content(apiobj, table_factory): def test_address_no_content(apiobj, frontend):
table_factory('word', apiobj.add_word_table([])
definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB')
assert apiobj.api.search_address(amenity='hotel', api = frontend(apiobj, options=API_OPTIONS)
street='Main St 34', assert api.search_address(amenity='hotel',
city='Happyville', street='Main St 34',
county='Wideland', city='Happyville',
state='Praerie', county='Wideland',
postalcode='55648', state='Praerie',
country='xx') == [] postalcode='55648',
country='xx') == []
@pytest.mark.parametrize('atype,address,search', [('street', 26, 26), @pytest.mark.parametrize('atype,address,search', [('street', 26, 26),
('city', 16, 18), ('city', 16, 18),
('county', 12, 12), ('county', 12, 12),
('state', 8, 8)]) ('state', 8, 8)])
def test_address_simple_places(apiobj, table_factory, atype, address, search): def test_address_simple_places(apiobj, frontend, atype, address, search):
table_factory('word', apiobj.add_word_table([(55, 'test', 'W', 'test', None),
definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB',
content=[(55, 'test', 'W', 'test', None),
(2, 'test', 'w', 'test', None)]) (2, 'test', 'w', 'test', None)])
apiobj.add_placex(place_id=444, apiobj.add_placex(place_id=444,
@@ -97,53 +95,51 @@ def test_address_simple_places(apiobj, table_factory, atype, address, search):
centroid=(1.3, 0.7)) centroid=(1.3, 0.7))
apiobj.add_search_name(444, names=[2, 55], address_rank=address, search_rank=search) apiobj.add_search_name(444, names=[2, 55], address_rank=address, search_rank=search)
results = apiobj.api.search_address(**{atype: 'TEST'}) api = frontend(apiobj, options=API_OPTIONS)
results = api.search_address(**{atype: 'TEST'})
assert [r.place_id for r in results] == [444] assert [r.place_id for r in results] == [444]
def test_address_country(apiobj, table_factory): def test_address_country(apiobj, frontend):
table_factory('word', apiobj.add_word_table([(None, 'ro', 'C', 'ro', None)])
definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB',
content=[(None, 'ro', 'C', 'ro', None)])
apiobj.add_country('ro', 'POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))') apiobj.add_country('ro', 'POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))')
apiobj.add_country_name('ro', {'name': 'România'}) apiobj.add_country_name('ro', {'name': 'România'})
assert len(apiobj.api.search_address(country='ro')) == 1 api = frontend(apiobj, options=API_OPTIONS)
assert len(api.search_address(country='ro')) == 1
def test_category_no_categories(apiobj, table_factory): def test_category_no_categories(apiobj, frontend):
table_factory('word', apiobj.add_word_table([])
definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB')
assert apiobj.api.search_category([], near_query='Berlin') == [] api = frontend(apiobj, options=API_OPTIONS)
assert api.search_category([], near_query='Berlin') == []
def test_category_no_content(apiobj, table_factory): def test_category_no_content(apiobj, frontend):
table_factory('word', apiobj.add_word_table([])
definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB')
assert apiobj.api.search_category([('amenity', 'restaurant')]) == [] api = frontend(apiobj, options=API_OPTIONS)
assert api.search_category([('amenity', 'restaurant')]) == []
def test_category_simple_restaurant(apiobj, table_factory): def test_category_simple_restaurant(apiobj, frontend):
table_factory('word', apiobj.add_word_table([])
definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB')
apiobj.add_placex(place_id=444, class_='amenity', type='restaurant', apiobj.add_placex(place_id=444, class_='amenity', type='restaurant',
centroid=(1.3, 0.7)) centroid=(1.3, 0.7))
apiobj.add_search_name(444, names=[2, 55], address_rank=16, search_rank=18) apiobj.add_search_name(444, names=[2, 55], address_rank=16, search_rank=18)
results = apiobj.api.search_category([('amenity', 'restaurant')], api = frontend(apiobj, options=API_OPTIONS)
near=(1.3, 0.701), near_radius=0.015) results = api.search_category([('amenity', 'restaurant')],
near=(1.3, 0.701), near_radius=0.015)
assert [r.place_id for r in results] == [444] assert [r.place_id for r in results] == [444]
def test_category_with_search_phrase(apiobj, table_factory): def test_category_with_search_phrase(apiobj, frontend):
table_factory('word', apiobj.add_word_table([(55, 'test', 'W', 'test', None),
definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB',
content=[(55, 'test', 'W', 'test', None),
(2, 'test', 'w', 'test', None)]) (2, 'test', 'w', 'test', None)])
apiobj.add_placex(place_id=444, class_='place', type='village', apiobj.add_placex(place_id=444, class_='place', type='village',
@@ -153,7 +149,7 @@ def test_category_with_search_phrase(apiobj, table_factory):
apiobj.add_placex(place_id=95, class_='amenity', type='restaurant', apiobj.add_placex(place_id=95, class_='amenity', type='restaurant',
centroid=(1.3, 0.7003)) centroid=(1.3, 0.7003))
results = apiobj.api.search_category([('amenity', 'restaurant')], api = frontend(apiobj, options=API_OPTIONS)
near_query='TEST') results = api.search_category([('amenity', 'restaurant')], near_query='TEST')
assert [r.place_id for r in results] == [95] assert [r.place_id for r in results] == [95]