fix style issue found by flake8

This commit is contained in:
Sarah Hoffmann
2024-11-10 22:47:14 +01:00
parent 8c14df55a6
commit 1f07967787
112 changed files with 656 additions and 1109 deletions

View File

@@ -23,6 +23,7 @@ LOG = logging.getLogger()
Cursor = psycopg.Cursor[Any]
Connection = psycopg.Connection[Any]
def execute_scalar(conn: Connection, sql: psycopg.abc.Query, args: Any = None) -> Any:
""" Execute query that returns a single value. The value is returned.
If the query yields more than one row, a ValueError is raised.
@@ -42,9 +43,10 @@ def execute_scalar(conn: Connection, sql: psycopg.abc.Query, args: Any = None) -
def table_exists(conn: Connection, table: str) -> bool:
""" Check that a table with the given name exists in the database.
"""
num = execute_scalar(conn,
"""SELECT count(*) FROM pg_tables
WHERE tablename = %s and schemaname = 'public'""", (table, ))
num = execute_scalar(
conn,
"""SELECT count(*) FROM pg_tables
WHERE tablename = %s and schemaname = 'public'""", (table, ))
return num == 1 if isinstance(num, int) else False
@@ -52,9 +54,9 @@ def table_has_column(conn: Connection, table: str, column: str) -> bool:
""" Check if the table 'table' exists and has a column with name 'column'.
"""
has_column = execute_scalar(conn,
"""SELECT count(*) FROM information_schema.columns
WHERE table_name = %s and column_name = %s""",
(table, column))
"""SELECT count(*) FROM information_schema.columns
WHERE table_name = %s and column_name = %s""",
(table, column))
return has_column > 0 if isinstance(has_column, int) else False
@@ -77,8 +79,9 @@ def index_exists(conn: Connection, index: str, table: Optional[str] = None) -> b
return True
def drop_tables(conn: Connection, *names: str,
if_exists: bool = True, cascade: bool = False) -> None:
if_exists: bool = True, cascade: bool = False) -> None:
""" Drop one or more tables with the given names.
Set `if_exists` to False if a non-existent table should raise
an exception instead of just being ignored. `cascade` will cause

View File

@@ -11,6 +11,7 @@ from typing import Optional, cast
from .connection import Connection, table_exists
def set_property(conn: Connection, name: str, value: str) -> None:
""" Add or replace the property with the given name.
"""

View File

@@ -18,6 +18,7 @@ LOG = logging.getLogger()
QueueItem = Optional[Tuple[psycopg.abc.Query, Any]]
class QueryPool:
""" Pool to run SQL queries in parallel asynchronous execution.
@@ -32,7 +33,6 @@ class QueryPool:
self.pool = [asyncio.create_task(self._worker_loop(dsn, **conn_args))
for _ in range(pool_size)]
async def put_query(self, query: psycopg.abc.Query, params: Any) -> None:
""" Schedule a query for execution.
"""
@@ -41,7 +41,6 @@ class QueryPool:
self.wait_time += time.time() - tstart
await asyncio.sleep(0)
async def finish(self) -> None:
""" Wait for all queries to finish and close the pool.
"""
@@ -57,7 +56,6 @@ class QueryPool:
if excp is not None:
raise excp
async def _worker_loop(self, dsn: str, **conn_args: Any) -> None:
conn_args['autocommit'] = True
aconn = await psycopg.AsyncConnection.connect(dsn, **conn_args)
@@ -78,10 +76,8 @@ class QueryPool:
str(item[0]), str(item[1]))
# item is still valid here, causing a retry
async def __aenter__(self) -> 'QueryPool':
return self
async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
await self.finish()

View File

@@ -15,6 +15,7 @@ from .connection import Connection, server_version_tuple, postgis_version_tuple
from ..config import Configuration
from ..db.query_pool import QueryPool
def _get_partitions(conn: Connection) -> Set[int]:
""" Get the set of partitions currently in use.
"""
@@ -35,6 +36,7 @@ def _get_tables(conn: Connection) -> Set[str]:
return set((row[0] for row in list(cur)))
def _get_middle_db_format(conn: Connection, tables: Set[str]) -> str:
""" Returns the version of the slim middle tables.
"""
@@ -73,9 +75,10 @@ def _setup_postgresql_features(conn: Connection) -> Dict[str, Any]:
ps3 = postgis_version >= (3, 0)
return {
'has_index_non_key_column': pg11plus,
'spgist_geom' : 'SPGIST' if pg11plus and ps3 else 'GIST'
'spgist_geom': 'SPGIST' if pg11plus and ps3 else 'GIST'
}
class SQLPreprocessor:
""" A environment for preprocessing SQL files from the
lib-sql directory.
@@ -102,7 +105,6 @@ class SQLPreprocessor:
self.env.globals['db'] = db_info
self.env.globals['postgres'] = _setup_postgresql_features(conn)
def run_string(self, conn: Connection, template: str, **kwargs: Any) -> None:
""" Execute the given SQL template string on the connection.
The keyword arguments may supply additional parameters
@@ -114,7 +116,6 @@ class SQLPreprocessor:
cur.execute(sql)
conn.commit()
def run_sql_file(self, conn: Connection, name: str, **kwargs: Any) -> None:
""" Execute the given SQL file on the connection. The keyword arguments
may supply additional parameters for preprocessing.
@@ -125,7 +126,6 @@ class SQLPreprocessor:
cur.execute(sql)
conn.commit()
async def run_parallel_sql_file(self, dsn: str, name: str, num_threads: int = 1,
**kwargs: Any) -> None:
""" Execute the given SQL files using parallel asynchronous connections.

View File

@@ -18,6 +18,7 @@ from ..errors import UsageError
LOG = logging.getLogger()
def _pipe_to_proc(proc: 'subprocess.Popen[bytes]',
fdesc: Union[IO[bytes], gzip.GzipFile]) -> int:
assert proc.stdin is not None
@@ -31,6 +32,7 @@ def _pipe_to_proc(proc: 'subprocess.Popen[bytes]',
return len(chunk)
def execute_file(dsn: str, fname: Path,
ignore_errors: bool = False,
pre_code: Optional[str] = None,