Filter special phrases by style and frequency to fix #235

This commit is contained in:
anqixxx
2025-04-07 21:40:42 -07:00
parent 800c56642b
commit 1a323165f9
2 changed files with 253 additions and 4 deletions

View File

@@ -16,7 +16,7 @@
from typing import Iterable, Tuple, Mapping, Sequence, Optional, Set from typing import Iterable, Tuple, Mapping, Sequence, Optional, Set
import logging import logging
import re import re
import json
from psycopg.sql import Identifier, SQL from psycopg.sql import Identifier, SQL
from ...typing import Protocol from ...typing import Protocol
@@ -65,6 +65,52 @@ class SPImporter():
# special phrases class/type on the wiki. # special phrases class/type on the wiki.
self.table_phrases_to_delete: Set[str] = set() self.table_phrases_to_delete: Set[str] = set()
def get_classtype_pairs_style(self) -> Set[Tuple[str, str]]:
    """
    Return the set of allowed (class, type) pairs from the style file,
    restricted to combinations whose value carries a 'main' property.

    NOTE(review): the 'main'-only restriction dates from 2021 — confirm
    it is still the desired behaviour before relying on it.
    """
    style_file = self.config.get_import_style_file()
    # The style file is a comma-separated sequence of JSON objects, not a
    # valid JSON document by itself, so wrap it in brackets to parse it
    # as one JSON array.
    with open(style_file, 'r', encoding='utf-8') as file:
        style_data = json.loads(f'[{file.read()}]')

    style_combinations = set()
    for tag_map in style_data:  # structure follows settings/import-extratags.style
        classes = tag_map.get("keys", [])
        values = tag_map.get("values", {})
        for tag_type, properties in values.items():
            # The type is the value of the main tag, so an empty string is
            # a catch-all entry, not a real (class, type) combination.
            if tag_type and "main" in properties:
                for tag_class in classes:
                    style_combinations.add((tag_class, tag_type))
    return style_combinations
def get_classtype_pairs(self) -> Set[Tuple[str, str]]:
    """
    Return the set of allowed (class, type) pairs from the database,
    restricted to combinations which occur more than 100 times in placex.
    """
    db_combinations = set()
    # Only combinations frequent enough to justify a classtype table.
    query = """
        SELECT class AS CLS, type AS typ
        FROM placex
        GROUP BY class, type
        HAVING COUNT(*) > 100
    """
    with self.db_connection.cursor() as db_cursor:
        db_cursor.execute(SQL(query))
        for row in db_cursor.fetchall():
            db_combinations.add((row[0], row[1]))
    return db_combinations
