make database import unit tests against real SQL

This commit is contained in:
Sarah Hoffmann
2026-02-15 21:38:38 +01:00
parent d0bd42298e
commit c31abf58d0
2 changed files with 47 additions and 48 deletions

View File

@@ -145,11 +145,12 @@ def country_row(country_table, temp_db_cursor):
@pytest.fixture
def load_sql(temp_db_conn, country_row):
proc = SQLPreprocessor(temp_db_conn, Configuration(None))
def load_sql(temp_db_conn, country_table):
conf = Configuration(None)
def _run(filename, **kwargs):
proc.run_sql_file(temp_db_conn, filename, **kwargs)
def _run(*filename, **kwargs):
for fn in filename:
SQLPreprocessor(temp_db_conn, conf).run_sql_file(temp_db_conn, fn, **kwargs)
return _run

View File

@@ -78,8 +78,8 @@ def test_setup_skeleton_already_exists(temp_db):
database_import.setup_database_skeleton(f'dbname={temp_db}')
def test_import_osm_data_simple(table_factory, osm2pgsql_options, capfd):
table_factory('place', content=((1, ), ))
def test_import_osm_data_simple(place_row, osm2pgsql_options, capfd):
place_row()
database_import.import_osm_data(Path('file.pbf'), osm2pgsql_options)
captured = capfd.readouterr()
@@ -92,8 +92,8 @@ def test_import_osm_data_simple(table_factory, osm2pgsql_options, capfd):
assert 'file.pbf' in captured.out
def test_import_osm_data_multifile(table_factory, tmp_path, osm2pgsql_options, capfd):
table_factory('place', content=((1, ), ))
def test_import_osm_data_multifile(place_row, tmp_path, osm2pgsql_options, capfd):
place_row()
osm2pgsql_options['osm2pgsql_cache'] = 0
files = [tmp_path / 'file1.osm', tmp_path / 'file2.osm']
@@ -107,22 +107,19 @@ def test_import_osm_data_multifile(table_factory, tmp_path, osm2pgsql_options, c
assert 'file2.osm' in captured.out
def test_import_osm_data_simple_no_data(table_factory, osm2pgsql_options):
table_factory('place')
def test_import_osm_data_simple_no_data(place_row, osm2pgsql_options):
with pytest.raises(UsageError, match='No data imported'):
database_import.import_osm_data(Path('file.pbf'), osm2pgsql_options)
def test_import_osm_data_simple_ignore_no_data(table_factory, osm2pgsql_options):
table_factory('place')
def test_import_osm_data_simple_ignore_no_data(place_table, osm2pgsql_options):
database_import.import_osm_data(Path('file.pbf'), osm2pgsql_options,
ignore_errors=True)
def test_import_osm_data_drop(table_factory, temp_db_cursor, tmp_path, osm2pgsql_options):
table_factory('place', content=((1, ), ))
def test_import_osm_data_drop(place_row, table_factory, temp_db_cursor,
tmp_path, osm2pgsql_options):
place_row()
table_factory('planet_osm_nodes')
flatfile = tmp_path / 'flatfile'
@@ -136,8 +133,8 @@ def test_import_osm_data_drop(table_factory, temp_db_cursor, tmp_path, osm2pgsql
assert not temp_db_cursor.table_exists('planet_osm_nodes')
def test_import_osm_data_default_cache(table_factory, osm2pgsql_options, capfd):
table_factory('place', content=((1, ), ))
def test_import_osm_data_default_cache(place_row, osm2pgsql_options, capfd):
place_row()
osm2pgsql_options['osm2pgsql_cache'] = 0
@@ -215,52 +212,53 @@ async def test_load_data(dsn, place_row, placex_table, osmline_table,
class TestSetupSQL:
@pytest.fixture(autouse=True)
def init_env(self, temp_db, tmp_path, def_config, sql_preprocessor_cfg):
def_config.lib_dir.sql = tmp_path / 'sql'
def_config.lib_dir.sql.mkdir()
def osm2ppsql_skel(self, def_config, temp_db_with_extensions, place_row,
country_table, table_factory, temp_db_conn):
self.config = def_config
place_row()
table_factory('osm2pgsql_properties', 'property TEXT, value TEXT',
(('db_format', 2),))
def write_sql(self, fname, content):
(self.config.lib_dir.sql / fname).write_text(content, encoding='utf-8')
table_factory('planet_osm_rels', 'id BIGINT, members JSONB, tags JSONB')
temp_db_conn.execute("""
CREATE OR REPLACE FUNCTION planet_osm_member_ids(jsonb, character)
RETURNS bigint[] AS $$
SELECT array_agg((el->>'ref')::int8)
FROM jsonb_array_elements($1) AS el WHERE el->>'type' = $2
$$ LANGUAGE sql IMMUTABLE;
""")
@pytest.mark.parametrize("reverse", [True, False])
def test_create_tables(self, temp_db_conn, temp_db_cursor, reverse):
self.write_sql('tables.sql',
"""CREATE FUNCTION test() RETURNS bool
AS $$ SELECT {{db.reverse_only}} $$ LANGUAGE SQL""")
self.write_sql('grants.sql', "-- Mock grants file for testing\n")
def test_create_tables(self, table_factory, temp_db_conn, temp_db_cursor, reverse):
table_factory('country_osm_grid')
database_import.create_tables(temp_db_conn, self.config, reverse)
temp_db_cursor.scalar('SELECT test()') == reverse
assert temp_db_cursor.table_exists('placex')
assert not reverse == temp_db_cursor.table_exists('search_name')
def test_create_table_triggers(self, temp_db_conn, temp_db_cursor):
self.write_sql('table-triggers.sql',
"""CREATE FUNCTION test() RETURNS TEXT
AS $$ SELECT 'a'::text $$ LANGUAGE SQL""")
def test_create_table_triggers(self, temp_db_conn, placex_table, osmline_table,
postcode_table, load_sql):
load_sql('functions.sql')
database_import.create_table_triggers(temp_db_conn, self.config)
temp_db_cursor.scalar('SELECT test()') == 'a'
def test_create_partition_tables(self, temp_db_conn, temp_db_cursor):
self.write_sql('partition-tables.src.sql',
"""CREATE FUNCTION test() RETURNS TEXT
AS $$ SELECT 'b'::text $$ LANGUAGE SQL""")
def test_create_partition_tables(self, country_row, temp_db_conn, temp_db_cursor, load_sql):
for i in range(3):
country_row(partition=i)
load_sql('tables/location_area.sql')
database_import.create_partition_tables(temp_db_conn, self.config)
temp_db_cursor.scalar('SELECT test()') == 'b'
for i in range(3):
assert temp_db_cursor.table_exists(f"location_area_large_{i}")
assert temp_db_cursor.table_exists(f"search_name_{i}")
@pytest.mark.parametrize("drop", [True, False])
@pytest.mark.asyncio
async def test_create_search_indices(self, temp_db_conn, temp_db_cursor, drop):
self.write_sql('indices.sql',
"""CREATE FUNCTION test() RETURNS bool
AS $$ SELECT {{drop}} $$ LANGUAGE SQL""")
async def test_create_search_indices(self, temp_db_conn, temp_db_cursor, drop, load_sql):
load_sql('tables.sql', 'functions/ranking.sql')
await database_import.create_search_indices(temp_db_conn, self.config, drop)
temp_db_cursor.scalar('SELECT test()') == drop
assert temp_db_cursor.index_exists('placex', 'idx_placex_geometry')
assert not drop == temp_db_cursor.index_exists('placex', 'idx_placex_geometry_buildings')