mirror of
https://github.com/osm-search/Nominatim.git
synced 2026-03-12 05:44:06 +00:00
type annotations for DB connection
This commit is contained in:
@@ -7,6 +7,7 @@
|
|||||||
"""
|
"""
|
||||||
Specialised connection and cursor functions.
|
Specialised connection and cursor functions.
|
||||||
"""
|
"""
|
||||||
|
from typing import Union, List, Optional, Any, Callable, ContextManager, Mapping, cast, TypeVar, overload, Tuple, Sequence
|
||||||
import contextlib
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@@ -18,23 +19,26 @@ from psycopg2 import sql as pysql
|
|||||||
|
|
||||||
from nominatim.errors import UsageError
|
from nominatim.errors import UsageError
|
||||||
|
|
||||||
|
Query = Union[str, bytes, pysql.Composable]
|
||||||
|
T = TypeVar('T', bound=psycopg2.extensions.cursor)
|
||||||
|
|
||||||
LOG = logging.getLogger()
|
LOG = logging.getLogger()
|
||||||
|
|
||||||
class _Cursor(psycopg2.extras.DictCursor):
|
class _Cursor(psycopg2.extras.DictCursor):
|
||||||
""" A cursor returning dict-like objects and providing specialised
|
""" A cursor returning dict-like objects and providing specialised
|
||||||
execution functions.
|
execution functions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pylint: disable=arguments-renamed,arguments-differ
|
# pylint: disable=arguments-renamed,arguments-differ
|
||||||
def execute(self, query, args=None):
|
def execute(self, query: Query, args: Any = None) -> None:
|
||||||
""" Query execution that logs the SQL query when debugging is enabled.
|
""" Query execution that logs the SQL query when debugging is enabled.
|
||||||
"""
|
"""
|
||||||
LOG.debug(self.mogrify(query, args).decode('utf-8'))
|
if LOG.isEnabledFor(logging.DEBUG):
|
||||||
|
LOG.debug(self.mogrify(query, args).decode('utf-8')) # type: ignore
|
||||||
|
|
||||||
super().execute(query, args)
|
super().execute(query, args)
|
||||||
|
|
||||||
|
|
||||||
def execute_values(self, sql, argslist, template=None):
|
def execute_values(self, sql: Query, argslist: List[Any], template: Optional[str] = None) -> None:
|
||||||
""" Wrapper for the psycopg2 convenience function to execute
|
""" Wrapper for the psycopg2 convenience function to execute
|
||||||
SQL for a list of values.
|
SQL for a list of values.
|
||||||
"""
|
"""
|
||||||
@@ -43,7 +47,7 @@ class _Cursor(psycopg2.extras.DictCursor):
|
|||||||
psycopg2.extras.execute_values(self, sql, argslist, template=template)
|
psycopg2.extras.execute_values(self, sql, argslist, template=template)
|
||||||
|
|
||||||
|
|
||||||
def scalar(self, sql, args=None):
|
def scalar(self, sql: Query, args: Any = None) -> Any:
|
||||||
""" Execute query that returns a single value. The value is returned.
|
""" Execute query that returns a single value. The value is returned.
|
||||||
If the query yields more than one row, a ValueError is raised.
|
If the query yields more than one row, a ValueError is raised.
|
||||||
"""
|
"""
|
||||||
@@ -52,10 +56,13 @@ class _Cursor(psycopg2.extras.DictCursor):
|
|||||||
if self.rowcount != 1:
|
if self.rowcount != 1:
|
||||||
raise RuntimeError("Query did not return a single row.")
|
raise RuntimeError("Query did not return a single row.")
|
||||||
|
|
||||||
return self.fetchone()[0]
|
result = self.fetchone() # type: ignore
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
return result[0]
|
||||||
|
|
||||||
|
|
||||||
def drop_table(self, name, if_exists=True, cascade=False):
|
def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
|
||||||
""" Drop the table with the given name.
|
""" Drop the table with the given name.
|
||||||
Set `if_exists` to False if a non-existant table should raise
|
Set `if_exists` to False if a non-existant table should raise
|
||||||
an exception instead of just being ignored. If 'cascade' is set
|
an exception instead of just being ignored. If 'cascade' is set
|
||||||
@@ -68,30 +75,41 @@ class _Cursor(psycopg2.extras.DictCursor):
|
|||||||
if cascade:
|
if cascade:
|
||||||
sql += ' CASCADE'
|
sql += ' CASCADE'
|
||||||
|
|
||||||
self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
|
self.execute(pysql.SQL(sql).format(pysql.Identifier(name))) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class _Connection(psycopg2.extensions.connection):
|
class _Connection(psycopg2.extensions.connection):
|
||||||
""" A connection that provides the specialised cursor by default and
|
""" A connection that provides the specialised cursor by default and
|
||||||
adds convenience functions for administrating the database.
|
adds convenience functions for administrating the database.
|
||||||
"""
|
"""
|
||||||
|
@overload # type: ignore[override]
|
||||||
|
def cursor(self) -> _Cursor:
|
||||||
|
...
|
||||||
|
|
||||||
def cursor(self, cursor_factory=_Cursor, **kwargs):
|
@overload
|
||||||
|
def cursor(self, name: str) -> _Cursor:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def cursor(self, cursor_factory: Callable[..., T]) -> T:
|
||||||
|
...
|
||||||
|
|
||||||
|
def cursor(self, cursor_factory = _Cursor, **kwargs): # type: ignore
|
||||||
""" Return a new cursor. By default the specialised cursor is returned.
|
""" Return a new cursor. By default the specialised cursor is returned.
|
||||||
"""
|
"""
|
||||||
return super().cursor(cursor_factory=cursor_factory, **kwargs)
|
return super().cursor(cursor_factory=cursor_factory, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def table_exists(self, table):
|
def table_exists(self, table: str) -> bool:
|
||||||
""" Check that a table with the given name exists in the database.
|
""" Check that a table with the given name exists in the database.
|
||||||
"""
|
"""
|
||||||
with self.cursor() as cur:
|
with self.cursor() as cur:
|
||||||
num = cur.scalar("""SELECT count(*) FROM pg_tables
|
num = cur.scalar("""SELECT count(*) FROM pg_tables
|
||||||
WHERE tablename = %s and schemaname = 'public'""", (table, ))
|
WHERE tablename = %s and schemaname = 'public'""", (table, ))
|
||||||
return num == 1
|
return num == 1 if isinstance(num, int) else False
|
||||||
|
|
||||||
|
|
||||||
def table_has_column(self, table, column):
|
def table_has_column(self, table: str, column: str) -> bool:
|
||||||
""" Check if the table 'table' exists and has a column with name 'column'.
|
""" Check if the table 'table' exists and has a column with name 'column'.
|
||||||
"""
|
"""
|
||||||
with self.cursor() as cur:
|
with self.cursor() as cur:
|
||||||
@@ -99,10 +117,10 @@ class _Connection(psycopg2.extensions.connection):
|
|||||||
WHERE table_name = %s
|
WHERE table_name = %s
|
||||||
and column_name = %s""",
|
and column_name = %s""",
|
||||||
(table, column))
|
(table, column))
|
||||||
return has_column > 0
|
return has_column > 0 if isinstance(has_column, int) else False
|
||||||
|
|
||||||
|
|
||||||
def index_exists(self, index, table=None):
|
def index_exists(self, index: str, table: Optional[str] = None) -> bool:
|
||||||
""" Check that an index with the given name exists in the database.
|
""" Check that an index with the given name exists in the database.
|
||||||
If table is not None then the index must relate to the given
|
If table is not None then the index must relate to the given
|
||||||
table.
|
table.
|
||||||
@@ -114,13 +132,15 @@ class _Connection(psycopg2.extensions.connection):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
if table is not None:
|
if table is not None:
|
||||||
row = cur.fetchone()
|
row = cur.fetchone() # type: ignore
|
||||||
|
if row is None or not isinstance(row[0], str):
|
||||||
|
return False
|
||||||
return row[0] == table
|
return row[0] == table
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def drop_table(self, name, if_exists=True, cascade=False):
|
def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
|
||||||
""" Drop the table with the given name.
|
""" Drop the table with the given name.
|
||||||
Set `if_exists` to False if a non-existant table should raise
|
Set `if_exists` to False if a non-existant table should raise
|
||||||
an exception instead of just being ignored.
|
an exception instead of just being ignored.
|
||||||
@@ -130,18 +150,18 @@ class _Connection(psycopg2.extensions.connection):
|
|||||||
self.commit()
|
self.commit()
|
||||||
|
|
||||||
|
|
||||||
def server_version_tuple(self):
|
def server_version_tuple(self) -> Tuple[int, int]:
|
||||||
""" Return the server version as a tuple of (major, minor).
|
""" Return the server version as a tuple of (major, minor).
|
||||||
Converts correctly for pre-10 and post-10 PostgreSQL versions.
|
Converts correctly for pre-10 and post-10 PostgreSQL versions.
|
||||||
"""
|
"""
|
||||||
version = self.server_version
|
version = self.server_version
|
||||||
if version < 100000:
|
if version < 100000:
|
||||||
return (int(version / 10000), (version % 10000) / 100)
|
return (int(version / 10000), int((version % 10000) / 100))
|
||||||
|
|
||||||
return (int(version / 10000), version % 10000)
|
return (int(version / 10000), version % 10000)
|
||||||
|
|
||||||
|
|
||||||
def postgis_version_tuple(self):
|
def postgis_version_tuple(self) -> Tuple[int, int]:
|
||||||
""" Return the postgis version installed in the database as a
|
""" Return the postgis version installed in the database as a
|
||||||
tuple of (major, minor). Assumes that the PostGIS extension
|
tuple of (major, minor). Assumes that the PostGIS extension
|
||||||
has been installed already.
|
has been installed already.
|
||||||
@@ -149,10 +169,16 @@ class _Connection(psycopg2.extensions.connection):
|
|||||||
with self.cursor() as cur:
|
with self.cursor() as cur:
|
||||||
version = cur.scalar('SELECT postgis_lib_version()')
|
version = cur.scalar('SELECT postgis_lib_version()')
|
||||||
|
|
||||||
return tuple((int(x) for x in version.split('.')[:2]))
|
version_parts = version.split('.')
|
||||||
|
if len(version_parts) < 2:
|
||||||
|
raise UsageError(f"Error fetching Postgis version. Bad format: {version}")
|
||||||
|
|
||||||
|
return (int(version_parts[0]), int(version_parts[1]))
|
||||||
|
|
||||||
def connect(dsn):
|
class _ConnectionContext(ContextManager[_Connection]):
|
||||||
|
connection: _Connection
|
||||||
|
|
||||||
|
def connect(dsn: str) -> _ConnectionContext:
|
||||||
""" Open a connection to the database using the specialised connection
|
""" Open a connection to the database using the specialised connection
|
||||||
factory. The returned object may be used in conjunction with 'with'.
|
factory. The returned object may be used in conjunction with 'with'.
|
||||||
When used outside a context manager, use the `connection` attribute
|
When used outside a context manager, use the `connection` attribute
|
||||||
@@ -160,8 +186,8 @@ def connect(dsn):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
conn = psycopg2.connect(dsn, connection_factory=_Connection)
|
conn = psycopg2.connect(dsn, connection_factory=_Connection)
|
||||||
ctxmgr = contextlib.closing(conn)
|
ctxmgr = cast(_ConnectionContext, contextlib.closing(conn))
|
||||||
ctxmgr.connection = conn
|
ctxmgr.connection = cast(_Connection, conn)
|
||||||
return ctxmgr
|
return ctxmgr
|
||||||
except psycopg2.OperationalError as err:
|
except psycopg2.OperationalError as err:
|
||||||
raise UsageError(f"Cannot connect to database: {err}") from err
|
raise UsageError(f"Cannot connect to database: {err}") from err
|
||||||
@@ -199,7 +225,8 @@ _PG_CONNECTION_STRINGS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_pg_env(dsn, base_env=None):
|
def get_pg_env(dsn: str,
|
||||||
|
base_env: Optional[Mapping[str, Optional[str]]] = None) -> Mapping[str, Optional[str]]:
|
||||||
""" Return a copy of `base_env` with the environment variables for
|
""" Return a copy of `base_env` with the environment variables for
|
||||||
PostgresSQL set up from the given database connection string.
|
PostgresSQL set up from the given database connection string.
|
||||||
If `base_env` is None, then the OS environment is used as a base
|
If `base_env` is None, then the OS environment is used as a base
|
||||||
@@ -207,7 +234,7 @@ def get_pg_env(dsn, base_env=None):
|
|||||||
"""
|
"""
|
||||||
env = dict(base_env if base_env is not None else os.environ)
|
env = dict(base_env if base_env is not None else os.environ)
|
||||||
|
|
||||||
for param, value in psycopg2.extensions.parse_dsn(dsn).items():
|
for param, value in psycopg2.extensions.parse_dsn(dsn).items(): # type: ignore
|
||||||
if param in _PG_CONNECTION_STRINGS:
|
if param in _PG_CONNECTION_STRINGS:
|
||||||
env[_PG_CONNECTION_STRINGS[param]] = value
|
env[_PG_CONNECTION_STRINGS[param]] = value
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user