def import_phrases(self, tokenizer: AbstractTokenizer, should_replace: bool) -> None: def import_phrases(self, tokenizer: AbstractTokenizer, should_replace: bool) -> None:
""" """
Iterate through all SpecialPhrases extracted from the Iterate through all SpecialPhrases extracted from the
@@ -85,9 +131,11 @@ class SPImporter():
if result: if result:
class_type_pairs.add(result) class_type_pairs.add(result)
self._create_classtype_table_and_indexes(class_type_pairs) self._create_classtype_table_and_indexes(class_type_pairs)
if should_replace: if should_replace:
self._remove_non_existent_tables_from_db() self._remove_non_existent_tables_from_db()
self.db_connection.commit() self.db_connection.commit()
with tokenizer.name_analyzer() as analyzer: with tokenizer.name_analyzer() as analyzer:
@@ -177,10 +225,17 @@ class SPImporter():
with self.db_connection.cursor() as db_cursor: with self.db_connection.cursor() as db_cursor:
db_cursor.execute("CREATE INDEX idx_placex_classtype ON placex (class, type)") db_cursor.execute("CREATE INDEX idx_placex_classtype ON placex (class, type)")
allowed_special_phrases = self.get_classtype_pairs()
for pair in class_type_pairs: for pair in class_type_pairs:
phrase_class = pair[0] phrase_class = pair[0]
phrase_type = pair[1] phrase_type = pair[1]
if (phrase_class, phrase_type) not in allowed_special_phrases:
LOG.warning("Skipping phrase %s=%s: not in allowed special phrases",
phrase_class, phrase_type)
continue
table_name = _classtype_table(phrase_class, phrase_type) table_name = _classtype_table(phrase_class, phrase_type)
if table_name in self.table_phrases_to_delete: if table_name in self.table_phrases_to_delete:
@@ -212,8 +267,8 @@ class SPImporter():
if doesn't exit. if doesn't exit.
""" """
table_name = _classtype_table(phrase_class, phrase_type) table_name = _classtype_table(phrase_class, phrase_type)
with self.db_connection.cursor() as cur: with self.db_connection.cursor() as db_cursor:
cur.execute(SQL("""CREATE TABLE IF NOT EXISTS {} {} AS db_cursor.execute(SQL("""CREATE TABLE IF NOT EXISTS {} {} AS
SELECT place_id AS place_id, SELECT place_id AS place_id,
st_centroid(geometry) AS centroid st_centroid(geometry) AS centroid
FROM placex FROM placex
@@ -266,3 +321,4 @@ class SPImporter():
drop_tables(self.db_connection, *self.table_phrases_to_delete) drop_tables(self.db_connection, *self.table_phrases_to_delete)
for _ in self.table_phrases_to_delete: for _ in self.table_phrases_to_delete:
self.statistics_handler.notify_one_table_deleted() self.statistics_handler.notify_one_table_deleted()

View File

@@ -0,0 +1,193 @@
import pytest
import tempfile
import json
import os
from unittest.mock import MagicMock
from nominatim_db.errors import UsageError
from nominatim_db.tools.special_phrases.sp_csv_loader import SPCsvLoader
from nominatim_db.tools.special_phrases.special_phrase import SpecialPhrase
from nominatim_db.tools.special_phrases.sp_importer import SPImporter
@pytest.fixture
def sample_style_file():
    """Temporary style file mixing main/skip/extra entries.

    Only (highway, motorway) pairs a non-empty type with a 'main'
    property, so it is the single combination a style filter should keep.
    """
    entries = [
        {"keys": ["emergency"],
         "values": {"fire_hydrant": "skip", "yes": "skip", "no": "skip", "": "main"}},
        {"keys": ["historic", "military"],
         "values": {"no": "skip", "yes": "skip", "": "main"}},
        {"keys": ["name:prefix", "name:suffix", "name:prefix:*", "name:suffix:*",
                  "name:botanical", "wikidata", "*:wikidata"],
         "values": {"": "extra"}},
        {"keys": ["addr:housename"],
         "values": {"": "name,house"}},
        {"keys": ["highway"],
         "values": {"motorway": "main", "": "skip"}},
    ]
    # Style files are comma-separated JSON objects, not one JSON document.
    body = ",".join(json.dumps(entry) for entry in entries)
    with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tmp:
        tmp.write(body)
        tmp_path = tmp.name
    yield tmp_path
    os.remove(tmp_path)
def test_get_sp_style(sample_style_file):
    """The style filter keeps only non-empty types flagged as 'main'."""
    mock_config = MagicMock()
    mock_config.get_import_style_file.return_value = sample_style_file
    importer = SPImporter(config=mock_config, conn=None, sp_loader=None)
    # Bug fix: the importer defines get_classtype_pairs_style, not
    # get_sp_style — the old call raised AttributeError.
    result = importer.get_classtype_pairs_style()
    expected = {
        ("highway", "motorway"),
    }
    assert result == expected
@pytest.fixture
def mock_phrase():
    """A single special phrase fixture for highway=motorway."""
    phrase_kwargs = {
        "p_label": "test",
        "p_class": "highway",
        "p_type": "motorway",
        "p_operator": "eq",
    }
    return SpecialPhrase(**phrase_kwargs)
def test_create_classtype_table_and_indexes():
    # NOTE(review): this test makes no assertions — it only prints the
    # recorded mock calls, so it can never fail on a regression. Consider
    # asserting on the call_args_list contents instead.
    # NOTE(review): with the new style/frequency filtering, the tested
    # method presumably queries placex via the mocked cursor; verify the
    # MagicMock cursor returns something iterable for that path.
    mock_config = MagicMock()
    mock_config.TABLESPACE_AUX_DATA = ''
    mock_config.DATABASE_WEBUSER = 'www-data'
    mock_cursor = MagicMock()
    mock_conn = MagicMock()
    # Make `with conn.cursor() as cur:` yield the mock cursor.
    mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
    importer = SPImporter(config=mock_config, conn=mock_conn, sp_loader=None)
    # Stub out the table/index/grant helpers so only the orchestration runs.
    importer._create_place_classtype_table = MagicMock()
    importer._create_place_classtype_indexes = MagicMock()
    importer._grant_access_to_webuser = MagicMock()
    importer.statistics_handler.notify_one_table_created = lambda: print("✓ Created table")
    importer.statistics_handler.notify_one_table_ignored = lambda: print("⨉ Ignored table")
    # Pre-existing table: the highway/motorway table should be treated as
    # already present rather than re-created.
    importer.table_phrases_to_delete = {"place_classtype_highway_motorway"}
    test_pairs = [("highway", "motorway"), ("natural", "peak")]
    importer._create_classtype_table_and_indexes(test_pairs)
    print("create_place_classtype_table calls:")
    for call in importer._create_place_classtype_table.call_args_list:
        print(call)
    print("\ncreate_place_classtype_indexes calls:")
    for call in importer._create_place_classtype_indexes.call_args_list:
        print(call)
    print("\ngrant_access_to_webuser calls:")
    for call in importer._grant_access_to_webuser.call_args_list:
        print(call)
@pytest.fixture
def mock_config():
    """Minimal SPImporter config mock with empty black/white lists."""
    cfg = MagicMock()
    cfg.TABLESPACE_AUX_DATA = ''
    cfg.DATABASE_WEBUSER = 'www-data'
    cfg.load_sub_configuration.return_value = {'blackList': {}, 'whiteList': {}}
    return cfg
def test_import_phrases_original(mock_config):
    """Full import_phrases run over one mocked phrase ends up in word_phrases."""
    test_phrase = SpecialPhrase("roundabout", "highway", "motorway", "eq")

    cursor_mock = MagicMock()
    conn_mock = MagicMock()
    conn_mock.cursor.return_value.__enter__.return_value = cursor_mock

    loader_mock = MagicMock()
    loader_mock.generate_phrases.return_value = [test_phrase]

    analyzer_mock = MagicMock()
    tokenizer_mock = MagicMock()
    tokenizer_mock.name_analyzer.return_value.__enter__.return_value = analyzer_mock

    importer = SPImporter(config=mock_config, conn=conn_mock, sp_loader=loader_mock)
    # Stub the database-heavy steps so only the phrase bookkeeping runs.
    importer._fetch_existing_place_classtype_tables = MagicMock()
    importer._create_classtype_table_and_indexes = MagicMock()
    importer._remove_non_existent_tables_from_db = MagicMock()

    importer.import_phrases(tokenizer=tokenizer_mock, should_replace=True)

    assert importer.word_phrases == {("roundabout", "highway", "motorway", "-")}
    analyzer_mock.update_special_phrases.assert_called_once_with(
        importer.word_phrases, True
    )
def test_get_sp_filters_correctly(sample_style_file):
    # NOTE(review): neither get_sp nor get_sp_db exists on SPImporter — the
    # importer defines get_classtype_pairs_style and get_classtype_pairs,
    # and no combined style∩db method is visible in this change. This test
    # will fail with AttributeError as written; confirm the intended API.
    mock_config = MagicMock()
    mock_config.get_import_style_file.return_value = sample_style_file
    mock_config.load_sub_configuration.return_value = {"blackList": {}, "whiteList": {}}
    importer = SPImporter(config=mock_config, conn=MagicMock(), sp_loader=None)
    allowed_from_db = {("highway", "motorway"), ("historic", "castle")}
    # Stub the DB lookup so only the style-file intersection is exercised.
    importer.get_sp_db = lambda: allowed_from_db
    result = importer.get_sp()
    expected = {("highway", "motorway")}
    assert result == expected, f"Expected {expected}, got {result}"
def test_get_sp_db_filters_by_count_threshold(mock_config):
    """The DB lookup returns the queried (class, type) rows as a set."""
    mock_cursor = MagicMock()
    # Simulate only results above the threshold being returned (as SQL would)
    # These tuples simulate real SELECT class, type FROM placex GROUP BY ... HAVING COUNT(*) > 100
    mock_cursor.fetchall.return_value = [
        ("highway", "motorway"),
        ("historic", "castle")
    ]
    mock_conn = MagicMock()
    mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
    importer = SPImporter(config=mock_config, conn=mock_conn, sp_loader=None)
    # Bug fix: the importer defines get_classtype_pairs, not get_sp_db —
    # the old call raised AttributeError.
    result = importer.get_classtype_pairs()
    expected = {
        ("highway", "motorway"),
        ("historic", "castle")
    }
    assert result == expected
    mock_cursor.execute.assert_called_once()