make NominatimAPI[Async] a context manager

If close() isn't properly called, it can lead to odd error messages
about uncaught exceptions.
This commit is contained in:
Sarah Hoffmann
2024-08-19 11:31:38 +02:00
parent 8b41b80bff
commit c2594aca40
8 changed files with 53 additions and 65 deletions

View File

@@ -38,6 +38,8 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
This class shares most of the functions with its synchronous This class shares most of the functions with its synchronous
version. There are some additional functions or parameters, version. There are some additional functions or parameters,
which are documented below. which are documented below.
This class should usually be used as a context manager in 'with' context.
""" """
def __init__(self, project_dir: Path, def __init__(self, project_dir: Path,
environ: Optional[Mapping[str, str]] = None, environ: Optional[Mapping[str, str]] = None,
@@ -166,6 +168,14 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
await self._engine.dispose() await self._engine.dispose()
async def __aenter__(self) -> 'NominatimAPIAsync':
return self
async def __aexit__(self, *_: Any) -> None:
await self.close()
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def begin(self) -> AsyncIterator[SearchConnection]: async def begin(self) -> AsyncIterator[SearchConnection]:
""" Create a new connection with automatic transaction handling. """ Create a new connection with automatic transaction handling.
@@ -351,6 +361,8 @@ class NominatimAPI:
""" This class provides a thin synchronous wrapper around the asynchronous """ This class provides a thin synchronous wrapper around the asynchronous
Nominatim functions. It creates its own event loop and runs each Nominatim functions. It creates its own event loop and runs each
synchronous function call to completion using that loop. synchronous function call to completion using that loop.
This class should usually be used as a context manager in 'with' context.
""" """
def __init__(self, project_dir: Path, def __init__(self, project_dir: Path,
@@ -376,8 +388,17 @@ class NominatimAPI:
This function also closes the asynchronous worker loop making This function also closes the asynchronous worker loop making
the NominatimAPI object unusable. the NominatimAPI object unusable.
""" """
self._loop.run_until_complete(self._async_api.close()) if not self._loop.is_closed():
self._loop.close() self._loop.run_until_complete(self._async_api.close())
self._loop.close()
def __enter__(self) -> 'NominatimAPI':
return self
def __exit__(self, *_: Any) -> None:
self.close()
@property @property

View File

@@ -9,6 +9,7 @@ Helper fixtures for API call tests.
""" """
from pathlib import Path from pathlib import Path
import pytest import pytest
import pytest_asyncio
import time import time
import datetime as dt import datetime as dt
@@ -244,3 +245,9 @@ def frontend(request, event_loop, tmp_path):
for api in testapis: for api in testapis:
api.close() api.close()
@pytest_asyncio.fixture
async def api(temp_db):
async with napi.NominatimAPIAsync(Path('/invalid')) as api:
yield api

View File

@@ -40,10 +40,9 @@ async def conn(table_factory):
table_factory('word', table_factory('word',
definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB') definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB')
api = NominatimAPIAsync(Path('/invalid'), {}) async with NominatimAPIAsync(Path('/invalid'), {}) as api:
async with api.begin() as conn: async with api.begin() as conn:
yield conn yield conn
await api.close()
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@@ -74,10 +74,9 @@ async def conn(table_factory, temp_db_cursor):
temp_db_cursor.execute("""CREATE OR REPLACE FUNCTION make_standard_name(name TEXT) temp_db_cursor.execute("""CREATE OR REPLACE FUNCTION make_standard_name(name TEXT)
RETURNS TEXT AS $$ SELECT lower(name); $$ LANGUAGE SQL;""") RETURNS TEXT AS $$ SELECT lower(name); $$ LANGUAGE SQL;""")
api = NominatimAPIAsync(Path('/invalid'), {}) async with NominatimAPIAsync(Path('/invalid'), {}) as api:
async with api.begin() as conn: async with api.begin() as conn:
yield conn yield conn
await api.close()
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@@ -11,41 +11,35 @@ from pathlib import Path
import pytest import pytest
from nominatim_api import NominatimAPIAsync
from nominatim_api.search.query_analyzer_factory import make_query_analyzer from nominatim_api.search.query_analyzer_factory import make_query_analyzer
from nominatim_api.search.icu_tokenizer import ICUQueryAnalyzer from nominatim_api.search.icu_tokenizer import ICUQueryAnalyzer
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_import_icu_tokenizer(table_factory): async def test_import_icu_tokenizer(table_factory, api):
table_factory('nominatim_properties', table_factory('nominatim_properties',
definition='property TEXT, value TEXT', definition='property TEXT, value TEXT',
content=(('tokenizer', 'icu'), content=(('tokenizer', 'icu'),
('tokenizer_import_normalisation', ':: lower();'), ('tokenizer_import_normalisation', ':: lower();'),
('tokenizer_import_transliteration', "'1' > '/1/'; 'ä' > 'ä '"))) ('tokenizer_import_transliteration', "'1' > '/1/'; 'ä' > 'ä '")))
api = NominatimAPIAsync(Path('/invalid'), {})
async with api.begin() as conn: async with api.begin() as conn:
ana = await make_query_analyzer(conn) ana = await make_query_analyzer(conn)
assert isinstance(ana, ICUQueryAnalyzer) assert isinstance(ana, ICUQueryAnalyzer)
await api.close()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_import_missing_property(table_factory): async def test_import_missing_property(table_factory, api):
api = NominatimAPIAsync(Path('/invalid'), {})
table_factory('nominatim_properties', table_factory('nominatim_properties',
definition='property TEXT, value TEXT') definition='property TEXT, value TEXT')
async with api.begin() as conn: async with api.begin() as conn:
with pytest.raises(ValueError, match='Property.*not found'): with pytest.raises(ValueError, match='Property.*not found'):
await make_query_analyzer(conn) await make_query_analyzer(conn)
await api.close()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_import_missing_module(table_factory): async def test_import_missing_module(table_factory, api):
api = NominatimAPIAsync(Path('/invalid'), {})
table_factory('nominatim_properties', table_factory('nominatim_properties',
definition='property TEXT, value TEXT', definition='property TEXT, value TEXT',
content=(('tokenizer', 'missing'),)) content=(('tokenizer', 'missing'),))
@@ -53,5 +47,3 @@ async def test_import_missing_module(table_factory):
async with api.begin() as conn: async with api.begin() as conn:
with pytest.raises(RuntimeError, match='Tokenizer not found'): with pytest.raises(RuntimeError, match='Tokenizer not found'):
await make_query_analyzer(conn) await make_query_analyzer(conn)
await api.close()

View File

@@ -9,45 +9,34 @@ Tests for enhanced connection class for API functions.
""" """
from pathlib import Path from pathlib import Path
import pytest import pytest
import pytest_asyncio
import sqlalchemy as sa import sqlalchemy as sa
from nominatim_api import NominatimAPIAsync
@pytest_asyncio.fixture
async def apiobj(temp_db):
""" Create an asynchronous SQLAlchemy engine for the test DB.
"""
api = NominatimAPIAsync(Path('/invalid'), {})
yield api
await api.close()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_scalar(apiobj, table_factory): async def test_run_scalar(api, table_factory):
table_factory('foo', definition='that TEXT', content=(('a', ),)) table_factory('foo', definition='that TEXT', content=(('a', ),))
async with apiobj.begin() as conn: async with api.begin() as conn:
assert await conn.scalar(sa.text('SELECT * FROM foo')) == 'a' assert await conn.scalar(sa.text('SELECT * FROM foo')) == 'a'
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_execute(apiobj, table_factory): async def test_run_execute(api, table_factory):
table_factory('foo', definition='that TEXT', content=(('a', ),)) table_factory('foo', definition='that TEXT', content=(('a', ),))
async with apiobj.begin() as conn: async with api.begin() as conn:
result = await conn.execute(sa.text('SELECT * FROM foo')) result = await conn.execute(sa.text('SELECT * FROM foo'))
assert result.fetchone()[0] == 'a' assert result.fetchone()[0] == 'a'
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_property_existing_cached(apiobj, table_factory): async def test_get_property_existing_cached(api, table_factory):
table_factory('nominatim_properties', table_factory('nominatim_properties',
definition='property TEXT, value TEXT', definition='property TEXT, value TEXT',
content=(('dbv', '96723'), )) content=(('dbv', '96723'), ))
async with apiobj.begin() as conn: async with api.begin() as conn:
assert await conn.get_property('dbv') == '96723' assert await conn.get_property('dbv') == '96723'
await conn.execute(sa.text('TRUNCATE nominatim_properties')) await conn.execute(sa.text('TRUNCATE nominatim_properties'))
@@ -56,12 +45,12 @@ async def test_get_property_existing_cached(apiobj, table_factory):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_property_existing_uncached(apiobj, table_factory): async def test_get_property_existing_uncached(api, table_factory):
table_factory('nominatim_properties', table_factory('nominatim_properties',
definition='property TEXT, value TEXT', definition='property TEXT, value TEXT',
content=(('dbv', '96723'), )) content=(('dbv', '96723'), ))
async with apiobj.begin() as conn: async with api.begin() as conn:
assert await conn.get_property('dbv') == '96723' assert await conn.get_property('dbv') == '96723'
await conn.execute(sa.text("UPDATE nominatim_properties SET value = '1'")) await conn.execute(sa.text("UPDATE nominatim_properties SET value = '1'"))
@@ -71,23 +60,23 @@ async def test_get_property_existing_uncached(apiobj, table_factory):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize('param', ['foo', 'DB:server_version']) @pytest.mark.parametrize('param', ['foo', 'DB:server_version'])
async def test_get_property_missing(apiobj, table_factory, param): async def test_get_property_missing(api, table_factory, param):
table_factory('nominatim_properties', table_factory('nominatim_properties',
definition='property TEXT, value TEXT') definition='property TEXT, value TEXT')
async with apiobj.begin() as conn: async with api.begin() as conn:
with pytest.raises(ValueError): with pytest.raises(ValueError):
await conn.get_property(param) await conn.get_property(param)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_db_property_existing(apiobj): async def test_get_db_property_existing(api):
async with apiobj.begin() as conn: async with api.begin() as conn:
assert await conn.get_db_property('server_version') > 0 assert await conn.get_db_property('server_version') > 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_db_property_existing(apiobj): async def test_get_db_property_existing(api):
async with apiobj.begin() as conn: async with api.begin() as conn:
with pytest.raises(ValueError): with pytest.raises(ValueError):
await conn.get_db_property('dfkgjd.rijg') await conn.get_db_property('dfkgjd.rijg')

View File

@@ -11,19 +11,10 @@ import json
from pathlib import Path from pathlib import Path
import pytest import pytest
import pytest_asyncio
from fake_adaptor import FakeAdaptor, FakeError, FakeResponse from fake_adaptor import FakeAdaptor, FakeError, FakeResponse
import nominatim_api.v1.server_glue as glue import nominatim_api.v1.server_glue as glue
import nominatim_api as napi
@pytest_asyncio.fixture
async def api():
api = napi.NominatimAPIAsync(Path('/invalid'))
yield api
await api.close()
class TestDeletableEndPoint: class TestDeletableEndPoint:
@@ -61,4 +52,3 @@ class TestDeletableEndPoint:
{'place_id': 3, 'country_code': 'cd', 'name': None, {'place_id': 3, 'country_code': 'cd', 'name': None,
'osm_id': 781, 'osm_type': 'R', 'osm_id': 781, 'osm_type': 'R',
'class': 'landcover', 'type': 'grass'}] 'class': 'landcover', 'type': 'grass'}]

View File

@@ -12,19 +12,10 @@ import datetime as dt
from pathlib import Path from pathlib import Path
import pytest import pytest
import pytest_asyncio
from fake_adaptor import FakeAdaptor, FakeError, FakeResponse from fake_adaptor import FakeAdaptor, FakeError, FakeResponse
import nominatim_api.v1.server_glue as glue import nominatim_api.v1.server_glue as glue
import nominatim_api as napi
@pytest_asyncio.fixture
async def api():
api = napi.NominatimAPIAsync(Path('/invalid'))
yield api
await api.close()
class TestPolygonsEndPoint: class TestPolygonsEndPoint: