port code to psycopg3

This commit is contained in:
Sarah Hoffmann
2024-07-05 10:43:10 +02:00
parent 3742fa2929
commit 9659afbade
57 changed files with 800 additions and 1330 deletions

View File

@@ -7,73 +7,27 @@
"""
Specialised connection and cursor functions.
"""
from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload,\
Tuple, Iterable
import contextlib
from typing import Optional, Any, Dict, Tuple
import logging
import os
import psycopg2
import psycopg2.extensions
import psycopg2.extras
from psycopg2 import sql as pysql
import psycopg
import psycopg.types.hstore
from psycopg import sql as pysql
from ..typing import SysEnv, Query, T_cursor
from ..typing import SysEnv
from ..errors import UsageError
LOG = logging.getLogger()
class Cursor(psycopg2.extras.DictCursor):
""" A cursor returning dict-like objects and providing specialised
execution functions.
"""
# pylint: disable=arguments-renamed,arguments-differ
def execute(self, query: Query, args: Any = None) -> None:
""" Query execution that logs the SQL query when debugging is enabled.
"""
if LOG.isEnabledFor(logging.DEBUG):
LOG.debug(self.mogrify(query, args).decode('utf-8'))
Cursor = psycopg.Cursor[Any]
Connection = psycopg.Connection[Any]
super().execute(query, args)
def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]],
template: Optional[Query] = None) -> None:
""" Wrapper for the psycopg2 convenience function to execute
SQL for a list of values.
"""
LOG.debug("SQL execute_values(%s, %s)", sql, argslist)
psycopg2.extras.execute_values(self, sql, argslist, template=template)
class Connection(psycopg2.extensions.connection):
""" A connection that provides the specialised cursor by default and
adds convenience functions for administrating the database.
"""
@overload # type: ignore[override]
def cursor(self) -> Cursor:
...
@overload
def cursor(self, name: str) -> Cursor:
...
@overload
def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor:
...
def cursor(self, cursor_factory = Cursor, **kwargs): # type: ignore
""" Return a new cursor. By default the specialised cursor is returned.
"""
return super().cursor(cursor_factory=cursor_factory, **kwargs)
def execute_scalar(conn: Connection, sql: Query, args: Any = None) -> Any:
def execute_scalar(conn: Connection, sql: psycopg.abc.Query, args: Any = None) -> Any:
""" Execute query that returns a single value. The value is returned.
If the query yields more than one row, a ValueError is raised.
"""
with conn.cursor() as cur:
with conn.cursor(row_factory=psycopg.rows.tuple_row) as cur:
cur.execute(sql, args)
if cur.rowcount != 1:
@@ -144,7 +98,7 @@ def server_version_tuple(conn: Connection) -> Tuple[int, int]:
""" Return the server version as a tuple of (major, minor).
Converts correctly for pre-10 and post-10 PostgreSQL versions.
"""
version = conn.server_version
version = conn.info.server_version
if version < 100000:
return (int(version / 10000), int((version % 10000) / 100))
@@ -164,31 +118,25 @@ def postgis_version_tuple(conn: Connection) -> Tuple[int, int]:
return (int(version_parts[0]), int(version_parts[1]))
def register_hstore(conn: Connection) -> None:
""" Register the hstore type with psycopg for the connection.
"""
psycopg2.extras.register_hstore(conn)
info = psycopg.types.TypeInfo.fetch(conn, "hstore")
if info is None:
raise RuntimeError('Hstore extension is requested but not installed.')
psycopg.types.hstore.register_hstore(info, conn)
class ConnectionContext(ContextManager[Connection]):
""" Context manager of the connection that also provides direct access
to the underlying connection.
"""
connection: Connection
def connect(dsn: str) -> ConnectionContext:
def connect(dsn: str, **kwargs: Any) -> Connection:
""" Open a connection to the database using the specialised connection
factory. The returned object may be used in conjunction with 'with'.
When used outside a context manager, use the `connection` attribute
to get the connection.
"""
try:
conn = psycopg2.connect(dsn, connection_factory=Connection)
ctxmgr = cast(ConnectionContext, contextlib.closing(conn))
ctxmgr.connection = conn
return ctxmgr
except psycopg2.OperationalError as err:
return psycopg.connect(dsn, row_factory=psycopg.rows.namedtuple_row, **kwargs)
except psycopg.OperationalError as err:
raise UsageError(f"Cannot connect to database: {err}") from err
@@ -233,10 +181,18 @@ def get_pg_env(dsn: str,
"""
env = dict(base_env if base_env is not None else os.environ)
for param, value in psycopg2.extensions.parse_dsn(dsn).items():
for param, value in psycopg.conninfo.conninfo_to_dict(dsn).items():
if param in _PG_CONNECTION_STRINGS:
env[_PG_CONNECTION_STRINGS[param]] = value
env[_PG_CONNECTION_STRINGS[param]] = str(value)
else:
LOG.error("Unknown connection parameter '%s' ignored.", param)
return env
async def run_async_query(dsn: str, query: psycopg.abc.Query) -> None:
""" Open a connection to the database and run a single query
asynchronously.
"""
async with await psycopg.AsyncConnection.connect(dsn) as aconn:
await aconn.execute(query)