make NominatimAPI[Async] a context manager

If close() isn't properly called, it can lead to odd error messages
about uncaught exceptions.
This commit is contained in:
Sarah Hoffmann
2024-08-19 11:31:38 +02:00
parent 8b41b80bff
commit c2594aca40
8 changed files with 53 additions and 65 deletions

View File

@@ -9,6 +9,7 @@ Helper fixtures for API call tests.
"""
from pathlib import Path
import pytest
import pytest_asyncio
import time
import datetime as dt
@@ -244,3 +245,9 @@ def frontend(request, event_loop, tmp_path):
for api in testapis:
api.close()
@pytest_asyncio.fixture
async def api(temp_db):
async with napi.NominatimAPIAsync(Path('/invalid')) as api:
yield api

View File

@@ -40,10 +40,9 @@ async def conn(table_factory):
table_factory('word',
definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB')
api = NominatimAPIAsync(Path('/invalid'), {})
async with api.begin() as conn:
yield conn
await api.close()
async with NominatimAPIAsync(Path('/invalid'), {}) as api:
async with api.begin() as conn:
yield conn
@pytest.mark.asyncio

View File

@@ -74,10 +74,9 @@ async def conn(table_factory, temp_db_cursor):
temp_db_cursor.execute("""CREATE OR REPLACE FUNCTION make_standard_name(name TEXT)
RETURNS TEXT AS $$ SELECT lower(name); $$ LANGUAGE SQL;""")
api = NominatimAPIAsync(Path('/invalid'), {})
async with api.begin() as conn:
yield conn
await api.close()
async with NominatimAPIAsync(Path('/invalid'), {}) as api:
async with api.begin() as conn:
yield conn
@pytest.mark.asyncio

View File

@@ -11,41 +11,35 @@ from pathlib import Path
import pytest
from nominatim_api import NominatimAPIAsync
from nominatim_api.search.query_analyzer_factory import make_query_analyzer
from nominatim_api.search.icu_tokenizer import ICUQueryAnalyzer
@pytest.mark.asyncio
async def test_import_icu_tokenizer(table_factory):
async def test_import_icu_tokenizer(table_factory, api):
table_factory('nominatim_properties',
definition='property TEXT, value TEXT',
content=(('tokenizer', 'icu'),
('tokenizer_import_normalisation', ':: lower();'),
('tokenizer_import_transliteration', "'1' > '/1/'; 'ä' > 'ä '")))
api = NominatimAPIAsync(Path('/invalid'), {})
async with api.begin() as conn:
ana = await make_query_analyzer(conn)
assert isinstance(ana, ICUQueryAnalyzer)
await api.close()
@pytest.mark.asyncio
async def test_import_missing_property(table_factory):
api = NominatimAPIAsync(Path('/invalid'), {})
async def test_import_missing_property(table_factory, api):
table_factory('nominatim_properties',
definition='property TEXT, value TEXT')
async with api.begin() as conn:
with pytest.raises(ValueError, match='Property.*not found'):
await make_query_analyzer(conn)
await api.close()
@pytest.mark.asyncio
async def test_import_missing_module(table_factory):
api = NominatimAPIAsync(Path('/invalid'), {})
async def test_import_missing_module(table_factory, api):
table_factory('nominatim_properties',
definition='property TEXT, value TEXT',
content=(('tokenizer', 'missing'),))
@@ -53,5 +47,3 @@ async def test_import_missing_module(table_factory):
async with api.begin() as conn:
with pytest.raises(RuntimeError, match='Tokenizer not found'):
await make_query_analyzer(conn)
await api.close()

View File

@@ -9,45 +9,34 @@ Tests for enhanced connection class for API functions.
"""
from pathlib import Path
import pytest
import pytest_asyncio
import sqlalchemy as sa
from nominatim_api import NominatimAPIAsync
@pytest_asyncio.fixture
async def apiobj(temp_db):
""" Create an asynchronous SQLAlchemy engine for the test DB.
"""
api = NominatimAPIAsync(Path('/invalid'), {})
yield api
await api.close()
@pytest.mark.asyncio
async def test_run_scalar(apiobj, table_factory):
async def test_run_scalar(api, table_factory):
table_factory('foo', definition='that TEXT', content=(('a', ),))
async with apiobj.begin() as conn:
async with api.begin() as conn:
assert await conn.scalar(sa.text('SELECT * FROM foo')) == 'a'
@pytest.mark.asyncio
async def test_run_execute(apiobj, table_factory):
async def test_run_execute(api, table_factory):
table_factory('foo', definition='that TEXT', content=(('a', ),))
async with apiobj.begin() as conn:
async with api.begin() as conn:
result = await conn.execute(sa.text('SELECT * FROM foo'))
assert result.fetchone()[0] == 'a'
@pytest.mark.asyncio
async def test_get_property_existing_cached(apiobj, table_factory):
async def test_get_property_existing_cached(api, table_factory):
table_factory('nominatim_properties',
definition='property TEXT, value TEXT',
content=(('dbv', '96723'), ))
async with apiobj.begin() as conn:
async with api.begin() as conn:
assert await conn.get_property('dbv') == '96723'
await conn.execute(sa.text('TRUNCATE nominatim_properties'))
@@ -56,12 +45,12 @@ async def test_get_property_existing_cached(apiobj, table_factory):
@pytest.mark.asyncio
async def test_get_property_existing_uncached(apiobj, table_factory):
async def test_get_property_existing_uncached(api, table_factory):
table_factory('nominatim_properties',
definition='property TEXT, value TEXT',
content=(('dbv', '96723'), ))
async with apiobj.begin() as conn:
async with api.begin() as conn:
assert await conn.get_property('dbv') == '96723'
await conn.execute(sa.text("UPDATE nominatim_properties SET value = '1'"))
@@ -71,23 +60,23 @@ async def test_get_property_existing_uncached(apiobj, table_factory):
@pytest.mark.asyncio
@pytest.mark.parametrize('param', ['foo', 'DB:server_version'])
async def test_get_property_missing(apiobj, table_factory, param):
async def test_get_property_missing(api, table_factory, param):
table_factory('nominatim_properties',
definition='property TEXT, value TEXT')
async with apiobj.begin() as conn:
async with api.begin() as conn:
with pytest.raises(ValueError):
await conn.get_property(param)
@pytest.mark.asyncio
async def test_get_db_property_existing(apiobj):
async with apiobj.begin() as conn:
async def test_get_db_property_existing(api):
async with api.begin() as conn:
assert await conn.get_db_property('server_version') > 0
@pytest.mark.asyncio
async def test_get_db_property_existing(apiobj):
async with apiobj.begin() as conn:
async def test_get_db_property_existing(api):
async with api.begin() as conn:
with pytest.raises(ValueError):
await conn.get_db_property('dfkgjd.rijg')

View File

@@ -11,19 +11,10 @@ import json
from pathlib import Path
import pytest
import pytest_asyncio
from fake_adaptor import FakeAdaptor, FakeError, FakeResponse
import nominatim_api.v1.server_glue as glue
import nominatim_api as napi
@pytest_asyncio.fixture
async def api():
api = napi.NominatimAPIAsync(Path('/invalid'))
yield api
await api.close()
class TestDeletableEndPoint:
@@ -61,4 +52,3 @@ class TestDeletableEndPoint:
{'place_id': 3, 'country_code': 'cd', 'name': None,
'osm_id': 781, 'osm_type': 'R',
'class': 'landcover', 'type': 'grass'}]

View File

@@ -12,19 +12,10 @@ import datetime as dt
from pathlib import Path
import pytest
import pytest_asyncio
from fake_adaptor import FakeAdaptor, FakeError, FakeResponse
import nominatim_api.v1.server_glue as glue
import nominatim_api as napi
@pytest_asyncio.fixture
async def api():
api = napi.NominatimAPIAsync(Path('/invalid'))
yield api
await api.close()
class TestPolygonsEndPoint: