implement BDD osm2pgsql tests with pytest-bdd

This commit is contained in:
Sarah Hoffmann
2025-03-31 09:39:01 +02:00
parent 0f725b1880
commit fb440f29a2
18 changed files with 2409 additions and 17 deletions

View File

@@ -9,6 +9,12 @@ Helper functions to compare expected values.
"""
import json
import re
import math
import itertools
from psycopg import sql as pysql
from psycopg.rows import dict_row, tuple_row
from .geometry_alias import ALIASES
COMPARATOR_TERMS = {
'exactly': lambda exp, act: exp == act,
@@ -43,10 +49,12 @@ COMPARISON_FUNCS = {
None: lambda val, exp: str(val) == exp,
'i': lambda val, exp: str(val).lower() == exp.lower(),
'fm': lambda val, exp: re.fullmatch(exp, val) is not None,
'dict': lambda val, exp: val is None if exp == '-' else (val == eval('{' + exp + '}')),
'in_box': within_box
}
OSM_TYPE = {'node': 'n', 'way': 'w', 'relation': 'r'}
OSM_TYPE = {'node': 'n', 'way': 'w', 'relation': 'r',
'N': 'n', 'W': 'w', 'R': 'r'}
class ResultAttr:
@@ -60,12 +68,15 @@ class ResultAttr:
Available formatters:
!:... - use a formatting expression according to Python Mini Format Spec
!i - make case-insensitive comparison
!fm - consider comparison string a regular expression and match full value
!:... - use a formatting expression according to Python Mini Format Spec
!i - make case-insensitive comparison
!fm - consider comparison string a regular expression and match full value
!wkt - convert the expected value to a WKT string before comparing
!in_box - the expected value is a comma-separated bbox description
"""
def __init__(self, obj, key):
def __init__(self, obj, key, grid=None):
self.grid = grid
self.obj = obj
if '!' in key:
self.key, self.fmt = key.rsplit('!', 1)
@@ -100,6 +111,9 @@ class ResultAttr:
if self.fmt.startswith(':'):
return other == f"{{{self.fmt}}}".format(self.subobj)
if self.fmt == 'wkt':
return self.compare_wkt(self.subobj, other)
raise RuntimeError(f"Unknown format string '{self.fmt}'.")
def __repr__(self):
@@ -107,3 +121,125 @@ class ResultAttr:
if self.fmt:
k += '!' + self.fmt
return f"result[{k}]({self.subobj})"
def compare_wkt(self, value, expected):
""" Compare a WKT value against a compact geometry format.
The function understands the following formats:
country:<country code>
Point geometry guaranteed to be in the given country
<P>
Point geometry
<P>,...,<P>
Line geometry
(<P>,...,<P>)
Polygon geometry
<P> may either be a coordinate of the form '<x> <y>' or a single
number. In the latter case it must refer to a point in
a previously defined grid.
"""
m = re.fullmatch(r'(POINT)\(([0-9. -]*)\)', value) \
or re.fullmatch(r'(LINESTRING)\(([0-9,. -]*)\)', value) \
or re.fullmatch(r'(POLYGON)\(\(([0-9,. -]*)\)\)', value)
if not m:
return False
converted = [list(map(float, pt.split(' ', 1)))
for pt in map(str.strip, m[2].split(','))]
if expected.startswith('country:'):
ccode = geom[8:].upper()
assert ccode in ALIASES, f"Geometry error: unknown country {ccode}"
return m[1] == 'POINT' and \
all(math.isclose(p1, p2) for p1, p2 in
zip(converted[0], ALIASES[ccode]))
if ',' not in expected:
return m[1] == 'POINT' and \
all(math.isclose(p1, p2) for p1, p2 in
zip(converted[0], self.get_point(expected)))
if '(' not in expected:
return m[1] == 'LINESTRING' and \
all(math.isclose(p1[0], p2[0]) and math.isclose(p1[1], p2[1]) for p1, p2 in
zip(converted, (self.get_point(p) for p in expected.split(','))))
if m[1] != 'POLYGON':
return False
# Polygon comparison is tricky because the polygons don't necessarily
# end at the same point or have the same winding order.
# Brute force all possible variants of the expected polygon
exp_coords = [self.get_point(p) for p in expected[1:-1].split(',')]
if exp_coords[0] != exp_coords[-1]:
raise RuntimeError(f"Invalid polygon {expected}. "
"First and last point need to be the same")
for line in (exp_coords[:-1], exp_coords[-1:0:-1]):
for i in range(len(line)):
if all(math.isclose(p1[0], p2[0]) and math.isclose(p1[1], p2[1]) for p1, p2 in
zip(converted, line[i:] + line[:i])):
return True
return False
def get_point(self, pt):
pt = pt.strip()
if ' ' in pt:
return list(map(float, pt.split(' ', 1)))
assert self.grid
return self.grid.get(pt)
def check_table_content(conn, tablename, data, grid=None, exact=False):
lines = set(range(1, len(data)))
cols = []
for col in data[0]:
if col == 'object':
cols.extend(('osm_id', 'osm_type'))
elif '!' in col:
name, fmt = col.rsplit('!', 1)
if fmt == 'wkt':
cols.append(f"ST_AsText({name}) as {name}")
else:
cols.append(name.split('+')[0])
else:
cols.append(col.split('+')[0])
with conn.cursor(row_factory=dict_row) as cur:
cur.execute(pysql.SQL(f"SELECT {','.join(cols)} FROM")
+ pysql.Identifier(tablename))
table_content = ''
for row in cur:
table_content += '\n' + str(row)
for i in lines:
for col, value in zip(data[0], data[i]):
if ResultAttr(row, col, grid=grid) != value:
break
else:
lines.remove(i)
break
else:
assert not exact, f"Unexpected row in table {tablename}: {row}"
assert not lines, \
"Rows not found:\n" \
+ '\n'.join(str(data[i]) for i in lines) \
+ "\nTable content:\n" \
+ table_content
def check_table_has_lines(conn, tablename, osm_type, osm_id, osm_class):
sql = pysql.SQL("""SELECT count(*) FROM {}
WHERE osm_type = %s and osm_id = %s""").format(pysql.Identifier(tablename))
params = [osm_type, int(osm_id)]
if osm_class:
sql += pysql.SQL(' AND class = %s')
params.append(osm_class)
with conn.cursor(row_factory=tuple_row) as cur:
assert cur.execute(sql, params).fetchone()[0] == 0

