diff --git a/test/python/conftest.py b/test/python/conftest.py index 891716d3..38bf70a6 100644 --- a/test/python/conftest.py +++ b/test/python/conftest.py @@ -145,11 +145,12 @@ def country_row(country_table, temp_db_cursor): @pytest.fixture -def load_sql(temp_db_conn, country_row): - proc = SQLPreprocessor(temp_db_conn, Configuration(None)) +def load_sql(temp_db_conn, country_table): + conf = Configuration(None) - def _run(filename, **kwargs): - proc.run_sql_file(temp_db_conn, filename, **kwargs) + def _run(*filename, **kwargs): + for fn in filename: + SQLPreprocessor(temp_db_conn, conf).run_sql_file(temp_db_conn, fn, **kwargs) return _run diff --git a/test/python/tools/test_database_import.py b/test/python/tools/test_database_import.py index ec8f504c..8c0aff8b 100644 --- a/test/python/tools/test_database_import.py +++ b/test/python/tools/test_database_import.py @@ -78,8 +78,8 @@ def test_setup_skeleton_already_exists(temp_db): database_import.setup_database_skeleton(f'dbname={temp_db}') -def test_import_osm_data_simple(table_factory, osm2pgsql_options, capfd): - table_factory('place', content=((1, ), )) +def test_import_osm_data_simple(place_row, osm2pgsql_options, capfd): + place_row() database_import.import_osm_data(Path('file.pbf'), osm2pgsql_options) captured = capfd.readouterr() @@ -92,8 +92,8 @@ def test_import_osm_data_simple(table_factory, osm2pgsql_options, capfd): assert 'file.pbf' in captured.out -def test_import_osm_data_multifile(table_factory, tmp_path, osm2pgsql_options, capfd): - table_factory('place', content=((1, ), )) +def test_import_osm_data_multifile(place_row, tmp_path, osm2pgsql_options, capfd): + place_row() osm2pgsql_options['osm2pgsql_cache'] = 0 files = [tmp_path / 'file1.osm', tmp_path / 'file2.osm'] @@ -107,22 +107,19 @@ def test_import_osm_data_multifile(table_factory, tmp_path, osm2pgsql_options, c assert 'file2.osm' in captured.out -def test_import_osm_data_simple_no_data(table_factory, osm2pgsql_options): - table_factory('place') - +def test_import_osm_data_simple_no_data(place_row, osm2pgsql_options): with pytest.raises(UsageError, match='No data imported'): database_import.import_osm_data(Path('file.pbf'), osm2pgsql_options) -def test_import_osm_data_simple_ignore_no_data(table_factory, osm2pgsql_options): - table_factory('place') - +def test_import_osm_data_simple_ignore_no_data(place_table, osm2pgsql_options): database_import.import_osm_data(Path('file.pbf'), osm2pgsql_options, ignore_errors=True) -def test_import_osm_data_drop(table_factory, temp_db_cursor, tmp_path, osm2pgsql_options): - table_factory('place', content=((1, ), )) +def test_import_osm_data_drop(place_row, table_factory, temp_db_cursor, + tmp_path, osm2pgsql_options): + place_row() table_factory('planet_osm_nodes') flatfile = tmp_path / 'flatfile' @@ -136,8 +133,8 @@ def test_import_osm_data_drop(table_factory, temp_db_cursor, tmp_path, osm2pgsql assert not temp_db_cursor.table_exists('planet_osm_nodes') -def test_import_osm_data_default_cache(table_factory, osm2pgsql_options, capfd): - table_factory('place', content=((1, ), )) +def test_import_osm_data_default_cache(place_row, osm2pgsql_options, capfd): + place_row() osm2pgsql_options['osm2pgsql_cache'] = 0 @@ -215,52 +212,53 @@ async def test_load_data(dsn, place_row, placex_table, osmline_table, class TestSetupSQL: @pytest.fixture(autouse=True) - def init_env(self, temp_db, tmp_path, def_config, sql_preprocessor_cfg): - def_config.lib_dir.sql = tmp_path / 'sql' - def_config.lib_dir.sql.mkdir() - + def osm2ppsql_skel(self, def_config, temp_db_with_extensions, place_row, + country_table, table_factory, temp_db_conn): self.config = def_config + place_row() + table_factory('osm2pgsql_properties', 'property TEXT, value TEXT', + (('db_format', 2),)) - def write_sql(self, fname, content): - (self.config.lib_dir.sql / fname).write_text(content, encoding='utf-8') + table_factory('planet_osm_rels', 'id BIGINT, members JSONB, tags JSONB') + temp_db_conn.execute(""" + CREATE OR REPLACE FUNCTION planet_osm_member_ids(jsonb, character) + RETURNS bigint[] AS $$ + SELECT array_agg((el->>'ref')::int8) + FROM jsonb_array_elements($1) AS el WHERE el->>'type' = $2 + $$ LANGUAGE sql IMMUTABLE; + """) @pytest.mark.parametrize("reverse", [True, False]) - def test_create_tables(self, temp_db_conn, temp_db_cursor, reverse): - self.write_sql('tables.sql', - """CREATE FUNCTION test() RETURNS bool - AS $$ SELECT {{db.reverse_only}} $$ LANGUAGE SQL""") - - self.write_sql('grants.sql', "-- Mock grants file for testing\n") + def test_create_tables(self, table_factory, temp_db_conn, temp_db_cursor, reverse): + table_factory('country_osm_grid') database_import.create_tables(temp_db_conn, self.config, reverse) - temp_db_cursor.scalar('SELECT test()') == reverse + assert temp_db_cursor.table_exists('placex') + assert not reverse == temp_db_cursor.table_exists('search_name') - def test_create_table_triggers(self, temp_db_conn, temp_db_cursor): - self.write_sql('table-triggers.sql', - """CREATE FUNCTION test() RETURNS TEXT - AS $$ SELECT 'a'::text $$ LANGUAGE SQL""") + def test_create_table_triggers(self, temp_db_conn, placex_table, osmline_table, + postcode_table, load_sql): + load_sql('functions.sql') database_import.create_table_triggers(temp_db_conn, self.config) - temp_db_cursor.scalar('SELECT test()') == 'a' - - def test_create_partition_tables(self, temp_db_conn, temp_db_cursor): - self.write_sql('partition-tables.src.sql', - """CREATE FUNCTION test() RETURNS TEXT - AS $$ SELECT 'b'::text $$ LANGUAGE SQL""") + def test_create_partition_tables(self, country_row, temp_db_conn, temp_db_cursor, load_sql): + for i in range(3): + country_row(partition=i) + load_sql('tables/location_area.sql') database_import.create_partition_tables(temp_db_conn, self.config) - temp_db_cursor.scalar('SELECT test()') == 'b' + for i in range(3): + assert temp_db_cursor.table_exists(f"location_area_large_{i}") + assert temp_db_cursor.table_exists(f"search_name_{i}") @pytest.mark.parametrize("drop", [True, False]) @pytest.mark.asyncio - async def test_create_search_indices(self, temp_db_conn, temp_db_cursor, drop): - self.write_sql('indices.sql', - """CREATE FUNCTION test() RETURNS bool - AS $$ SELECT {{drop}} $$ LANGUAGE SQL""") - + async def test_create_search_indices(self, temp_db_conn, temp_db_cursor, drop, load_sql): + load_sql('tables.sql', 'functions/ranking.sql') await database_import.create_search_indices(temp_db_conn, self.config, drop) - temp_db_cursor.scalar('SELECT test()') == drop + assert temp_db_cursor.index_exists('placex', 'idx_placex_geometry') + assert not drop == temp_db_cursor.index_exists('placex', 'idx_placex_geometry_buildings')