type annotations for DB connection

This commit is contained in:
Sarah Hoffmann
2022-07-01 13:55:24 +02:00
parent 9d716f0f7d
commit e6ee3c772c

View File

@@ -7,6 +7,7 @@
""" """
Specialised connection and cursor functions. Specialised connection and cursor functions.
""" """
from typing import Union, List, Optional, Any, Callable, ContextManager, Mapping, cast, TypeVar, overload, Tuple, Sequence
import contextlib import contextlib
import logging import logging
import os import os
@@ -18,23 +19,26 @@ from psycopg2 import sql as pysql
from nominatim.errors import UsageError from nominatim.errors import UsageError
Query = Union[str, bytes, pysql.Composable]
T = TypeVar('T', bound=psycopg2.extensions.cursor)
LOG = logging.getLogger() LOG = logging.getLogger()
class _Cursor(psycopg2.extras.DictCursor): class _Cursor(psycopg2.extras.DictCursor):
""" A cursor returning dict-like objects and providing specialised """ A cursor returning dict-like objects and providing specialised
execution functions. execution functions.
""" """
# pylint: disable=arguments-renamed,arguments-differ # pylint: disable=arguments-renamed,arguments-differ
def execute(self, query, args=None): def execute(self, query: Query, args: Any = None) -> None:
""" Query execution that logs the SQL query when debugging is enabled. """ Query execution that logs the SQL query when debugging is enabled.
""" """
LOG.debug(self.mogrify(query, args).decode('utf-8')) if LOG.isEnabledFor(logging.DEBUG):
LOG.debug(self.mogrify(query, args).decode('utf-8')) # type: ignore
super().execute(query, args) super().execute(query, args)
def execute_values(self, sql, argslist, template=None): def execute_values(self, sql: Query, argslist: List[Any], template: Optional[str] = None) -> None:
""" Wrapper for the psycopg2 convenience function to execute """ Wrapper for the psycopg2 convenience function to execute
SQL for a list of values. SQL for a list of values.
""" """
@@ -43,7 +47,7 @@ class _Cursor(psycopg2.extras.DictCursor):
psycopg2.extras.execute_values(self, sql, argslist, template=template) psycopg2.extras.execute_values(self, sql, argslist, template=template)
def scalar(self, sql, args=None): def scalar(self, sql: Query, args: Any = None) -> Any:
""" Execute query that returns a single value. The value is returned. """ Execute query that returns a single value. The value is returned.
If the query yields more than one row, a ValueError is raised. If the query yields more than one row, a ValueError is raised.
""" """
@@ -52,10 +56,13 @@ class _Cursor(psycopg2.extras.DictCursor):
if self.rowcount != 1: if self.rowcount != 1:
raise RuntimeError("Query did not return a single row.") raise RuntimeError("Query did not return a single row.")
return self.fetchone()[0] result = self.fetchone() # type: ignore
assert result is not None
return result[0]
def drop_table(self, name, if_exists=True, cascade=False): def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
""" Drop the table with the given name. """ Drop the table with the given name.
Set `if_exists` to False if a non-existant table should raise Set `if_exists` to False if a non-existant table should raise
an exception instead of just being ignored. If 'cascade' is set an exception instead of just being ignored. If 'cascade' is set
@@ -68,30 +75,41 @@ class _Cursor(psycopg2.extras.DictCursor):
if cascade: if cascade:
sql += ' CASCADE' sql += ' CASCADE'
self.execute(pysql.SQL(sql).format(pysql.Identifier(name))) self.execute(pysql.SQL(sql).format(pysql.Identifier(name))) # type: ignore
class _Connection(psycopg2.extensions.connection): class _Connection(psycopg2.extensions.connection):
""" A connection that provides the specialised cursor by default and """ A connection that provides the specialised cursor by default and
adds convenience functions for administrating the database. adds convenience functions for administrating the database.
""" """
@overload # type: ignore[override]
def cursor(self) -> _Cursor:
...
def cursor(self, cursor_factory=_Cursor, **kwargs): @overload
def cursor(self, name: str) -> _Cursor:
...
@overload
def cursor(self, cursor_factory: Callable[..., T]) -> T:
...
def cursor(self, cursor_factory = _Cursor, **kwargs): # type: ignore
""" Return a new cursor. By default the specialised cursor is returned. """ Return a new cursor. By default the specialised cursor is returned.
""" """
return super().cursor(cursor_factory=cursor_factory, **kwargs) return super().cursor(cursor_factory=cursor_factory, **kwargs)
def table_exists(self, table): def table_exists(self, table: str) -> bool:
""" Check that a table with the given name exists in the database. """ Check that a table with the given name exists in the database.
""" """
with self.cursor() as cur: with self.cursor() as cur:
num = cur.scalar("""SELECT count(*) FROM pg_tables num = cur.scalar("""SELECT count(*) FROM pg_tables
WHERE tablename = %s and schemaname = 'public'""", (table, )) WHERE tablename = %s and schemaname = 'public'""", (table, ))
return num == 1 return num == 1 if isinstance(num, int) else False
def table_has_column(self, table, column): def table_has_column(self, table: str, column: str) -> bool:
""" Check if the table 'table' exists and has a column with name 'column'. """ Check if the table 'table' exists and has a column with name 'column'.
""" """
with self.cursor() as cur: with self.cursor() as cur:
@@ -99,10 +117,10 @@ class _Connection(psycopg2.extensions.connection):
WHERE table_name = %s WHERE table_name = %s
and column_name = %s""", and column_name = %s""",
(table, column)) (table, column))
return has_column > 0 return has_column > 0 if isinstance(has_column, int) else False
def index_exists(self, index, table=None): def index_exists(self, index: str, table: Optional[str] = None) -> bool:
""" Check that an index with the given name exists in the database. """ Check that an index with the given name exists in the database.
If table is not None then the index must relate to the given If table is not None then the index must relate to the given
table. table.
@@ -114,13 +132,15 @@ class _Connection(psycopg2.extensions.connection):
return False return False
if table is not None: if table is not None:
row = cur.fetchone() row = cur.fetchone() # type: ignore
if row is None or not isinstance(row[0], str):
return False
return row[0] == table return row[0] == table
return True return True
def drop_table(self, name, if_exists=True, cascade=False): def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
""" Drop the table with the given name. """ Drop the table with the given name.
Set `if_exists` to False if a non-existant table should raise Set `if_exists` to False if a non-existant table should raise
an exception instead of just being ignored. an exception instead of just being ignored.
@@ -130,18 +150,18 @@ class _Connection(psycopg2.extensions.connection):
self.commit() self.commit()
def server_version_tuple(self): def server_version_tuple(self) -> Tuple[int, int]:
""" Return the server version as a tuple of (major, minor). """ Return the server version as a tuple of (major, minor).
Converts correctly for pre-10 and post-10 PostgreSQL versions. Converts correctly for pre-10 and post-10 PostgreSQL versions.
""" """
version = self.server_version version = self.server_version
if version < 100000: if version < 100000:
return (int(version / 10000), (version % 10000) / 100) return (int(version / 10000), int((version % 10000) / 100))
return (int(version / 10000), version % 10000) return (int(version / 10000), version % 10000)
def postgis_version_tuple(self): def postgis_version_tuple(self) -> Tuple[int, int]:
""" Return the postgis version installed in the database as a """ Return the postgis version installed in the database as a
tuple of (major, minor). Assumes that the PostGIS extension tuple of (major, minor). Assumes that the PostGIS extension
has been installed already. has been installed already.
@@ -149,10 +169,16 @@ class _Connection(psycopg2.extensions.connection):
with self.cursor() as cur: with self.cursor() as cur:
version = cur.scalar('SELECT postgis_lib_version()') version = cur.scalar('SELECT postgis_lib_version()')
return tuple((int(x) for x in version.split('.')[:2])) version_parts = version.split('.')
if len(version_parts) < 2:
raise UsageError(f"Error fetching Postgis version. Bad format: {version}")
return (int(version_parts[0]), int(version_parts[1]))
def connect(dsn): class _ConnectionContext(ContextManager[_Connection]):
connection: _Connection
def connect(dsn: str) -> _ConnectionContext:
""" Open a connection to the database using the specialised connection """ Open a connection to the database using the specialised connection
factory. The returned object may be used in conjunction with 'with'. factory. The returned object may be used in conjunction with 'with'.
When used outside a context manager, use the `connection` attribute When used outside a context manager, use the `connection` attribute
@@ -160,8 +186,8 @@ def connect(dsn):
""" """
try: try:
conn = psycopg2.connect(dsn, connection_factory=_Connection) conn = psycopg2.connect(dsn, connection_factory=_Connection)
ctxmgr = contextlib.closing(conn) ctxmgr = cast(_ConnectionContext, contextlib.closing(conn))
ctxmgr.connection = conn ctxmgr.connection = cast(_Connection, conn)
return ctxmgr return ctxmgr
except psycopg2.OperationalError as err: except psycopg2.OperationalError as err:
raise UsageError(f"Cannot connect to database: {err}") from err raise UsageError(f"Cannot connect to database: {err}") from err
@@ -199,7 +225,8 @@ _PG_CONNECTION_STRINGS = {
} }
def get_pg_env(dsn, base_env=None): def get_pg_env(dsn: str,
base_env: Optional[Mapping[str, Optional[str]]] = None) -> Mapping[str, Optional[str]]:
""" Return a copy of `base_env` with the environment variables for """ Return a copy of `base_env` with the environment variables for
PostgresSQL set up from the given database connection string. PostgresSQL set up from the given database connection string.
If `base_env` is None, then the OS environment is used as a base If `base_env` is None, then the OS environment is used as a base
@@ -207,7 +234,7 @@ def get_pg_env(dsn, base_env=None):
""" """
env = dict(base_env if base_env is not None else os.environ) env = dict(base_env if base_env is not None else os.environ)
for param, value in psycopg2.extensions.parse_dsn(dsn).items(): for param, value in psycopg2.extensions.parse_dsn(dsn).items(): # type: ignore
if param in _PG_CONNECTION_STRINGS: if param in _PG_CONNECTION_STRINGS:
env[_PG_CONNECTION_STRINGS[param]] = value env[_PG_CONNECTION_STRINGS[param]] = value
else: else: