forked from hans/Nominatim
Merge pull request #3710 from anqixxx/fix-special-phrases-filtering
Fix special phrases filtering
This commit is contained in:
@@ -342,7 +342,8 @@ HTML_HEADER: str = """<!DOCTYPE html>
|
||||
<title>Nominatim - Debug</title>
|
||||
<style>
|
||||
""" + \
|
||||
(HtmlFormatter(nobackground=True).get_style_defs('.highlight') if CODE_HIGHLIGHT else '') + \
|
||||
(HtmlFormatter(nobackground=True).get_style_defs('.highlight') # type: ignore[no-untyped-call]
|
||||
if CODE_HIGHLIGHT else '') + \
|
||||
"""
|
||||
h2 { font-size: x-large }
|
||||
|
||||
|
||||
@@ -127,7 +127,7 @@ def import_osm_data(osm_files: Union[Path, Sequence[Path]],
|
||||
fsize += os.stat(str(fname)).st_size
|
||||
else:
|
||||
fsize = os.stat(str(osm_files)).st_size
|
||||
options['osm2pgsql_cache'] = int(min((mem.available + mem.cached) * 0.75,
|
||||
options['osm2pgsql_cache'] = int(min((mem.available + getattr(mem, 'cached', 0)) * 0.75,
|
||||
fsize * 2) / 1024 / 1024) + 1
|
||||
|
||||
run_osm2pgsql(options)
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
from typing import Iterable, Tuple, Mapping, Sequence, Optional, Set
|
||||
import logging
|
||||
import re
|
||||
|
||||
from psycopg.sql import Identifier, SQL
|
||||
|
||||
from ...typing import Protocol
|
||||
@@ -65,6 +64,29 @@ class SPImporter():
|
||||
# special phrases class/type on the wiki.
|
||||
self.table_phrases_to_delete: Set[str] = set()
|
||||
|
||||
def get_classtype_pairs(self, min: int = 0) -> Set[Tuple[str, str]]:
    """
        Return the set of (class, type) combinations from the placex
        table which occur more than `min` times.

        With the default value of 0, every combination that occurs
        at least once is returned.

        The parameter name `min` shadows the builtin but is kept
        because callers pass it by keyword.
    """
    # The threshold is handed over as a bound query parameter instead of
    # being formatted into the statement text.
    query = """
        SELECT class AS cls, type AS typ
        FROM placex
        GROUP BY class, type
        HAVING COUNT(*) > %s
    """

    with self.db_connection.cursor() as db_cursor:
        db_cursor.execute(query, (min,))
        return {(row[0], row[1]) for row in db_cursor}
|
||||
|
||||
def import_phrases(self, tokenizer: AbstractTokenizer, should_replace: bool) -> None:
|
||||
"""
|
||||
Iterate through all SpecialPhrases extracted from the
|
||||
@@ -88,6 +110,7 @@ class SPImporter():
|
||||
self._create_classtype_table_and_indexes(class_type_pairs)
|
||||
if should_replace:
|
||||
self._remove_non_existent_tables_from_db()
|
||||
|
||||
self.db_connection.commit()
|
||||
|
||||
with tokenizer.name_analyzer() as analyzer:
|
||||
@@ -177,10 +200,17 @@ class SPImporter():
|
||||
with self.db_connection.cursor() as db_cursor:
|
||||
db_cursor.execute("CREATE INDEX idx_placex_classtype ON placex (class, type)")
|
||||
|
||||
allowed_special_phrases = self.get_classtype_pairs()
|
||||
|
||||
for pair in class_type_pairs:
|
||||
phrase_class = pair[0]
|
||||
phrase_type = pair[1]
|
||||
|
||||
if (phrase_class, phrase_type) not in allowed_special_phrases:
|
||||
LOG.warning("Skipping phrase %s=%s: not in allowed special phrases",
|
||||
phrase_class, phrase_type)
|
||||
continue
|
||||
|
||||
table_name = _classtype_table(phrase_class, phrase_type)
|
||||
|
||||
if table_name in self.table_phrases_to_delete:
|
||||
|
||||
@@ -127,7 +127,7 @@ def test_grant_access_to_web_user(temp_db_conn, temp_db_cursor, table_factory,
|
||||
|
||||
def test_create_place_classtype_table_and_indexes(
|
||||
temp_db_cursor, def_config, placex_table,
|
||||
sp_importer, temp_db_conn):
|
||||
sp_importer, temp_db_conn, monkeypatch):
|
||||
"""
|
||||
Test that _create_place_classtype_table_and_indexes()
|
||||
create the right place_classtype tables and place_id indexes
|
||||
@@ -135,7 +135,8 @@ def test_create_place_classtype_table_and_indexes(
|
||||
for the given set of pairs.
|
||||
"""
|
||||
pairs = set([('class1', 'type1'), ('class2', 'type2')])
|
||||
|
||||
for pair in pairs:
|
||||
placex_table.add(cls=pair[0], typ=pair[1]) # adding to db
|
||||
sp_importer._create_classtype_table_and_indexes(pairs)
|
||||
temp_db_conn.commit()
|
||||
|
||||
@@ -194,14 +195,16 @@ def test_import_phrases(monkeypatch, temp_db_cursor, def_config, sp_importer,
|
||||
monkeypatch.setattr('nominatim_db.tools.special_phrases.sp_wiki_loader._get_wiki_content',
|
||||
lambda lang: xml_wiki_content)
|
||||
|
||||
class_test = 'aerialway'
|
||||
type_test = 'zip_line'
|
||||
|
||||
tokenizer = tokenizer_mock()
|
||||
placex_table.add(cls=class_test, typ=type_test) # in db for special phrase filtering
|
||||
placex_table.add(cls='amenity', typ='animal_shelter') # in db for special phrase filtering
|
||||
sp_importer.import_phrases(tokenizer, should_replace)
|
||||
|
||||
assert len(tokenizer.analyser_cache['special_phrases']) == 18
|
||||
|
||||
class_test = 'aerialway'
|
||||
type_test = 'zip_line'
|
||||
|
||||
assert check_table_exist(temp_db_cursor, class_test, type_test)
|
||||
assert check_placeid_and_centroid_indexes(temp_db_cursor, class_test, type_test)
|
||||
assert check_grant_access(temp_db_cursor, def_config.DATABASE_WEBUSER, class_test, type_test)
|
||||
@@ -250,3 +253,38 @@ def check_placeid_and_centroid_indexes(temp_db_cursor, phrase_class, phrase_type
|
||||
and
|
||||
temp_db_cursor.index_exists(table_name, index_prefix + 'place_id')
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("should_replace", [True, False])
def test_import_phrases_special_phrase_filtering(monkeypatch, temp_db_cursor, def_config,
                                                 sp_importer, placex_table, tokenizer_mock,
                                                 xml_wiki_content, should_replace):
    """
        Import the wiki phrases with a single matching class/type row
        in placex and check that this phrase is imported together with
        its table, indexes and access grant.
    """
    phrase_class, phrase_type = 'aerialway', 'zip_line'

    # Serve the fixture content instead of loading the wiki page.
    monkeypatch.setattr('nominatim_db.tools.special_phrases.sp_wiki_loader._get_wiki_content',
                        lambda lang: xml_wiki_content)

    # The pair has to be present in placex to pass the filtering.
    placex_table.add(cls=phrase_class, typ=phrase_type)

    sp_importer.import_phrases(tokenizer_mock(), should_replace)

    assert ('Zip Line', 'aerialway', 'zip_line', '-') in sp_importer.word_phrases
    assert check_table_exist(temp_db_cursor, phrase_class, phrase_type)
    assert check_placeid_and_centroid_indexes(temp_db_cursor, phrase_class, phrase_type)
    assert check_grant_access(temp_db_cursor, def_config.DATABASE_WEBUSER,
                              phrase_class, phrase_type)
|
||||
|
||||
|
||||
def test_get_classtype_pairs_directly(placex_table, temp_db_conn, sp_importer):
    """
        Check that get_classtype_pairs(100) returns a pair with
        101 rows in placex and leaves out a pair with 99 rows.
    """
    for _ in range(101):
        placex_table.add(cls='highway', typ='residential')
    for _ in range(99):
        placex_table.add(cls='amenity', typ='toilet')

    temp_db_conn.commit()

    result = sp_importer.get_classtype_pairs(100)

    assert ('highway', 'residential') in result
    assert ('amenity', 'toilet') not in result
|
||||
|
||||
69
test/python/tools/test_sp_importer.py
Normal file
69
test/python/tools/test_sp_importer.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from nominatim_db.tools.special_phrases.sp_importer import SPImporter
|
||||
|
||||
|
||||
# Testing database class/type pair retrieval using conftest.py and placex
|
||||
def test_get_classtype_pair_data(placex_table, def_config, temp_db_conn):
    """
        With min=100, pairs with 101 and 150 rows in placex are
        returned while a pair with 99 rows is left out.
    """
    row_counts = (
        ('highway', 'motorway', 101),  # edge case: just above the limit
        ('amenity', 'prison', 99),     # edge case: just below the limit
        ('tourism', 'hotel', 150),
    )
    for osm_class, osm_type, count in row_counts:
        for _ in range(count):
            placex_table.add(cls=osm_class, typ=osm_type)

    importer = SPImporter(config=def_config, conn=temp_db_conn, sp_loader=None)
    result = importer.get_classtype_pairs(min=100)

    expected = {("highway", "motorway"), ("tourism", "hotel")}

    assert result == expected, f"Expected {expected}, got {result}"
|
||||
|
||||
|
||||
def test_get_classtype_pair_data_more(placex_table, def_config, temp_db_conn):
    """
        With min=100, a pair with exactly 100 rows in placex is
        left out while pairs with 199 and 3478 rows are returned.
    """
    row_counts = (
        ('emergency', 'firehydrant', 100),  # edge case: exactly the limit, excluded
        ('amenity', 'prison', 199),
        ('tourism', 'hotel', 3478),
    )
    for osm_class, osm_type, count in row_counts:
        for _ in range(count):
            placex_table.add(cls=osm_class, typ=osm_type)

    importer = SPImporter(config=def_config, conn=temp_db_conn, sp_loader=None)
    result = importer.get_classtype_pairs(min=100)

    expected = {("amenity", "prison"), ("tourism", "hotel")}

    assert result == expected, f"Expected {expected}, got {result}"
|
||||
|
||||
|
||||
def test_get_classtype_pair_data_default(placex_table, def_config, temp_db_conn):
    """
        Without an explicit minimum, every pair present in placex
        is returned, including one with a single row.
    """
    row_counts = (
        ('emergency', 'firehydrant', 1),
        ('amenity', 'prison', 199),
        ('tourism', 'hotel', 3478),
    )
    for osm_class, osm_type, count in row_counts:
        for _ in range(count):
            placex_table.add(cls=osm_class, typ=osm_type)

    importer = SPImporter(config=def_config, conn=temp_db_conn, sp_loader=None)
    result = importer.get_classtype_pairs()

    expected = {("amenity", "prison"), ("tourism", "hotel"), ("emergency", "firehydrant")}

    assert result == expected, f"Expected {expected}, got {result}"
|
||||
Reference in New Issue
Block a user