Removed magic mocking, using monkeypatch instead, and using a placex table to simulate a 'real database'

This commit is contained in:
anqixxx
2025-04-11 12:03:57 -07:00
parent 1a323165f9
commit 1952290359
2 changed files with 106 additions and 141 deletions

View File

@@ -16,7 +16,7 @@
from typing import Iterable, Tuple, Mapping, Sequence, Optional, Set from typing import Iterable, Tuple, Mapping, Sequence, Optional, Set
import logging import logging
import re import re
import json import json
from psycopg.sql import Identifier, SQL from psycopg.sql import Identifier, SQL
from ...typing import Protocol from ...typing import Protocol
@@ -65,37 +65,37 @@ class SPImporter():
# special phrases class/type on the wiki. # special phrases class/type on the wiki.
self.table_phrases_to_delete: Set[str] = set() self.table_phrases_to_delete: Set[str] = set()
def get_classtype_pairs_style(self) -> Set[Tuple[str, str]]: def get_classtype_pairs_style(self) -> Set[Tuple[str, str]]:
""" """
Returns list of allowed special phrases from the the style file, Returns list of allowed special phrases from the the style file,
restricting to a list of combinations of classes and types restricting to a list of combinations of classes and types
which have a 'main' property which have a 'main' property
Note: This requirement was from 2021 and I am a bit unsure if it is still relevant Note: This requirement was from 2021 and I am a bit unsure if it is still relevant
""" """
style_file = self.config.get_import_style_file() # this gives the path, so i will import it as a json style_file = self.config.get_import_style_file() # import style file as json
with open(style_file, 'r') as file: with open(style_file, 'r') as file:
style_data = json.loads(f'[{file.read()}]') style_data = json.loads(f'[{file.read()}]')
style_combinations = set() style_combinations = set()
for _map in style_data: # following ../settings/import-extratags.style for _map in style_data: # following ../settings/import-extratags.style
classes = _map.get("keys", []) classes = _map.get("keys", [])
values = _map.get("values", {}) values = _map.get("values", {})
for _type, properties in values.items(): for _type, properties in values.items():
if "main" in properties and _type: # make sure the tag is not an empty string. since type is the value of the main tag if "main" in properties and _type: # make sure the tag is a non-empty string
for _class in classes: for _class in classes:
style_combinations.add((_class, _type)) style_combinations.add((_class, _type)) # type is the value of the main tag
return style_combinations return style_combinations
def get_classtype_pairs(self) -> Set[Tuple[str, str]]: def get_classtype_pairs(self) -> Set[Tuple[str, str]]:
""" """
Returns list of allowed special phrases from the database, Returns list of allowed special phrases from the database,
restricting to a list of combinations of classes and types restricting to a list of combinations of classes and types
whic occur more than 100 times which occur more than 100 times
""" """
db_combinations = set() db_combinations = set()
query = """ query = """
SELECT class AS CLS, type AS typ SELECT class AS CLS, type AS typ
FROM placex FROM placex
@@ -104,13 +104,12 @@ class SPImporter():
""" """
with self.db_connection.cursor() as db_cursor: with self.db_connection.cursor() as db_cursor:
db_cursor.execute(SQL(query)) db_cursor.execute(SQL(query))
for row in db_cursor.fetchall(): for row in db_cursor.fetchall():
db_combinations.add((row[0], row[1])) db_combinations.add((row[0], row[1]))
return db_combinations return db_combinations
def import_phrases(self, tokenizer: AbstractTokenizer, should_replace: bool) -> None: def import_phrases(self, tokenizer: AbstractTokenizer, should_replace: bool) -> None:
""" """
Iterate through all SpecialPhrases extracted from the Iterate through all SpecialPhrases extracted from the
@@ -131,11 +130,10 @@ class SPImporter():
if result: if result:
class_type_pairs.add(result) class_type_pairs.add(result)
self._create_classtype_table_and_indexes(class_type_pairs) self._create_classtype_table_and_indexes(class_type_pairs)
if should_replace: if should_replace:
self._remove_non_existent_tables_from_db() self._remove_non_existent_tables_from_db()
self.db_connection.commit() self.db_connection.commit()
with tokenizer.name_analyzer() as analyzer: with tokenizer.name_analyzer() as analyzer:
@@ -235,7 +233,7 @@ class SPImporter():
LOG.warning("Skipping phrase %s=%s: not in allowed special phrases", LOG.warning("Skipping phrase %s=%s: not in allowed special phrases",
phrase_class, phrase_type) phrase_class, phrase_type)
continue continue
table_name = _classtype_table(phrase_class, phrase_type) table_name = _classtype_table(phrase_class, phrase_type)
if table_name in self.table_phrases_to_delete: if table_name in self.table_phrases_to_delete:
@@ -268,13 +266,13 @@ class SPImporter():
""" """
table_name = _classtype_table(phrase_class, phrase_type) table_name = _classtype_table(phrase_class, phrase_type)
with self.db_connection.cursor() as db_cursor: with self.db_connection.cursor() as db_cursor:
db_cursor.execute(SQL("""CREATE TABLE IF NOT EXISTS {} {} AS db_cursor.execute(SQL(
SELECT place_id AS place_id, """CREATE TABLE IF NOT EXISTS {} {} AS
st_centroid(geometry) AS centroid SELECT place_id AS place_id,
FROM placex st_centroid(geometry) AS centroid
WHERE class = %s AND type = %s FROM placex WHERE class = %s AND type = %s
""").format(Identifier(table_name), SQL(sql_tablespace)), """).format(Identifier(table_name), SQL(sql_tablespace)),
(phrase_class, phrase_type)) (phrase_class, phrase_type))
def _create_place_classtype_indexes(self, sql_tablespace: str, def _create_place_classtype_indexes(self, sql_tablespace: str,
phrase_class: str, phrase_type: str) -> None: phrase_class: str, phrase_type: str) -> None:
@@ -321,4 +319,3 @@ class SPImporter():
drop_tables(self.db_connection, *self.table_phrases_to_delete) drop_tables(self.db_connection, *self.table_phrases_to_delete)
for _ in self.table_phrases_to_delete: for _ in self.table_phrases_to_delete:
self.statistics_handler.notify_one_table_deleted() self.statistics_handler.notify_one_table_deleted()

View File

@@ -2,13 +2,10 @@ import pytest
import tempfile import tempfile
import json import json
import os import os
from unittest.mock import MagicMock
from nominatim_db.errors import UsageError
from nominatim_db.tools.special_phrases.sp_csv_loader import SPCsvLoader
from nominatim_db.tools.special_phrases.special_phrase import SpecialPhrase
from nominatim_db.tools.special_phrases.sp_importer import SPImporter from nominatim_db.tools.special_phrases.sp_importer import SPImporter
# Testing Style Class Pair Retrival
@pytest.fixture @pytest.fixture
def sample_style_file(): def sample_style_file():
sample_data = [ sample_data = [
@@ -59,135 +56,106 @@ def sample_style_file():
yield tmp_path yield tmp_path
os.remove(tmp_path) os.remove(tmp_path)
def test_get_classtype_style(sample_style_file):
class Config:
def get_import_style_file(self):
return sample_style_file
def load_sub_configuration(self, name):
return {'blackList': {}, 'whiteList': {}}
def test_get_sp_style(sample_style_file): config = Config()
mock_config = MagicMock() importer = SPImporter(config=config, conn=None, sp_loader=None)
mock_config.get_import_style_file.return_value = sample_style_file
importer = SPImporter(config=mock_config, conn=None, sp_loader=None) result = importer.get_classtype_pairs_style()
result = importer.get_sp_style()
expected = { expected = {
("highway", "motorway"), ("highway", "motorway"),
} }
assert expected.issubset(result)
# Testing Database Class Pair Retrival using Mock Database
def test_get_classtype_pairs(monkeypatch):
class Config:
def load_sub_configuration(self, path, section=None):
return {"blackList": {}, "whiteList": {}}
class Cursor:
def execute(self, query): pass
def fetchall(self):
return [
("highway", "motorway"),
("historic", "castle")
]
def __enter__(self): return self
def __exit__(self, exc_type, exc_val, exc_tb): pass
class Connection:
def cursor(self): return Cursor()
config = Config()
conn = Connection()
importer = SPImporter(config=config, conn=conn, sp_loader=None)
result = importer.get_classtype_pairs()
expected = {
("highway", "motorway"),
("historic", "castle")
}
assert result == expected assert result == expected
@pytest.fixture # Testing Database Class Pair Retrival using Conftest.py and placex
def mock_phrase(): def test_get_classtype_pair_data(placex_table, temp_db_conn):
return SpecialPhrase( class Config:
p_label="test", def load_sub_configuration(self, *_):
p_class="highway", return {'blackList': {}, 'whiteList': {}}
p_type="motorway",
p_operator="eq" for _ in range(101):
) placex_table.add(cls='highway', typ='motorway') # edge case 101
def test_create_classtype_table_and_indexes(): for _ in range(99):
mock_config = MagicMock() placex_table.add(cls='amenity', typ='prison') # edge case 99
mock_config.TABLESPACE_AUX_DATA = ''
mock_config.DATABASE_WEBUSER = 'www-data'
mock_cursor = MagicMock() for _ in range(150):
mock_conn = MagicMock() placex_table.add(cls='tourism', typ='hotel')
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
importer = SPImporter(config=mock_config, conn=mock_conn, sp_loader=None) config = Config()
importer = SPImporter(config=config, conn=temp_db_conn, sp_loader=None)
importer._create_place_classtype_table = MagicMock() result = importer.get_classtype_pairs()
importer._create_place_classtype_indexes = MagicMock()
importer._grant_access_to_webuser = MagicMock()
importer.statistics_handler.notify_one_table_created = lambda: print("✓ Created table")
importer.statistics_handler.notify_one_table_ignored = lambda: print("⨉ Ignored table")
importer.table_phrases_to_delete = {"place_classtype_highway_motorway"} expected = {
("highway", "motorway"),
test_pairs = [("highway", "motorway"), ("natural", "peak")] ("tourism", "hotel")
importer._create_classtype_table_and_indexes(test_pairs) }
print("create_place_classtype_table calls:")
for call in importer._create_place_classtype_table.call_args_list:
print(call)
print("\ncreate_place_classtype_indexes calls:")
for call in importer._create_place_classtype_indexes.call_args_list:
print(call)
print("\ngrant_access_to_webuser calls:")
for call in importer._grant_access_to_webuser.call_args_list:
print(call)
@pytest.fixture
def mock_config():
config = MagicMock()
config.TABLESPACE_AUX_DATA = ''
config.DATABASE_WEBUSER = 'www-data'
config.load_sub_configuration.return_value = {'blackList': {}, 'whiteList': {}}
return config
def test_import_phrases_original(mock_config):
phrase = SpecialPhrase("roundabout", "highway", "motorway", "eq")
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor
mock_loader = MagicMock()
mock_loader.generate_phrases.return_value = [phrase]
mock_analyzer = MagicMock()
mock_tokenizer = MagicMock()
mock_tokenizer.name_analyzer.return_value.__enter__.return_value = mock_analyzer
importer = SPImporter(config=mock_config, conn=mock_conn, sp_loader=mock_loader)
importer._fetch_existing_place_classtype_tables = MagicMock()
importer._create_classtype_table_and_indexes = MagicMock()
importer._remove_non_existent_tables_from_db = MagicMock()
importer.import_phrases(tokenizer=mock_tokenizer, should_replace=True)
assert importer.word_phrases == {("roundabout", "highway", "motorway", "-")}
mock_analyzer.update_special_phrases.assert_called_once_with(
importer.word_phrases, True
)
def test_get_sp_filters_correctly(sample_style_file):
mock_config = MagicMock()
mock_config.get_import_style_file.return_value = sample_style_file
mock_config.load_sub_configuration.return_value = {"blackList": {}, "whiteList": {}}
importer = SPImporter(config=mock_config, conn=MagicMock(), sp_loader=None)
allowed_from_db = {("highway", "motorway"), ("historic", "castle")}
importer.get_sp_db = lambda: allowed_from_db
result = importer.get_sp()
expected = {("highway", "motorway")}
assert result == expected, f"Expected {expected}, got {result}" assert result == expected, f"Expected {expected}, got {result}"
def test_get_sp_db_filters_by_count_threshold(mock_config): def test_get_classtype_pair_data_more(placex_table, temp_db_conn):
mock_cursor = MagicMock() class Config:
def load_sub_configuration(self, *_):
# Simulate only results above the threshold being returned (as SQL would) return {'blackList': {}, 'whiteList': {}}
# These tuples simulate real SELECT class, type FROM placex GROUP BY ... HAVING COUNT(*) > 100
mock_cursor.fetchall.return_value = [ for _ in range(100):
("highway", "motorway"), placex_table.add(cls='emergency', typ='firehydrant') # edge case 100, not included
("historic", "castle")
]
mock_conn = MagicMock() for _ in range(199):
mock_conn.cursor.return_value.__enter__.return_value = mock_cursor placex_table.add(cls='amenity', typ='prison')
importer = SPImporter(config=mock_config, conn=mock_conn, sp_loader=None)
result = importer.get_sp_db() for _ in range(3478):
placex_table.add(cls='tourism', typ='hotel')
config = Config()
importer = SPImporter(config=config, conn=temp_db_conn, sp_loader=None)
result = importer.get_classtype_pairs()
expected = { expected = {
("highway", "motorway"), ("amenity", "prison"),
("historic", "castle") ("tourism", "hotel")
} }
assert result == expected assert result == expected, f"Expected {expected}, got {result}"
mock_cursor.execute.assert_called_once()