make database import unit tests against real SQL

This commit is contained in:
Sarah Hoffmann
2026-02-15 21:38:38 +01:00
parent d0bd42298e
commit c31abf58d0
2 changed files with 47 additions and 48 deletions

View File

@@ -145,11 +145,12 @@ def country_row(country_table, temp_db_cursor):
@pytest.fixture @pytest.fixture
def load_sql(temp_db_conn, country_row): def load_sql(temp_db_conn, country_table):
proc = SQLPreprocessor(temp_db_conn, Configuration(None)) conf = Configuration(None)
def _run(filename, **kwargs): def _run(*filename, **kwargs):
proc.run_sql_file(temp_db_conn, filename, **kwargs) for fn in filename:
SQLPreprocessor(temp_db_conn, conf).run_sql_file(temp_db_conn, fn, **kwargs)
return _run return _run

View File

@@ -78,8 +78,8 @@ def test_setup_skeleton_already_exists(temp_db):
database_import.setup_database_skeleton(f'dbname={temp_db}') database_import.setup_database_skeleton(f'dbname={temp_db}')
def test_import_osm_data_simple(table_factory, osm2pgsql_options, capfd): def test_import_osm_data_simple(place_row, osm2pgsql_options, capfd):
table_factory('place', content=((1, ), )) place_row()
database_import.import_osm_data(Path('file.pbf'), osm2pgsql_options) database_import.import_osm_data(Path('file.pbf'), osm2pgsql_options)
captured = capfd.readouterr() captured = capfd.readouterr()
@@ -92,8 +92,8 @@ def test_import_osm_data_simple(table_factory, osm2pgsql_options, capfd):
assert 'file.pbf' in captured.out assert 'file.pbf' in captured.out
def test_import_osm_data_multifile(table_factory, tmp_path, osm2pgsql_options, capfd): def test_import_osm_data_multifile(place_row, tmp_path, osm2pgsql_options, capfd):
table_factory('place', content=((1, ), )) place_row()
osm2pgsql_options['osm2pgsql_cache'] = 0 osm2pgsql_options['osm2pgsql_cache'] = 0
files = [tmp_path / 'file1.osm', tmp_path / 'file2.osm'] files = [tmp_path / 'file1.osm', tmp_path / 'file2.osm']
@@ -107,22 +107,19 @@ def test_import_osm_data_multifile(table_factory, tmp_path, osm2pgsql_options, c
assert 'file2.osm' in captured.out assert 'file2.osm' in captured.out
def test_import_osm_data_simple_no_data(table_factory, osm2pgsql_options): def test_import_osm_data_simple_no_data(place_row, osm2pgsql_options):
table_factory('place')
with pytest.raises(UsageError, match='No data imported'): with pytest.raises(UsageError, match='No data imported'):
database_import.import_osm_data(Path('file.pbf'), osm2pgsql_options) database_import.import_osm_data(Path('file.pbf'), osm2pgsql_options)
def test_import_osm_data_simple_ignore_no_data(table_factory, osm2pgsql_options): def test_import_osm_data_simple_ignore_no_data(place_table, osm2pgsql_options):
table_factory('place')
database_import.import_osm_data(Path('file.pbf'), osm2pgsql_options, database_import.import_osm_data(Path('file.pbf'), osm2pgsql_options,
ignore_errors=True) ignore_errors=True)
def test_import_osm_data_drop(table_factory, temp_db_cursor, tmp_path, osm2pgsql_options): def test_import_osm_data_drop(place_row, table_factory, temp_db_cursor,
table_factory('place', content=((1, ), )) tmp_path, osm2pgsql_options):
place_row()
table_factory('planet_osm_nodes') table_factory('planet_osm_nodes')
flatfile = tmp_path / 'flatfile' flatfile = tmp_path / 'flatfile'
@@ -136,8 +133,8 @@ def test_import_osm_data_drop(table_factory, temp_db_cursor, tmp_path, osm2pgsql
assert not temp_db_cursor.table_exists('planet_osm_nodes') assert not temp_db_cursor.table_exists('planet_osm_nodes')
def test_import_osm_data_default_cache(table_factory, osm2pgsql_options, capfd): def test_import_osm_data_default_cache(place_row, osm2pgsql_options, capfd):
table_factory('place', content=((1, ), )) place_row()
osm2pgsql_options['osm2pgsql_cache'] = 0 osm2pgsql_options['osm2pgsql_cache'] = 0
@@ -215,52 +212,53 @@ async def test_load_data(dsn, place_row, placex_table, osmline_table,
class TestSetupSQL: class TestSetupSQL:
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def init_env(self, temp_db, tmp_path, def_config, sql_preprocessor_cfg): def osm2ppsql_skel(self, def_config, temp_db_with_extensions, place_row,
def_config.lib_dir.sql = tmp_path / 'sql' country_table, table_factory, temp_db_conn):
def_config.lib_dir.sql.mkdir()
self.config = def_config self.config = def_config
place_row()
table_factory('osm2pgsql_properties', 'property TEXT, value TEXT',
(('db_format', 2),))
def write_sql(self, fname, content): table_factory('planet_osm_rels', 'id BIGINT, members JSONB, tags JSONB')
(self.config.lib_dir.sql / fname).write_text(content, encoding='utf-8') temp_db_conn.execute("""
CREATE OR REPLACE FUNCTION planet_osm_member_ids(jsonb, character)
RETURNS bigint[] AS $$
SELECT array_agg((el->>'ref')::int8)
FROM jsonb_array_elements($1) AS el WHERE el->>'type' = $2
$$ LANGUAGE sql IMMUTABLE;
""")
@pytest.mark.parametrize("reverse", [True, False]) @pytest.mark.parametrize("reverse", [True, False])
def test_create_tables(self, temp_db_conn, temp_db_cursor, reverse): def test_create_tables(self, table_factory, temp_db_conn, temp_db_cursor, reverse):
self.write_sql('tables.sql', table_factory('country_osm_grid')
"""CREATE FUNCTION test() RETURNS bool
AS $$ SELECT {{db.reverse_only}} $$ LANGUAGE SQL""")
self.write_sql('grants.sql', "-- Mock grants file for testing\n")
database_import.create_tables(temp_db_conn, self.config, reverse) database_import.create_tables(temp_db_conn, self.config, reverse)
temp_db_cursor.scalar('SELECT test()') == reverse assert temp_db_cursor.table_exists('placex')
assert not reverse == temp_db_cursor.table_exists('search_name')
def test_create_table_triggers(self, temp_db_conn, temp_db_cursor): def test_create_table_triggers(self, temp_db_conn, placex_table, osmline_table,
self.write_sql('table-triggers.sql', postcode_table, load_sql):
"""CREATE FUNCTION test() RETURNS TEXT load_sql('functions.sql')
AS $$ SELECT 'a'::text $$ LANGUAGE SQL""")
database_import.create_table_triggers(temp_db_conn, self.config) database_import.create_table_triggers(temp_db_conn, self.config)
temp_db_cursor.scalar('SELECT test()') == 'a' def test_create_partition_tables(self, country_row, temp_db_conn, temp_db_cursor, load_sql):
for i in range(3):
def test_create_partition_tables(self, temp_db_conn, temp_db_cursor): country_row(partition=i)
self.write_sql('partition-tables.src.sql', load_sql('tables/location_area.sql')
"""CREATE FUNCTION test() RETURNS TEXT
AS $$ SELECT 'b'::text $$ LANGUAGE SQL""")
database_import.create_partition_tables(temp_db_conn, self.config) database_import.create_partition_tables(temp_db_conn, self.config)
temp_db_cursor.scalar('SELECT test()') == 'b' for i in range(3):
assert temp_db_cursor.table_exists(f"location_area_large_{i}")
assert temp_db_cursor.table_exists(f"search_name_{i}")
@pytest.mark.parametrize("drop", [True, False]) @pytest.mark.parametrize("drop", [True, False])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_search_indices(self, temp_db_conn, temp_db_cursor, drop): async def test_create_search_indices(self, temp_db_conn, temp_db_cursor, drop, load_sql):
self.write_sql('indices.sql', load_sql('tables.sql', 'functions/ranking.sql')
"""CREATE FUNCTION test() RETURNS bool
AS $$ SELECT {{drop}} $$ LANGUAGE SQL""")
await database_import.create_search_indices(temp_db_conn, self.config, drop) await database_import.create_search_indices(temp_db_conn, self.config, drop)
temp_db_cursor.scalar('SELECT test()') == drop assert temp_db_cursor.index_exists('placex', 'idx_placex_geometry')
assert not drop == temp_db_cursor.index_exists('placex', 'idx_placex_geometry_buildings')