add type annotation to DB utils

As a cursor is needed as type, make this a public type.
This commit is contained in:
Sarah Hoffmann
2022-07-05 10:46:55 +02:00
parent e6775e713c
commit 26f30bff28
2 changed files with 12 additions and 12 deletions

View File

@@ -22,7 +22,7 @@ from nominatim.errors import UsageError
LOG = logging.getLogger() LOG = logging.getLogger()
class _Cursor(psycopg2.extras.DictCursor): class Cursor(psycopg2.extras.DictCursor):
""" A cursor returning dict-like objects and providing specialised """ A cursor returning dict-like objects and providing specialised
execution functions. execution functions.
""" """
@@ -82,18 +82,18 @@ class Connection(psycopg2.extensions.connection):
adds convenience functions for administrating the database. adds convenience functions for administrating the database.
""" """
@overload # type: ignore[override] @overload # type: ignore[override]
def cursor(self) -> _Cursor: def cursor(self) -> Cursor:
... ...
@overload @overload
def cursor(self, name: str) -> _Cursor: def cursor(self, name: str) -> Cursor:
... ...
@overload @overload
def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor: def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor:
... ...
def cursor(self, cursor_factory = _Cursor, **kwargs): # type: ignore def cursor(self, cursor_factory = Cursor, **kwargs): # type: ignore
""" Return a new cursor. By default the specialised cursor is returned. """ Return a new cursor. By default the specialised cursor is returned.
""" """
return super().cursor(cursor_factory=cursor_factory, **kwargs) return super().cursor(cursor_factory=cursor_factory, **kwargs)

View File

@@ -7,14 +7,14 @@
""" """
Helper functions for handling DB accesses. Helper functions for handling DB accesses.
""" """
from typing import IO, Optional, Union from typing import IO, Optional, Union, Any, Iterable
import subprocess import subprocess
import logging import logging
import gzip import gzip
import io import io
from pathlib import Path from pathlib import Path
from nominatim.db.connection import get_pg_env from nominatim.db.connection import get_pg_env, Cursor
from nominatim.errors import UsageError from nominatim.errors import UsageError
LOG = logging.getLogger() LOG = logging.getLogger()
@@ -84,20 +84,20 @@ class CopyBuffer:
""" Data collector for the copy_from command. """ Data collector for the copy_from command.
""" """
def __init__(self): def __init__(self) -> None:
self.buffer = io.StringIO() self.buffer = io.StringIO()
def __enter__(self): def __enter__(self) -> 'CopyBuffer':
return self return self
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
if self.buffer is not None: if self.buffer is not None:
self.buffer.close() self.buffer.close()
def add(self, *data): def add(self, *data: Any) -> None:
""" Add another row of data to the copy buffer. """ Add another row of data to the copy buffer.
""" """
first = True first = True
@@ -113,9 +113,9 @@ class CopyBuffer:
self.buffer.write('\n') self.buffer.write('\n')
def copy_out(self, cur, table, columns=None): def copy_out(self, cur: Cursor, table: str, columns: Optional[Iterable[str]] = None) -> None:
""" Copy all collected data into the given table. """ Copy all collected data into the given table.
""" """
if self.buffer.tell() > 0: if self.buffer.tell() > 0:
self.buffer.seek(0) self.buffer.seek(0)
cur.copy_from(self.buffer, table, columns=columns) cur.copy_from(self.buffer, table, columns=columns) # type: ignore[no-untyped-call]