bdd: move column comparison into a separate file

Introduces a new class DBRow that encapsulates the comparison
functions. The class is also responsible for formatting more informative
assert messages. The place and placex steps are unified.
This commit is contained in:
Sarah Hoffmann
2021-01-06 12:28:09 +01:00
parent d586b95ff1
commit 1f29475fa5
2 changed files with 194 additions and 176 deletions

View File

@@ -1,44 +1,8 @@
import re
import psycopg2.extras import psycopg2.extras
from check_functions import Almost
from place_inserter import PlaceColumn from place_inserter import PlaceColumn
from table_compare import NominatimID from table_compare import NominatimID, DBRow
class PlaceObjName(object):
def __init__(self, placeid, conn):
self.pid = placeid
self.conn = conn
def __str__(self):
if self.pid is None:
return "<null>"
if self.pid == 0:
return "place ID 0"
cur = self.conn.cursor()
cur.execute("""SELECT osm_type, osm_id, class
FROM placex WHERE place_id = %s""",
(self.pid, ))
assert cur.rowcount == 1, "No entry found for place id %s" % self.pid
return "%s%s:%s" % cur.fetchone()
def compare_place_id(expected, result, column, context):
if expected == '0':
assert result == 0, \
"Bad place id in column {}. Expected: 0, got: {!s}.".format(
column, PlaceObjName(result, context.db))
elif expected == '-':
assert result is None, \
"Bad place id in column {}: {!s}.".format(
column, PlaceObjName(result, context.db))
else:
assert NominatimID(expected).get_place_id(context.db.cursor()) == result, \
"Bad place id in column {}. Expected: {}, got: {!s}.".format(
column, expected, PlaceObjName(result, context.db))
def check_database_integrity(context): def check_database_integrity(context):
""" Check some generic constraints on the tables. """ Check some generic constraints on the tables.
@@ -52,37 +16,6 @@ def check_database_integrity(context):
assert cur.fetchone()[0] == 0, "Duplicates found in place_addressline" assert cur.fetchone()[0] == 0, "Duplicates found in place_addressline"
def assert_db_column(row, column, value, context):
if column == 'object':
return
if column.startswith('centroid'):
if value == 'in geometry':
query = """SELECT ST_Within(ST_SetSRID(ST_Point({}, {}), 4326),
ST_SetSRID('{}'::geometry, 4326))""".format(
row['cx'], row['cy'], row['geomtxt'])
cur = context.db.cursor()
cur.execute(query)
assert cur.fetchone()[0], "(Row %s failed: %s)" % (column, query)
else:
fac = float(column[9:]) if column.startswith('centroid*') else 1.0
x, y = value.split(' ')
assert Almost(float(x) * fac) == row['cx'], "Bad x coordinate"
assert Almost(float(y) * fac) == row['cy'], "Bad y coordinate"
elif column == 'geometry':
geom = context.osm.parse_geometry(value, context.scene)
cur = context.db.cursor()
query = "SELECT ST_Equals(ST_SnapToGrid(%s, 0.00001, 0.00001), ST_SnapToGrid(ST_SetSRID('%s'::geometry, 4326), 0.00001, 0.00001))" % (
geom, row['geomtxt'],)
cur.execute(query)
assert cur.fetchone()[0], "(Row %s failed: %s)" % (column, query)
elif value == '-':
assert row[column] is None, "Row %s" % column
else:
assert value == str(row[column]), \
"Row '%s': expected: %s, got: %s" % (column, value, str(row[column]))
################################ GIVEN ################################## ################################ GIVEN ##################################
@given("the (?P<named>named )?places") @given("the (?P<named>named )?places")
@@ -178,100 +111,49 @@ def delete_places(context, oids):
################################ THEN ################################## ################################ THEN ##################################
@then("placex contains(?P<exact> exactly)?") @then("(?P<table>placex|place) contains(?P<exact> exactly)?")
def check_placex_contents(context, exact): def check_place_contents(context, table, exact):
with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
expected_content = set() expected_content = set()
for row in context.table: for row in context.table:
NominatimID(row['object']).query_osm_id(cur, nid = NominatimID(row['object'])
"""SELECT *, ST_AsText(geometry) as geomtxt, query = 'SELECT *, ST_AsText(geometry) as geomtxt, ST_GeometryType(geometry) as geometrytype'
ST_X(centroid) as cx, ST_Y(centroid) as cy if table == 'placex':
FROM placex WHERE {}""") query += ' ,ST_X(centroid) as cx, ST_Y(centroid) as cy'
query += " FROM %s WHERE {}" % (table, )
nid.query_osm_id(cur, query)
assert cur.rowcount > 0, "No rows found for " + row['object'] assert cur.rowcount > 0, "No rows found for " + row['object']
for res in cur: for res in cur:
if exact: if exact:
expected_content.add((res['osm_type'], res['osm_id'], res['class'])) expected_content.add((res['osm_type'], res['osm_id'], res['class']))
for h in row.headings:
if h in ('extratags', 'address'): DBRow(nid, res, context).assert_row(row, ['object'])
if row[h] == '-':
assert res[h] is None
else:
vdict = eval('{' + row[h] + '}')
assert vdict == res[h]
elif h.startswith('name'):
name = h[5:] if h.startswith('name+') else 'name'
assert name in res['name']
assert res['name'][name] == row[h]
elif h.startswith('extratags+'):
assert res['extratags'][h[10:]] == row[h]
elif h.startswith('addr+'):
if row[h] == '-':
if res['address'] is not None:
assert h[5:] not in res['address']
else:
assert h[5:] in res['address'], "column " + h
assert res['address'][h[5:]] == row[h], "column %s" % h
elif h in ('linked_place_id', 'parent_place_id'):
compare_place_id(row[h], res[h], h, context)
else:
assert_db_column(res, h, row[h], context)
if exact: if exact:
cur.execute('SELECT osm_type, osm_id, class from placex') cur.execute('SELECT osm_type, osm_id, class from {}'.format(table))
assert expected_content == set([(r[0], r[1], r[2]) for r in cur]) assert expected_content == set([(r[0], r[1], r[2]) for r in cur])
@then("place contains(?P<exact> exactly)?")
def check_placex_contents(context, exact): @then("(?P<table>placex|place) has no entry for (?P<oid>.*)")
def check_place_has_entry(context, table, oid):
with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
expected_content = set() NominatimID(oid).query_osm_id(cur, "SELECT * FROM %s where {}" % table)
for row in context.table: assert cur.rowcount == 0, \
NominatimID(row['object']).query_osm_id(cur, "Found {} entries for ID {}".format(cur.rowcount, oid)
"""SELECT *, ST_AsText(geometry) as geomtxt,
ST_GeometryType(geometry) as geometrytype
FROM place WHERE {}""")
assert cur.rowcount > 0, "No rows found for " + row['object']
for res in cur:
if exact:
expected_content.add((res['osm_type'], res['osm_id'], res['class']))
for h in row.headings:
msg = "%s: %s" % (row['object'], h)
if h in ('name', 'extratags', 'address'):
if row[h] == '-':
assert res[h] is None, msg
else:
vdict = eval('{' + row[h] + '}')
assert vdict == res[h], msg
elif h.startswith('name+'):
assert res['name'][h[5:]] == row[h], msg
elif h.startswith('extratags+'):
assert res['extratags'][h[10:]] == row[h], msg
elif h.startswith('addr+'):
if row[h] == '-':
if res['address'] is not None:
assert h[5:] not in res['address']
else:
assert res['address'][h[5:]] == row[h], msg
elif h in ('linked_place_id', 'parent_place_id'):
compare_place_id(row[h], res[h], h, context)
else:
assert_db_column(res, h, row[h], context)
if exact:
cur.execute('SELECT osm_type, osm_id, class from place')
assert expected_content, set([(r[0], r[1], r[2]) for r in cur])
@then("search_name contains(?P<exclude> not)?") @then("search_name contains(?P<exclude> not)?")
def check_search_name_contents(context, exclude): def check_search_name_contents(context, exclude):
with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
for row in context.table: for row in context.table:
NominatimID(row['object']).query_place_id(cur, nid = NominatimID(row['object'])
"""SELECT *, ST_X(centroid) as cx, ST_Y(centroid) as cy nid.row_by_place_id(cur, 'search_name',
FROM search_name WHERE place_id = %s""") ['ST_X(centroid) as cx', 'ST_Y(centroid) as cy'])
assert cur.rowcount > 0, "No rows found for " + row['object'] assert cur.rowcount > 0, "No rows found for " + row['object']
for res in cur: for res in cur:
db_row = DBRow(nid, res, context)
for h in row.headings: for h in row.headings:
if h in ('name_vector', 'nameaddress_vector'): if h in ('name_vector', 'nameaddress_vector'):
terms = [x.strip() for x in row[h].split(',') if not x.strip().startswith('#')] terms = [x.strip() for x in row[h].split(',') if not x.strip().startswith('#')]
@@ -298,8 +180,8 @@ def check_search_name_contents(context, exclude):
assert wid[0] not in res[h], "Found term for %s/%s: %s" % (row['object'], h, wid[1]) assert wid[0] not in res[h], "Found term for %s/%s: %s" % (row['object'], h, wid[1])
else: else:
assert wid[0] in res[h], "Missing term for %s/%s: %s" % (row['object'], h, wid[1]) assert wid[0] in res[h], "Missing term for %s/%s: %s" % (row['object'], h, wid[1])
else: elif h != 'object':
assert_db_column(res, h, row[h], context) assert db_row.contains(h, row[h]), db_row.assert_msg(h, row[h])
@then("location_postcode contains exactly") @then("location_postcode contains exactly")
def check_location_postcode(context): def check_location_postcode(context):
@@ -308,15 +190,18 @@ def check_location_postcode(context):
assert cur.rowcount == len(list(context.table)), \ assert cur.rowcount == len(list(context.table)), \
"Postcode table has %d rows, expected %d rows." % (cur.rowcount, len(list(context.table))) "Postcode table has %d rows, expected %d rows." % (cur.rowcount, len(list(context.table)))
table = list(cur) results = {}
for row in cur:
key = (row['country_code'], row['postcode'])
assert key not in results, "Postcode table has duplicate entry: {}".format(row)
results[key] = DBRow((row['country_code'],row['postcode']), row, context)
for row in context.table: for row in context.table:
for i in range(len(table)): db_row = results.get((row['country'],row['postcode']))
if table[i]['country_code'] != row['country'] \ assert db_row is not None, \
or table[i]['postcode'] != row['postcode']: "Missing row for country '{}' postcode '{}'.".format(r['country'],['postcode'])
continue
for h in row.headings: db_row.assert_row(row, ('country', 'postcode'))
if h not in ('country', 'postcode'):
assert_db_column(table[i], h, row[h], context)
@then("word contains(?P<exclude> not)?") @then("word contains(?P<exclude> not)?")
def check_word_table(context, exclude): def check_word_table(context, exclude):
@@ -337,7 +222,8 @@ def check_word_table(context, exclude):
def check_place_addressline(context): def check_place_addressline(context):
with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
for row in context.table: for row in context.table:
pid = NominatimID(row['object']).get_place_id(cur) nid = NominatimID(row['object'])
pid = nid.get_place_id(cur)
apid = NominatimID(row['address']).get_place_id(cur) apid = NominatimID(row['address']).get_place_id(cur)
cur.execute(""" SELECT * FROM place_addressline cur.execute(""" SELECT * FROM place_addressline
WHERE place_id = %s AND address_place_id = %s""", WHERE place_id = %s AND address_place_id = %s""",
@@ -346,9 +232,7 @@ def check_place_addressline(context):
"No rows found for place %s and address %s" % (row['object'], row['address']) "No rows found for place %s and address %s" % (row['object'], row['address'])
for res in cur: for res in cur:
for h in row.headings: DBRow(nid, res, context).assert_row(row, ('address', 'object'))
if h not in ('address', 'object'):
assert_db_column(res, h, row[h], context)
@then("place_addressline doesn't contain") @then("place_addressline doesn't contain")
def check_place_addressline_exclude(context): def check_place_addressline_exclude(context):
@@ -389,28 +273,15 @@ def check_location_property_osmline(context, oid, neg):
else: else:
assert False, "Unexpected row %s" % (str(res)) assert False, "Unexpected row %s" % (str(res))
for h in row.headings: DBRow(nid, res, context).assert_row(row, ('start', 'end'))
if h in ('start', 'end'):
continue
elif h == 'parent_place_id':
compare_place_id(row[h], res[h], h, context)
else:
assert_db_column(res, h, row[h], context)
assert not todo assert not todo
@then("(?P<table>placex|place) has no entry for (?P<oid>.*)")
def check_placex_has_entry(context, table, oid):
with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
NominatimID(oid).query_osm_id(cur, "SELECT * FROM %s where {}" % table)
assert cur.rowcount == 0, \
"Found {} entries for ID {}".format(cur.rowcount, oid)
@then("search_name has no entry for (?P<oid>.*)") @then("search_name has no entry for (?P<oid>.*)")
def check_search_name_has_entry(context, oid): def check_search_name_has_entry(context, oid):
with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
NominatimID(oid).query_place_id(cur, NominatimID(oid).row_by_place_id(cur, 'search_name')
"SELECT * FROM search_name WHERE place_id = %s")
assert cur.rowcount == 0, \ assert cur.rowcount == 0, \
"Found {} entries for ID {}".format(cur.rowcount, oid) "Found {} entries for ID {}".format(cur.rowcount, oid)

View File

@@ -2,6 +2,9 @@
Functions to facilitate accessing and comparing the content of DB tables. Functions to facilitate accessing and comparing the content of DB tables.
""" """
import re import re
import json
from steps.check_functions import Almost
ID_REGEX = re.compile(r"(?P<typ>[NRW])(?P<oid>\d+)(:(?P<cls>\w+))?") ID_REGEX = re.compile(r"(?P<typ>[NRW])(?P<oid>\d+)(:(?P<cls>\w+))?")
@@ -41,15 +44,17 @@ class NominatimID:
where += ' and class = %s' where += ' and class = %s'
params.append(self.cls) params.append(self.cls)
return cur.execute(query.format(where), params) cur.execute(query.format(where), params)
def query_place_id(self, cur, query): def row_by_place_id(self, cur, table, extra_columns=None):
""" Run a query on cursor `cur` using the place ID. The `query` string """ Get a row by place_id from the given table using cursor `cur`.
must contain exactly one placeholder '%s' where the 'where' query extra_columns may contain a list additional elements for the select
should go. part of the query.
""" """
pid = self.get_place_id(cur) pid = self.get_place_id(cur)
return cur.execute(query, (pid, )) query = "SELECT {} FROM {} WHERE place_id = %s".format(
','.join(['*'] + (extra_columns or [])), table)
cur.execute(query, (pid, ))
def get_place_id(self, cur): def get_place_id(self, cur):
""" Look up the place id for the ID. Throws an assertion if the ID """ Look up the place id for the ID. Throws an assertion if the ID
@@ -60,3 +65,145 @@ class NominatimID:
"Place ID {!s} not unique. Found {} entries.".format(self, cur.rowcount) "Place ID {!s} not unique. Found {} entries.".format(self, cur.rowcount)
return cur.fetchone()[0] return cur.fetchone()[0]
class DBRow:
    """ Represents a row from a database and offers comparison functions.

        `nid` is only used to identify the row in assertion messages.
        `db_row` is a mapping from column name to value. `context` must
        provide the database connection as `context.db`; geometry
        comparisons additionally use `context.osm` and `context.scene`.
    """

    def __init__(self, nid, db_row, context):
        self.nid = nid
        self.db_row = db_row
        self.context = context

    def assert_row(self, row, exclude_columns):
        """ Check that all columns of the given behave row are contained
            in the database row. Exclude behave rows with the names given
            in the `exclude_columns` list.

            Raises an AssertionError with the text of `assert_msg()` for
            the first column that does not match.
        """
        for name, value in zip(row.headings, row.cells):
            if name not in exclude_columns:
                assert self.contains(name, value), self.assert_msg(name, value)

    def contains(self, name, expected):
        """ Check that the DB row contains a column `name` with the given value.

            Special cases for `name`:
              * 'column+field' compares a single field of a dictionary
                column ('addr' is an alias for the 'address' column)
              * 'geometry' compares against the 'geomtxt' column
              * names containing 'place_id' resolve `expected` as an OSM ID
              * 'centroid' compares against the 'cx'/'cy' columns
            An expected value of '-' stands for NULL.

            Returns True when the value matches and False otherwise,
            including when the column does not exist.
        """
        if '+' in name:
            column, field = name.split('+', 1)
            return self._contains_hstore_value(column, field, expected)

        if name == 'geometry':
            return self._has_geometry(expected)

        if name not in self.db_row:
            return False

        actual = self.db_row[name]

        if expected == '-':
            return actual is None

        if name == 'name' and ':' not in expected:
            # A plain string (no "key: value" syntax) refers to the default
            # name only. Report a missing name as a mismatch instead of
            # failing with a TypeError/KeyError.
            if actual is None or name not in actual:
                return False
            return self._compare_column(actual[name], expected)

        if 'place_id' in name:
            return self._compare_place_id(actual, expected)

        if name == 'centroid':
            return self._has_centroid(expected)

        return self._compare_column(actual, expected)

    def _contains_hstore_value(self, column, field, expected):
        """ Compare the entry `field` of the dictionary column `column`
            with `expected`. An expected value of '-' matches when the
            column is NULL or does not have the field.
        """
        if column == 'addr':
            column = 'address'

        if column not in self.db_row:
            return False

        if expected == '-':
            return self.db_row[column] is None or field not in self.db_row[column]

        if self.db_row[column] is None:
            return False

        return self._compare_column(self.db_row[column].get(field), expected)

    def _compare_column(self, actual, expected):
        """ Compare a DB value with its expected string representation.
            Dictionary values are compared against `expected` interpreted
            as the inner part of a Python dict literal.
        """
        if isinstance(actual, dict):
            # NOTE: eval() executes arbitrary code. `expected` comes from
            # the feature files of the test suite; never feed external
            # input into this function.
            return actual == eval('{' + expected + '}')

        return str(actual) == expected

    def _compare_place_id(self, actual, expected):
        """ Check that the place ID `actual` belongs to the object that
            `expected` refers to. '0' stands for the literal place ID 0.
        """
        if expected == '0':
            return actual == 0

        with self.context.db.cursor() as cur:
            return NominatimID(expected).get_place_id(cur) == actual

    def _has_centroid(self, expected):
        """ Check the centroid given by the columns 'cx' and 'cy'.
            `expected` is either 'in geometry', which requires the point
            to be within the geometry in column 'geomtxt', or a string
            of the form '<x> <y>'.
        """
        if expected == 'in geometry':
            with self.context.db.cursor() as cur:
                cur.execute("""SELECT ST_Within(ST_SetSRID(ST_Point({cx}, {cy}), 4326),
                                        ST_SetSRID('{geomtxt}'::geometry, 4326))""".format(**self.db_row))
                return cur.fetchone()[0]

        x, y = expected.split(' ')
        return Almost(float(x)) == self.db_row['cx'] and Almost(float(y)) == self.db_row['cy']

    def _has_geometry(self, expected):
        """ Check that the geometry in column 'geomtxt' equals the parsed
            `expected` geometry after snapping both to a 0.00001 grid.
        """
        geom = self.context.osm.parse_geometry(expected, self.context.scene)
        with self.context.db.cursor() as cur:
            cur.execute("""SELECT ST_Equals(ST_SnapToGrid({}, 0.00001, 0.00001),
                             ST_SnapToGrid(ST_SetSRID('{}'::geometry, 4326), 0.00001, 0.00001))""".format(
                                 geom, self.db_row['geomtxt']))
            return cur.fetchone()[0]

    def assert_msg(self, name, value):
        """ Return a string with an informative message for a failed compare.
        """
        msg = "\nBad column '{}' in row '{!s}'.".format(name, self.nid)
        actual = self._get_actual(name)
        if actual is not None:
            msg += " Expected: {}, got: {}.".format(value, actual)
        else:
            msg += " No such column."

        return msg + "\nFull DB row: {}".format(json.dumps(dict(self.db_row), indent=4, default=str))

    def _get_actual(self, name):
        """ Return a printable version of the value that `contains()`
            compares for column `name`, or None when there is none.
            Place IDs are translated into the OSM object they refer to.
        """
        if '+' in name:
            column, field = name.split('+', 1)
            if column == 'addr':
                column = 'address'
            return (self.db_row.get(column) or {}).get(field)

        if name == 'geometry':
            return self.db_row['geomtxt']

        if name not in self.db_row:
            return None

        if name == 'centroid':
            return "POINT({cx} {cy})".format(**self.db_row)

        actual = self.db_row[name]
        if 'place_id' in name:
            if actual is None:
                return '<null>'

            if actual == 0:
                return "place ID 0"

            # The cursor must be bound to a name; without it the lookup
            # below fails with a NameError and hides the real error message.
            with self.context.db.cursor() as cur:
                cur.execute("""SELECT osm_type, osm_id, class
                               FROM placex WHERE place_id = %s""",
                            (actual, ))

                if cur.rowcount == 1:
                    return "{0[0]}{0[1]}:{0[2]}".format(cur.fetchone())

            return "[place ID {} not found]".format(actual)

        return actual