diff --git a/src/nominatim_db/clicmd/args.py b/src/nominatim_db/clicmd/args.py index 45df9b7c..5c6a806a 100644 --- a/src/nominatim_db/clicmd/args.py +++ b/src/nominatim_db/clicmd/args.py @@ -136,6 +136,7 @@ class NominatimArgs: import_from_wiki: bool import_from_csv: Optional[str] no_replace: bool + min: int # Arguments to all query functions format: str diff --git a/src/nominatim_db/clicmd/special_phrases.py b/src/nominatim_db/clicmd/special_phrases.py index 9ba751a0..90560fb7 100644 --- a/src/nominatim_db/clicmd/special_phrases.py +++ b/src/nominatim_db/clicmd/special_phrases.py @@ -58,6 +58,8 @@ class ImportSpecialPhrases: help='Import special phrases from a CSV file') group.add_argument('--no-replace', action='store_true', help='Keep the old phrases and only add the new ones') + group.add_argument('--min', type=int, default=0, + help='Restrict special phrases by minimum occurance') def run(self, args: NominatimArgs) -> int: @@ -82,7 +84,9 @@ class ImportSpecialPhrases: tokenizer = tokenizer_factory.get_tokenizer_for_db(args.config) should_replace = not args.no_replace + min = args.min + with connect(args.config.get_libpq_dsn()) as db_connection: SPImporter( args.config, db_connection, loader - ).import_phrases(tokenizer, should_replace) + ).import_phrases(tokenizer, should_replace, min) diff --git a/src/nominatim_db/tools/special_phrases/sp_importer.py b/src/nominatim_db/tools/special_phrases/sp_importer.py index ac50377f..12e695b6 100644 --- a/src/nominatim_db/tools/special_phrases/sp_importer.py +++ b/src/nominatim_db/tools/special_phrases/sp_importer.py @@ -68,16 +68,17 @@ class SPImporter(): """ Returns list of allowed special phrases from the database, restricting to a list of combinations of classes and types - which occur more than a specified amount of times. + which occur equal to or more than a specified amount of times. - Default value for this, if not specified, is at least once. + Default value for this is 0, which allows everything in database. """ db_combinations = set() + query = f""" SELECT class AS CLS, type AS typ FROM placex GROUP BY class, type - HAVING COUNT(*) > {min} + HAVING COUNT(*) >= {min} """ with self.db_connection.cursor() as db_cursor: @@ -87,7 +88,8 @@ class SPImporter(): return db_combinations - def import_phrases(self, tokenizer: AbstractTokenizer, should_replace: bool) -> None: + def import_phrases(self, tokenizer: AbstractTokenizer, should_replace: bool, + min: int = 0) -> None: """ Iterate through all SpecialPhrases extracted from the loader and import them into the database. @@ -107,7 +109,7 @@ class SPImporter(): if result: class_type_pairs.add(result) - self._create_classtype_table_and_indexes(class_type_pairs) + self._create_classtype_table_and_indexes(class_type_pairs, min) if should_replace: self._remove_non_existent_tables_from_db() @@ -186,7 +188,8 @@ class SPImporter(): return (phrase.p_class, phrase.p_type) def _create_classtype_table_and_indexes(self, - class_type_pairs: Iterable[Tuple[str, str]]) -> None: + class_type_pairs: Iterable[Tuple[str, str]], + min: int = 0) -> None: """ Create table place_classtype for each given pair. Also create indexes on place_id and centroid. @@ -200,13 +203,15 @@ class SPImporter(): with self.db_connection.cursor() as db_cursor: db_cursor.execute("CREATE INDEX idx_placex_classtype ON placex (class, type)") - allowed_special_phrases = self.get_classtype_pairs() + if min: + allowed_special_phrases = self.get_classtype_pairs(min) for pair in class_type_pairs: phrase_class = pair[0] phrase_type = pair[1] - if (phrase_class, phrase_type) not in allowed_special_phrases: + # Will only filter if min is not 0 + if min and (phrase_class, phrase_type) not in allowed_special_phrases: LOG.warning("Skipping phrase %s=%s: not in allowed special phrases", phrase_class, phrase_type) continue diff --git a/test/python/tools/test_sp_importer.py b/test/python/tools/test_sp_importer.py index dda02f11..c64c2b7d 100644 --- a/test/python/tools/test_sp_importer.py +++ b/test/python/tools/test_sp_importer.py @@ -3,8 +3,8 @@ from nominatim_db.tools.special_phrases.sp_importer import SPImporter # Testing Database Class Pair Retrival using Conftest.py and placex def test_get_classtype_pair_data(placex_table, def_config, temp_db_conn): - for _ in range(101): - placex_table.add(cls='highway', typ='motorway') # edge case 101 + for _ in range(100): + placex_table.add(cls='highway', typ='motorway') # edge case 100 for _ in range(99): placex_table.add(cls='amenity', typ='prison') # edge case 99 @@ -25,8 +25,8 @@ def test_get_classtype_pair_data(placex_table, def_config, temp_db_conn): def test_get_classtype_pair_data_more(placex_table, def_config, temp_db_conn): - for _ in range(100): - placex_table.add(cls='emergency', typ='firehydrant') # edge case 100, not included + for _ in range(99): + placex_table.add(cls='emergency', typ='firehydrant') # edge case 99, not included for _ in range(199): placex_table.add(cls='amenity', typ='prison')