port code to psycopg3

This commit is contained in:
Sarah Hoffmann
2024-07-05 10:43:10 +02:00
parent 3742fa2929
commit 9659afbade
57 changed files with 800 additions and 1330 deletions

View File

@@ -1,236 +0,0 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
""" Non-blocking database connections.
"""
from typing import Callable, Any, Optional, Iterator, Sequence
import logging
import select
import time
import psycopg2
from psycopg2.extras import wait_select
# psycopg2 emits different exceptions pre and post 2.8. Detect if the new error
# module is available and adapt the error handling accordingly.
try:
import psycopg2.errors # pylint: disable=no-name-in-module,import-error
__has_psycopg2_errors__ = True
except ImportError:
__has_psycopg2_errors__ = False
from ..typing import T_cursor, Query
LOG = logging.getLogger()
class DeadlockHandler:
    """ Context manager that runs its body and, when the body raises a
        Postgres deadlock error, swallows the exception and invokes the
        given handler function instead. Every other exception propagates
        unchanged, unless `ignore_sql_errors` is set, in which case any
        psycopg2 error is logged and suppressed as well.
    """

    def __init__(self, handler: Callable[[], None], ignore_sql_errors: bool = False) -> None:
        self.handler = handler
        self.ignore_sql_errors = ignore_sql_errors

    def __enter__(self) -> 'DeadlockHandler':
        return self

    def _is_deadlock(self, exc_type: Any, exc_value: Any) -> bool:
        # psycopg2 >= 2.8 has a dedicated exception class; older versions
        # only expose the generic rollback error plus the SQLSTATE code
        # 40P01 ("deadlock_detected").
        if __has_psycopg2_errors__:
            return exc_type == psycopg2.errors.DeadlockDetected # pylint: disable=E1101
        return exc_type == psycopg2.extensions.TransactionRollbackError \
               and exc_value.pgcode == '40P01'

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> bool:
        if self._is_deadlock(exc_type, exc_value):
            self.handler()
            return True
        if self.ignore_sql_errors and isinstance(exc_value, psycopg2.Error):
            LOG.info("SQL error ignored: %s", exc_value)
            return True
        return False
class DBConnection:
    """ A single non-blocking database connection.

        Wraps a psycopg2 connection opened in asynchronous mode. Queries
        are sent with `perform()` and completion is checked via `is_done()`
        or awaited with `wait()`. Deadlocked queries are retried
        automatically (see `_deadlock_handler`).
    """

    def __init__(self, dsn: str,
                 cursor_factory: Optional[Callable[..., T_cursor]] = None,
                 ignore_sql_errors: bool = False) -> None:
        self.dsn = dsn
        # Query and parameters currently in flight. Kept so that the
        # query can be reissued after a deadlock-triggered rollback.
        self.current_query: Optional[Query] = None
        self.current_params: Optional[Sequence[Any]] = None
        self.ignore_sql_errors = ignore_sql_errors
        self.conn: Optional['psycopg2._psycopg.connection'] = None
        self.cursor: Optional['psycopg2._psycopg.cursor'] = None
        self.connect(cursor_factory=cursor_factory)

    def close(self) -> None:
        """ Close all open connections. Does not wait for pending requests.
        """
        if self.conn is not None:
            if self.cursor is not None:
                self.cursor.close()
                self.cursor = None
            self.conn.close()
            self.conn = None

    def connect(self, cursor_factory: Optional[Callable[..., T_cursor]] = None) -> None:
        """ (Re)connect to the database. Creates an asynchronous connection
            with JIT and parallel processing disabled. If a connection was
            already open, it is closed and a new connection established.
            The caller must ensure that no query is pending before reconnecting.
        """
        self.close()

        # Use a dict to hand in the parameters because async is a reserved
        # word in Python3.
        self.conn = psycopg2.connect(**{'dsn': self.dsn, 'async': True}) # type: ignore
        assert self.conn
        # The async connection is not usable until the handshake finished.
        self.wait()

        if cursor_factory is not None:
            self.cursor = self.conn.cursor(cursor_factory=cursor_factory)
        else:
            self.cursor = self.conn.cursor()
        # Disable JIT and parallel workers as they are known to cause problems.
        # Update pg_settings instead of using SET because it does not yield
        # errors on older versions of Postgres where the settings are not
        # implemented.
        self.perform(
            """ UPDATE pg_settings SET setting = -1 WHERE name = 'jit_above_cost';
                UPDATE pg_settings SET setting = 0
                   WHERE name = 'max_parallel_workers_per_gather';""")
        self.wait()

    def _deadlock_handler(self) -> None:
        # Called by DeadlockHandler: simply reissue the remembered query.
        LOG.info("Deadlock detected (params = %s), retry.", str(self.current_params))
        assert self.cursor is not None
        assert self.current_query is not None
        assert self.current_params is not None

        self.cursor.execute(self.current_query, self.current_params)

    def wait(self) -> None:
        """ Block until any pending operation is done.

            If the pending query runs into a deadlock, it is retried and
            the wait loop starts over.
        """
        while True:
            with DeadlockHandler(self._deadlock_handler, self.ignore_sql_errors):
                wait_select(self.conn)
                self.current_query = None
                return

    def perform(self, sql: Query, args: Optional[Sequence[Any]] = None) -> None:
        """ Send SQL query to the server. Returns immediately without
            blocking.
        """
        assert self.cursor is not None
        self.current_query = sql
        self.current_params = args
        self.cursor.execute(sql, args)

    def fileno(self) -> int:
        """ File descriptor to wait for. (Makes this class select()able.)
        """
        assert self.conn is not None
        return self.conn.fileno()

    def is_done(self) -> bool:
        """ Check if the connection is available for a new query.

            Also checks if the previous query has run into a deadlock.
            If so, then the previous query is repeated.
        """
        assert self.conn is not None

        if self.current_query is None:
            return True

        # On deadlock the handler reissues the query, so the connection
        # stays busy and False is returned below.
        with DeadlockHandler(self._deadlock_handler, self.ignore_sql_errors):
            if self.conn.poll() == psycopg2.extensions.POLL_OK:
                self.current_query = None
                return True

        return False
class WorkerPool:
    """ A pool of asynchronous database connections.

        The pool may be used as a context manager.
    """
    # After this many executed queries all connections are torn down
    # and reopened.
    REOPEN_CONNECTIONS_AFTER = 100000

    def __init__(self, dsn: str, pool_size: int, ignore_sql_errors: bool = False) -> None:
        self.threads = [DBConnection(dsn, ignore_sql_errors=ignore_sql_errors)
                        for _ in range(pool_size)]
        self.free_workers = self._yield_free_worker()
        self.wait_time = 0.0

    def finish_all(self) -> None:
        """ Wait for all connection to finish.
        """
        for conn in self.threads:
            while not conn.is_done():
                conn.wait()

        self.free_workers = self._yield_free_worker()

    def close(self) -> None:
        """ Close all connections and clear the pool.
        """
        for conn in self.threads:
            conn.close()
        self.threads = []
        self.free_workers = iter([])

    def next_free_worker(self) -> DBConnection:
        """ Get the next free connection.
        """
        return next(self.free_workers)

    def _yield_free_worker(self) -> Iterator[DBConnection]:
        # Endless generator cycling over the pool and yielding every
        # connection that is ready to take the next query.
        ready = self.threads
        executed = 0
        while True:
            for conn in ready:
                if conn.is_done():
                    executed += 1
                    yield conn

            if executed > self.REOPEN_CONNECTIONS_AFTER:
                self._reconnect_threads()
                ready = self.threads
                executed = 0
            else:
                # Nothing ready right now: block until one connection
                # becomes writable and account the time spent waiting.
                started = time.time()
                _, ready, _ = select.select([], self.threads, [])
                self.wait_time += time.time() - started

    def _reconnect_threads(self) -> None:
        # Drain any pending work, then re-establish each connection.
        for conn in self.threads:
            while not conn.is_done():
                conn.wait()
            conn.connect()

    def __enter__(self) -> 'WorkerPool':
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.finish_all()
        self.close()

View File

