mirror of
https://github.com/osm-search/Nominatim.git
synced 2026-02-26 11:08:13 +00:00
port code to psycopg3
This commit is contained in:
@@ -7,73 +7,27 @@
|
||||
"""
|
||||
Specialised connection and cursor functions.
|
||||
"""
|
||||
from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload,\
|
||||
Tuple, Iterable
|
||||
import contextlib
|
||||
from typing import Optional, Any, Dict, Tuple
|
||||
import logging
|
||||
import os
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extensions
|
||||
import psycopg2.extras
|
||||
from psycopg2 import sql as pysql
|
||||
import psycopg
|
||||
import psycopg.types.hstore
|
||||
from psycopg import sql as pysql
|
||||
|
||||
from ..typing import SysEnv, Query, T_cursor
|
||||
from ..typing import SysEnv
|
||||
from ..errors import UsageError
|
||||
|
||||
LOG = logging.getLogger()
|
||||
|
||||
class Cursor(psycopg2.extras.DictCursor):
|
||||
""" A cursor returning dict-like objects and providing specialised
|
||||
execution functions.
|
||||
"""
|
||||
# pylint: disable=arguments-renamed,arguments-differ
|
||||
def execute(self, query: Query, args: Any = None) -> None:
|
||||
""" Query execution that logs the SQL query when debugging is enabled.
|
||||
"""
|
||||
if LOG.isEnabledFor(logging.DEBUG):
|
||||
LOG.debug(self.mogrify(query, args).decode('utf-8'))
|
||||
Cursor = psycopg.Cursor[Any]
|
||||
Connection = psycopg.Connection[Any]
|
||||
|
||||
super().execute(query, args)
|
||||
|
||||
|
||||
def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]],
|
||||
template: Optional[Query] = None) -> None:
|
||||
""" Wrapper for the psycopg2 convenience function to execute
|
||||
SQL for a list of values.
|
||||
"""
|
||||
LOG.debug("SQL execute_values(%s, %s)", sql, argslist)
|
||||
|
||||
psycopg2.extras.execute_values(self, sql, argslist, template=template)
|
||||
|
||||
|
||||
class Connection(psycopg2.extensions.connection):
|
||||
""" A connection that provides the specialised cursor by default and
|
||||
adds convenience functions for administrating the database.
|
||||
"""
|
||||
@overload # type: ignore[override]
|
||||
def cursor(self) -> Cursor:
|
||||
...
|
||||
|
||||
@overload
|
||||
def cursor(self, name: str) -> Cursor:
|
||||
...
|
||||
|
||||
@overload
|
||||
def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor:
|
||||
...
|
||||
|
||||
def cursor(self, cursor_factory = Cursor, **kwargs): # type: ignore
|
||||
""" Return a new cursor. By default the specialised cursor is returned.
|
||||
"""
|
||||
return super().cursor(cursor_factory=cursor_factory, **kwargs)
|
||||
|
||||
|
||||
def execute_scalar(conn: Connection, sql: Query, args: Any = None) -> Any:
|
||||
def execute_scalar(conn: Connection, sql: psycopg.abc.Query, args: Any = None) -> Any:
|
||||
""" Execute query that returns a single value. The value is returned.
|
||||
If the query yields more than one row, a ValueError is raised.
|
||||
"""
|
||||
with conn.cursor() as cur:
|
||||
with conn.cursor(row_factory=psycopg.rows.tuple_row) as cur:
|
||||
cur.execute(sql, args)
|
||||
|
||||
if cur.rowcount != 1:
|
||||
@@ -144,7 +98,7 @@ def server_version_tuple(conn: Connection) -> Tuple[int, int]:
|
||||
""" Return the server version as a tuple of (major, minor).
|
||||
Converts correctly for pre-10 and post-10 PostgreSQL versions.
|
||||
"""
|
||||
version = conn.server_version
|
||||
version = conn.info.server_version
|
||||
if version < 100000:
|
||||
return (int(version / 10000), int((version % 10000) / 100))
|
||||
|
||||
@@ -164,31 +118,25 @@ def postgis_version_tuple(conn: Connection) -> Tuple[int, int]:
|
||||
|
||||
return (int(version_parts[0]), int(version_parts[1]))
|
||||
|
||||
|
||||
def register_hstore(conn: Connection) -> None:
|
||||
""" Register the hstore type with psycopg for the connection.
|
||||
"""
|
||||
psycopg2.extras.register_hstore(conn)
|
||||
info = psycopg.types.TypeInfo.fetch(conn, "hstore")
|
||||
if info is None:
|
||||
raise RuntimeError('Hstore extension is requested but not installed.')
|
||||
psycopg.types.hstore.register_hstore(info, conn)
|
||||
|
||||
|
||||
class ConnectionContext(ContextManager[Connection]):
|
||||
""" Context manager of the connection that also provides direct access
|
||||
to the underlying connection.
|
||||
"""
|
||||
connection: Connection
|
||||
|
||||
|
||||
def connect(dsn: str) -> ConnectionContext:
|
||||
def connect(dsn: str, **kwargs: Any) -> Connection:
|
||||
""" Open a connection to the database using the specialised connection
|
||||
factory. The returned object may be used in conjunction with 'with'.
|
||||
When used outside a context manager, use the `connection` attribute
|
||||
to get the connection.
|
||||
"""
|
||||
try:
|
||||
conn = psycopg2.connect(dsn, connection_factory=Connection)
|
||||
ctxmgr = cast(ConnectionContext, contextlib.closing(conn))
|
||||
ctxmgr.connection = conn
|
||||
return ctxmgr
|
||||
except psycopg2.OperationalError as err:
|
||||
return psycopg.connect(dsn, row_factory=psycopg.rows.namedtuple_row, **kwargs)
|
||||
except psycopg.OperationalError as err:
|
||||
raise UsageError(f"Cannot connect to database: {err}") from err
|
||||
|
||||
|
||||
@@ -233,10 +181,18 @@ def get_pg_env(dsn: str,
|
||||
"""
|
||||
env = dict(base_env if base_env is not None else os.environ)
|
||||
|
||||
for param, value in psycopg2.extensions.parse_dsn(dsn).items():
|
||||
for param, value in psycopg.conninfo.conninfo_to_dict(dsn).items():
|
||||
if param in _PG_CONNECTION_STRINGS:
|
||||
env[_PG_CONNECTION_STRINGS[param]] = value
|
||||
env[_PG_CONNECTION_STRINGS[param]] = str(value)
|
||||
else:
|
||||
LOG.error("Unknown connection parameter '%s' ignored.", param)
|
||||
|
||||
return env
|
||||
|
||||
|
||||
async def run_async_query(dsn: str, query: psycopg.abc.Query) -> None:
|
||||
""" Open a connection to the database and run a single query
|
||||
asynchronously.
|
||||
"""
|
||||
async with await psycopg.AsyncConnection.connect(dsn) as aconn:
|
||||
await aconn.execute(query)
|
||||
|
||||
Reference in New Issue
Block a user