make DB helper functions free functions

Also changes the drop function so that it can drop multiple tables
at once.
This commit is contained in:
Sarah Hoffmann
2024-07-02 15:15:50 +02:00
parent 71249bd94a
commit 3742fa2929
30 changed files with 347 additions and 364 deletions

View File

@@ -7,7 +7,8 @@
"""
Specialised connection and cursor functions.
"""
from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple, Iterable
from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload,\
Tuple, Iterable
import contextlib
import logging
import os
@@ -46,37 +47,6 @@ class Cursor(psycopg2.extras.DictCursor):
psycopg2.extras.execute_values(self, sql, argslist, template=template)
def scalar(self, sql: Query, args: Any = None) -> Any:
""" Execute query that returns a single value. The value is returned.
If the query yields more than one row, a ValueError is raised.
"""
self.execute(sql, args)
if self.rowcount != 1:
raise RuntimeError("Query did not return a single row.")
result = self.fetchone()
assert result is not None
return result[0]
def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
""" Drop the table with the given name.
Set `if_exists` to False if a non-existent table should raise
an exception instead of just being ignored. If 'cascade' is set
to True then all dependent tables are deleted as well.
"""
sql = 'DROP TABLE '
if if_exists:
sql += 'IF EXISTS '
sql += '{}'
if cascade:
sql += ' CASCADE'
self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
class Connection(psycopg2.extensions.connection):
""" A connection that provides the specialised cursor by default and
adds convenience functions for administrating the database.
@@ -99,80 +69,105 @@ class Connection(psycopg2.extensions.connection):
return super().cursor(cursor_factory=cursor_factory, **kwargs)
def table_exists(self, table: str) -> bool:
""" Check that a table with the given name exists in the database.
"""
with self.cursor() as cur:
num = cur.scalar("""SELECT count(*) FROM pg_tables
WHERE tablename = %s and schemaname = 'public'""", (table, ))
return num == 1 if isinstance(num, int) else False
def execute_scalar(conn: Connection, sql: Query, args: Any = None) -> Any:
""" Execute query that returns a single value. The value is returned.
If the query yields more than one row, a ValueError is raised.
"""
with conn.cursor() as cur:
cur.execute(sql, args)
if cur.rowcount != 1:
raise RuntimeError("Query did not return a single row.")
result = cur.fetchone()
assert result is not None
return result[0]
def table_has_column(self, table: str, column: str) -> bool:
""" Check if the table 'table' exists and has a column with name 'column'.
"""
with self.cursor() as cur:
has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns
WHERE table_name = %s
and column_name = %s""",
(table, column))
return has_column > 0 if isinstance(has_column, int) else False
def table_exists(conn: Connection, table: str) -> bool:
""" Check that a table with the given name exists in the database.
"""
num = execute_scalar(conn,
"""SELECT count(*) FROM pg_tables
WHERE tablename = %s and schemaname = 'public'""", (table, ))
return num == 1 if isinstance(num, int) else False
def index_exists(self, index: str, table: Optional[str] = None) -> bool:
""" Check that an index with the given name exists in the database.
If table is not None then the index must relate to the given
table.
"""
with self.cursor() as cur:
cur.execute("""SELECT tablename FROM pg_indexes
WHERE indexname = %s and schemaname = 'public'""", (index, ))
if cur.rowcount == 0:
def table_has_column(conn: Connection, table: str, column: str) -> bool:
""" Check if the table 'table' exists and has a column with name 'column'.
"""
has_column = execute_scalar(conn,
"""SELECT count(*) FROM information_schema.columns
WHERE table_name = %s and column_name = %s""",
(table, column))
return has_column > 0 if isinstance(has_column, int) else False
def index_exists(conn: Connection, index: str, table: Optional[str] = None) -> bool:
""" Check that an index with the given name exists in the database.
If table is not None then the index must relate to the given
table.
"""
with conn.cursor() as cur:
cur.execute("""SELECT tablename FROM pg_indexes
WHERE indexname = %s and schemaname = 'public'""", (index, ))
if cur.rowcount == 0:
return False
if table is not None:
row = cur.fetchone()
if row is None or not isinstance(row[0], str):
return False
return row[0] == table
if table is not None:
row = cur.fetchone()
if row is None or not isinstance(row[0], str):
return False
return row[0] == table
return True
return True
def drop_tables(conn: Connection, *names: str,
if_exists: bool = True, cascade: bool = False) -> None:
""" Drop one or more tables with the given names.
Set `if_exists` to False if a non-existent table should raise
an exception instead of just being ignored. `cascade` will cause
depended objects to be dropped as well.
The caller needs to take care of committing the change.
"""
sql = pysql.SQL('DROP TABLE%s{}%s' % (
' IF EXISTS ' if if_exists else ' ',
' CASCADE' if cascade else ''))
with conn.cursor() as cur:
for name in names:
cur.execute(sql.format(pysql.Identifier(name)))
def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
""" Drop the table with the given name.
Set `if_exists` to False if a non-existent table should raise
an exception instead of just being ignored.
"""
with self.cursor() as cur:
cur.drop_table(name, if_exists, cascade)
self.commit()
def server_version_tuple(conn: Connection) -> Tuple[int, int]:
""" Return the server version as a tuple of (major, minor).
Converts correctly for pre-10 and post-10 PostgreSQL versions.
"""
version = conn.server_version
if version < 100000:
return (int(version / 10000), int((version % 10000) / 100))
return (int(version / 10000), version % 10000)
def server_version_tuple(self) -> Tuple[int, int]:
""" Return the server version as a tuple of (major, minor).
Converts correctly for pre-10 and post-10 PostgreSQL versions.
"""
version = self.server_version
if version < 100000:
return (int(version / 10000), int((version % 10000) / 100))
def postgis_version_tuple(conn: Connection) -> Tuple[int, int]:
""" Return the postgis version installed in the database as a
tuple of (major, minor). Assumes that the PostGIS extension
has been installed already.
"""
version = execute_scalar(conn, 'SELECT postgis_lib_version()')
return (int(version / 10000), version % 10000)
version_parts = version.split('.')
if len(version_parts) < 2:
raise UsageError(f"Error fetching Postgis version. Bad format: {version}")
return (int(version_parts[0]), int(version_parts[1]))
def postgis_version_tuple(self) -> Tuple[int, int]:
""" Return the postgis version installed in the database as a
tuple of (major, minor). Assumes that the PostGIS extension
has been installed already.
"""
with self.cursor() as cur:
version = cur.scalar('SELECT postgis_lib_version()')
version_parts = version.split('.')
if len(version_parts) < 2:
raise UsageError(f"Error fetching Postgis version. Bad format: {version}")
return (int(version_parts[0]), int(version_parts[1]))
def register_hstore(conn: Connection) -> None:
""" Register the hstore type with psycopg for the connection.
"""
psycopg2.extras.register_hstore(conn)
class ConnectionContext(ContextManager[Connection]):

View File

@@ -9,7 +9,7 @@ Query and access functions for the in-database property table.
"""
from typing import Optional, cast
from .connection import Connection
from .connection import Connection, table_exists
def set_property(conn: Connection, name: str, value: str) -> None:
""" Add or replace the property with the given name.
@@ -31,7 +31,7 @@ def get_property(conn: Connection, name: str) -> Optional[str]:
""" Return the current value of the given property or None if the property
is not set.
"""
if not conn.table_exists('nominatim_properties'):
if not table_exists(conn, 'nominatim_properties'):
return None
with conn.cursor() as cur:

View File

@@ -10,7 +10,7 @@ Preprocessing of SQL files.
from typing import Set, Dict, Any, cast
import jinja2
from .connection import Connection
from .connection import Connection, server_version_tuple, postgis_version_tuple
from .async_connection import WorkerPool
from ..config import Configuration
@@ -66,8 +66,8 @@ def _setup_postgresql_features(conn: Connection) -> Dict[str, Any]:
""" Set up a dictionary with various optional Postgresql/Postgis features that
depend on the database version.
"""
pg_version = conn.server_version_tuple()
postgis_version = conn.postgis_version_tuple()
pg_version = server_version_tuple(conn)
postgis_version = postgis_version_tuple(conn)
pg11plus = pg_version >= (11, 0, 0)
ps3 = postgis_version >= (3, 0)
return {

View File

@@ -12,7 +12,7 @@ import datetime as dt
import logging
import re
from .connection import Connection
from .connection import Connection, table_exists, execute_scalar
from ..utils.url_utils import get_url
from ..errors import UsageError
from ..typing import TypedDict
@@ -34,7 +34,7 @@ def compute_database_date(conn: Connection, offline: bool = False) -> dt.datetim
data base.
"""
# If there is a date from osm2pgsql available, use that.
if conn.table_exists('osm2pgsql_properties'):
if table_exists(conn, 'osm2pgsql_properties'):
with conn.cursor() as cur:
cur.execute(""" SELECT value FROM osm2pgsql_properties
WHERE property = 'current_timestamp' """)
@@ -47,15 +47,14 @@ def compute_database_date(conn: Connection, offline: bool = False) -> dt.datetim
raise UsageError("Cannot determine database date from data in offline mode.")
# Else, find the node with the highest ID in the database
with conn.cursor() as cur:
if conn.table_exists('place'):
osmid = cur.scalar("SELECT max(osm_id) FROM place WHERE osm_type='N'")
else:
osmid = cur.scalar("SELECT max(osm_id) FROM placex WHERE osm_type='N'")
if table_exists(conn, 'place'):
osmid = execute_scalar(conn, "SELECT max(osm_id) FROM place WHERE osm_type='N'")
else:
osmid = execute_scalar(conn, "SELECT max(osm_id) FROM placex WHERE osm_type='N'")
if osmid is None:
LOG.fatal("No data found in the database.")
raise UsageError("No data found in the database.")
if osmid is None:
LOG.fatal("No data found in the database.")
raise UsageError("No data found in the database.")
LOG.info("Using node id %d for timestamp lookup", osmid)
# Get the node from the API to find the timestamp when it was created.