type annotations for non-blocking DB connection

This commit is contained in:
Sarah Hoffmann
2022-07-05 15:00:33 +02:00
parent 0dff71a410
commit 7a1d22ff15

View File

@@ -4,8 +4,9 @@
#
# Copyright (C) 2022 by the Nominatim developer community.
# For a full list of authors see the git log.
""" Database helper functions for the indexer.
""" Non-blocking database connections.
"""
from typing import Callable, Any, Optional, List, Iterator
import logging
import select
import time
@@ -21,6 +22,8 @@ try:
except ImportError:
__has_psycopg2_errors__ = False
from nominatim.typing import T_cursor
LOG = logging.getLogger()
class DeadlockHandler:
@@ -29,14 +32,14 @@ class DeadlockHandler:
normally.
"""
def __init__(self, handler: Callable[[], None], ignore_sql_errors: bool = False) -> None:
    """ Remember the deadlock callback and the error-handling flag.

        `handler` is a callable without arguments. It is invoked when the
        context is left because psycopg2 reported a deadlock.
        `ignore_sql_errors` is only stored by this method; it is evaluated
        elsewhere.
    """
    self.handler = handler
    self.ignore_sql_errors = ignore_sql_errors
def __enter__(self) -> 'DeadlockHandler':
    """ Enter the context. Returns the handler itself so that it can be
        bound with `with ... as`.
    """
    return self
def __exit__(self, exc_type, exc_value, traceback):
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> bool:
if __has_psycopg2_errors__:
if exc_type == psycopg2.errors.DeadlockDetected: # pylint: disable=E1101
self.handler()
@@ -57,26 +60,31 @@ class DBConnection:
""" A single non-blocking database connection.
"""
def __init__(self, dsn: str,
             cursor_factory: Optional[Callable[..., T_cursor]] = None,
             ignore_sql_errors: bool = False) -> None:
    """ Set up the connection state and open the connection.

        `dsn` is the connection string later handed to psycopg2.
        `cursor_factory` is forwarded unchanged to `connect()`.
        `ignore_sql_errors` is only stored by this method.
    """
    self.dsn = dsn

    # Query and parameters of the statement currently in flight, if any.
    self.current_query: Optional[str] = None
    self.current_params: Optional[List[Any]] = None
    self.ignore_sql_errors = ignore_sql_errors

    # Both stay None until connect() has established the connection.
    self.conn: Optional['psycopg2.connection'] = None
    self.cursor: Optional['psycopg2.cursor'] = None
    self.connect(cursor_factory=cursor_factory)
def close(self) -> None:
    """ Close all open connections. Does not wait for pending requests.

        Safe to call when no connection or no cursor is open. Afterwards
        both `conn` and `cursor` are None.
    """
    if self.conn is not None:
        # The cursor may be missing even though a connection exists,
        # so it needs its own guard.
        if self.cursor is not None:
            self.cursor.close() # type: ignore[no-untyped-call]
            self.cursor = None
        self.conn.close()
        self.conn = None
def connect(self, cursor_factory=None):
def connect(self, cursor_factory: Optional[Callable[..., T_cursor]] = None) -> None:
""" (Re)connect to the database. Creates an asynchronous connection
with JIT and parallel processing disabled. If a connection was
already open, it is closed and a new connection established.
@@ -89,7 +97,10 @@ class DBConnection:
self.conn = psycopg2.connect(**{'dsn': self.dsn, 'async': True})
self.wait()
self.cursor = self.conn.cursor(cursor_factory=cursor_factory)
if cursor_factory is not None:
self.cursor = self.conn.cursor(cursor_factory=cursor_factory)
else:
self.cursor = self.conn.cursor()
# Disable JIT and parallel workers as they are known to cause problems.
# Update pg_settings instead of using SET because it does not yield
# errors on older versions of Postgres where the settings are not
@@ -100,11 +111,15 @@ class DBConnection:
WHERE name = 'max_parallel_workers_per_gather';""")
self.wait()
def _deadlock_handler(self) -> None:
    """ Log the deadlock and send the current query with its parameters
        to the server again.
    """
    # Logging formats lazily, so no explicit str() conversion is needed.
    LOG.info("Deadlock detected (params = %s), retry.", self.current_params)
    # The asserts narrow the Optional attributes for the type checker.
    assert self.cursor is not None
    assert self.current_query is not None
    assert self.current_params is not None
    self.cursor.execute(self.current_query, self.current_params)
def wait(self):
def wait(self) -> None:
""" Block until any pending operation is done.
"""
while True:
@@ -113,25 +128,29 @@ class DBConnection:
self.current_query = None
return
def perform(self, sql: str, args: Optional[List[Any]] = None) -> None:
    """ Send SQL query to the server. Returns immediately without
        blocking.

        Query and parameters are remembered in `current_query` and
        `current_params` before the statement is executed, which lets
        the deadlock handler repeat them.
    """
    assert self.cursor is not None
    self.current_query = sql
    self.current_params = args
    self.cursor.execute(sql, args)
def fileno(self) -> int:
    """ File descriptor to wait for. (Makes this class select()able.)

        Requires an open connection.
    """
    assert self.conn is not None
    return self.conn.fileno()
def is_done(self):
def is_done(self) -> bool:
""" Check if the connection is available for a new query.
Also checks if the previous query has run into a deadlock.
If so, then the previous query is repeated.
"""
assert self.conn is not None
if self.current_query is None:
return True
@@ -150,14 +169,14 @@ class WorkerPool:
"""
REOPEN_CONNECTIONS_AFTER = 100000
def __init__(self, dsn: str, pool_size: int, ignore_sql_errors: bool = False) -> None:
    """ Open `pool_size` connections for the given `dsn`.

        `ignore_sql_errors` is passed on to every connection.
    """
    self.threads = [DBConnection(dsn, ignore_sql_errors=ignore_sql_errors)
                    for _ in range(pool_size)]
    self.free_workers = self._yield_free_worker()
    # Accumulated time differences in seconds, hence a float.
    self.wait_time = 0.0
def finish_all(self):
def finish_all(self) -> None:
""" Wait for all connection to finish.
"""
for thread in self.threads:
@@ -166,22 +185,22 @@ class WorkerPool:
self.free_workers = self._yield_free_worker()
def close(self) -> None:
    """ Close all connections and clear the pool.
    """
    for thread in self.threads:
        thread.close()
    self.threads = []
    # An exhausted iterator keeps the attribute an iterator, so it
    # never has to be None.
    self.free_workers = iter([])
def next_free_worker(self) -> DBConnection:
    """ Get the next free connection.

        The connection is taken from the `free_workers` iterator.
    """
    return next(self.free_workers)
def _yield_free_worker(self):
def _yield_free_worker(self) -> Iterator[DBConnection]:
ready = self.threads
command_stat = 0
while True:
@@ -200,17 +219,17 @@ class WorkerPool:
self.wait_time += time.time() - tstart
def _reconnect_threads(self) -> None:
    """ Reconnect every connection of the pool, one after the other.

        Each connection is first waited on until it reports that it is
        done, so that no pending query is cut off.
    """
    for thread in self.threads:
        while not thread.is_done():
            thread.wait()
        thread.connect()
def __enter__(self) -> 'WorkerPool':
    """ Enter the context. Returns the pool itself so that it can be
        bound with `with ... as`.
    """
    return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
    """ Leave the context: wait for all connections to finish, then
        close them.

        Exceptions are never suppressed. The connections are closed even
        when waiting for them fails.
    """
    try:
        self.finish_all()
    finally:
        # Always release the connections, also when finish_all() raises.
        self.close()