@@ -7,73 +7,27 @@
"""
Specialised connection and cursor functions.
"""
from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload,\
Tuple, Iterable
import contextlib
from typing import Optional, Any, Dict, Tuple
import logging
import os
import psycopg2
import psycopg2.extensions
import psycopg2.extras
from psycopg2 import sql as pysql
import psycopg
import psycopg.types.hstore
from psycopg import sql as pysql
from ..typing import SysEnv, Query, T_cursor
from ..typing import SysEnv
from ..errors import UsageError
LOG = logging.getLogger()
class Cursor(psycopg2.extras.DictCursor):
""" A cursor returning dict-like objects and providing specialised
execution functions.
"""
# pylint: disable=arguments-renamed,arguments-differ
def execute(self, query: Query, args: Any = None) -> None:
""" Query execution that logs the SQL query when debugging is enabled.
"""
if LOG.isEnabledFor(logging.DEBUG):
LOG.debug(self.mogrify(query, args).decode('utf-8'))
Cursor = psycopg.Cursor[Any]
Connection = psycopg.Connection[Any]
super().execute(query, args)
def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]],
template: Optional[Query] = None) -> None:
""" Wrapper for the psycopg2 convenience function to execute
SQL for a list of values.
"""
LOG.debug("SQL execute_values(%s, %s)", sql, argslist)
psycopg2.extras.execute_values(self, sql, argslist, template=template)
class Connection(psycopg2.extensions.connection):
""" A connection that provides the specialised cursor by default and
adds convenience functions for administrating the database.
"""
@overload # type: ignore[override]
def cursor(self) -> Cursor:
...
@overload
def cursor(self, name: str) -> Cursor:
...
@overload
def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor:
...
def cursor(self, cursor_factory = Cursor, **kwargs): # type: ignore
""" Return a new cursor. By default the specialised cursor is returned.
"""
return super().cursor(cursor_factory=cursor_factory, **kwargs)
def execute_scalar(conn: Connection, sql: Query, args: Any = None) -> Any:
def execute_scalar(conn: Connection, sql: psycopg.abc.Query, args: Any = None) -> Any:
""" Execute query that returns a single value. The value is returned.
If the query yields more than one row, a ValueError is raised.
"""
with conn.cursor() as cur:
with conn.cursor(row_factory=psycopg.rows.tuple_row) as cur:
cur.execute(sql, args)
if cur.rowcount != 1:
@@ -144,7 +98,7 @@ def server_version_tuple(conn: Connection) -> Tuple[int, int]:
""" Return the server version as a tuple of (major, minor).
Converts correctly for pre-10 and post-10 PostgreSQL versions.
"""
version = conn.server_version
version = conn.info.server_version
if version < 100000:
return (int(version / 10000), int((version % 10000) / 100))
@@ -164,31 +118,25 @@ def postgis_version_tuple(conn: Connection) -> Tuple[int, int]:
return (int(version_parts[0]), int(version_parts[1]))
def register_hstore(conn: Connection) -> None:
""" Register the hstore type with psycopg for the connection.
"""
psycopg2.extras.register_hstore(conn)
info = psycopg.types.TypeInfo.fetch(conn, "hstore")
if info is None:
raise RuntimeError('Hstore extension is requested but not installed.')
psycopg.types.hstore.register_hstore(info, conn)
class ConnectionContext(ContextManager[Connection]):
""" Context manager of the connection that also provides direct access
to the underlying connection.
"""
connection: Connection
def connect(dsn: str) -> ConnectionContext:
def connect(dsn: str, **kwargs: Any) -> Connection:
""" Open a connection to the database using the specialised connection
factory. The returned object may be used in conjunction with 'with'.
When used outside a context manager, use the `connection` attribute
to get the connection.
"""
try:
conn = psycopg2.connect(dsn, connection_factory=Connection)
ctxmgr = cast(ConnectionContext, contextlib.closing(conn))
ctxmgr.connection = conn
return ctxmgr
except psycopg2.OperationalError as err:
return psycopg.connect(dsn, row_factory=psycopg.rows.namedtuple_row, **kwargs)
except psycopg.OperationalError as err:
raise UsageError(f"Cannot connect to database: {err}") from err
@@ -233,10 +181,18 @@ def get_pg_env(dsn: str,
"""
env = dict(base_env if base_env is not None else os.environ)
for param, value in psycopg2.extensions.parse_dsn(dsn).items():
for param, value in psycopg.conninfo.conninfo_to_dict(dsn).items():
if param in _PG_CONNECTION_STRINGS:
env[_PG_CONNECTION_STRINGS[param]] = value
env[_PG_CONNECTION_STRINGS[param]] = str(value)
else:
LOG.error("Unknown connection parameter '%s' ignored.", param)
return env
async def run_async_query(dsn: str, query: psycopg.abc.Query) -> None:
    """ Open a connection to the database and run a single query
        asynchronously. The connection is closed again when done.
    """
    aconn = await psycopg.AsyncConnection.connect(dsn)
    async with aconn:
        await aconn.execute(query)

View File

@@ -0,0 +1,87 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
A connection pool that executes incoming queries in parallel.
"""
from typing import Any, Tuple, Optional
import asyncio
import logging
import time
import psycopg
LOG = logging.getLogger()
QueueItem = Optional[Tuple[psycopg.abc.Query, Any]]
class QueryPool:
    """ Pool to run SQL queries in parallel asynchronous execution.

        All queries are run in autocommit mode. If parallel execution leads
        to a deadlock, then the query is repeated.
        The results of the queries is discarded.
    """

    def __init__(self, dsn: str, pool_size: int = 1, **conn_args: Any) -> None:
        self.wait_time = 0.0
        # Bounded queue so producers are throttled when workers lag behind.
        self.query_queue: 'asyncio.Queue[QueueItem]' = asyncio.Queue(maxsize=2 * pool_size)
        self.pool = [asyncio.create_task(self._worker_loop(dsn, **conn_args))
                     for _ in range(pool_size)]

    async def put_query(self, query: psycopg.abc.Query, params: Any) -> None:
        """ Schedule a query for execution.
        """
        queued_at = time.time()
        await self.query_queue.put((query, params))
        self.wait_time += time.time() - queued_at
        # Yield control so a worker may pick up the query immediately.
        await asyncio.sleep(0)

    async def finish(self) -> None:
        """ Wait for all queries to finish and close the pool.
        """
        # One sentinel per worker terminates its loop.
        for _ in self.pool:
            await self.query_queue.put(None)

        drain_start = time.time()
        await asyncio.wait(self.pool)
        self.wait_time += time.time() - drain_start

        # Surface the first worker failure, if any.
        for task in self.pool:
            error = task.exception()
            if error is not None:
                raise error

    async def _worker_loop(self, dsn: str, **conn_args: Any) -> None:
        # Each worker owns one autocommit connection for its lifetime.
        conn_args['autocommit'] = True
        aconn = await psycopg.AsyncConnection.connect(dsn, **conn_args)
        async with aconn:
            async with aconn.cursor() as cur:
                next_item = await self.query_queue.get()
                while next_item is not None:
                    query, params = next_item
                    try:
                        if params is None:
                            await cur.execute(query)
                        else:
                            await cur.execute(query, params)

                        next_item = await self.query_queue.get()
                    except psycopg.errors.DeadlockDetected:
                        LOG.info("Deadlock detected (sql = %s, params = %s), retry.",
                                 str(query), str(params))
                        # next_item is left untouched, causing a retry.

    async def __aenter__(self) -> 'QueryPool':
        return self

    async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        await self.finish()

View File

@@ -8,11 +8,12 @@
Preprocessing of SQL files.
"""
from typing import Set, Dict, Any, cast
import jinja2
from .connection import Connection, server_version_tuple, postgis_version_tuple
from .async_connection import WorkerPool
from ..config import Configuration
from ..db.query_pool import QueryPool
def _get_partitions(conn: Connection) -> Set[int]:
""" Get the set of partitions currently in use.
@@ -125,8 +126,8 @@ class SQLPreprocessor:
conn.commit()
def run_parallel_sql_file(self, dsn: str, name: str, num_threads: int = 1,
**kwargs: Any) -> None:
async def run_parallel_sql_file(self, dsn: str, name: str, num_threads: int = 1,
**kwargs: Any) -> None:
""" Execute the given SQL files using parallel asynchronous connections.
The keyword arguments may supply additional parameters for
preprocessing.
@@ -138,6 +139,6 @@ class SQLPreprocessor:
parts = sql.split('\n---\n')
with WorkerPool(dsn, num_threads) as pool:
async with QueryPool(dsn, num_threads) as pool:
for part in parts:
pool.next_free_worker().perform(part)
await pool.put_query(part, None)

View File

@@ -7,7 +7,7 @@
"""
Access and helper functions for the status and status log table.
"""
from typing import Optional, Tuple, cast
from typing import Optional, Tuple
import datetime as dt
import logging
import re
@@ -15,20 +15,11 @@ import re
from .connection import Connection, table_exists, execute_scalar
from ..utils.url_utils import get_url
from ..errors import UsageError
from ..typing import TypedDict
LOG = logging.getLogger()
ISODATE_FORMAT = '%Y-%m-%dT%H:%M:%S'
class StatusRow(TypedDict):
""" Dictionary of columns of the import_status table.
"""
lastimportdate: dt.datetime
sequence_id: Optional[int]
indexed: Optional[bool]
def compute_database_date(conn: Connection, offline: bool = False) -> dt.datetime:
""" Determine the date of the database from the newest object in the
data base.
@@ -102,8 +93,9 @@ def get_status(conn: Connection) -> Tuple[Optional[dt.datetime], Optional[int],
if cur.rowcount < 1:
return None, None, None
row = cast(StatusRow, cur.fetchone())
return row['lastimportdate'], row['sequence_id'], row['indexed']
row = cur.fetchone()
assert row
return row.lastimportdate, row.sequence_id, row.indexed
def set_indexed(conn: Connection, state: bool) -> None:

View File

@@ -7,14 +7,13 @@
"""
Helper functions for handling DB accesses.
"""
from typing import IO, Optional, Union, Any, Iterable
from typing import IO, Optional, Union
import subprocess
import logging
import gzip
import io
from pathlib import Path
from .connection import get_pg_env, Cursor
from .connection import get_pg_env
from ..errors import UsageError
LOG = logging.getLogger()
@@ -72,58 +71,3 @@ def execute_file(dsn: str, fname: Path,
if ret != 0 or remain > 0:
raise UsageError("Failed to execute SQL file.")
# List of characters that need to be quoted for the copy command.
# Maps each character's ordinal to its escape sequence as required by
# Postgres' COPY text format (for use with str.translate()).
_SQL_TRANSLATION = {ord('\\'): '\\\\',
                    ord('\t'): '\\t',
                    ord('\n'): '\\n'}
class CopyBuffer:
    """ Data collector for the copy_from command.

        Rows are accumulated in an in-memory buffer in Postgres' COPY
        text format and flushed to a table with `copy_out()`.
    """

    def __init__(self) -> None:
        self.buffer = io.StringIO()

    def __enter__(self) -> 'CopyBuffer':
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        if self.buffer is not None:
            self.buffer.close()

    def size(self) -> int:
        """ Return the number of bytes the buffer currently contains.
        """
        return self.buffer.tell()

    def add(self, *data: Any) -> None:
        """ Add another row of data to the copy buffer.
        """
        # None becomes the COPY NULL marker; everything else is
        # stringified with tab/newline/backslash escaped.
        fields = ('\\N' if column is None
                  else str(column).translate(_SQL_TRANSLATION)
                  for column in data)
        self.buffer.write('\t'.join(fields))
        self.buffer.write('\n')

    def copy_out(self, cur: Cursor, table: str, columns: Optional[Iterable[str]] = None) -> None:
        """ Copy all collected data into the given table.

            The buffer is empty and reusable after this operation.
        """
        if self.size() > 0:
            self.buffer.seek(0)
            cur.copy_from(self.buffer, table, columns=columns)
            self.buffer = io.StringIO()