mirror of
https://github.com/osm-search/Nominatim.git
synced 2026-02-16 15:47:58 +00:00
implement BDD osm2pgsql tests with pytest-bdd
This commit is contained in:
@@ -9,6 +9,12 @@ Helper functions to compare expected values.
|
||||
"""
|
||||
import json
|
||||
import re
|
||||
import math
|
||||
import itertools
|
||||
|
||||
from psycopg import sql as pysql
|
||||
from psycopg.rows import dict_row, tuple_row
|
||||
from .geometry_alias import ALIASES
|
||||
|
||||
COMPARATOR_TERMS = {
|
||||
'exactly': lambda exp, act: exp == act,
|
||||
@@ -43,10 +49,12 @@ COMPARISON_FUNCS = {
|
||||
None: lambda val, exp: str(val) == exp,
|
||||
'i': lambda val, exp: str(val).lower() == exp.lower(),
|
||||
'fm': lambda val, exp: re.fullmatch(exp, val) is not None,
|
||||
'dict': lambda val, exp: val is None if exp == '-' else (val == eval('{' + exp + '}')),
|
||||
'in_box': within_box
|
||||
}
|
||||
|
||||
OSM_TYPE = {'node': 'n', 'way': 'w', 'relation': 'r'}
|
||||
OSM_TYPE = {'node': 'n', 'way': 'w', 'relation': 'r',
|
||||
'N': 'n', 'W': 'w', 'R': 'r'}
|
||||
|
||||
|
||||
class ResultAttr:
|
||||
@@ -60,12 +68,15 @@ class ResultAttr:
|
||||
|
||||
Available formatters:
|
||||
|
||||
!:... - use a formatting expression according to Python Mini Format Spec
|
||||
!i - make case-insensitive comparison
|
||||
!fm - consider comparison string a regular expression and match full value
|
||||
!:... - use a formatting expression according to Python Mini Format Spec
|
||||
!i - make case-insensitive comparison
|
||||
!fm - consider comparison string a regular expression and match full value
|
||||
!wkt - convert the expected value to a WKT string before comparing
|
||||
!in_box - the expected value is a comma-separated bbox description
|
||||
"""
|
||||
|
||||
def __init__(self, obj, key):
|
||||
def __init__(self, obj, key, grid=None):
|
||||
self.grid = grid
|
||||
self.obj = obj
|
||||
if '!' in key:
|
||||
self.key, self.fmt = key.rsplit('!', 1)
|
||||
@@ -100,6 +111,9 @@ class ResultAttr:
|
||||
if self.fmt.startswith(':'):
|
||||
return other == f"{{{self.fmt}}}".format(self.subobj)
|
||||
|
||||
if self.fmt == 'wkt':
|
||||
return self.compare_wkt(self.subobj, other)
|
||||
|
||||
raise RuntimeError(f"Unknown format string '{self.fmt}'.")
|
||||
|
||||
def __repr__(self):
|
||||
@@ -107,3 +121,125 @@ class ResultAttr:
|
||||
if self.fmt:
|
||||
k += '!' + self.fmt
|
||||
return f"result[{k}]({self.subobj})"
|
||||
|
||||
def compare_wkt(self, value, expected):
|
||||
""" Compare a WKT value against a compact geometry format.
|
||||
The function understands the following formats:
|
||||
|
||||
country:<country code>
|
||||
Point geometry guaranteed to be in the given country
|
||||
<P>
|
||||
Point geometry
|
||||
<P>,...,<P>
|
||||
Line geometry
|
||||
(<P>,...,<P>)
|
||||
Polygon geometry
|
||||
|
||||
<P> may either be a coordinate of the form '<x> <y>' or a single
|
||||
number. In the latter case it must refer to a point in
|
||||
a previously defined grid.
|
||||
"""
|
||||
m = re.fullmatch(r'(POINT)\(([0-9. -]*)\)', value) \
|
||||
or re.fullmatch(r'(LINESTRING)\(([0-9,. -]*)\)', value) \
|
||||
or re.fullmatch(r'(POLYGON)\(\(([0-9,. -]*)\)\)', value)
|
||||
if not m:
|
||||
return False
|
||||
|
||||
converted = [list(map(float, pt.split(' ', 1)))
|
||||
for pt in map(str.strip, m[2].split(','))]
|
||||
|
||||
if expected.startswith('country:'):
|
||||
ccode = geom[8:].upper()
|
||||
assert ccode in ALIASES, f"Geometry error: unknown country {ccode}"
|
||||
return m[1] == 'POINT' and \
|
||||
all(math.isclose(p1, p2) for p1, p2 in
|
||||
zip(converted[0], ALIASES[ccode]))
|
||||
|
||||
if ',' not in expected:
|
||||
return m[1] == 'POINT' and \
|
||||
all(math.isclose(p1, p2) for p1, p2 in
|
||||
zip(converted[0], self.get_point(expected)))
|
||||
|
||||
if '(' not in expected:
|
||||
return m[1] == 'LINESTRING' and \
|
||||
all(math.isclose(p1[0], p2[0]) and math.isclose(p1[1], p2[1]) for p1, p2 in
|
||||
zip(converted, (self.get_point(p) for p in expected.split(','))))
|
||||
|
||||
if m[1] != 'POLYGON':
|
||||
return False
|
||||
|
||||
# Polygon comparison is tricky because the polygons don't necessarily
|
||||
# end at the same point or have the same winding order.
|
||||
# Brute force all possible variants of the expected polygon
|
||||
exp_coords = [self.get_point(p) for p in expected[1:-1].split(',')]
|
||||
if exp_coords[0] != exp_coords[-1]:
|
||||
raise RuntimeError(f"Invalid polygon {expected}. "
|
||||
"First and last point need to be the same")
|
||||
for line in (exp_coords[:-1], exp_coords[-1:0:-1]):
|
||||
for i in range(len(line)):
|
||||
if all(math.isclose(p1[0], p2[0]) and math.isclose(p1[1], p2[1]) for p1, p2 in
|
||||
zip(converted, line[i:] + line[:i])):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_point(self, pt):
|
||||
pt = pt.strip()
|
||||
if ' ' in pt:
|
||||
return list(map(float, pt.split(' ', 1)))
|
||||
|
||||
assert self.grid
|
||||
|
||||
return self.grid.get(pt)
|
||||
|
||||
|
||||
def check_table_content(conn, tablename, data, grid=None, exact=False):
|
||||
lines = set(range(1, len(data)))
|
||||
|
||||
cols = []
|
||||
for col in data[0]:
|
||||
if col == 'object':
|
||||
cols.extend(('osm_id', 'osm_type'))
|
||||
elif '!' in col:
|
||||
name, fmt = col.rsplit('!', 1)
|
||||
if fmt == 'wkt':
|
||||
cols.append(f"ST_AsText({name}) as {name}")
|
||||
else:
|
||||
cols.append(name.split('+')[0])
|
||||
else:
|
||||
cols.append(col.split('+')[0])
|
||||
|
||||
with conn.cursor(row_factory=dict_row) as cur:
|
||||
cur.execute(pysql.SQL(f"SELECT {','.join(cols)} FROM")
|
||||
+ pysql.Identifier(tablename))
|
||||
|
||||
table_content = ''
|
||||
for row in cur:
|
||||
table_content += '\n' + str(row)
|
||||
for i in lines:
|
||||
for col, value in zip(data[0], data[i]):
|
||||
if ResultAttr(row, col, grid=grid) != value:
|
||||
break
|
||||
else:
|
||||
lines.remove(i)
|
||||
break
|
||||
else:
|
||||
assert not exact, f"Unexpected row in table {tablename}: {row}"
|
||||
|
||||
assert not lines, \
|
||||
"Rows not found:\n" \
|
||||
+ '\n'.join(str(data[i]) for i in lines) \
|
||||
+ "\nTable content:\n" \
|
||||
+ table_content
|
||||
|
||||
|
||||
def check_table_has_lines(conn, tablename, osm_type, osm_id, osm_class):
|
||||
sql = pysql.SQL("""SELECT count(*) FROM {}
|
||||
WHERE osm_type = %s and osm_id = %s""").format(pysql.Identifier(tablename))
|
||||
params = [osm_type, int(osm_id)]
|
||||
if osm_class:
|
||||
sql += pysql.SQL(' AND class = %s')
|
||||
params.append(osm_class)
|
||||
|
||||
with conn.cursor(row_factory=tuple_row) as cur:
|
||||
assert cur.execute(sql, params).fetchone()[0] == 0
|
||||
|
||||
Reference in New Issue
Block a user