mirror of
https://github.com/osm-search/Nominatim.git
synced 2026-02-26 11:08:13 +00:00
make DB helper functions free functions
Also changes the drop function so that it can drop multiple tables at once.
This commit is contained in:
@@ -7,7 +7,8 @@
|
||||
"""
|
||||
Specialised connection and cursor functions.
|
||||
"""
|
||||
from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple, Iterable
|
||||
from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload,\
|
||||
Tuple, Iterable
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
@@ -46,37 +47,6 @@ class Cursor(psycopg2.extras.DictCursor):
|
||||
psycopg2.extras.execute_values(self, sql, argslist, template=template)
|
||||
|
||||
|
||||
def scalar(self, sql: Query, args: Any = None) -> Any:
|
||||
""" Execute query that returns a single value. The value is returned.
|
||||
If the query yields more than one row, a ValueError is raised.
|
||||
"""
|
||||
self.execute(sql, args)
|
||||
|
||||
if self.rowcount != 1:
|
||||
raise RuntimeError("Query did not return a single row.")
|
||||
|
||||
result = self.fetchone()
|
||||
assert result is not None
|
||||
|
||||
return result[0]
|
||||
|
||||
|
||||
def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
|
||||
""" Drop the table with the given name.
|
||||
Set `if_exists` to False if a non-existent table should raise
|
||||
an exception instead of just being ignored. If 'cascade' is set
|
||||
to True then all dependent tables are deleted as well.
|
||||
"""
|
||||
sql = 'DROP TABLE '
|
||||
if if_exists:
|
||||
sql += 'IF EXISTS '
|
||||
sql += '{}'
|
||||
if cascade:
|
||||
sql += ' CASCADE'
|
||||
|
||||
self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
|
||||
|
||||
|
||||
class Connection(psycopg2.extensions.connection):
|
||||
""" A connection that provides the specialised cursor by default and
|
||||
adds convenience functions for administrating the database.
|
||||
@@ -99,80 +69,105 @@ class Connection(psycopg2.extensions.connection):
|
||||
return super().cursor(cursor_factory=cursor_factory, **kwargs)
|
||||
|
||||
|
||||
def table_exists(self, table: str) -> bool:
|
||||
""" Check that a table with the given name exists in the database.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
num = cur.scalar("""SELECT count(*) FROM pg_tables
|
||||
WHERE tablename = %s and schemaname = 'public'""", (table, ))
|
||||
return num == 1 if isinstance(num, int) else False
|
||||
def execute_scalar(conn: Connection, sql: Query, args: Any = None) -> Any:
|
||||
""" Execute query that returns a single value. The value is returned.
|
||||
If the query yields more than one row, a ValueError is raised.
|
||||
"""
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(sql, args)
|
||||
|
||||
if cur.rowcount != 1:
|
||||
raise RuntimeError("Query did not return a single row.")
|
||||
|
||||
result = cur.fetchone()
|
||||
|
||||
assert result is not None
|
||||
return result[0]
|
||||
|
||||
|
||||
def table_has_column(self, table: str, column: str) -> bool:
|
||||
""" Check if the table 'table' exists and has a column with name 'column'.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns
|
||||
WHERE table_name = %s
|
||||
and column_name = %s""",
|
||||
(table, column))
|
||||
return has_column > 0 if isinstance(has_column, int) else False
|
||||
def table_exists(conn: Connection, table: str) -> bool:
|
||||
""" Check that a table with the given name exists in the database.
|
||||
"""
|
||||
num = execute_scalar(conn,
|
||||
"""SELECT count(*) FROM pg_tables
|
||||
WHERE tablename = %s and schemaname = 'public'""", (table, ))
|
||||
return num == 1 if isinstance(num, int) else False
|
||||
|
||||
|
||||
def index_exists(self, index: str, table: Optional[str] = None) -> bool:
|
||||
""" Check that an index with the given name exists in the database.
|
||||
If table is not None then the index must relate to the given
|
||||
table.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
cur.execute("""SELECT tablename FROM pg_indexes
|
||||
WHERE indexname = %s and schemaname = 'public'""", (index, ))
|
||||
if cur.rowcount == 0:
|
||||
def table_has_column(conn: Connection, table: str, column: str) -> bool:
|
||||
""" Check if the table 'table' exists and has a column with name 'column'.
|
||||
"""
|
||||
has_column = execute_scalar(conn,
|
||||
"""SELECT count(*) FROM information_schema.columns
|
||||
WHERE table_name = %s and column_name = %s""",
|
||||
(table, column))
|
||||
return has_column > 0 if isinstance(has_column, int) else False
|
||||
|
||||
|
||||
def index_exists(conn: Connection, index: str, table: Optional[str] = None) -> bool:
|
||||
""" Check that an index with the given name exists in the database.
|
||||
If table is not None then the index must relate to the given
|
||||
table.
|
||||
"""
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("""SELECT tablename FROM pg_indexes
|
||||
WHERE indexname = %s and schemaname = 'public'""", (index, ))
|
||||
if cur.rowcount == 0:
|
||||
return False
|
||||
|
||||
if table is not None:
|
||||
row = cur.fetchone()
|
||||
if row is None or not isinstance(row[0], str):
|
||||
return False
|
||||
return row[0] == table
|
||||
|
||||
if table is not None:
|
||||
row = cur.fetchone()
|
||||
if row is None or not isinstance(row[0], str):
|
||||
return False
|
||||
return row[0] == table
|
||||
return True
|
||||
|
||||
return True
|
||||
def drop_tables(conn: Connection, *names: str,
|
||||
if_exists: bool = True, cascade: bool = False) -> None:
|
||||
""" Drop one or more tables with the given names.
|
||||
Set `if_exists` to False if a non-existent table should raise
|
||||
an exception instead of just being ignored. `cascade` will cause
|
||||
depended objects to be dropped as well.
|
||||
The caller needs to take care of committing the change.
|
||||
"""
|
||||
sql = pysql.SQL('DROP TABLE%s{}%s' % (
|
||||
' IF EXISTS ' if if_exists else ' ',
|
||||
' CASCADE' if cascade else ''))
|
||||
|
||||
with conn.cursor() as cur:
|
||||
for name in names:
|
||||
cur.execute(sql.format(pysql.Identifier(name)))
|
||||
|
||||
|
||||
def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
|
||||
""" Drop the table with the given name.
|
||||
Set `if_exists` to False if a non-existent table should raise
|
||||
an exception instead of just being ignored.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
cur.drop_table(name, if_exists, cascade)
|
||||
self.commit()
|
||||
def server_version_tuple(conn: Connection) -> Tuple[int, int]:
|
||||
""" Return the server version as a tuple of (major, minor).
|
||||
Converts correctly for pre-10 and post-10 PostgreSQL versions.
|
||||
"""
|
||||
version = conn.server_version
|
||||
if version < 100000:
|
||||
return (int(version / 10000), int((version % 10000) / 100))
|
||||
|
||||
return (int(version / 10000), version % 10000)
|
||||
|
||||
|
||||
def server_version_tuple(self) -> Tuple[int, int]:
|
||||
""" Return the server version as a tuple of (major, minor).
|
||||
Converts correctly for pre-10 and post-10 PostgreSQL versions.
|
||||
"""
|
||||
version = self.server_version
|
||||
if version < 100000:
|
||||
return (int(version / 10000), int((version % 10000) / 100))
|
||||
def postgis_version_tuple(conn: Connection) -> Tuple[int, int]:
|
||||
""" Return the postgis version installed in the database as a
|
||||
tuple of (major, minor). Assumes that the PostGIS extension
|
||||
has been installed already.
|
||||
"""
|
||||
version = execute_scalar(conn, 'SELECT postgis_lib_version()')
|
||||
|
||||
return (int(version / 10000), version % 10000)
|
||||
version_parts = version.split('.')
|
||||
if len(version_parts) < 2:
|
||||
raise UsageError(f"Error fetching Postgis version. Bad format: {version}")
|
||||
|
||||
return (int(version_parts[0]), int(version_parts[1]))
|
||||
|
||||
def postgis_version_tuple(self) -> Tuple[int, int]:
|
||||
""" Return the postgis version installed in the database as a
|
||||
tuple of (major, minor). Assumes that the PostGIS extension
|
||||
has been installed already.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
version = cur.scalar('SELECT postgis_lib_version()')
|
||||
|
||||
version_parts = version.split('.')
|
||||
if len(version_parts) < 2:
|
||||
raise UsageError(f"Error fetching Postgis version. Bad format: {version}")
|
||||
|
||||
return (int(version_parts[0]), int(version_parts[1]))
|
||||
def register_hstore(conn: Connection) -> None:
|
||||
""" Register the hstore type with psycopg for the connection.
|
||||
"""
|
||||
psycopg2.extras.register_hstore(conn)
|
||||
|
||||
|
||||
class ConnectionContext(ContextManager[Connection]):
|
||||
|
||||
@@ -9,7 +9,7 @@ Query and access functions for the in-database property table.
|
||||
"""
|
||||
from typing import Optional, cast
|
||||
|
||||
from .connection import Connection
|
||||
from .connection import Connection, table_exists
|
||||
|
||||
def set_property(conn: Connection, name: str, value: str) -> None:
|
||||
""" Add or replace the property with the given name.
|
||||
@@ -31,7 +31,7 @@ def get_property(conn: Connection, name: str) -> Optional[str]:
|
||||
""" Return the current value of the given property or None if the property
|
||||
is not set.
|
||||
"""
|
||||
if not conn.table_exists('nominatim_properties'):
|
||||
if not table_exists(conn, 'nominatim_properties'):
|
||||
return None
|
||||
|
||||
with conn.cursor() as cur:
|
||||
|
||||
@@ -10,7 +10,7 @@ Preprocessing of SQL files.
|
||||
from typing import Set, Dict, Any, cast
|
||||
import jinja2
|
||||
|
||||
from .connection import Connection
|
||||
from .connection import Connection, server_version_tuple, postgis_version_tuple
|
||||
from .async_connection import WorkerPool
|
||||
from ..config import Configuration
|
||||
|
||||
@@ -66,8 +66,8 @@ def _setup_postgresql_features(conn: Connection) -> Dict[str, Any]:
|
||||
""" Set up a dictionary with various optional Postgresql/Postgis features that
|
||||
depend on the database version.
|
||||
"""
|
||||
pg_version = conn.server_version_tuple()
|
||||
postgis_version = conn.postgis_version_tuple()
|
||||
pg_version = server_version_tuple(conn)
|
||||
postgis_version = postgis_version_tuple(conn)
|
||||
pg11plus = pg_version >= (11, 0, 0)
|
||||
ps3 = postgis_version >= (3, 0)
|
||||
return {
|
||||
|
||||
@@ -12,7 +12,7 @@ import datetime as dt
|
||||
import logging
|
||||
import re
|
||||
|
||||
from .connection import Connection
|
||||
from .connection import Connection, table_exists, execute_scalar
|
||||
from ..utils.url_utils import get_url
|
||||
from ..errors import UsageError
|
||||
from ..typing import TypedDict
|
||||
@@ -34,7 +34,7 @@ def compute_database_date(conn: Connection, offline: bool = False) -> dt.datetim
|
||||
data base.
|
||||
"""
|
||||
# If there is a date from osm2pgsql available, use that.
|
||||
if conn.table_exists('osm2pgsql_properties'):
|
||||
if table_exists(conn, 'osm2pgsql_properties'):
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(""" SELECT value FROM osm2pgsql_properties
|
||||
WHERE property = 'current_timestamp' """)
|
||||
@@ -47,15 +47,14 @@ def compute_database_date(conn: Connection, offline: bool = False) -> dt.datetim
|
||||
raise UsageError("Cannot determine database date from data in offline mode.")
|
||||
|
||||
# Else, find the node with the highest ID in the database
|
||||
with conn.cursor() as cur:
|
||||
if conn.table_exists('place'):
|
||||
osmid = cur.scalar("SELECT max(osm_id) FROM place WHERE osm_type='N'")
|
||||
else:
|
||||
osmid = cur.scalar("SELECT max(osm_id) FROM placex WHERE osm_type='N'")
|
||||
if table_exists(conn, 'place'):
|
||||
osmid = execute_scalar(conn, "SELECT max(osm_id) FROM place WHERE osm_type='N'")
|
||||
else:
|
||||
osmid = execute_scalar(conn, "SELECT max(osm_id) FROM placex WHERE osm_type='N'")
|
||||
|
||||
if osmid is None:
|
||||
LOG.fatal("No data found in the database.")
|
||||
raise UsageError("No data found in the database.")
|
||||
if osmid is None:
|
||||
LOG.fatal("No data found in the database.")
|
||||
raise UsageError("No data found in the database.")
|
||||
|
||||
LOG.info("Using node id %d for timestamp lookup", osmid)
|
||||
# Get the node from the API to find the timestamp when it was created.
|
||||
|
||||
Reference in New Issue
Block a user