diff --git a/src/nominatim_db/tools/special_phrases/sp_importer.py b/src/nominatim_db/tools/special_phrases/sp_importer.py index 89ac6dac..323decf9 100644 --- a/src/nominatim_db/tools/special_phrases/sp_importer.py +++ b/src/nominatim_db/tools/special_phrases/sp_importer.py @@ -64,23 +64,25 @@ class SPImporter(): # special phrases class/type on the wiki. self.table_phrases_to_delete: Set[str] = set() - def get_classtype_pairs(self) -> Set[Tuple[str, str]]: + def get_classtype_pairs(self, min: int = 0) -> Set[Tuple[str, str]]: """ Returns list of allowed special phrases from the database, restricting to a list of combinations of classes and types - which occur more than 100 times + which occur more than a specified amount of times. + + Default value for this, if not specified, is at least once. """ db_combinations = set() - query = """ + query = f""" SELECT class AS CLS, type AS typ FROM placex GROUP BY class, type - HAVING COUNT(*) > 100 + HAVING COUNT(*) > {min} """ with self.db_connection.cursor() as db_cursor: db_cursor.execute(SQL(query)) - for row in db_cursor.fetchall(): + for row in db_cursor: db_combinations.add((row[0], row[1])) return db_combinations diff --git a/test/python/tools/test_sp_importer.py b/test/python/tools/test_sp_importer.py index b27172c8..dda02f11 100644 --- a/test/python/tools/test_sp_importer.py +++ b/test/python/tools/test_sp_importer.py @@ -1,60 +1,20 @@ -import pytest -import tempfile -import os - from nominatim_db.tools.special_phrases.sp_importer import SPImporter -# Testing Database Class Pair Retrival using Mock Database -def test_get_classtype_pairs(monkeypatch): - class Config: - def load_sub_configuration(self, path, section=None): - return {"blackList": {}, "whiteList": {}} - - class Cursor: - def execute(self, query): pass - def fetchall(self): - return [ - ("highway", "motorway"), - ("historic", "castle") - ] - def __enter__(self): return self - def __exit__(self, exc_type, exc_val, exc_tb): pass - - class Connection: - def cursor(self): return Cursor() - - config = Config() - conn = Connection() - importer = SPImporter(config=config, conn=conn, sp_loader=None) - - result = importer.get_classtype_pairs() - - expected = { - ("highway", "motorway"), - ("historic", "castle") - } - - assert result == expected # Testing Database Class Pair Retrival using Conftest.py and placex -def test_get_classtype_pair_data(placex_table, temp_db_conn): - class Config: - def load_sub_configuration(self, *_): - return {'blackList': {}, 'whiteList': {}} - +def test_get_classtype_pair_data(placex_table, def_config, temp_db_conn): for _ in range(101): - placex_table.add(cls='highway', typ='motorway') # edge case 101 + placex_table.add(cls='highway', typ='motorway') # edge case 101 for _ in range(99): - placex_table.add(cls='amenity', typ='prison') # edge case 99 + placex_table.add(cls='amenity', typ='prison') # edge case 99 for _ in range(150): placex_table.add(cls='tourism', typ='hotel') - config = Config() - importer = SPImporter(config=config, conn=temp_db_conn, sp_loader=None) + importer = SPImporter(config=def_config, conn=temp_db_conn, sp_loader=None) - result = importer.get_classtype_pairs() + result = importer.get_classtype_pairs(min=100) expected = { ("highway", "motorway"), @@ -63,24 +23,20 @@ def test_get_classtype_pair_data(placex_table, temp_db_conn): assert result == expected, f"Expected {expected}, got {result}" -def test_get_classtype_pair_data_more(placex_table, temp_db_conn): - class Config: - def load_sub_configuration(self, *_): - return {'blackList': {}, 'whiteList': {}} - + +def test_get_classtype_pair_data_more(placex_table, def_config, temp_db_conn): for _ in range(100): - placex_table.add(cls='emergency', typ='firehydrant') # edge case 100, not included + placex_table.add(cls='emergency', typ='firehydrant') # edge case 100, not included for _ in range(199): - placex_table.add(cls='amenity', typ='prison') + placex_table.add(cls='amenity', typ='prison') for _ in range(3478): placex_table.add(cls='tourism', typ='hotel') - config = Config() - importer = SPImporter(config=config, conn=temp_db_conn, sp_loader=None) + importer = SPImporter(config=def_config, conn=temp_db_conn, sp_loader=None) - result = importer.get_classtype_pairs() + result = importer.get_classtype_pairs(min=100) expected = { ("amenity", "prison"), @@ -88,3 +44,26 @@ def test_get_classtype_pair_data_more(placex_table, temp_db_conn): } assert result == expected, f"Expected {expected}, got {result}" + + +def test_get_classtype_pair_data_default(placex_table, def_config, temp_db_conn): + for _ in range(1): + placex_table.add(cls='emergency', typ='firehydrant') + + for _ in range(199): + placex_table.add(cls='amenity', typ='prison') + + for _ in range(3478): + placex_table.add(cls='tourism', typ='hotel') + + importer = SPImporter(config=def_config, conn=temp_db_conn, sp_loader=None) + + result = importer.get_classtype_pairs() + + expected = { + ("amenity", "prison"), + ("tourism", "hotel"), + ("emergency", "firehydrant") + } + + assert result == expected, f"Expected {expected}, got {result}"