move search to bind parameters

This commit is contained in:
Sarah Hoffmann
2023-06-26 15:56:10 +02:00
parent 6c4c9ec1f2
commit 9f6f12cfeb
6 changed files with 117 additions and 88 deletions

View File

@@ -528,7 +528,7 @@ class ReverseGeocoder:
log().function('reverse_lookup', coord=coord, params=self.params) log().function('reverse_lookup', coord=coord, params=self.params)
self.bind_params['wkt'] = f'SRID=4326;POINT({coord[0]} {coord[1]})' self.bind_params['wkt'] = f'POINT({coord[0]} {coord[1]})'
row: Optional[SaRow] = None row: Optional[SaRow] = None
row_func: RowFunc = nres.create_from_placex_row row_func: RowFunc = nres.create_from_placex_row

View File

@@ -7,7 +7,7 @@
""" """
Implementation of the acutal database accesses for forward search. Implementation of the acutal database accesses for forward search.
""" """
from typing import List, Tuple, AsyncIterator from typing import List, Tuple, AsyncIterator, Dict, Any
import abc import abc
import sqlalchemy as sa import sqlalchemy as sa
@@ -19,10 +19,36 @@ from nominatim.api.connection import SearchConnection
from nominatim.api.types import SearchDetails, DataLayer, GeometryFormat, Bbox from nominatim.api.types import SearchDetails, DataLayer, GeometryFormat, Bbox
import nominatim.api.results as nres import nominatim.api.results as nres
from nominatim.api.search.db_search_fields import SearchData, WeightedCategories from nominatim.api.search.db_search_fields import SearchData, WeightedCategories
from nominatim.db.sqlalchemy_types import Geometry
#pylint: disable=singleton-comparison,not-callable #pylint: disable=singleton-comparison,not-callable
#pylint: disable=too-many-branches,too-many-arguments,too-many-locals,too-many-statements #pylint: disable=too-many-branches,too-many-arguments,too-many-locals,too-many-statements
def _details_to_bind_params(details: SearchDetails) -> Dict[str, Any]:
""" Create a dictionary from search parameters that can be used
as bind parameter for SQL execute.
"""
return {'limit': details.max_results,
'min_rank': details.min_rank,
'max_rank': details.max_rank,
'viewbox': details.viewbox,
'viewbox2': details.viewbox_x2,
'near': details.near,
'near_radius': details.near_radius,
'excluded': details.excluded,
'countries': details.countries}
LIMIT_PARAM = sa.bindparam('limit')
MIN_RANK_PARAM = sa.bindparam('min_rank')
MAX_RANK_PARAM = sa.bindparam('max_rank')
VIEWBOX_PARAM = sa.bindparam('viewbox', type_=Geometry)
VIEWBOX2_PARAM = sa.bindparam('viewbox2', type_=Geometry)
NEAR_PARAM = sa.bindparam('near', type_=Geometry)
NEAR_RADIUS_PARAM = sa.bindparam('near_radius')
EXCLUDED_PARAM = sa.bindparam('excluded')
COUNTRIES_PARAM = sa.bindparam('countries')
def _select_placex(t: SaFromClause) -> SaSelect: def _select_placex(t: SaFromClause) -> SaSelect:
return sa.select(t.c.place_id, t.c.osm_type, t.c.osm_id, t.c.name, return sa.select(t.c.place_id, t.c.osm_type, t.c.osm_id, t.c.name,
t.c.class_, t.c.type, t.c.class_, t.c.type,
@@ -41,16 +67,16 @@ def _add_geometry_columns(sql: SaSelect, col: SaColumn, details: SearchDetails)
out = [] out = []
if details.geometry_simplification > 0.0: if details.geometry_simplification > 0.0:
col = col.ST_SimplifyPreserveTopology(details.geometry_simplification) col = sa.func.ST_SimplifyPreserveTopology(col, details.geometry_simplification)
if details.geometry_output & GeometryFormat.GEOJSON: if details.geometry_output & GeometryFormat.GEOJSON:
out.append(col.ST_AsGeoJSON().label('geometry_geojson')) out.append(sa.func.ST_AsGeoJSON(col).label('geometry_geojson'))
if details.geometry_output & GeometryFormat.TEXT: if details.geometry_output & GeometryFormat.TEXT:
out.append(col.ST_AsText().label('geometry_text')) out.append(sa.func.ST_AsText(col).label('geometry_text'))
if details.geometry_output & GeometryFormat.KML: if details.geometry_output & GeometryFormat.KML:
out.append(col.ST_AsKML().label('geometry_kml')) out.append(sa.func.ST_AsKML(col).label('geometry_kml'))
if details.geometry_output & GeometryFormat.SVG: if details.geometry_output & GeometryFormat.SVG:
out.append(col.ST_AsSVG().label('geometry_svg')) out.append(sa.func.ST_AsSVG(col).label('geometry_svg'))
return sql.add_columns(*out) return sql.add_columns(*out)
@@ -70,7 +96,7 @@ def _make_interpolation_subquery(table: SaFromClause, inner: SaFromClause,
for n in numerals))) for n in numerals)))
if details.excluded: if details.excluded:
sql = sql.where(table.c.place_id.not_in(details.excluded)) sql = sql.where(table.c.place_id.not_in(EXCLUDED_PARAM))
return sql.scalar_subquery() return sql.scalar_subquery()
@@ -129,7 +155,7 @@ async def _get_placex_housenumbers(conn: SearchConnection,
for row in await conn.execute(sql): for row in await conn.execute(sql):
result = nres.create_from_placex_row(row, nres.SearchResult) result = nres.create_from_placex_row(row, nres.SearchResult)
assert result assert result
result.bbox = Bbox.from_wkb(row.bbox.data) result.bbox = Bbox.from_wkb(row.bbox)
yield result yield result
@@ -259,28 +285,25 @@ class NearSearch(AbstractSearch):
sql = sql.join(table, t.c.place_id == table.c.place_id)\ sql = sql.join(table, t.c.place_id == table.c.place_id)\
.join(tgeom, .join(tgeom,
sa.case((sa.and_(tgeom.c.rank_address < 9, sa.case((sa.and_(tgeom.c.rank_address < 9,
tgeom.c.geometry.ST_GeometryType().in_( tgeom.c.geometry.is_area()),
('ST_Polygon', 'ST_MultiPolygon'))),
tgeom.c.geometry.ST_Contains(table.c.centroid)), tgeom.c.geometry.ST_Contains(table.c.centroid)),
else_ = tgeom.c.centroid.ST_DWithin(table.c.centroid, 0.05)))\ else_ = tgeom.c.centroid.ST_DWithin(table.c.centroid, 0.05)))\
.order_by(tgeom.c.centroid.ST_Distance(table.c.centroid)) .order_by(tgeom.c.centroid.ST_Distance(table.c.centroid))
sql = sql.where(t.c.rank_address.between(MIN_RANK_PARAM, MAX_RANK_PARAM))
if details.countries: if details.countries:
sql = sql.where(t.c.country_code.in_(details.countries)) sql = sql.where(t.c.country_code.in_(COUNTRIES_PARAM))
if details.min_rank > 0:
sql = sql.where(t.c.rank_address >= details.min_rank)
if details.max_rank < 30:
sql = sql.where(t.c.rank_address <= details.max_rank)
if details.excluded: if details.excluded:
sql = sql.where(t.c.place_id.not_in(details.excluded)) sql = sql.where(t.c.place_id.not_in(EXCLUDED_PARAM))
if details.layers is not None: if details.layers is not None:
sql = sql.where(_filter_by_layer(t, details.layers)) sql = sql.where(_filter_by_layer(t, details.layers))
for row in await conn.execute(sql.limit(details.max_results)): sql = sql.limit(LIMIT_PARAM)
for row in await conn.execute(sql, _details_to_bind_params(details)):
result = nres.create_from_placex_row(row, nres.SearchResult) result = nres.create_from_placex_row(row, nres.SearchResult)
assert result assert result
result.accuracy = self.penalty + penalty result.accuracy = self.penalty + penalty
result.bbox = Bbox.from_wkb(row.bbox.data) result.bbox = Bbox.from_wkb(row.bbox)
results.append(result) results.append(result)
@@ -298,6 +321,7 @@ class PoiSearch(AbstractSearch):
details: SearchDetails) -> nres.SearchResults: details: SearchDetails) -> nres.SearchResults:
""" Find results for the search in the database. """ Find results for the search in the database.
""" """
bind_params = _details_to_bind_params(details)
t = conn.t.placex t = conn.t.placex
rows: List[SaRow] = [] rows: List[SaRow] = []
@@ -306,15 +330,14 @@ class PoiSearch(AbstractSearch):
# simply search in placex table # simply search in placex table
sql = _select_placex(t) \ sql = _select_placex(t) \
.where(t.c.linked_place_id == None) \ .where(t.c.linked_place_id == None) \
.where(t.c.geometry.ST_DWithin(details.near.sql_value(), .where(t.c.geometry.ST_DWithin(NEAR_PARAM, NEAR_RADIUS_PARAM)) \
details.near_radius)) \ .order_by(t.c.centroid.ST_Distance(NEAR_PARAM))
.order_by(t.c.centroid.ST_Distance(details.near.sql_value()))
if self.countries: if self.countries:
sql = sql.where(t.c.country_code.in_(self.countries.values)) sql = sql.where(t.c.country_code.in_(self.countries.values))
if details.viewbox is not None and details.bounded_viewbox: if details.viewbox is not None and details.bounded_viewbox:
sql = sql.where(t.c.geometry.intersects(details.viewbox.sql_value())) sql = sql.where(t.c.geometry.intersects(VIEWBOX_PARAM))
classtype = self.categories.values classtype = self.categories.values
if len(classtype) == 1: if len(classtype) == 1:
@@ -324,7 +347,8 @@ class PoiSearch(AbstractSearch):
sql = sql.where(sa.or_(*(sa.and_(t.c.class_ == cls, t.c.type == typ) sql = sql.where(sa.or_(*(sa.and_(t.c.class_ == cls, t.c.type == typ)
for cls, typ in classtype))) for cls, typ in classtype)))
rows.extend(await conn.execute(sql.limit(details.max_results))) sql = sql.limit(LIMIT_PARAM)
rows.extend(await conn.execute(sql, bind_params))
else: else:
# use the class type tables # use the class type tables
for category in self.categories.values: for category in self.categories.values:
@@ -336,24 +360,25 @@ class PoiSearch(AbstractSearch):
.where(t.c.type == category[1]) .where(t.c.type == category[1])
if details.viewbox is not None and details.bounded_viewbox: if details.viewbox is not None and details.bounded_viewbox:
sql = sql.where(table.c.centroid.intersects(details.viewbox.sql_value())) sql = sql.where(table.c.centroid.intersects(VIEWBOX_PARAM))
if details.near: if details.near and details.near_radius is not None:
sql = sql.order_by(table.c.centroid.ST_Distance(details.near.sql_value()))\ sql = sql.order_by(table.c.centroid.ST_Distance(NEAR_PARAM))\
.where(table.c.centroid.ST_DWithin(details.near.sql_value(), .where(table.c.centroid.ST_DWithin(NEAR_PARAM,
details.near_radius or 0.5)) NEAR_RADIUS_PARAM))
if self.countries: if self.countries:
sql = sql.where(t.c.country_code.in_(self.countries.values)) sql = sql.where(t.c.country_code.in_(self.countries.values))
rows.extend(await conn.execute(sql.limit(details.max_results))) sql = sql.limit(LIMIT_PARAM)
rows.extend(await conn.execute(sql, bind_params))
results = nres.SearchResults() results = nres.SearchResults()
for row in rows: for row in rows:
result = nres.create_from_placex_row(row, nres.SearchResult) result = nres.create_from_placex_row(row, nres.SearchResult)
assert result assert result
result.accuracy = self.penalty + self.categories.get_penalty((row.class_, row.type)) result.accuracy = self.penalty + self.categories.get_penalty((row.class_, row.type))
result.bbox = Bbox.from_wkb(row.bbox.data) result.bbox = Bbox.from_wkb(row.bbox)
results.append(result) results.append(result)
return results return results
@@ -380,17 +405,16 @@ class CountrySearch(AbstractSearch):
sql = _add_geometry_columns(sql, t.c.geometry, details) sql = _add_geometry_columns(sql, t.c.geometry, details)
if details.excluded: if details.excluded:
sql = sql.where(t.c.place_id.not_in(details.excluded)) sql = sql.where(t.c.place_id.not_in(EXCLUDED_PARAM))
if details.viewbox is not None and details.bounded_viewbox: if details.viewbox is not None and details.bounded_viewbox:
sql = sql.where(t.c.geometry.intersects(details.viewbox.sql_value())) sql = sql.where(t.c.geometry.intersects(VIEWBOX_PARAM))
if details.near is not None and details.near_radius is not None: if details.near is not None and details.near_radius is not None:
sql = sql.where(t.c.geometry.ST_DWithin(details.near.sql_value(), sql = sql.where(t.c.geometry.ST_DWithin(NEAR_PARAM, NEAR_RADIUS_PARAM))
details.near_radius))
results = nres.SearchResults() results = nres.SearchResults()
for row in await conn.execute(sql): for row in await conn.execute(sql, _details_to_bind_params(details)):
result = nres.create_from_placex_row(row, nres.SearchResult) result = nres.create_from_placex_row(row, nres.SearchResult)
assert result assert result
result.accuracy = self.penalty + self.countries.get_penalty(row.country_code, 5.0) result.accuracy = self.penalty + self.countries.get_penalty(row.country_code, 5.0)
@@ -419,10 +443,9 @@ class CountrySearch(AbstractSearch):
.group_by(tgrid.c.country_code) .group_by(tgrid.c.country_code)
if details.viewbox is not None and details.bounded_viewbox: if details.viewbox is not None and details.bounded_viewbox:
sql = sql.where(tgrid.c.geometry.intersects(details.viewbox.sql_value())) sql = sql.where(tgrid.c.geometry.intersects(VIEWBOX_PARAM))
if details.near is not None and details.near_radius is not None: if details.near is not None and details.near_radius is not None:
sql = sql.where(tgrid.c.geometry.ST_DWithin(details.near.sql_value(), sql = sql.where(tgrid.c.geometry.ST_DWithin(NEAR_PARAM, NEAR_RADIUS_PARAM))
details.near_radius))
sub = sql.subquery('grid') sub = sql.subquery('grid')
@@ -435,7 +458,7 @@ class CountrySearch(AbstractSearch):
.join(sub, t.c.country_code == sub.c.country_code) .join(sub, t.c.country_code == sub.c.country_code)
results = nres.SearchResults() results = nres.SearchResults()
for row in await conn.execute(sql): for row in await conn.execute(sql, _details_to_bind_params(details)):
result = nres.create_from_country_row(row, nres.SearchResult) result = nres.create_from_country_row(row, nres.SearchResult)
assert result assert result
result.accuracy = self.penalty + self.countries.get_penalty(row.country_code, 5.0) result.accuracy = self.penalty + self.countries.get_penalty(row.country_code, 5.0)
@@ -474,23 +497,22 @@ class PostcodeSearch(AbstractSearch):
if details.viewbox is not None: if details.viewbox is not None:
if details.bounded_viewbox: if details.bounded_viewbox:
sql = sql.where(t.c.geometry.intersects(details.viewbox.sql_value())) sql = sql.where(t.c.geometry.intersects(VIEWBOX_PARAM))
else: else:
penalty += sa.case((t.c.geometry.intersects(details.viewbox.sql_value()), 0.0), penalty += sa.case((t.c.geometry.intersects(VIEWBOX_PARAM), 0.0),
(t.c.geometry.intersects(details.viewbox_x2.sql_value()), 1.0), (t.c.geometry.intersects(VIEWBOX2_PARAM), 1.0),
else_=2.0) else_=2.0)
if details.near is not None: if details.near is not None:
if details.near_radius is not None: if details.near_radius is not None:
sql = sql.where(t.c.geometry.ST_DWithin(details.near.sql_value(), sql = sql.where(t.c.geometry.ST_DWithin(NEAR_PARAM, NEAR_RADIUS_PARAM))
details.near_radius)) sql = sql.order_by(t.c.geometry.ST_Distance(NEAR_PARAM))
sql = sql.order_by(t.c.geometry.ST_Distance(details.near.sql_value()))
if self.countries: if self.countries:
sql = sql.where(t.c.country_code.in_(self.countries.values)) sql = sql.where(t.c.country_code.in_(self.countries.values))
if details.excluded: if details.excluded:
sql = sql.where(t.c.place_id.not_in(details.excluded)) sql = sql.where(t.c.place_id.not_in(EXCLUDED_PARAM))
if self.lookups: if self.lookups:
assert len(self.lookups) == 1 assert len(self.lookups) == 1
@@ -509,10 +531,10 @@ class PostcodeSearch(AbstractSearch):
sql = sql.add_columns(penalty.label('accuracy')) sql = sql.add_columns(penalty.label('accuracy'))
sql = sql.order_by('accuracy') sql = sql.order_by('accuracy').limit(LIMIT_PARAM)
results = nres.SearchResults() results = nres.SearchResults()
for row in await conn.execute(sql.limit(details.max_results)): for row in await conn.execute(sql, _details_to_bind_params(details)):
result = nres.create_from_postcode_row(row, nres.SearchResult) result = nres.create_from_postcode_row(row, nres.SearchResult)
assert result assert result
result.accuracy = row.accuracy result.accuracy = row.accuracy
@@ -542,7 +564,6 @@ class PlaceSearch(AbstractSearch):
""" """
t = conn.t.placex.alias('p') t = conn.t.placex.alias('p')
tsearch = conn.t.search_name.alias('s') tsearch = conn.t.search_name.alias('s')
limit = details.max_results
sql = sa.select(t.c.place_id, t.c.osm_type, t.c.osm_id, t.c.name, sql = sa.select(t.c.place_id, t.c.osm_type, t.c.osm_id, t.c.name,
t.c.class_, t.c.type, t.c.class_, t.c.type,
@@ -587,17 +608,16 @@ class PlaceSearch(AbstractSearch):
if details.viewbox is not None: if details.viewbox is not None:
if details.bounded_viewbox: if details.bounded_viewbox:
sql = sql.where(tsearch.c.centroid.intersects(details.viewbox.sql_value())) sql = sql.where(tsearch.c.centroid.intersects(VIEWBOX_PARAM))
else: else:
penalty += sa.case((t.c.geometry.intersects(details.viewbox.sql_value()), 0.0), penalty += sa.case((t.c.geometry.intersects(VIEWBOX_PARAM), 0.0),
(t.c.geometry.intersects(details.viewbox_x2.sql_value()), 1.0), (t.c.geometry.intersects(VIEWBOX2_PARAM), 1.0),
else_=2.0) else_=2.0)
if details.near is not None: if details.near is not None:
if details.near_radius is not None: if details.near_radius is not None:
sql = sql.where(tsearch.c.centroid.ST_DWithin(details.near.sql_value(), sql = sql.where(tsearch.c.centroid.ST_DWithin(NEAR_PARAM, NEAR_RADIUS_PARAM))
details.near_radius)) sql = sql.add_columns(-tsearch.c.centroid.ST_Distance(NEAR_PARAM)
sql = sql.add_columns(-tsearch.c.centroid.ST_Distance(details.near.sql_value())
.label('importance')) .label('importance'))
sql = sql.order_by(sa.desc(sa.text('importance'))) sql = sql.order_by(sa.desc(sa.text('importance')))
else: else:
@@ -629,7 +649,7 @@ class PlaceSearch(AbstractSearch):
.where(thnr.c.indexed_status == 0) .where(thnr.c.indexed_status == 0)
if details.excluded: if details.excluded:
place_sql = place_sql.where(thnr.c.place_id.not_in(details.excluded)) place_sql = place_sql.where(thnr.c.place_id.not_in(EXCLUDED_PARAM))
if self.qualifiers: if self.qualifiers:
place_sql = place_sql.where(self.qualifiers.sql_restrict(thnr)) place_sql = place_sql.where(self.qualifiers.sql_restrict(thnr))
@@ -665,22 +685,23 @@ class PlaceSearch(AbstractSearch):
if self.qualifiers: if self.qualifiers:
sql = sql.where(self.qualifiers.sql_restrict(t)) sql = sql.where(self.qualifiers.sql_restrict(t))
if details.excluded: if details.excluded:
sql = sql.where(tsearch.c.place_id.not_in(details.excluded)) sql = sql.where(tsearch.c.place_id.not_in(EXCLUDED_PARAM))
if details.min_rank > 0: if details.min_rank > 0:
sql = sql.where(sa.or_(tsearch.c.address_rank >= details.min_rank, sql = sql.where(sa.or_(tsearch.c.address_rank >= MIN_RANK_PARAM,
tsearch.c.search_rank >= details.min_rank)) tsearch.c.search_rank >= MIN_RANK_PARAM))
if details.max_rank < 30: if details.max_rank < 30:
sql = sql.where(sa.or_(tsearch.c.address_rank <= details.max_rank, sql = sql.where(sa.or_(tsearch.c.address_rank <= MAX_RANK_PARAM,
tsearch.c.search_rank <= details.max_rank)) tsearch.c.search_rank <= MAX_RANK_PARAM))
if details.layers is not None: if details.layers is not None:
sql = sql.where(_filter_by_layer(t, details.layers)) sql = sql.where(_filter_by_layer(t, details.layers))
sql = sql.limit(LIMIT_PARAM)
results = nres.SearchResults() results = nres.SearchResults()
for row in await conn.execute(sql.limit(limit)): for row in await conn.execute(sql, _details_to_bind_params(details)):
result = nres.create_from_placex_row(row, nres.SearchResult) result = nres.create_from_placex_row(row, nres.SearchResult)
assert result assert result
result.bbox = Bbox.from_wkb(row.bbox.data) result.bbox = Bbox.from_wkb(row.bbox)
result.accuracy = row.accuracy result.accuracy = row.accuracy
if not details.excluded or not result.place_id in details.excluded: if not details.excluded or not result.place_id in details.excluded:
results.append(result) results.append(result)

View File

@@ -79,7 +79,7 @@ class Point(NamedTuple):
if isinstance(wkb, str): if isinstance(wkb, str):
wkb = unhexlify(wkb) wkb = unhexlify(wkb)
if len(wkb) != 25: if len(wkb) != 25:
raise ValueError("Point wkb has unexpected length") raise ValueError(f"Point wkb has unexpected length {len(wkb)}")
if wkb[0] == 0: if wkb[0] == 0:
gtype, srid, x, y = unpack('>iidd', wkb[1:]) gtype, srid, x, y = unpack('>iidd', wkb[1:])
elif wkb[0] == 1: elif wkb[0] == 1:
@@ -124,8 +124,8 @@ class Point(NamedTuple):
return Point(x, y) return Point(x, y)
def sql_value(self) -> str: def to_wkt(self) -> str:
""" Create an SQL expression for the point. """ Return the WKT representation of the point.
""" """
return f'POINT({self.x} {self.y})' return f'POINT({self.x} {self.y})'
@@ -181,12 +181,6 @@ class Bbox:
return (self.coords[2] - self.coords[0]) * (self.coords[3] - self.coords[1]) return (self.coords[2] - self.coords[0]) * (self.coords[3] - self.coords[1])
def sql_value(self) -> Any:
""" Create an SQL expression for the box.
"""
return sa.func.ST_MakeEnvelope(*self.coords, 4326)
def contains(self, pt: Point) -> bool: def contains(self, pt: Point) -> bool:
""" Check if the point is inside or on the boundary of the box. """ Check if the point is inside or on the boundary of the box.
""" """
@@ -194,6 +188,13 @@ class Bbox:
and self.coords[2] >= pt[0] and self.coords[3] >= pt[1] and self.coords[2] >= pt[0] and self.coords[3] >= pt[1]
def to_wkt(self) -> str:
""" Return the WKT representation of the Bbox. This
is a simple polygon with four points.
"""
return 'POLYGON(({0} {1},{0} {3},{2} {3},{2} {1},{0} {1}))'.format(*self.coords)
@staticmethod @staticmethod
def from_wkb(wkb: Union[None, str, bytes]) -> 'Optional[Bbox]': def from_wkb(wkb: Union[None, str, bytes]) -> 'Optional[Bbox]':
""" Create a Bbox from a bounding box polygon as returned by """ Create a Bbox from a bounding box polygon as returned by
@@ -451,6 +452,8 @@ class SearchDetails(LookupDetails):
yext = (self.viewbox.maxlat - self.viewbox.minlat)/2 yext = (self.viewbox.maxlat - self.viewbox.minlat)/2
self.viewbox_x2 = Bbox(self.viewbox.minlon - xext, self.viewbox.minlat - yext, self.viewbox_x2 = Bbox(self.viewbox.minlon - xext, self.viewbox.minlat - yext,
self.viewbox.maxlon + xext, self.viewbox.maxlat + yext) self.viewbox.maxlon + xext, self.viewbox.maxlat + yext)
else:
self.viewbox_x2 = None
def restrict_min_max_rank(self, new_min: int, new_max: int) -> None: def restrict_min_max_rank(self, new_min: int, new_max: int) -> None:

View File

@@ -30,8 +30,10 @@ class Geometry(types.UserDefinedType[Any]):
def bind_processor(self, dialect: sa.Dialect) -> Callable[[Any], str]: def bind_processor(self, dialect: sa.Dialect) -> Callable[[Any], str]:
def process(value: Any) -> str: def process(value: Any) -> str:
assert isinstance(value, str) if isinstance(value, str):
return value return 'SRID=4326;' + value
return 'SRID=4326;' + value.to_wkt()
return process return process
@@ -84,6 +86,10 @@ class Geometry(types.UserDefinedType[Any]):
return sa.func.ST_Expand(self, other, type_=Geometry) return sa.func.ST_Expand(self, other, type_=Geometry)
def ST_Collect(self) -> SaColumn:
return sa.func.ST_Collect(self, type_=Geometry)
def ST_Centroid(self) -> SaColumn: def ST_Centroid(self) -> SaColumn:
return sa.func.ST_Centroid(self, type_=Geometry) return sa.func.ST_Centroid(self, type_=Geometry)

View File

@@ -66,11 +66,11 @@ class APITester:
'rank_search': kw.get('rank_search', 30), 'rank_search': kw.get('rank_search', 30),
'rank_address': kw.get('rank_address', 30), 'rank_address': kw.get('rank_address', 30),
'importance': kw.get('importance'), 'importance': kw.get('importance'),
'centroid': 'SRID=4326;POINT(%f %f)' % centroid, 'centroid': 'POINT(%f %f)' % centroid,
'indexed_status': kw.get('indexed_status', 0), 'indexed_status': kw.get('indexed_status', 0),
'indexed_date': kw.get('indexed_date', 'indexed_date': kw.get('indexed_date',
dt.datetime(2022, 12, 7, 14, 14, 46, 0)), dt.datetime(2022, 12, 7, 14, 14, 46, 0)),
'geometry': 'SRID=4326;' + geometry}) 'geometry': geometry})
def add_address_placex(self, object_id, **kw): def add_address_placex(self, object_id, **kw):
@@ -97,7 +97,7 @@ class APITester:
'address': kw.get('address'), 'address': kw.get('address'),
'postcode': kw.get('postcode'), 'postcode': kw.get('postcode'),
'country_code': kw.get('country_code'), 'country_code': kw.get('country_code'),
'linegeo': 'SRID=4326;' + kw.get('geometry', 'LINESTRING(1.1 -0.2, 1.09 -0.22)')}) 'linegeo': kw.get('geometry', 'LINESTRING(1.1 -0.2, 1.09 -0.22)')})
def add_tiger(self, **kw): def add_tiger(self, **kw):
@@ -108,7 +108,7 @@ class APITester:
'endnumber': kw.get('endnumber', 6), 'endnumber': kw.get('endnumber', 6),
'step': kw.get('step', 2), 'step': kw.get('step', 2),
'postcode': kw.get('postcode'), 'postcode': kw.get('postcode'),
'linegeo': 'SRID=4326;' + kw.get('geometry', 'LINESTRING(1.1 -0.2, 1.09 -0.22)')}) 'linegeo': kw.get('geometry', 'LINESTRING(1.1 -0.2, 1.09 -0.22)')})
def add_postcode(self, **kw): def add_postcode(self, **kw):
@@ -121,14 +121,14 @@ class APITester:
'rank_address': kw.get('rank_address', 22), 'rank_address': kw.get('rank_address', 22),
'indexed_date': kw.get('indexed_date', 'indexed_date': kw.get('indexed_date',
dt.datetime(2022, 12, 7, 14, 14, 46, 0)), dt.datetime(2022, 12, 7, 14, 14, 46, 0)),
'geometry': 'SRID=4326;' + kw.get('geometry', 'POINT(23 34)')}) 'geometry': kw.get('geometry', 'POINT(23 34)')})
def add_country(self, country_code, geometry): def add_country(self, country_code, geometry):
self.add_data('country_grid', self.add_data('country_grid',
{'country_code': country_code, {'country_code': country_code,
'area': 0.1, 'area': 0.1,
'geometry': 'SRID=4326;' + geometry}) 'geometry': geometry})
def add_country_name(self, country_code, names, partition=0): def add_country_name(self, country_code, names, partition=0):
@@ -148,7 +148,7 @@ class APITester:
'name_vector': kw.get('names', []), 'name_vector': kw.get('names', []),
'nameaddress_vector': kw.get('address', []), 'nameaddress_vector': kw.get('address', []),
'country_code': kw.get('country_code', 'xx'), 'country_code': kw.get('country_code', 'xx'),
'centroid': 'SRID=4326;POINT(%f %f)' % centroid}) 'centroid': 'POINT(%f %f)' % centroid})
def add_class_type_table(self, cls, typ): def add_class_type_table(self, cls, typ):