View File

@@ -7,9 +7,16 @@
"""
Helper functions for managing test databases.
"""
import asyncio
import psycopg
from psycopg import sql as pysql
from nominatim_db.tools.database_import import setup_database_skeleton, create_tables, \
create_partition_tables, create_search_indices
from nominatim_db.data.country_info import setup_country_tables
from nominatim_db.tools.refresh import create_functions, load_address_levels_from_config
from nominatim_db.tools.exec_utils import run_osm2pgsql
from nominatim_db.tokenizer import factory as tokenizer_factory
class DBManager:
@@ -42,3 +49,53 @@ class DBManager:
cur = conn.execute('select count(*) from pg_database where datname = %s',
(dbname,))
return cur.fetchone()[0] == 1
def create_db_from_template(self, dbname, template):
""" Create a new database from the given template database.
Any existing database with the same name will be dropped.
"""
with psycopg.connect(dbname='postgres') as conn:
conn.autocommit = True
conn.execute(pysql.SQL('DROP DATABASE IF EXISTS')
+ pysql.Identifier(dbname))
conn.execute(pysql.SQL('CREATE DATABASE {} WITH TEMPLATE {}')
.format(pysql.Identifier(dbname),
pysql.Identifier(template)))
def setup_template_db(self, config):
""" Create a template DB which contains the necessary extensions
and basic static tables.
The template will only be created if the database does not yet
exist or 'purge' is set.
"""
dsn = config.get_libpq_dsn()
if self.check_for_db(config.get_database_params()['dbname']):
return
setup_database_skeleton(dsn)
run_osm2pgsql(dict(osm2pgsql='osm2pgsql',
osm2pgsql_cache=1,
osm2pgsql_style=str(config.get_import_style_file()),
osm2pgsql_style_path=config.lib_dir.lua,
threads=1,
dsn=dsn,
flatnode_file='',
tablespaces=dict(slim_data='', slim_index='',
main_data='', main_index=''),
append=False,
import_data=b'<osm version="0.6"></osm>'))
setup_country_tables(dsn, config.lib_dir.data)
with psycopg.connect(dsn) as conn:
create_tables(conn, config)
load_address_levels_from_config(conn, config)
create_partition_tables(conn, config)
create_functions(conn, config, enable_diff_updates=False)
asyncio.run(create_search_indices(conn, config))
tokenizer_factory.create_tokenizer(config)

