apply request timeout also while waiting for a connection from pool

This commit is contained in:
Sarah Hoffmann
2025-09-05 14:47:14 +02:00
parent 563255202d
commit 3a50f749dd
6 changed files with 93 additions and 22 deletions

View File

@@ -15,6 +15,7 @@ classifiers = [
"Operating System :: OS Independent", "Operating System :: OS Independent",
] ]
dependencies = [ dependencies = [
"async-timeout",
"python-dotenv", "python-dotenv",
"pyYAML>=5.1", "pyYAML>=5.1",
"SQLAlchemy>=1.4.31", "SQLAlchemy>=1.4.31",

View File

@@ -175,12 +175,15 @@ NOMINATIM_SERVE_LEGACY_URLS=yes
NOMINATIM_API_POOL_SIZE=5 NOMINATIM_API_POOL_SIZE=5
# Timeout in seconds after which a single query to the database is cancelled. # Timeout in seconds after which a single query to the database is cancelled.
# The user receives a 503 response, when a query times out. # The caller receives a TimeoutError (or HTTP 503), when a query times out.
# When empty, then timeouts are disabled. # When empty, then timeouts are disabled.
NOMINATIM_QUERY_TIMEOUT=10 NOMINATIM_QUERY_TIMEOUT=10
# Maximum time a single request is allowed to take. When the timeout is # Maximum time a single request is allowed to take. If the timeout is exceeded
# exceeded, the available results are returned. # before the request is able to obtain a database connection from the
# connection pool, a TimeoutError (or HTTP 503) is thrown. If the timeout
# is exceeded while the search is ongoing, all results already found will
# be returned.
# When empty, then timeouts are disabled. # When empty, then timeouts are disabled.
NOMINATIM_REQUEST_TIMEOUT=60 NOMINATIM_REQUEST_TIMEOUT=60

View File

@@ -2,7 +2,7 @@
# #
# This file is part of Nominatim. (https://nominatim.org) # This file is part of Nominatim. (https://nominatim.org)
# #
# Copyright (C) 2024 by the Nominatim developer community. # Copyright (C) 2025 by the Nominatim developer community.
# For a full list of authors see the git log. # For a full list of authors see the git log.
""" """
Implementation of classes for API access via libraries. Implementation of classes for API access via libraries.
@@ -14,6 +14,11 @@ import sys
import contextlib import contextlib
from pathlib import Path from pathlib import Path
if sys.version_info >= (3, 11):
from asyncio import timeout_at
else:
from async_timeout import timeout_at
import sqlalchemy as sa import sqlalchemy as sa
import sqlalchemy.ext.asyncio as sa_asyncio import sqlalchemy.ext.asyncio as sa_asyncio
@@ -26,6 +31,7 @@ from .connection import SearchConnection
from .status import get_status, StatusResult from .status import get_status, StatusResult
from .lookup import get_places, get_detailed_place from .lookup import get_places, get_detailed_place
from .reverse import ReverseGeocoder from .reverse import ReverseGeocoder
from .timeout import Timeout
from . import search as nsearch from . import search as nsearch
from . import types as ntyp from . import types as ntyp
from .results import DetailedResult, ReverseResult, SearchResults from .results import DetailedResult, ReverseResult, SearchResults
@@ -172,12 +178,15 @@ class NominatimAPIAsync:
await self.close() await self.close()
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def begin(self) -> AsyncIterator[SearchConnection]: async def begin(self, abs_timeout: Optional[float] = None) -> AsyncIterator[SearchConnection]:
""" Create a new connection with automatic transaction handling. """ Create a new connection with automatic transaction handling.
This function may be used to get low-level access to the database. This function may be used to get low-level access to the database.
Refer to the documentation of SQLAlchemy for details how to use Refer to the documentation of SQLAlchemy for details how to use
the connection object. the connection object.
You may optionally give an absolute timeout until when to wait
for a connection to become available.
""" """
if self._engine is None: if self._engine is None:
await self.setup_database() await self.setup_database()
@@ -185,14 +194,15 @@ class NominatimAPIAsync:
assert self._engine is not None assert self._engine is not None
assert self._tables is not None assert self._tables is not None
async with self._engine.begin() as conn: async with timeout_at(abs_timeout), self._engine.begin() as conn:
yield SearchConnection(conn, self._tables, self._property_cache, self.config) yield SearchConnection(conn, self._tables, self._property_cache, self.config)
async def status(self) -> StatusResult: async def status(self) -> StatusResult:
""" Return the status of the database. """ Return the status of the database.
""" """
timeout = Timeout(self.request_timeout)
try: try:
async with self.begin() as conn: async with self.begin(abs_timeout=timeout.abs) as conn:
conn.set_query_timeout(self.query_timeout) conn.set_query_timeout(self.query_timeout)
status = await get_status(conn) status = await get_status(conn)
except (PGCORE_ERROR, sa.exc.OperationalError): except (PGCORE_ERROR, sa.exc.OperationalError):
@@ -205,8 +215,9 @@ class NominatimAPIAsync:
Returns None if there is no entry under the given ID. Returns None if there is no entry under the given ID.
""" """
timeout = Timeout(self.request_timeout)
details = ntyp.LookupDetails.from_kwargs(params) details = ntyp.LookupDetails.from_kwargs(params)
async with self.begin() as conn: async with self.begin(abs_timeout=timeout.abs) as conn:
conn.set_query_timeout(self.query_timeout) conn.set_query_timeout(self.query_timeout)
if details.keywords: if details.keywords:
await nsearch.make_query_analyzer(conn) await nsearch.make_query_analyzer(conn)
@@ -217,8 +228,9 @@ class NominatimAPIAsync:
Returns a list of place information for all IDs that were found. Returns a list of place information for all IDs that were found.
""" """
timeout = Timeout(self.request_timeout)
details = ntyp.LookupDetails.from_kwargs(params) details = ntyp.LookupDetails.from_kwargs(params)
async with self.begin() as conn: async with self.begin(abs_timeout=timeout.abs) as conn:
conn.set_query_timeout(self.query_timeout) conn.set_query_timeout(self.query_timeout)
if details.keywords: if details.keywords:
await nsearch.make_query_analyzer(conn) await nsearch.make_query_analyzer(conn)
@@ -235,8 +247,9 @@ class NominatimAPIAsync:
# There are no results to be expected outside valid coordinates. # There are no results to be expected outside valid coordinates.
return None return None
timeout = Timeout(self.request_timeout)
details = ntyp.ReverseDetails.from_kwargs(params) details = ntyp.ReverseDetails.from_kwargs(params)
async with self.begin() as conn: async with self.begin(abs_timeout=timeout.abs) as conn:
conn.set_query_timeout(self.query_timeout) conn.set_query_timeout(self.query_timeout)
if details.keywords: if details.keywords:
await nsearch.make_query_analyzer(conn) await nsearch.make_query_analyzer(conn)
@@ -251,10 +264,11 @@ class NominatimAPIAsync:
if not query: if not query:
raise UsageError('Nothing to search for.') raise UsageError('Nothing to search for.')
async with self.begin() as conn: timeout = Timeout(self.request_timeout)
async with self.begin(abs_timeout=timeout.abs) as conn:
conn.set_query_timeout(self.query_timeout) conn.set_query_timeout(self.query_timeout)
geocoder = nsearch.ForwardGeocoder(conn, ntyp.SearchDetails.from_kwargs(params), geocoder = nsearch.ForwardGeocoder(conn, ntyp.SearchDetails.from_kwargs(params),
self.request_timeout) timeout)
phrases = [nsearch.Phrase(nsearch.PHRASE_ANY, p.strip()) for p in query.split(',')] phrases = [nsearch.Phrase(nsearch.PHRASE_ANY, p.strip()) for p in query.split(',')]
return await geocoder.lookup(phrases) return await geocoder.lookup(phrases)
@@ -268,7 +282,8 @@ class NominatimAPIAsync:
**params: Any) -> SearchResults: **params: Any) -> SearchResults:
""" Find an address using structured search. """ Find an address using structured search.
""" """
async with self.begin() as conn: timeout = Timeout(self.request_timeout)
async with self.begin(abs_timeout=timeout.abs) as conn:
conn.set_query_timeout(self.query_timeout) conn.set_query_timeout(self.query_timeout)
details = ntyp.SearchDetails.from_kwargs(params) details = ntyp.SearchDetails.from_kwargs(params)
@@ -310,7 +325,7 @@ class NominatimAPIAsync:
if amenity: if amenity:
details.layers |= ntyp.DataLayer.POI details.layers |= ntyp.DataLayer.POI
geocoder = nsearch.ForwardGeocoder(conn, details, self.request_timeout) geocoder = nsearch.ForwardGeocoder(conn, details, timeout)
return await geocoder.lookup(phrases) return await geocoder.lookup(phrases)
async def search_category(self, categories: List[Tuple[str, str]], async def search_category(self, categories: List[Tuple[str, str]],
@@ -323,8 +338,9 @@ class NominatimAPIAsync:
if not categories: if not categories:
return SearchResults() return SearchResults()
timeout = Timeout(self.request_timeout)
details = ntyp.SearchDetails.from_kwargs(params) details = ntyp.SearchDetails.from_kwargs(params)
async with self.begin() as conn: async with self.begin(abs_timeout=timeout.abs) as conn:
conn.set_query_timeout(self.query_timeout) conn.set_query_timeout(self.query_timeout)
if near_query: if near_query:
phrases = [nsearch.Phrase(nsearch.PHRASE_ANY, p) for p in near_query.split(',')] phrases = [nsearch.Phrase(nsearch.PHRASE_ANY, p) for p in near_query.split(',')]
@@ -333,7 +349,7 @@ class NominatimAPIAsync:
if details.keywords: if details.keywords:
await nsearch.make_query_analyzer(conn) await nsearch.make_query_analyzer(conn)
geocoder = nsearch.ForwardGeocoder(conn, details, self.request_timeout) geocoder = nsearch.ForwardGeocoder(conn, details, timeout)
return await geocoder.lookup_pois(categories, phrases) return await geocoder.lookup_pois(categories, phrases)

View File

@@ -10,12 +10,12 @@ Public interface to the search code.
from typing import List, Any, Optional, Iterator, Tuple, Dict from typing import List, Any, Optional, Iterator, Tuple, Dict
import itertools import itertools
import re import re
import datetime as dt
import difflib import difflib
from ..connection import SearchConnection from ..connection import SearchConnection
from ..types import SearchDetails from ..types import SearchDetails
from ..results import SearchResult, SearchResults, add_result_details from ..results import SearchResult, SearchResults, add_result_details
from ..timeout import Timeout
from ..logging import log from ..logging import log
from .token_assignment import yield_token_assignments from .token_assignment import yield_token_assignments
from .db_search_builder import SearchBuilder, build_poi_search, wrap_near_search from .db_search_builder import SearchBuilder, build_poi_search, wrap_near_search
@@ -29,10 +29,10 @@ class ForwardGeocoder:
""" """
def __init__(self, conn: SearchConnection, def __init__(self, conn: SearchConnection,
params: SearchDetails, timeout: Optional[int]) -> None: params: SearchDetails, timeout: Timeout) -> None:
self.conn = conn self.conn = conn
self.params = params self.params = params
self.timeout = dt.timedelta(seconds=timeout or 1000000) self.timeout = timeout
self.query_analyzer: Optional[AbstractQueryAnalyzer] = None self.query_analyzer: Optional[AbstractQueryAnalyzer] = None
@property @property
@@ -78,8 +78,6 @@ class ForwardGeocoder:
log().section('Execute database searches') log().section('Execute database searches')
results: Dict[Any, SearchResult] = {} results: Dict[Any, SearchResult] = {}
end_time = dt.datetime.now() + self.timeout
min_ranking = searches[0].penalty + 2.0 min_ranking = searches[0].penalty + 2.0
prev_penalty = 0.0 prev_penalty = 0.0
for i, search in enumerate(searches): for i, search in enumerate(searches):
@@ -99,7 +97,7 @@ class ForwardGeocoder:
min_ranking = min(min_ranking, result.accuracy * 1.2, 2.0) min_ranking = min(min_ranking, result.accuracy * 1.2, 2.0)
log().result_dump('Results', ((r.accuracy, r) for r in lookup_results)) log().result_dump('Results', ((r.accuracy, r) for r in lookup_results))
prev_penalty = search.penalty prev_penalty = search.penalty
if dt.datetime.now() >= end_time: if self.timeout.is_elapsed():
break break
return SearchResults(results.values()) return SearchResults(results.values())

View File

@@ -0,0 +1,24 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2025 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Helpers for handling request timeouts.
"""
from typing import Union, Optional
import asyncio
class Timeout:
    """ Tracks an absolute deadline on the clock of the running event loop.

        The deadline is stored in `abs`. It is None when no timeout was
        requested, otherwise it is the loop time at construction plus the
        given timeout. Instances with an actual timeout must be created
        and queried from within a coroutine because the running event
        loop is looked up for the current time.
    """

    def __init__(self, timeout: Optional[Union[int, float]]) -> None:
        if timeout is None:
            # No limit requested: there is no deadline to compute.
            self.abs = None
        else:
            self.abs = asyncio.get_running_loop().time() + timeout

    def is_elapsed(self) -> bool:
        """ Return True when a deadline exists and has been reached.
        """
        if self.abs is None:
            return False

        return asyncio.get_running_loop().time() >= self.abs

View File

@@ -7,8 +7,11 @@
""" """
Tests for the status API call. Tests for the status API call.
""" """
import asyncio
import datetime as dt import datetime as dt
import pytest
from nominatim_api.version import NOMINATIM_API_VERSION from nominatim_api.version import NOMINATIM_API_VERSION
import nominatim_api as napi import nominatim_api as napi
@@ -53,3 +56,29 @@ def test_status_database_not_found(monkeypatch):
assert result.software_version == NOMINATIM_API_VERSION assert result.software_version == NOMINATIM_API_VERSION
assert result.database_version is None assert result.database_version is None
assert result.data_updated is None assert result.data_updated is None
@pytest.mark.asyncio
async def test_status_connection_timeout_single_pool(status_table, property_table, monkeypatch):
    """ With a pool size of 1, status() must raise a timeout error while
        the connection is held by another user and must run normally
        once that connection has been given back.
    """
    monkeypatch.setenv('NOMINATIM_API_POOL_SIZE', '1')
    monkeypatch.setenv('NOMINATIM_REQUEST_TIMEOUT', '1')

    async with napi.NominatimAPIAsync() as api:
        # Hold a connection open so that status() cannot obtain one.
        async with api.begin():
            # Both exception classes are accepted; they are distinct types
            # on Python versions before 3.11.
            with pytest.raises((TimeoutError, asyncio.TimeoutError)):
                await api.status()
        # The begin() block has been left, so this call must not raise.
        await api.status()
@pytest.mark.asyncio
async def test_status_connection_timeout_multi_pool(status_table, property_table, monkeypatch):
    """ With a pool size of 2, status() must raise a timeout error while
        two connections are held by other users and must run normally
        once they have been given back.
    """
    monkeypatch.setenv('NOMINATIM_API_POOL_SIZE', '2')
    monkeypatch.setenv('NOMINATIM_REQUEST_TIMEOUT', '1')

    async with napi.NominatimAPIAsync() as api:
        # Hold as many connections as the configured pool size.
        async with api.begin(), api.begin():
            # Both exception classes are accepted; they are distinct types
            # on Python versions before 3.11.
            with pytest.raises((TimeoutError, asyncio.TimeoutError)):
                await api.status()
        # Both begin() blocks have been left, so this call must not raise.
        await api.status()