View File

@@ -8,6 +8,7 @@
Tests for result datatype helper functions. Tests for result datatype helper functions.
""" """
import struct import struct
from binascii import hexlify
import pytest import pytest
import pytest_asyncio import pytest_asyncio
@@ -17,10 +18,8 @@ import sqlalchemy as sa
from nominatim.api import SourceTable, DetailedResult, Point from nominatim.api import SourceTable, DetailedResult, Point
import nominatim.api.results as nresults import nominatim.api.results as nresults
class FakeCentroid: def mkpoint(x, y):
def __init__(self, x, y): return hexlify(struct.pack("=biidd", 1, 0x20000001, 4326, x, y)).decode('utf-8')
self.data = struct.pack("=biidd", 1, 0x20000001, 4326,
x, y)
class FakeRow: class FakeRow:
def __init__(self, **kwargs): def __init__(self, **kwargs):
@@ -60,7 +59,7 @@ def test_create_row_none(func):
def test_create_row_with_housenumber(func): def test_create_row_with_housenumber(func):
row = FakeRow(place_id=2345, osm_type='W', osm_id=111, housenumber=4, row = FakeRow(place_id=2345, osm_type='W', osm_id=111, housenumber=4,
address=None, postcode='99900', country_code='xd', address=None, postcode='99900', country_code='xd',
centroid=FakeCentroid(0, 0)) centroid=mkpoint(0, 0))
res = func(row, DetailedResult) res = func(row, DetailedResult)
@@ -75,7 +74,7 @@ def test_create_row_without_housenumber(func):
row = FakeRow(place_id=2345, osm_type='W', osm_id=111, row = FakeRow(place_id=2345, osm_type='W', osm_id=111,
startnumber=1, endnumber=11, step=2, startnumber=1, endnumber=11, step=2,
address=None, postcode='99900', country_code='xd', address=None, postcode='99900', country_code='xd',
centroid=FakeCentroid(0, 0)) centroid=mkpoint(0, 0))
res = func(row, DetailedResult) res = func(row, DetailedResult)