View File

@@ -0,0 +1,262 @@
# SPDX-License-Identifier: GPL-2.0-only
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2025 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Collection of aliases for various world coordinates.
"""
ALIASES = {
# Country aliases
'AD': (1.58972, 42.54241),
'AE': (54.61589, 24.82431),
'AF': (65.90264, 34.84708),
'AG': (-61.72430, 17.069),
'AI': (-63.10571, 18.25461),
'AL': (19.84941, 40.21232),
'AM': (44.64229, 40.37821),
'AO': (16.21924, -12.77014),
'AQ': (44.99999, -75.65695),
'AR': (-61.10759, -34.37615),
'AS': (-170.68470, -14.29307),
'AT': (14.25747, 47.36542),
'AU': (138.23155, -23.72068),
'AW': (-69.98255, 12.555),
'AX': (19.91839, 59.81682),
'AZ': (48.38555, 40.61639),
'BA': (17.18514, 44.25582),
'BB': (-59.53342, 13.19),
'BD': (89.75989, 24.34205),
'BE': (4.90078, 50.34682),
'BF': (-0.56743, 11.90471),
'BG': (24.80616, 43.09859),
'BH': (50.52032, 25.94685),
'BI': (29.54561, -2.99057),
'BJ': (2.70062, 10.02792),
'BL': (-62.79349, 17.907),
'BM': (-64.77406, 32.30199),
'BN': (114.52196, 4.28638),
'BO': (-62.02473, -17.77723),
'BQ': (-63.14322, 17.566),
'BR': (-45.77065, -9.58685),
'BS': (-77.60916, 23.8745),
'BT': (90.01350, 27.28137),
'BV': (3.35744, -54.4215),
'BW': (23.51505, -23.48391),
'BY': (26.77259, 53.15885),
'BZ': (-88.63489, 16.33951),
'CA': (-107.74817, 67.12612),
'CC': (96.84420, -12.01734),
'CD': (24.09544, -1.67713),
'CF': (22.58701, 5.98438),
'CG': (15.78875, 0.40388),
'CH': (7.65705, 46.57446),
'CI': (-6.31190, 6.62783),
'CK': (-159.77835, -21.23349),
'CL': (-70.41790, -53.77189),
'CM': (13.26022, 5.94519),
'CN': (96.44285, 38.04260),
'CO': (-72.52951, 2.45174),
'CR': (-83.83314, 9.93514),
'CU': (-80.81673, 21.88852),
'CV': (-24.50810, 14.929),
'CW': (-68.96409, 12.1845),
'CX': (105.62411, -10.48417),
'CY': (32.95922, 35.37010),
'CZ': (16.32098, 49.50692),
'DE': (9.30716, 50.21289),
'DJ': (42.96904, 11.41542),
'DK': (9.18490, 55.98916),
'DM': (-61.00358, 15.65470),
'DO': (-69.62855, 18.58841),
'DZ': (4.24749, 25.79721),
'EC': (-77.45831, -0.98284),
'EE': (23.94288, 58.43952),
'EG': (28.95293, 28.17718),
'EH': (-13.69031, 25.01241),
'ER': (39.01223, 14.96033),
'ES': (-2.59110, 38.79354),
'ET': (38.61697, 7.71399),
'FI': (26.89798, 63.56194),
'FJ': (177.91853, -17.74237),
'FK': (-58.99044, -51.34509),
'FM': (151.95358, 8.5045),
'FO': (-6.60483, 62.10000),
'FR': (0.28410, 47.51045),
'GA': (10.81070, -0.07429),
'GB': (-0.92823, 52.01618),
'GD': (-61.64524, 12.191),
'GE': (44.16664, 42.00385),
'GF': (-53.46524, 3.56188),
'GG': (-2.50580, 49.58543),
'GH': (-0.46348, 7.16051),
'GI': (-5.32053, 36.11066),
'GL': (-33.85511, 74.66355),
'GM': (-16.40960, 13.25),
'GN': (-13.83940, 10.96291),
'GP': (-61.68712, 16.23049),
'GQ': (10.23973, 1.43119),
'GR': (23.17850, 39.06206),
'GS': (-36.49430, -54.43067),
'GT': (-90.74368, 15.20428),
'GU': (144.73362, 13.44413),
'GW': (-14.83525, 11.92486),
'GY': (-58.45167, 5.73698),
'HK': (114.18577, 22.34923),
'HM': (73.68230, -53.22105),
'HN': (-86.95414, 15.23820),
'HR': (17.49966, 45.52689),
'HT': (-73.51925, 18.32492),
'HU': (20.35362, 47.51721),
'ID': (123.34505, -0.83791),
'IE': (-9.00520, 52.87725),
'IL': (35.46314, 32.86165),
'IM': (-4.86740, 54.023),
'IN': (88.67620, 27.86155),
'IO': (71.42743, -6.14349),
'IQ': (42.58109, 34.26103),
'IR': (56.09355, 30.46751),
'IS': (-17.51785, 64.71687),
'IT': (10.42639, 44.87904),
'JE': (-2.19261, 49.12458),
'JM': (-76.84020, 18.3935),
'JO': (36.55552, 30.75741),
'JP': (138.72531, 35.92099),
'KE': (36.90602, 1.08512),
'KG': (76.15571, 41.66497),
'KH': (104.31901, 12.95555),
'KI': (173.63353, 0.139),
'KM': (44.31474, -12.241),
'KN': (-62.69379, 17.2555),
'KP': (126.65575, 39.64575),
'KR': (127.27740, 36.41388),
'KW': (47.30684, 29.69180),
'KY': (-81.07455, 19.29949),
'KZ': (72.00811, 49.88855),
'LA': (102.44391, 19.81609),
'LB': (35.48464, 33.41766),
'LC': (-60.97894, 13.891),
'LI': (9.54693, 47.15934),
'LK': (80.38520, 8.41649),
'LR': (-11.16960, 4.04122),
'LS': (28.66984, -29.94538),
'LT': (24.51735, 55.49293),
'LU': (6.08649, 49.81533),
'LV': (23.51033, 56.67144),
'LY': (15.36841, 28.12177),
'MA': (-4.03061, 33.21696),
'MC': (7.47743, 43.62917),
'MD': (29.61725, 46.66517),
'ME': (19.72291, 43.02441),
'MF': (-63.06666, 18.08102),
'MG': (45.86378, -20.50245),
'MH': (171.94982, 5.983),
'MK': (21.42108, 41.08980),
'ML': (-1.93310, 16.46993),
'MM': (95.54624, 21.09620),
'MN': (99.81138, 48.18615),
'MO': (113.56441, 22.16209),
'MP': (145.21345, 14.14902),
'MQ': (-60.81128, 14.43706),
'MR': (-9.42324, 22.59251),
'MS': (-62.19455, 16.745),
'MT': (14.38363, 35.94467),
'MU': (57.55121, -20.41),
'MV': (73.39292, 4.19375),
'MW': (33.95722, -12.28218),
'MX': (-105.89221, 25.86826),
'MY': (112.71154, 2.10098),
'MZ': (37.58689, -13.72682),
'NA': (16.68569, -21.46572),
'NC': (164.95322, -20.38889),
'NE': (10.06041, 19.08273),
'NF': (167.95718, -29.0645),
'NG': (10.17781, 10.17804),
'NI': (-85.87974, 13.21715),
'NL': (-68.57062, 12.041),
'NO': (23.11556, 70.09934),
'NP': (83.36259, 28.13107),
'NR': (166.93479, -0.5275),
'NU': (-169.84873, -19.05305),
'NZ': (167.97209, -45.13056),
'OM': (56.86055, 20.47413),
'PA': (-79.40160, 8.80656),
'PE': (-78.66540, -7.54711),
'PF': (-145.05719, -16.70862),
'PG': (146.64600, -7.37427),
'PH': (121.48359, 15.09965),
'PK': (72.11347, 31.14629),
'PL': (17.88136, 52.77182),
'PM': (-56.19515, 46.78324),
'PN': (-130.10642, -25.06955),
'PR': (-65.88755, 18.37169),
'PS': (35.39801, 32.24773),
'PT': (-8.45743, 40.11154),
'PW': (134.49645, 7.3245),
'PY': (-59.51787, -22.41281),
'QA': (51.49903, 24.99816),
'RE': (55.77345, -21.36388),
'RO': (26.37632, 45.36120),
'RS': (20.40371, 44.56413),
'RU': (116.44060, 59.06780),
'RW': (29.57882, -1.62404),
'SA': (47.73169, 22.43790),
'SB': (164.63894, -10.23606),
'SC': (46.36566, -9.454),
'SD': (28.14720, 14.56423),
'SE': (15.68667, 60.35568),
'SG': (103.84187, 1.304),
'SH': (-12.28155, -37.11546),
'SI': (14.04738, 46.39085),
'SJ': (15.27552, 79.23365),
'SK': (20.41603, 48.86970),
'SL': (-11.47773, 8.78156),
'SM': (12.46062, 43.94279),
'SN': (-15.37111, 14.99477),
'SO': (46.93383, 9.34094),
'SR': (-55.42864, 4.56985),
'SS': (28.13573, 8.50933),
'ST': (6.61025, 0.2215),
'SV': (-89.36665, 13.43072),
'SX': (-63.15393, 17.9345),
'SY': (38.15513, 35.34221),
'SZ': (31.78263, -26.14244),
'TC': (-71.32554, 21.35),
'TD': (17.42092, 13.46223),
'TF': (137.5, -67.5),
'TG': (1.06983, 7.87677),
'TH': (102.00877, 16.42310),
'TJ': (71.91349, 39.01527),
'TK': (-171.82603, -9.20990),
'TL': (126.22520, -8.72636),
'TM': (57.71603, 39.92534),
'TN': (9.04958, 34.84199),
'TO': (-176.99320, -23.11104),
'TR': (32.82002, 39.86350),
'TT': (-60.70793, 11.1385),
'TV': (178.77499, -9.41685),
'TW': (120.30074, 23.17002),
'TZ': (33.53892, -5.01840),
'UA': (33.44335, 49.30619),
'UG': (32.96523, 2.08584),
'UM': (-169.50993, 16.74605),
'US': (-116.39535, 40.71379),
'UY': (-56.46505, -33.62658),
'UZ': (61.35529, 42.96107),
'VA': (12.33197, 42.04931),
'VC': (-61.09905, 13.316),
'VE': (-64.88323, 7.69849),
'VG': (-64.62479, 18.419),
'VI': (-64.88950, 18.32263),
'VN': (104.20179, 10.27644),
'VU': (167.31919, -15.88687),
'WF': (-176.20781, -13.28535),
'WS': (-172.10966, -13.85093),
'YE': (45.94562, 16.16338),
'YT': (44.93774, -12.60882),
'ZA': (23.19488, -30.43276),
'ZM': (26.38618, -14.39966),
'ZW': (30.12419, -19.86907)
}

34
test/bdd/utils/grid.py Normal file
View File

@@ -0,0 +1,34 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2025 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
A grid describing node placement in an area.
Useful for visually describing geometries.
"""
class Grid:
def __init__(self, table, step, origin):
if step is None:
step = 0.00001
if origin is None:
origin = (0.0, 0.0)
self.grid = {}
y = origin[1]
for line in table:
x = origin[0]
for pt_id in line:
if pt_id:
self.grid[pt_id] = (x, y)
x += step
y += step
def get(self, nodeid):
""" Get the coordinates for the given grid node.
"""
return self.grid.get(nodeid)