type annotations for non-blocking DB connection

This commit is contained in:
Sarah Hoffmann
2022-07-05 15:00:33 +02:00
parent 0dff71a410
commit 7a1d22ff15

View File

@@ -4,8 +4,9 @@
# #
# Copyright (C) 2022 by the Nominatim developer community. # Copyright (C) 2022 by the Nominatim developer community.
# For a full list of authors see the git log. # For a full list of authors see the git log.
""" Database helper functions for the indexer. """ Non-blocking database connections.
""" """
from typing import Callable, Any, Optional, List, Iterator
import logging import logging
import select import select
import time import time
@@ -21,6 +22,8 @@ try:
except ImportError: except ImportError:
__has_psycopg2_errors__ = False __has_psycopg2_errors__ = False
from nominatim.typing import T_cursor
LOG = logging.getLogger() LOG = logging.getLogger()
class DeadlockHandler: class DeadlockHandler:
@@ -29,14 +32,14 @@ class DeadlockHandler:
normally. normally.
""" """
def __init__(self, handler: Callable[[], None], ignore_sql_errors: bool = False) -> None:
    """ Set up the context manager.

        `handler` is a parameterless callback; it is stored here and
        invoked from `__exit__` when a deadlock is reported.
        `ignore_sql_errors` is only stored by this method.
        NOTE(review): presumably it lets `__exit__` swallow other SQL
        errors - confirm against the full `__exit__` implementation.
    """
    self.handler = handler
    self.ignore_sql_errors = ignore_sql_errors
def __enter__(self) -> 'DeadlockHandler':
    """ Enter the context. Returns the handler object itself.
    """
    return self
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> bool:
if __has_psycopg2_errors__: if __has_psycopg2_errors__:
if exc_type == psycopg2.errors.DeadlockDetected: # pylint: disable=E1101 if exc_type == psycopg2.errors.DeadlockDetected: # pylint: disable=E1101
self.handler() self.handler()
@@ -57,26 +60,31 @@ class DBConnection:
""" A single non-blocking database connection. """ A single non-blocking database connection.
""" """
def __init__(self, dsn: str,
             cursor_factory: 'Optional[Callable[..., T_cursor]]' = None,
             ignore_sql_errors: bool = False) -> None:
    """ Remember the connection settings and open the connection.

        `dsn` is the connection string that is handed to the database
        driver on every (re)connect. `cursor_factory`, when given, is
        passed on to `connect()`. `ignore_sql_errors` is stored for
        later use.

        The constructor calls `connect()` itself, so the object is
        connected as soon as it has been created.
    """
    self.dsn = dsn

    # Query state. Both stay None until a query is sent with perform().
    self.current_query: Optional[str] = None
    self.current_params: Optional[List[Any]] = None
    self.ignore_sql_errors = ignore_sql_errors

    # Connection and cursor are created by connect() below.
    self.conn: Optional['psycopg2.connection'] = None
    self.cursor: Optional['psycopg2.cursor'] = None
    self.connect(cursor_factory=cursor_factory)
def close(self) -> None:
    """ Close all open connections. Does not wait for pending requests.

        Calling the function on an already closed connection is a no-op.
    """
    if self.conn is not None:
        # The cursor may be missing when connecting failed half-way,
        # so it is checked separately from the connection.
        if self.cursor is not None:
            self.cursor.close() # type: ignore[no-untyped-call]
            self.cursor = None
        self.conn.close()

    self.conn = None
def connect(self, cursor_factory=None): def connect(self, cursor_factory: Optional[Callable[..., T_cursor]] = None) -> None:
""" (Re)connect to the database. Creates an asynchronous connection """ (Re)connect to the database. Creates an asynchronous connection
with JIT and parallel processing disabled. If a connection was with JIT and parallel processing disabled. If a connection was
already open, it is closed and a new connection established. already open, it is closed and a new connection established.
@@ -89,7 +97,10 @@ class DBConnection:
self.conn = psycopg2.connect(**{'dsn': self.dsn, 'async': True}) self.conn = psycopg2.connect(**{'dsn': self.dsn, 'async': True})
self.wait() self.wait()
if cursor_factory is not None:
self.cursor = self.conn.cursor(cursor_factory=cursor_factory) self.cursor = self.conn.cursor(cursor_factory=cursor_factory)
else:
self.cursor = self.conn.cursor()
# Disable JIT and parallel workers as they are known to cause problems. # Disable JIT and parallel workers as they are known to cause problems.
# Update pg_settings instead of using SET because it does not yield # Update pg_settings instead of using SET because it does not yield
# errors on older versions of Postgres where the settings are not # errors on older versions of Postgres where the settings are not
@@ -100,11 +111,15 @@ class DBConnection:
WHERE name = 'max_parallel_workers_per_gather';""") WHERE name = 'max_parallel_workers_per_gather';""")
self.wait() self.wait()
def _deadlock_handler(self) -> None:
    """ Log the deadlock and send the query saved by `perform()` once
        more, using the same parameters.
    """
    LOG.info("Deadlock detected (params = %s), retry.", str(self.current_params))

    # The attributes are Optional. The assertions narrow the types for
    # the type checker and fail early when no query has been recorded.
    assert self.cursor is not None
    assert self.current_query is not None
    assert self.current_params is not None

    self.cursor.execute(self.current_query, self.current_params)
