BDD tests: factor out insert sql code

This commit is contained in:
Sarah Hoffmann
2026-02-16 18:04:47 +01:00
parent c0f1aeea4d
commit b43116ff52

View File

@@ -14,6 +14,7 @@ import re
from collections import defaultdict from collections import defaultdict
import psycopg import psycopg
import psycopg.sql as pysql
import pytest import pytest
from pytest_bdd import when, then, given from pytest_bdd import when, then, given
@@ -50,6 +51,34 @@ def _collect_place_ids(conn):
return pids return pids
@pytest.fixture
def row_factory(db_conn):
def _insert_row(table, **data):
columns = []
placeholders = []
values = []
for k, v in data.items():
columns.append(pysql.Identifier(k))
if isinstance(v, tuple):
placeholders.append(pysql.SQL(v[0]))
values.append(v[1])
elif isinstance(v, (pysql.Literal, pysql.SQL)):
placeholders.append(v)
else:
placeholders.append(pysql.Placeholder())
values.append(v)
sql = pysql.SQL("INSERT INTO {table} ({columns}) VALUES({values})")\
.format(table=pysql.Identifier(table),
columns=pysql.SQL(',').join(columns),
values=pysql.SQL(',').join(placeholders))
db_conn.execute(sql, values)
db_conn.commit()
return _insert_row
@pytest.fixture @pytest.fixture
def test_config_env(pytestconfig): def test_config_env(pytestconfig):
dbname = pytestconfig.getini('nominatim_test_db') dbname = pytestconfig.getini('nominatim_test_db')
@@ -85,18 +114,19 @@ def import_places(db_conn, named, datatable, node_grid):
@given(step_parse('the entrances'), target_fixture=None) @given(step_parse('the entrances'), target_fixture=None)
def import_place_entrances(db_conn, datatable, node_grid): def import_place_entrances(row_factory, datatable, node_grid):
""" Insert todo rows into the place_entrance table. """ Insert todo rows into the place_entrance table.
""" """
with db_conn.cursor() as cur:
for row in datatable[1:]: for row in datatable[1:]:
data = PlaceColumn(node_grid).add_row(datatable[0], row, False) data = PlaceColumn(node_grid).add_row(datatable[0], row, False)
assert data.columns['osm_type'] == 'N' assert data.columns['osm_type'] == 'N'
cur.execute("""INSERT INTO place_entrance (osm_id, type, extratags, geometry) params = {'osm_id': data.columns['osm_id'],
VALUES (%s, %s, %s, {})""".format(data.get_wkt()), 'type': data.columns['type'],
(data.columns['osm_id'], data.columns['type'], 'extratags': data.columns.get('extratags'),
data.columns.get('extratags'))) 'geometry': pysql.SQL(data.get_wkt())}
row_factory('place_entrance', **params)
@given(step_parse('the postcodes'), target_fixture=None) @given(step_parse('the postcodes'), target_fixture=None)
@@ -135,27 +165,24 @@ def import_place_postcode(db_conn, datatable, node_grid):
@given('the ways', target_fixture=None) @given('the ways', target_fixture=None)
def import_ways(db_conn, datatable): def import_ways(row_factory, datatable):
""" Import raw ways into the osm2pgsql way middle table. """ Import raw ways into the osm2pgsql way middle table.
""" """
with db_conn.cursor() as cur:
id_idx = datatable[0].index('id') id_idx = datatable[0].index('id')
node_idx = datatable[0].index('nodes') node_idx = datatable[0].index('nodes')
for line in datatable[1:]: for line in datatable[1:]:
tags = psycopg.types.json.Json( row_factory('planet_osm_ways',
id=line[id_idx],
nodes=[int(x) for x in line[node_idx].split(',')],
tags=psycopg.types.json.Json(
{k[5:]: v for k, v in zip(datatable[0], line) {k[5:]: v for k, v in zip(datatable[0], line)
if k.startswith("tags+")}) if k.startswith("tags+")}))
nodes = [int(x) for x in line[node_idx].split(',')]
cur.execute("INSERT INTO planet_osm_ways (id, nodes, tags) VALUES (%s, %s, %s)",
(line[id_idx], nodes, tags))
@given('the relations', target_fixture=None) @given('the relations', target_fixture=None)
def import_rels(db_conn, datatable): def import_rels(row_factory, datatable):
""" Import raw relations into the osm2pgsql relation middle table. """ Import raw relations into the osm2pgsql relation middle table.
""" """
with db_conn.cursor() as cur:
id_idx = datatable[0].index('id') id_idx = datatable[0].index('id')
memb_idx = datatable[0].index('members') memb_idx = datatable[0].index('members')
for line in datatable[1:]: for line in datatable[1:]:
@@ -170,8 +197,9 @@ def import_rels(db_conn, datatable):
raise ValueError(f'Illegal member {member}.') raise ValueError(f'Illegal member {member}.')
members.append({'ref': int(m[2]), 'role': m[3] or '', 'type': m[1]}) members.append({'ref': int(m[2]), 'role': m[3] or '', 'type': m[1]})
cur.execute('INSERT INTO planet_osm_rels (id, tags, members) VALUES (%s, %s, %s)', row_factory('planet_osm_rels',
(int(line[id_idx]), tags, psycopg.types.json.Json(members))) id=int(line[id_idx]), tags=tags,
members=psycopg.types.json.Json(members))
@when('importing', target_fixture='place_ids') @when('importing', target_fixture='place_ids')