add property cache for API

This caches results from querying nominatim_properties.
This commit is contained in:
Sarah Hoffmann
2023-01-28 22:24:36 +01:00
parent 2156fd4909
commit 16b6484c65
4 changed files with 152 additions and 13 deletions

View File

@@ -7,7 +7,7 @@
""" """
Extended SQLAlchemy connection class that also includes access to the schema. Extended SQLAlchemy connection class that also includes access to the schema.
""" """
from typing import Any, Mapping, Sequence, Union from typing import Any, Mapping, Sequence, Union, Dict, cast
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncConnection from sqlalchemy.ext.asyncio import AsyncConnection
@@ -22,9 +22,11 @@ class SearchConnection:
""" """
def __init__(self, conn: AsyncConnection, def __init__(self, conn: AsyncConnection,
tables: SearchTables) -> None: tables: SearchTables,
properties: Dict[str, Any]) -> None:
self.connection = conn self.connection = conn
self.t = tables # pylint: disable=invalid-name self.t = tables # pylint: disable=invalid-name
self._property_cache = properties
async def scalar(self, sql: sa.sql.base.Executable, async def scalar(self, sql: sa.sql.base.Executable,
@@ -41,3 +43,44 @@ class SearchConnection:
""" Execute a 'execute()' query on the connection. """ Execute a 'execute()' query on the connection.
""" """
return await self.connection.execute(sql, params) return await self.connection.execute(sql, params)
async def get_property(self, name: str, cached: bool = True) -> str:
""" Get a property from Nominatim's property table.
Property values are normally cached so that they are only
retrieved from the database when they are queried for the
first time with this function. Set 'cached' to False to force
reading the property from the database.
Raises a ValueError if the property does not exist.
"""
if name.startswith('DB:'):
raise ValueError(f"Illegal property value '{name}'.")
if cached and name in self._property_cache:
return cast(str, self._property_cache[name])
sql = sa.select(self.t.properties.c.value)\
.where(self.t.properties.c.property == name)
value = await self.connection.scalar(sql)
if value is None:
raise ValueError(f"Property '{name}' not found in database.")
self._property_cache[name] = cast(str, value)
return cast(str, value)
async def get_db_property(self, name: str) -> Any:
""" Get a setting from the database. At the moment, only
'server_version', the version of the database software, can
be retrieved with this function.
Raises a ValueError if the property does not exist.
"""
if name != 'server_version':
raise ValueError(f"DB setting '{name}' not found in database.")
return self._property_cache['DB:server_version']

View File

@@ -7,7 +7,7 @@
""" """
Implementation of classes for API access via libraries. Implementation of classes for API access via libraries.
""" """
from typing import Mapping, Optional, Any, AsyncIterator from typing import Mapping, Optional, Any, AsyncIterator, Dict
import asyncio import asyncio
import contextlib import contextlib
from pathlib import Path from pathlib import Path
@@ -32,6 +32,7 @@ class NominatimAPIAsync:
self._engine_lock = asyncio.Lock() self._engine_lock = asyncio.Lock()
self._engine: Optional[sa_asyncio.AsyncEngine] = None self._engine: Optional[sa_asyncio.AsyncEngine] = None
self._tables: Optional[SearchTables] = None self._tables: Optional[SearchTables] = None
self._property_cache: Dict[str, Any] = {'DB:server_version': 0}
async def setup_database(self) -> None: async def setup_database(self) -> None:
@@ -64,11 +65,11 @@ class NominatimAPIAsync:
try: try:
async with engine.begin() as conn: async with engine.begin() as conn:
result = await conn.scalar(sa.text('SHOW server_version_num')) result = await conn.scalar(sa.text('SHOW server_version_num'))
self.server_version = int(result) server_version = int(result)
except asyncpg.PostgresError: except asyncpg.PostgresError:
self.server_version = 0 server_version = 0
if self.server_version >= 110000: if server_version >= 110000:
@sa.event.listens_for(engine.sync_engine, "connect") @sa.event.listens_for(engine.sync_engine, "connect")
def _on_connect(dbapi_con: Any, _: Any) -> None: def _on_connect(dbapi_con: Any, _: Any) -> None:
cursor = dbapi_con.cursor() cursor = dbapi_con.cursor()
@@ -76,6 +77,8 @@ class NominatimAPIAsync:
# Make sure that all connections get the new settings # Make sure that all connections get the new settings
await self.close() await self.close()
self._property_cache['DB:server_version'] = server_version
self._tables = SearchTables(sa.MetaData(), engine.name) # pylint: disable=no-member self._tables = SearchTables(sa.MetaData(), engine.name) # pylint: disable=no-member
self._engine = engine self._engine = engine
@@ -104,7 +107,7 @@ class NominatimAPIAsync:
assert self._tables is not None assert self._tables is not None
async with self._engine.begin() as conn: async with self._engine.begin() as conn:
yield SearchConnection(conn, self._tables) yield SearchConnection(conn, self._tables, self._property_cache)
async def status(self) -> StatusResult: async def status(self) -> StatusResult:

View File

@@ -7,7 +7,7 @@
""" """
Classes and function releated to status call. Classes and function releated to status call.
""" """
from typing import Optional, cast from typing import Optional
import datetime as dt import datetime as dt
import dataclasses import dataclasses
@@ -37,10 +37,10 @@ async def get_status(conn: SearchConnection) -> StatusResult:
status.data_updated = await conn.scalar(sql) status.data_updated = await conn.scalar(sql)
# Database version # Database version
sql = sa.select(conn.t.properties.c.value)\ try:
.where(conn.t.properties.c.property == 'database_version') verstr = await conn.get_property('database_version')
verstr = await conn.scalar(sql) status.database_version = version.parse_version(verstr)
if verstr is not None: except ValueError:
status.database_version = version.parse_version(cast(str, verstr)) pass
return status return status

View File

@@ -0,0 +1,93 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2023 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Tests for enhanced connection class for API functions.
"""
from pathlib import Path
import pytest
import pytest_asyncio
import sqlalchemy as sa
from nominatim.api import NominatimAPIAsync
@pytest_asyncio.fixture
async def apiobj(temp_db):
""" Create an asynchronous SQLAlchemy engine for the test DB.
"""
api = NominatimAPIAsync(Path('/invalid'), {})
yield api
await api.close()
@pytest.mark.asyncio
async def test_run_scalar(apiobj, table_factory):
table_factory('foo', definition='that TEXT', content=(('a', ),))
async with apiobj.begin() as conn:
assert await conn.scalar(sa.text('SELECT * FROM foo')) == 'a'
@pytest.mark.asyncio
async def test_run_execute(apiobj, table_factory):
table_factory('foo', definition='that TEXT', content=(('a', ),))
async with apiobj.begin() as conn:
result = await conn.execute(sa.text('SELECT * FROM foo'))
assert result.fetchone()[0] == 'a'
@pytest.mark.asyncio
async def test_get_property_existing_cached(apiobj, table_factory):
table_factory('nominatim_properties',
definition='property TEXT, value TEXT',
content=(('dbv', '96723'), ))
async with apiobj.begin() as conn:
assert await conn.get_property('dbv') == '96723'
await conn.execute(sa.text('TRUNCATE nominatim_properties'))
assert await conn.get_property('dbv') == '96723'
@pytest.mark.asyncio
async def test_get_property_existing_uncached(apiobj, table_factory):
table_factory('nominatim_properties',
definition='property TEXT, value TEXT',
content=(('dbv', '96723'), ))
async with apiobj.begin() as conn:
assert await conn.get_property('dbv') == '96723'
await conn.execute(sa.text("UPDATE nominatim_properties SET value = '1'"))
assert await conn.get_property('dbv', cached=False) == '1'
@pytest.mark.asyncio
@pytest.mark.parametrize('param', ['foo', 'DB:server_version'])
async def test_get_property_missing(apiobj, table_factory, param):
table_factory('nominatim_properties',
definition='property TEXT, value TEXT')
async with apiobj.begin() as conn:
with pytest.raises(ValueError):
await conn.get_property(param)
@pytest.mark.asyncio
async def test_get_db_property_existing(apiobj):
async with apiobj.begin() as conn:
assert await conn.get_db_property('server_version') > 0
@pytest.mark.asyncio
async def test_get_db_property_existing(apiobj):
async with apiobj.begin() as conn:
with pytest.raises(ValueError):
await conn.get_db_property('dfkgjd.rijg')