def wait(self): def wait(self) -> None:
""" Block until any pending operation is done. """ Block until any pending operation is done.
""" """
while True: while True:
@@ -113,25 +128,29 @@ class DBConnection:
self.current_query = None self.current_query = None
return return
def perform(self, sql: str, args: Optional[List[Any]] = None) -> None:
    """ Send SQL query to the server. Returns immediately without
        blocking.

        The query and its parameters are saved on the object, so that
        the deadlock handler is able to send them again.
    """
    assert self.cursor is not None
    self.current_query = sql
    self.current_params = args
    self.cursor.execute(sql, args)
def fileno(self) -> int:
    """ File descriptor to wait for. (Makes this class select()able.)

        Must only be called while the connection is open.
    """
    assert self.conn is not None
    return self.conn.fileno()
def is_done(self): def is_done(self) -> bool:
""" Check if the connection is available for a new query. """ Check if the connection is available for a new query.
Also checks if the previous query has run into a deadlock. Also checks if the previous query has run into a deadlock.
If so, then the previous query is repeated. If so, then the previous query is repeated.
""" """
assert self.conn is not None
if self.current_query is None: if self.current_query is None:
return True return True
@@ -150,14 +169,14 @@ class WorkerPool:
""" """
REOPEN_CONNECTIONS_AFTER = 100000 REOPEN_CONNECTIONS_AFTER = 100000
def __init__(self, dsn: str, pool_size: int, ignore_sql_errors: bool = False) -> None:
    """ Open `pool_size` connections to the database given by `dsn`.

        `ignore_sql_errors` is handed through to every connection.
    """
    self.threads = [DBConnection(dsn, ignore_sql_errors=ignore_sql_errors)
                    for _ in range(pool_size)]
    self.free_workers = self._yield_free_worker()
    # Accumulated time (in seconds) spent waiting for a free connection.
    self.wait_time = 0.0
def finish_all(self): def finish_all(self) -> None:
""" Wait for all connection to finish. """ Wait for all connection to finish.
""" """
for thread in self.threads: for thread in self.threads:
@@ -166,22 +185,22 @@ class WorkerPool:
self.free_workers = self._yield_free_worker() self.free_workers = self._yield_free_worker()
def close(self) -> None:
    """ Close all connections and clear the pool.
    """
    for thread in self.threads:
        thread.close()
    self.threads = []
    # An exhausted iterator keeps the attribute type stable; asking a
    # closed pool for a worker raises StopIteration.
    self.free_workers = iter([])
def next_free_worker(self) -> 'DBConnection':
    """ Get the next free connection.

        The connection is taken from the `free_workers` iterator.
    """
    return next(self.free_workers)
def _yield_free_worker(self): def _yield_free_worker(self) -> Iterator[DBConnection]:
ready = self.threads ready = self.threads
command_stat = 0 command_stat = 0
while True: while True:
@@ -200,17 +219,17 @@ class WorkerPool:
self.wait_time += time.time() - tstart self.wait_time += time.time() - tstart
def _reconnect_threads(self) -> None:
    """ Reopen every connection of the pool. Each connection is
        drained of pending work before it is reconnected.
    """
    for thread in self.threads:
        # Reconnecting closes the old connection, so let the running
        # query finish first.
        while not thread.is_done():
            thread.wait()
        thread.connect()
def __enter__(self) -> 'WorkerPool':
    """ Enter the context. Returns the pool itself.
    """
    return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
    """ Leave the context: wait for all connections to finish, then
        close them.

        Exceptions raised inside the context are never suppressed.
    """
    try:
        self.finish_all()
    finally:
        # Release the connections even when waiting for them failed.
        self.close()