Removed magic mocking in favor of monkeypatch, and switched to a placex table to simulate a 'real database'

This commit is contained in:
anqixxx
2025-04-11 12:03:57 -07:00
parent 1a323165f9
commit 1952290359
2 changed files with 106 additions and 141 deletions

View File

@@ -65,7 +65,7 @@ class SPImporter():
# special phrases class/type on the wiki. # special phrases class/type on the wiki.
self.table_phrases_to_delete: Set[str] = set() self.table_phrases_to_delete: Set[str] = set()
def get_classtype_pairs_style(self) -> Set[Tuple[str, str]]: def get_classtype_pairs_style(self) -> Set[Tuple[str, str]]:
""" """
Returns list of allowed special phrases from the style file, Returns list of allowed special phrases from the style file,
restricting to a list of combinations of classes and types restricting to a list of combinations of classes and types
@@ -73,27 +73,27 @@ class SPImporter():
Note: This requirement was from 2021 and I am a bit unsure if it is still relevant Note: This requirement was from 2021 and I am a bit unsure if it is still relevant
""" """
style_file = self.config.get_import_style_file() # this gives the path, so i will import it as a json style_file = self.config.get_import_style_file() # import style file as json
with open(style_file, 'r') as file: with open(style_file, 'r') as file:
style_data = json.loads(f'[{file.read()}]') style_data = json.loads(f'[{file.read()}]')
style_combinations = set() style_combinations = set()
for _map in style_data: # following ../settings/import-extratags.style for _map in style_data: # following ../settings/import-extratags.style
classes = _map.get("keys", []) classes = _map.get("keys", [])
values = _map.get("values", {}) values = _map.get("values", {})
for _type, properties in values.items(): for _type, properties in values.items():
if "main" in properties and _type: # make sure the tag is not an empty string. since type is the value of the main tag if "main" in properties and _type: # make sure the tag is a non-empty string
for _class in classes: for _class in classes:
style_combinations.add((_class, _type)) style_combinations.add((_class, _type)) # type is the value of the main tag
return style_combinations return style_combinations
def get_classtype_pairs(self) -> Set[Tuple[str, str]]: def get_classtype_pairs(self) -> Set[Tuple[str, str]]:
""" """
Returns list of allowed special phrases from the database, Returns list of allowed special phrases from the database,
restricting to a list of combinations of classes and types restricting to a list of combinations of classes and types
whic occur more than 100 times which occur more than 100 times
""" """
db_combinations = set() db_combinations = set()
query = """ query = """
@@ -108,8 +108,7 @@ class SPImporter():
for row in db_cursor.fetchall(): for row in db_cursor.fetchall():
db_combinations.add((row[0], row[1])) db_combinations.add((row[0], row[1]))
return db_combinations return db_combinations
def import_phrases(self, tokenizer: AbstractTokenizer, should_replace: bool) -> None: def import_phrases(self, tokenizer: AbstractTokenizer, should_replace: bool) -> None:
""" """
@@ -135,7 +134,6 @@ class SPImporter():
if should_replace: if should_replace:
self._remove_non_existent_tables_from_db() self._remove_non_existent_tables_from_db()
self.db_connection.commit() self.db_connection.commit()
with tokenizer.name_analyzer() as analyzer: with tokenizer.name_analyzer() as analyzer:
@@ -268,13 +266,13 @@ class SPImporter():
""" """
table_name = _classtype_table(phrase_class, phrase_type) table_name = _classtype_table(phrase_class, phrase_type)
with self.db_connection.cursor() as db_cursor: with self.db_connection.cursor() as db_cursor:
db_cursor.execute(SQL("""CREATE TABLE IF NOT EXISTS {} {} AS db_cursor.execute(SQL(
SELECT place_id AS place_id, """CREATE TABLE IF NOT EXISTS {} {} AS
st_centroid(geometry) AS centroid SELECT place_id AS place_id,
FROM placex st_centroid(geometry) AS centroid
WHERE class = %s AND type = %s FROM placex WHERE class = %s AND type = %s
""").format(Identifier(table_name), SQL(sql_tablespace)), """).format(Identifier(table_name), SQL(sql_tablespace)),
(phrase_class, phrase_type)) (phrase_class, phrase_type))
def _create_place_classtype_indexes(self, sql_tablespace: str, def _create_place_classtype_indexes(self, sql_tablespace: str,
phrase_class: str, phrase_type: str) -> None: phrase_class: str, phrase_type: str) -> None:
@@ -321,4 +319,3 @@ class SPImporter():
drop_tables(self.db_connection, *self.table_phrases_to_delete) drop_tables(self.db_connection, *self.table_phrases_to_delete)
for _ in self.table_phrases_to_delete: for _ in self.table_phrases_to_delete:
self.statistics_handler.notify_one_table_deleted() self.statistics_handler.notify_one_table_deleted()

