Files
Nominatim/test/bdd/utils/checks.py
Itz-Agasta 9fa980bca2 Replaces eval with json.loads for safer dict parsing
Switches from eval to json.loads when parsing string representations
of dictionaries to prevent arbitrary code
execution.
2026-01-24 15:32:47 +05:30

244 lines
8.2 KiB
Python

# SPDX-License-Identifier: GPL-2.0-only
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2025 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Helper functions to compare expected values.
"""
import collections.abc
import json
import re
import math
from psycopg import sql as pysql
from psycopg.rows import dict_row
from .geometry_alias import ALIASES
def _cmp_exactly(exp, act):
    """ True when the expected and the actual value are equal.
    """
    return exp == act


def _cmp_more_than(exp, act):
    """ True when the actual value is greater than the expected one.
    """
    return act > exp


def _cmp_less_than(exp, act):
    """ True when the actual value is smaller than the expected one.
    """
    return act < exp


# Predicates selectable by a comparison phrase. Each one is called with
# the expected value first and the actual value second.
COMPARATOR_TERMS = {
    'exactly': _cmp_exactly,
    'more than': _cmp_more_than,
    'less than': _cmp_less_than,
}
def _pretty(obj):
return json.dumps(obj, sort_keys=True, indent=2)
def _pt_close(p1, p2):
return math.isclose(p1[0], p2[0], abs_tol=1e-07) \
and math.isclose(p1[1], p2[1], abs_tol=1e-07)
def within_box(value, expect):
    """ Check a coordinate or box against an expected box.

        `expect` is a string of comma-separated numbers. `value` may be
        a string of the form 'POINT(a b)', a string of comma-separated
        numbers or an iterable of values convertible to float.

        A value with two numbers must lie between expected numbers
        0 and 2 (first number) and 1 and 3 (second number). A value with
        four numbers is compared pairwise: numbers 0 and 2 must not be
        smaller, numbers 1 and 3 must not be larger than their expected
        counterparts. Any other length raises a ValueError.
    """
    bounds = [float(part) for part in expect.split(',')]

    if not isinstance(value, str):
        parts = value
    elif value.startswith('POINT'):
        # cut off the leading 'POINT(' and the trailing ')'
        parts = value[6:-1].split(' ')
    else:
        parts = value.split(',')
    numbers = [float(part) for part in parts]

    if len(numbers) == 2:
        first, second = numbers
        return bounds[0] <= first <= bounds[2] \
            and bounds[1] <= second <= bounds[3]

    if len(numbers) == 4:
        return numbers[0] >= bounds[0] and numbers[1] <= bounds[1] \
            and numbers[2] >= bounds[2] and numbers[3] <= bounds[3]

    raise ValueError("Not a coordinate or bbox.")
# Comparison functions selectable through the formatting suffix of a key.
# Each function receives the actual value first and the expected string
# second. The None entry is used when no suffix was given.
COMPARISON_FUNCS = {
    None: lambda val, exp: str(val) == exp,
    'i': lambda val, exp: str(val).lower() == exp.lower(),
    # Stringify the value like the comparators above do: re.fullmatch()
    # raises a TypeError when handed a non-string value (e.g. an integer).
    'fm': lambda val, exp: re.fullmatch(exp, str(val)) is not None,
    # '-' expects a missing value, anything else is the inner part of
    # a JSON object.
    'dict': lambda val, exp: val is None if exp == '-' else (val == json.loads('{' + exp + '}')),
    'in_box': within_box
}
# Lower-case single-letter type codes, keyed by the long type name
# as well as by the upper-case type letter.
OSM_TYPE = {
    'node': 'n',
    'way': 'w',
    'relation': 'r',
    'N': 'n',
    'W': 'w',
    'R': 'r',
}
class ResultAttr:
    """ Returns the given attribute as a string.

        The key parameter determines how the value is formatted before
        returning. To refer to sub attributes, use '+' to add more keys
        (e.g. 'name+ref' will access obj['name']['ref']). A '!' introduces
        a formatting suffix. If no suffix is given, the value will be
        converted using the str() function.

        The special key 'object' yields the single-letter OSM type
        followed by the OSM id and is always compared case-insensitively.

        Available formatters:

        !:...   - use a formatting expression according to Python Mini Format Spec
        !i      - make case-insensitive comparison
        !fm     - consider comparison string a regular expression and match full value
        !dict   - the expected value is the inner part of a JSON object,
                  '-' stands for a missing value
        !wkt    - convert the expected value to a WKT string before comparing
        !in_box - the expected value is a comma-separated bbox description
    """

    def __init__(self, obj, key, grid=None):
        """ Look up the attribute described by `key` in `obj`.

            `grid` is only needed for '!wkt' comparisons that refer to
            grid points. Missing attributes and out-of-bound indexes
            fail with an assertion.
        """
        self.grid = grid
        self.obj = obj
        # only the last '!' separates the formatting suffix
        if '!' in key:
            self.key, self.fmt = key.rsplit('!', 1)
        else:
            self.key = key
            self.fmt = None

        if self.key == 'object':
            assert 'osm_id' in obj
            assert 'osm_type' in obj
            self.subobj = OSM_TYPE[obj['osm_type']] + str(obj['osm_id'])
            self.fmt = 'i'
        else:
            # path walked so far, only used for error messages
            done = ''
            self.subobj = self.obj
            for sub in self.key.split('+'):
                done += f"[{sub}]"
                # numeric keys are positions when the current value is a sequence
                if isinstance(self.subobj, collections.abc.Sequence) and sub.isdigit():
                    sub = int(sub)
                    assert sub < len(self.subobj), \
                        f"Out of bound index {done}. Full object:\n{_pretty(self.obj)}"
                else:
                    assert sub in self.subobj, \
                        f"Missing attribute {done}. Full object:\n{_pretty(self.obj)}"
                self.subobj = self.subobj[sub]

    def __eq__(self, other):
        """ Compare the attribute against an expected value.

            Non-string values are compared directly. Strings are compared
            according to the formatting suffix. An unknown suffix raises
            a RuntimeError.
        """
        if not isinstance(other, str):
            return self.subobj == other

        # work around bad quoting by pytest-bdd
        other = other.replace(r'\\', '\\')

        if self.fmt in COMPARISON_FUNCS:
            return COMPARISON_FUNCS[self.fmt](self.subobj, other)

        if self.fmt.startswith(':'):
            # the suffix including its colon becomes the format spec: '{:...}'
            return other == f"{{{self.fmt}}}".format(self.subobj)

        if self.fmt == 'wkt':
            return self.compare_wkt(self.subobj, other)

        raise RuntimeError(f"Unknown format string '{self.fmt}'.")

    def __repr__(self):
        k = self.key.replace('+', '][')
        if self.fmt:
            k += '!' + self.fmt
        return f"result[{k}]({self.subobj})"

    def compare_wkt(self, value, expected):
        """ Compare a WKT value against a compact geometry format.

            The function understands the following formats:

              country:<country code>
                 Point geometry guaranteed to be in the given country
              <P>
                 Point geometry
              <P>,...,<P>
                 Line geometry
              (<P>,...,<P>)
                 Polygon geometry

           <P> may either be a coordinate of the form '<x> <y>' or a single
           number. In the latter case it must refer to a point in
           a previously defined grid.

           Returns False when `value` is not a POINT, a LINESTRING or
           a POLYGON with a single ring, when the geometry type differs
           from the expected one or when the number of points differs.
        """
        m = re.fullmatch(r'(POINT)\(([0-9. -]*)\)', value) \
            or re.fullmatch(r'(LINESTRING)\(([0-9,. -]*)\)', value) \
            or re.fullmatch(r'(POLYGON)\(\(([0-9,. -]*)\)\)', value)
        if not m:
            return False

        converted = [list(map(float, pt.split(' ', 1)))
                     for pt in map(str.strip, m[2].split(','))]

        if expected.startswith('country:'):
            ccode = expected[8:].upper()
            assert ccode in ALIASES, f"Geometry error: unknown country {ccode}"
            return m[1] == 'POINT' and _pt_close(converted[0], ALIASES[ccode])

        if ',' not in expected:
            return m[1] == 'POINT' and _pt_close(converted[0], self.get_point(expected))

        if '(' not in expected:
            if m[1] != 'LINESTRING':
                return False
            exp_line = [self.get_point(p) for p in expected.split(',')]
            # zip() stops at the shorter input, so lines of different
            # length must be rejected explicitly.
            return len(converted) == len(exp_line) \
                and all(_pt_close(p1, p2) for p1, p2 in zip(converted, exp_line))

        if m[1] != 'POLYGON':
            return False

        # Polygon comparison is tricky because the polygons don't necessarily
        # end at the same point or have the same winding order.
        # Brute force all possible variants of the expected polygon
        exp_coords = [self.get_point(p) for p in expected[1:-1].split(',')]
        if exp_coords[0] != exp_coords[-1]:
            raise RuntimeError(f"Invalid polygon {expected}. "
                               "First and last point need to be the same")
        # Both rings repeat their first point at the end. The comparison
        # below ignores surplus points, so the sizes have to agree.
        if len(converted) != len(exp_coords):
            return False
        # try the ring in its given and in reversed order, each time
        # without the repeated closing point
        for line in (exp_coords[:-1], exp_coords[-1:0:-1]):
            for i in range(len(line)):
                if all(_pt_close(p1, p2) for p1, p2 in
                       zip(converted, line[i:] + line[:i])):
                    return True

        return False

    def get_point(self, pt):
        """ Resolve a point description into coordinates.

            A description containing a space is split into two floats.
            Anything else is looked up in the grid, which must have been
            supplied.
        """
        pt = pt.strip()
        if ' ' in pt:
            return list(map(float, pt.split(' ', 1)))

        assert self.grid

        return self.grid.get(pt)
def check_table_content(conn, tablename, data, grid=None, exact=False):
    """ Check that the table `tablename` contains the rows listed in `data`.

        The first entry of `data` holds the column keys, in the syntax
        understood by ResultAttr. All following entries are expected
        rows. An expected value of '-' stands for None.

        Each row of the table is compared against the expected rows that
        have not been matched yet. The first one that fits counts as
        found. When `exact` is set, a table row without a fitting
        expected row fails the check. Finally, all expected rows must
        have been found.
    """
    header = data[0]
    pending = set(range(1, len(data)))

    # Translate the column keys into the select list of the query.
    select_cols = []
    for column in header:
        if column == 'object':
            select_cols.extend(('osm_id', 'osm_type'))
            continue

        if '!' in column:
            name, fmt = column.rsplit('!', 1)
        else:
            name, fmt = column, None

        if fmt in ('wkt', 'in_box'):
            select_cols.append(f"ST_AsText({name}) as {name}")
        else:
            # sub attributes are resolved later, only fetch the column itself
            select_cols.append(name.split('+')[0])

    query = pysql.SQL(f"SELECT {','.join(select_cols)} FROM") \
        + pysql.Identifier(tablename)

    # rows seen in the table, only needed for the final error message
    seen_rows = []
    with conn.cursor(row_factory=dict_row) as cur:
        cur.execute(query)

        for row in cur:
            seen_rows.append('\n' + str(row))
            for idx in pending:
                mismatch = any(
                    ResultAttr(row, col, grid=grid) != (None if value == '-' else value)
                    for col, value in zip(header, data[idx]))
                if not mismatch:
                    # an expected row can only be used once
                    pending.remove(idx)
                    break
            else:
                assert not exact, f"Unexpected row in table {tablename}: {row}"

    assert not pending, \
        "Rows not found:\n" \
        + '\n'.join(str(data[idx]) for idx in pending) \
        + "\nTable content:\n" \
        + ''.join(seen_rows)