diff --git a/test/bdd/test_db.py b/test/bdd/test_db.py index 1a7eef7c..68a2a8f2 100644 --- a/test/bdd/test_db.py +++ b/test/bdd/test_db.py @@ -14,6 +14,7 @@ import re from collections import defaultdict import psycopg +import psycopg.sql as pysql import pytest from pytest_bdd import when, then, given @@ -50,6 +51,34 @@ def _collect_place_ids(conn): return pids +@pytest.fixture +def row_factory(db_conn): + def _insert_row(table, **data): + columns = [] + placeholders = [] + values = [] + for k, v in data.items(): + columns.append(pysql.Identifier(k)) + if isinstance(v, tuple): + placeholders.append(pysql.SQL(v[0])) + values.append(v[1]) + elif isinstance(v, (pysql.Literal, pysql.SQL)): + placeholders.append(v) + else: + placeholders.append(pysql.Placeholder()) + values.append(v) + + sql = pysql.SQL("INSERT INTO {table} ({columns}) VALUES({values})")\ + .format(table=pysql.Identifier(table), + columns=pysql.SQL(',').join(columns), + values=pysql.SQL(',').join(placeholders)) + + db_conn.execute(sql, values) + db_conn.commit() + + return _insert_row + + @pytest.fixture def test_config_env(pytestconfig): dbname = pytestconfig.getini('nominatim_test_db') @@ -85,18 +114,19 @@ def import_places(db_conn, named, datatable, node_grid): @given(step_parse('the entrances'), target_fixture=None) -def import_place_entrances(db_conn, datatable, node_grid): +def import_place_entrances(row_factory, datatable, node_grid): """ Insert todo rows into the place_entrance table. """ - with db_conn.cursor() as cur: - for row in datatable[1:]: - data = PlaceColumn(node_grid).add_row(datatable[0], row, False) - assert data.columns['osm_type'] == 'N' + for row in datatable[1:]: + data = PlaceColumn(node_grid).add_row(datatable[0], row, False) + assert data.columns['osm_type'] == 'N' - cur.execute("""INSERT INTO place_entrance (osm_id, type, extratags, geometry) - VALUES (%s, %s, %s, {})""".format(data.get_wkt()), - (data.columns['osm_id'], data.columns['type'], - data.columns.get('extratags'))) + params = {'osm_id': data.columns['osm_id'], + 'type': data.columns['type'], + 'extratags': data.columns.get('extratags'), + 'geometry': pysql.SQL(data.get_wkt())} + + row_factory('place_entrance', **params) @given(step_parse('the postcodes'), target_fixture=None) @@ -135,43 +165,41 @@ def import_place_postcode(db_conn, datatable, node_grid): @given('the ways', target_fixture=None) -def import_ways(db_conn, datatable): +def import_ways(row_factory, datatable): """ Import raw ways into the osm2pgsql way middle table. """ - with db_conn.cursor() as cur: - id_idx = datatable[0].index('id') - node_idx = datatable[0].index('nodes') - for line in datatable[1:]: - tags = psycopg.types.json.Json( - {k[5:]: v for k, v in zip(datatable[0], line) - if k.startswith("tags+")}) - nodes = [int(x) for x in line[node_idx].split(',')] - - cur.execute("INSERT INTO planet_osm_ways (id, nodes, tags) VALUES (%s, %s, %s)", - (line[id_idx], nodes, tags)) + id_idx = datatable[0].index('id') + node_idx = datatable[0].index('nodes') + for line in datatable[1:]: + row_factory('planet_osm_ways', + id=line[id_idx], + nodes=[int(x) for x in line[node_idx].split(',')], + tags=psycopg.types.json.Json( + {k[5:]: v for k, v in zip(datatable[0], line) + if k.startswith("tags+")})) @given('the relations', target_fixture=None) -def import_rels(db_conn, datatable): +def import_rels(row_factory, datatable): """ Import raw relations into the osm2pgsql relation middle table. """ - with db_conn.cursor() as cur: - id_idx = datatable[0].index('id') - memb_idx = datatable[0].index('members') - for line in datatable[1:]: - tags = psycopg.types.json.Json( - {k[5:]: v for k, v in zip(datatable[0], line) - if k.startswith("tags+")}) - members = [] - if line[memb_idx]: - for member in line[memb_idx].split(','): - m = re.fullmatch(r'\s*([RWN])(\d+)(?::(\S+))?\s*', member) - if not m: - raise ValueError(f'Illegal member {member}.') - members.append({'ref': int(m[2]), 'role': m[3] or '', 'type': m[1]}) + id_idx = datatable[0].index('id') + memb_idx = datatable[0].index('members') + for line in datatable[1:]: + tags = psycopg.types.json.Json( + {k[5:]: v for k, v in zip(datatable[0], line) + if k.startswith("tags+")}) + members = [] + if line[memb_idx]: + for member in line[memb_idx].split(','): + m = re.fullmatch(r'\s*([RWN])(\d+)(?::(\S+))?\s*', member) + if not m: + raise ValueError(f'Illegal member {member}.') + members.append({'ref': int(m[2]), 'role': m[3] or '', 'type': m[1]}) - cur.execute('INSERT INTO planet_osm_rels (id, tags, members) VALUES (%s, %s, %s)', - (int(line[id_idx]), tags, psycopg.types.json.Json(members))) + row_factory('planet_osm_rels', + id=int(line[id_idx]), tags=tags, + members=psycopg.types.json.Json(members)) @when('importing', target_fixture='place_ids')