View File

@@ -2,13 +2,10 @@ import pytest
import tempfile import tempfile
import json import json
import os import os
from unittest.mock import MagicMock
from nominatim_db.errors import UsageError
from nominatim_db.tools.special_phrases.sp_csv_loader import SPCsvLoader
from nominatim_db.tools.special_phrases.special_phrase import SpecialPhrase
from nominatim_db.tools.special_phrases.sp_importer import SPImporter from nominatim_db.tools.special_phrases.sp_importer import SPImporter
# Testing Style Class Pair Retrieval
@pytest.fixture @pytest.fixture
def sample_style_file(): def sample_style_file():
sample_data = [ sample_data = [
@@ -59,135 +56,106 @@ def sample_style_file():
yield tmp_path yield tmp_path
os.remove(tmp_path) os.remove(tmp_path)
def test_get_classtype_style(sample_style_file):
class Config:
def get_import_style_file(self):
return sample_style_file
def test_get_sp_style(sample_style_file): def load_sub_configuration(self, name):
mock_config = MagicMock() return {'blackList': {}, 'whiteList': {}}
mock_config.get_import_style_file.return_value = sample_style_file
importer = SPImporter(config=mock_config, conn=None, sp_loader=None) config = Config()
result = importer.get_sp_style() importer = SPImporter(config=config, conn=None, sp_loader=None)
result = importer.get_classtype_pairs_style()
expected = { expected = {
("highway", "motorway"), ("highway", "motorway"),
} }
assert expected.issubset(result)
# Testing Database Class Pair Retrieval using Mock Database
def test_get_classtype_pairs(monkeypatch):
class Config:
def load_sub_configuration(self, path, section=None):
return {"blackList": {}, "whiteList": {}}
class Cursor:
def execute(self, query): pass
def fetchall(self):
return [
("highway", "motorway"),
("historic", "castle")
]
def __enter__(self): return self
def __exit__(self, exc_type, exc_val, exc_tb): pass
class Connection:
def cursor(self): return Cursor()
config = Config()
conn = Connection()
importer = SPImporter(config=config, conn=conn, sp_loader=None)
result = importer.get_classtype_pairs()
expected = {
("highway", "motorway"),
("historic", "castle")
}
assert result == expected assert result == expected
@pytest.fixture # Testing Database Class Pair Retrieval using Conftest.py and placex
def mock_phrase(): def test_get_classtype_pair_data(placex_table, temp_db_conn):
return SpecialPhrase( class Config:
p_label="test", def load_sub_configuration(self, *_):
p_class="highway", return {'blackList': {}, 'whiteList': {}}
p_type="motorway",
p_operator="eq"
)
def test_create_classtype_table_and_indexes(): for _ in range(101):
mock_config = MagicMock() placex_table.add(cls='highway', typ='motorway') # edge case 101
mock_config.TABLESPACE_AUX_DATA = ''
mock_config.DATABASE_WEBUSER = 'www-data'
mock_cursor = MagicMock() for _ in range(99):
mock_conn = MagicMock() placex_table.add(cls='amenity', typ='prison') # edge case 99
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
importer = SPImporter(config=mock_config, conn=mock_conn, sp_loader=None) for _ in range(150):
placex_table.add(cls='tourism', typ='hotel')
importer._create_place_classtype_table = MagicMock() config = Config()
importer._create_place_classtype_indexes = MagicMock() importer = SPImporter(config=config, conn=temp_db_conn, sp_loader=None)
importer._grant_access_to_webuser = MagicMock()
importer.statistics_handler.notify_one_table_created = lambda: print("✓ Created table")
importer.statistics_handler.notify_one_table_ignored = lambda: print("⨉ Ignored table")
importer.table_phrases_to_delete = {"place_classtype_highway_motorway"} result = importer.get_classtype_pairs()
test_pairs = [("highway", "motorway"), ("natural", "peak")] expected = {
importer._create_classtype_table_and_indexes(test_pairs) ("highway", "motorway"),
("tourism", "hotel")
print("create_place_classtype_table calls:") }
for call in importer._create_place_classtype_table.call_args_list:
print(call)
print("\ncreate_place_classtype_indexes calls:")
for call in importer._create_place_classtype_indexes.call_args_list:
print(call)
print("\ngrant_access_to_webuser calls:")
for call in importer._grant_access_to_webuser.call_args_list:
print(call)
@pytest.fixture
def mock_config():
config = MagicMock()
config.TABLESPACE_AUX_DATA = ''
config.DATABASE_WEBUSER = 'www-data'
config.load_sub_configuration.return_value = {'blackList': {}, 'whiteList': {}}
return config
def test_import_phrases_original(mock_config):
phrase = SpecialPhrase("roundabout", "highway", "motorway", "eq")
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
mock_loader = MagicMock()
mock_loader.generate_phrases.return_value = [phrase]
mock_analyzer = MagicMock()
mock_tokenizer = MagicMock()
mock_tokenizer.name_analyzer.return_value.__enter__.return_value = mock_analyzer
importer = SPImporter(config=mock_config, conn=mock_conn, sp_loader=mock_loader)
importer._fetch_existing_place_classtype_tables = MagicMock()
importer._create_classtype_table_and_indexes = MagicMock()
importer._remove_non_existent_tables_from_db = MagicMock()
importer.import_phrases(tokenizer=mock_tokenizer, should_replace=True)
assert importer.word_phrases == {("roundabout", "highway", "motorway", "-")}
mock_analyzer.update_special_phrases.assert_called_once_with(
importer.word_phrases, True
)
def test_get_sp_filters_correctly(sample_style_file):
mock_config = MagicMock()
mock_config.get_import_style_file.return_value = sample_style_file
mock_config.load_sub_configuration.return_value = {"blackList": {}, "whiteList": {}}
importer = SPImporter(config=mock_config, conn=MagicMock(), sp_loader=None)
allowed_from_db = {("highway", "motorway"), ("historic", "castle")}
importer.get_sp_db = lambda: allowed_from_db
result = importer.get_sp()
expected = {("highway", "motorway")}
assert result == expected, f"Expected {expected}, got {result}" assert result == expected, f"Expected {expected}, got {result}"
def test_get_sp_db_filters_by_count_threshold(mock_config): def test_get_classtype_pair_data_more(placex_table, temp_db_conn):
mock_cursor = MagicMock() class Config:
def load_sub_configuration(self, *_):
return {'blackList': {}, 'whiteList': {}}
# Simulate only results above the threshold being returned (as SQL would) for _ in range(100):
# These tuples simulate real SELECT class, type FROM placex GROUP BY ... HAVING COUNT(*) > 100 placex_table.add(cls='emergency', typ='firehydrant') # edge case 100, not included
mock_cursor.fetchall.return_value = [
("highway", "motorway"),
("historic", "castle")
]
mock_conn = MagicMock() for _ in range(199):
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor placex_table.add(cls='amenity', typ='prison')
importer = SPImporter(config=mock_config, conn=mock_conn, sp_loader=None)
result = importer.get_sp_db() for _ in range(3478):
placex_table.add(cls='tourism', typ='hotel')
config = Config()
importer = SPImporter(config=config, conn=temp_db_conn, sp_loader=None)
result = importer.get_classtype_pairs()
expected = { expected = {
("highway", "motorway"), ("amenity", "prison"),
("historic", "castle") ("tourism", "hotel")
} }
assert result == expected assert result == expected, f"Expected {expected}, got {result}"
mock_cursor.execute.assert_called_once()