add type annotations to freeze functions

This commit is contained in:
Sarah Hoffmann
2022-07-03 19:04:05 +02:00
parent aaf2b6032e
commit 845c43137a
2 changed files with 15 additions and 9 deletions

View File

@@ -77,7 +77,7 @@ class _Cursor(psycopg2.extras.DictCursor):
self.execute(pysql.SQL(sql).format(pysql.Identifier(name))) # type: ignore self.execute(pysql.SQL(sql).format(pysql.Identifier(name))) # type: ignore
class _Connection(psycopg2.extensions.connection): class Connection(psycopg2.extensions.connection):
""" A connection that provides the specialised cursor by default and """ A connection that provides the specialised cursor by default and
adds convenience functions for administrating the database. adds convenience functions for administrating the database.
""" """
@@ -174,19 +174,22 @@ class _Connection(psycopg2.extensions.connection):
return (int(version_parts[0]), int(version_parts[1])) return (int(version_parts[0]), int(version_parts[1]))
class _ConnectionContext(ContextManager[_Connection]): class ConnectionContext(ContextManager[Connection]):
connection: _Connection """ Context manager of the connection that also provides direct access
to the underlying connection.
"""
connection: Connection
def connect(dsn: str) -> _ConnectionContext: def connect(dsn: str) -> ConnectionContext:
""" Open a connection to the database using the specialised connection """ Open a connection to the database using the specialised connection
factory. The returned object may be used in conjunction with 'with'. factory. The returned object may be used in conjunction with 'with'.
When used outside a context manager, use the `connection` attribute When used outside a context manager, use the `connection` attribute
to get the connection. to get the connection.
""" """
try: try:
conn = psycopg2.connect(dsn, connection_factory=_Connection) conn = psycopg2.connect(dsn, connection_factory=Connection)
ctxmgr = cast(_ConnectionContext, contextlib.closing(conn)) ctxmgr = cast(ConnectionContext, contextlib.closing(conn))
ctxmgr.connection = cast(_Connection, conn) ctxmgr.connection = cast(Connection, conn)
return ctxmgr return ctxmgr
except psycopg2.OperationalError as err: except psycopg2.OperationalError as err:
raise UsageError(f"Cannot connect to database: {err}") from err raise UsageError(f"Cannot connect to database: {err}") from err

View File

@@ -7,10 +7,13 @@
""" """
Functions for removing unnecessary data from the database. Functions for removing unnecessary data from the database.
""" """
from typing import Optional
from pathlib import Path from pathlib import Path
from psycopg2 import sql as pysql from psycopg2 import sql as pysql
from nominatim.db.connection import Connection
UPDATE_TABLES = [ UPDATE_TABLES = [
'address_levels', 'address_levels',
'gb_postcode', 'gb_postcode',
@@ -25,7 +28,7 @@ UPDATE_TABLES = [
'wikipedia_%' 'wikipedia_%'
] ]
def drop_update_tables(conn): def drop_update_tables(conn: Connection) -> None:
""" Drop all tables only necessary for updating the database from """ Drop all tables only necessary for updating the database from
OSM replication data. OSM replication data.
""" """
@@ -42,7 +45,7 @@ def drop_update_tables(conn):
conn.commit() conn.commit()
def drop_flatnode_file(fpath): def drop_flatnode_file(fpath: Optional[Path]) -> None:
""" Remove the flatnode file if it exists. """ Remove the flatnode file if it exists.
""" """
if fpath and fpath.exists(): if fpath and fpath.exists():