forked from hans/Nominatim
reduce from 3 to 2 packages
This commit is contained in:
0
src/nominatim_db/db/__init__.py
Normal file
0
src/nominatim_db/db/__init__.py
Normal file
236
src/nominatim_db/db/async_connection.py
Normal file
236
src/nominatim_db/db/async_connection.py
Normal file
@@ -0,0 +1,236 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
#
|
||||
# This file is part of Nominatim. (https://nominatim.org)
|
||||
#
|
||||
# Copyright (C) 2024 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
""" Non-blocking database connections.
|
||||
"""
|
||||
from typing import Callable, Any, Optional, Iterator, Sequence
|
||||
import logging
|
||||
import select
|
||||
import time
|
||||
|
||||
import psycopg2
|
||||
from psycopg2.extras import wait_select
|
||||
|
||||
# psycopg2 emits different exceptions pre and post 2.8. Detect if the new error
|
||||
# module is available and adapt the error handling accordingly.
|
||||
try:
|
||||
import psycopg2.errors # pylint: disable=no-name-in-module,import-error
|
||||
__has_psycopg2_errors__ = True
|
||||
except ImportError:
|
||||
__has_psycopg2_errors__ = False
|
||||
|
||||
from ..typing import T_cursor, Query
|
||||
|
||||
LOG = logging.getLogger()
|
||||
|
||||
class DeadlockHandler:
    """ Context manager that catches deadlock exceptions and calls
        the given handler function. All other exceptions are passed on
        normally.
    """

    def __init__(self, handler: Callable[[], None], ignore_sql_errors: bool = False) -> None:
        # Called with no arguments whenever a deadlock exception is caught.
        self.handler = handler
        # When True, any other psycopg2 error is logged and swallowed
        # instead of being propagated to the caller.
        self.ignore_sql_errors = ignore_sql_errors

    def __enter__(self) -> 'DeadlockHandler':
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> bool:
        # psycopg2 >= 2.8 has a dedicated DeadlockDetected exception class;
        # on older versions fall back to inspecting the Postgres error code
        # ('40P01' is deadlock_detected) on the generic rollback error.
        # Returning True from __exit__ suppresses the exception.
        if __has_psycopg2_errors__:
            if exc_type == psycopg2.errors.DeadlockDetected: # pylint: disable=E1101
                self.handler()
                return True
        elif exc_type == psycopg2.extensions.TransactionRollbackError \
             and exc_value.pgcode == '40P01':
            self.handler()
            return True

        if self.ignore_sql_errors and isinstance(exc_value, psycopg2.Error):
            LOG.info("SQL error ignored: %s", exc_value)
            return True

        # Any other exception propagates normally.
        return False
|
||||
|
||||
|
||||
class DBConnection:
    """ A single non-blocking database connection.
    """

    def __init__(self, dsn: str,
                 cursor_factory: Optional[Callable[..., T_cursor]] = None,
                 ignore_sql_errors: bool = False) -> None:
        self.dsn = dsn

        # Query and parameters currently in flight. They are remembered so
        # that the query can be re-issued after a deadlock.
        self.current_query: Optional[Query] = None
        self.current_params: Optional[Sequence[Any]] = None
        self.ignore_sql_errors = ignore_sql_errors

        self.conn: Optional['psycopg2._psycopg.connection'] = None
        self.cursor: Optional['psycopg2._psycopg.cursor'] = None
        self.connect(cursor_factory=cursor_factory)

    def close(self) -> None:
        """ Close all open connections. Does not wait for pending requests.
        """
        if self.conn is not None:
            if self.cursor is not None:
                self.cursor.close()
                self.cursor = None
            self.conn.close()

        self.conn = None

    def connect(self, cursor_factory: Optional[Callable[..., T_cursor]] = None) -> None:
        """ (Re)connect to the database. Creates an asynchronous connection
            with JIT and parallel processing disabled. If a connection was
            already open, it is closed and a new connection established.
            The caller must ensure that no query is pending before reconnecting.
        """
        self.close()

        # Use a dict to hand in the parameters because async is a reserved
        # word in Python3.
        self.conn = psycopg2.connect(**{'dsn': self.dsn, 'async': True}) # type: ignore
        assert self.conn
        # Asynchronous connect returns immediately; wait until it is usable.
        self.wait()

        if cursor_factory is not None:
            self.cursor = self.conn.cursor(cursor_factory=cursor_factory)
        else:
            self.cursor = self.conn.cursor()
        # Disable JIT and parallel workers as they are known to cause problems.
        # Update pg_settings instead of using SET because it does not yield
        # errors on older versions of Postgres where the settings are not
        # implemented.
        self.perform(
            """ UPDATE pg_settings SET setting = -1 WHERE name = 'jit_above_cost';
                UPDATE pg_settings SET setting = 0
                   WHERE name = 'max_parallel_workers_per_gather';""")
        self.wait()

    def _deadlock_handler(self) -> None:
        # Invoked by DeadlockHandler: re-issue the query that was aborted
        # by the deadlock.
        LOG.info("Deadlock detected (params = %s), retry.", str(self.current_params))
        assert self.cursor is not None
        assert self.current_query is not None
        assert self.current_params is not None

        self.cursor.execute(self.current_query, self.current_params)

    def wait(self) -> None:
        """ Block until any pending operation is done.
        """
        # If a deadlock is caught, the handler has restarted the query, so
        # the loop waits again; otherwise the function returns after one pass.
        while True:
            with DeadlockHandler(self._deadlock_handler, self.ignore_sql_errors):
                wait_select(self.conn)
                self.current_query = None
                return

    def perform(self, sql: Query, args: Optional[Sequence[Any]] = None) -> None:
        """ Send SQL query to the server. Returns immediately without
            blocking.
        """
        assert self.cursor is not None
        # Remember the query for a possible deadlock retry.
        self.current_query = sql
        self.current_params = args
        self.cursor.execute(sql, args)

    def fileno(self) -> int:
        """ File descriptor to wait for. (Makes this class select()able.)
        """
        assert self.conn is not None
        return self.conn.fileno()

    def is_done(self) -> bool:
        """ Check if the connection is available for a new query.

            Also checks if the previous query has run into a deadlock.
            If so, then the previous query is repeated.
        """
        assert self.conn is not None

        if self.current_query is None:
            return True

        with DeadlockHandler(self._deadlock_handler, self.ignore_sql_errors):
            if self.conn.poll() == psycopg2.extensions.POLL_OK:
                self.current_query = None
                return True

        return False
|
||||
|
||||
|
||||
class WorkerPool:
    """ A pool of asynchronous database connections.

        The pool may be used as a context manager.
    """
    # Number of finished queries after which all connections are
    # re-established (see _yield_free_worker).
    REOPEN_CONNECTIONS_AFTER = 100000

    def __init__(self, dsn: str, pool_size: int, ignore_sql_errors: bool = False) -> None:
        self.threads = [DBConnection(dsn, ignore_sql_errors=ignore_sql_errors)
                        for _ in range(pool_size)]
        self.free_workers = self._yield_free_worker()
        # Accumulated time (seconds) spent blocked in select() waiting for
        # a connection to become free.
        self.wait_time = 0.0


    def finish_all(self) -> None:
        """ Wait for all connection to finish.
        """
        for thread in self.threads:
            while not thread.is_done():
                thread.wait()

        # Restart the generator so it does not hand out stale state.
        self.free_workers = self._yield_free_worker()

    def close(self) -> None:
        """ Close all connections and clear the pool.
        """
        for thread in self.threads:
            thread.close()
        self.threads = []
        self.free_workers = iter([])


    def next_free_worker(self) -> DBConnection:
        """ Get the next free connection.
        """
        return next(self.free_workers)


    def _yield_free_worker(self) -> Iterator[DBConnection]:
        # Endless generator yielding connections as they become idle.
        ready = self.threads
        command_stat = 0
        while True:
            for thread in ready:
                if thread.is_done():
                    command_stat += 1
                    yield thread

            if command_stat > self.REOPEN_CONNECTIONS_AFTER:
                # Periodically re-establish all connections
                # (NOTE(review): reason not documented in this file —
                # presumably to release server-side resources; confirm).
                self._reconnect_threads()
                ready = self.threads
                command_stat = 0
            else:
                # Block until at least one connection is writable again;
                # DBConnection.fileno() makes the objects select()able.
                tstart = time.time()
                _, ready, _ = select.select([], self.threads, [])
                self.wait_time += time.time() - tstart


    def _reconnect_threads(self) -> None:
        # Drain each connection's pending query, then reconnect it.
        for thread in self.threads:
            while not thread.is_done():
                thread.wait()
            thread.connect()


    def __enter__(self) -> 'WorkerPool':
        return self


    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.finish_all()
        self.close()
|
||||
254
src/nominatim_db/db/connection.py
Normal file
254
src/nominatim_db/db/connection.py
Normal file
@@ -0,0 +1,254 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
#
|
||||
# This file is part of Nominatim. (https://nominatim.org)
|
||||
#
|
||||
# Copyright (C) 2024 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
"""
|
||||
Specialised connection and cursor functions.
|
||||
"""
|
||||
from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple, Iterable
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extensions
|
||||
import psycopg2.extras
|
||||
from psycopg2 import sql as pysql
|
||||
|
||||
from ..typing import SysEnv, Query, T_cursor
|
||||
from ..errors import UsageError
|
||||
|
||||
LOG = logging.getLogger()
|
||||
|
||||
class Cursor(psycopg2.extras.DictCursor):
    """ A cursor returning dict-like objects and providing specialised
        execution functions.
    """
    # pylint: disable=arguments-renamed,arguments-differ
    def execute(self, query: Query, args: Any = None) -> None:
        """ Query execution that logs the SQL query when debugging is enabled.
        """
        # mogrify() renders the query with parameters interpolated; only do
        # that work when the debug message would actually be emitted.
        if LOG.isEnabledFor(logging.DEBUG):
            LOG.debug(self.mogrify(query, args).decode('utf-8'))

        super().execute(query, args)


    def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]],
                       template: Optional[Query] = None) -> None:
        """ Wrapper for the psycopg2 convenience function to execute
            SQL for a list of values.
        """
        LOG.debug("SQL execute_values(%s, %s)", sql, argslist)

        psycopg2.extras.execute_values(self, sql, argslist, template=template)


    def scalar(self, sql: Query, args: Any = None) -> Any:
        """ Execute query that returns a single value. The value is returned.
            If the query does not yield exactly one row, a RuntimeError is
            raised.
        """
        self.execute(sql, args)

        if self.rowcount != 1:
            raise RuntimeError("Query did not return a single row.")

        result = self.fetchone()
        assert result is not None

        return result[0]


    def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
        """ Drop the table with the given name.
            Set `if_exists` to False if a non-existent table should raise
            an exception instead of just being ignored. If 'cascade' is set
            to True then all dependent tables are deleted as well.
        """
        sql = 'DROP TABLE '
        if if_exists:
            sql += 'IF EXISTS '
        sql += '{}'
        if cascade:
            sql += ' CASCADE'

        # The table name is passed through pysql.Identifier so that it is
        # properly quoted (guards against injection and reserved words).
        self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
|
||||
|
||||
|
||||
class Connection(psycopg2.extensions.connection):
    """ A connection that provides the specialised cursor by default and
        adds convenience functions for administrating the database.
    """
    # The overloads describe the supported call patterns for type checking;
    # the actual implementation simply forwards to psycopg2.
    @overload # type: ignore[override]
    def cursor(self) -> Cursor:
        ...

    @overload
    def cursor(self, name: str) -> Cursor:
        ...

    @overload
    def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor:
        ...

    def cursor(self, cursor_factory = Cursor, **kwargs): # type: ignore
        """ Return a new cursor. By default the specialised cursor is returned.
        """
        return super().cursor(cursor_factory=cursor_factory, **kwargs)


    def table_exists(self, table: str) -> bool:
        """ Check that a table with the given name exists in the database.
        """
        with self.cursor() as cur:
            num = cur.scalar("""SELECT count(*) FROM pg_tables
                                WHERE tablename = %s and schemaname = 'public'""", (table, ))
            return num == 1 if isinstance(num, int) else False


    def table_has_column(self, table: str, column: str) -> bool:
        """ Check if the table 'table' exists and has a column with name 'column'.
        """
        with self.cursor() as cur:
            has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns
                                       WHERE table_name = %s
                                             and column_name = %s""",
                                    (table, column))
            return has_column > 0 if isinstance(has_column, int) else False


    def index_exists(self, index: str, table: Optional[str] = None) -> bool:
        """ Check that an index with the given name exists in the database.
            If table is not None then the index must relate to the given
            table.
        """
        with self.cursor() as cur:
            cur.execute("""SELECT tablename FROM pg_indexes
                           WHERE indexname = %s and schemaname = 'public'""", (index, ))
            if cur.rowcount == 0:
                return False

            if table is not None:
                row = cur.fetchone()
                if row is None or not isinstance(row[0], str):
                    return False
                return row[0] == table

        return True


    def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
        """ Drop the table with the given name.
            Set `if_exists` to False if a non-existent table should raise
            an exception instead of just being ignored.
        """
        with self.cursor() as cur:
            cur.drop_table(name, if_exists, cascade)
        self.commit()


    def server_version_tuple(self) -> Tuple[int, int]:
        """ Return the server version as a tuple of (major, minor).
            Converts correctly for pre-10 and post-10 PostgreSQL versions.
        """
        version = self.server_version
        if version < 100000:
            # Pre-10 numbering: Mmmpp (e.g. 90603 -> (9, 6)).
            return (int(version / 10000), int((version % 10000) / 100))

        # Post-10 numbering: MM00mm (e.g. 120003 -> (12, 3)).
        return (int(version / 10000), version % 10000)


    def postgis_version_tuple(self) -> Tuple[int, int]:
        """ Return the postgis version installed in the database as a
            tuple of (major, minor). Assumes that the PostGIS extension
            has been installed already.
        """
        with self.cursor() as cur:
            version = cur.scalar('SELECT postgis_lib_version()')

        version_parts = version.split('.')
        if len(version_parts) < 2:
            raise UsageError(f"Error fetching Postgis version. Bad format: {version}")

        return (int(version_parts[0]), int(version_parts[1]))


    def extension_loaded(self, extension_name: str) -> bool:
        """ Return True if the given extension is loaded in the database.
        """
        with self.cursor() as cur:
            cur.execute('SELECT extname FROM pg_extension WHERE extname = %s', (extension_name, ))
            return cur.rowcount > 0
|
||||
|
||||
|
||||
class ConnectionContext(ContextManager[Connection]):
    """ Context manager of the connection that also provides direct access
        to the underlying connection.
    """
    # Set by connect() so that the raw connection is reachable even when
    # the object is not used via the 'with' protocol.
    connection: Connection
|
||||
|
||||
def connect(dsn: str) -> ConnectionContext:
    """ Open a connection to the database using the specialised connection
        factory. The returned object may be used in conjunction with 'with'.
        When used outside a context manager, use the `connection` attribute
        to get the connection.
    """
    # Only the connect call itself can raise OperationalError, so keep the
    # try-block tight around it.
    try:
        conn = psycopg2.connect(dsn, connection_factory=Connection)
    except psycopg2.OperationalError as err:
        raise UsageError(f"Cannot connect to database: {err}") from err

    # closing() turns the connection into a context manager; additionally
    # expose the raw connection as an attribute for non-'with' use.
    wrapped = cast(ConnectionContext, contextlib.closing(conn))
    wrapped.connection = conn
    return wrapped
|
||||
|
||||
|
||||
# Translation from PG connection string parameters to PG environment variables.
# Derived from https://www.postgresql.org/docs/current/libpq-envars.html.
# Parameters that do not appear here have no environment-variable equivalent
# and are ignored (with an error logged) by get_pg_env().
_PG_CONNECTION_STRINGS = {
    'host': 'PGHOST',
    'hostaddr': 'PGHOSTADDR',
    'port': 'PGPORT',
    'dbname': 'PGDATABASE',
    'user': 'PGUSER',
    'password': 'PGPASSWORD',
    'passfile': 'PGPASSFILE',
    'channel_binding': 'PGCHANNELBINDING',
    'service': 'PGSERVICE',
    'options': 'PGOPTIONS',
    'application_name': 'PGAPPNAME',
    'sslmode': 'PGSSLMODE',
    'requiressl': 'PGREQUIRESSL',
    'sslcompression': 'PGSSLCOMPRESSION',
    'sslcert': 'PGSSLCERT',
    'sslkey': 'PGSSLKEY',
    'sslrootcert': 'PGSSLROOTCERT',
    'sslcrl': 'PGSSLCRL',
    'requirepeer': 'PGREQUIREPEER',
    'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION',
    'ssl_max_protocol_version': 'PGSSLMAXPROTOCOLVERSION',
    'gssencmode': 'PGGSSENCMODE',
    'krbsrvname': 'PGKRBSRVNAME',
    'gsslib': 'PGGSSLIB',
    'connect_timeout': 'PGCONNECT_TIMEOUT',
    'target_session_attrs': 'PGTARGETSESSIONATTRS',
}
|
||||
|
||||
|
||||
def get_pg_env(dsn: str,
               base_env: Optional[SysEnv] = None) -> Dict[str, str]:
    """ Return a copy of `base_env` with the environment variables for
        PostgreSQL set up from the given database connection string.
        If `base_env` is None, then the OS environment is used as a base
        environment.
    """
    env = dict(os.environ if base_env is None else base_env)

    for param, value in psycopg2.extensions.parse_dsn(dsn).items():
        varname = _PG_CONNECTION_STRINGS.get(param)
        if varname is None:
            # No libpq environment variable exists for this parameter.
            LOG.error("Unknown connection parameter '%s' ignored.", param)
        else:
            env[varname] = value

    return env
|
||||
47
src/nominatim_db/db/properties.py
Normal file
47
src/nominatim_db/db/properties.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
#
|
||||
# This file is part of Nominatim. (https://nominatim.org)
|
||||
#
|
||||
# Copyright (C) 2024 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
"""
|
||||
Query and access functions for the in-database property table.
|
||||
"""
|
||||
from typing import Optional, cast
|
||||
|
||||
from .connection import Connection
|
||||
|
||||
def set_property(conn: Connection, name: str, value: str) -> None:
    """ Add or replace the property with the given name.
    """
    with conn.cursor() as cur:
        # Probe for an existing row to decide between INSERT and UPDATE.
        cur.execute('SELECT value FROM nominatim_properties WHERE property = %s',
                    (name, ))
        exists = cur.rowcount > 0

        # Both statements take their parameters in the order (value, name).
        if exists:
            sql = 'UPDATE nominatim_properties SET value = %s WHERE property = %s'
        else:
            sql = 'INSERT INTO nominatim_properties (value, property) VALUES (%s, %s)'

        cur.execute(sql, (value, name))
    conn.commit()
|
||||
|
||||
|
||||
def get_property(conn: Connection, name: str) -> Optional[str]:
    """ Return the current value of the given property or None if the property
        is not set.
    """
    # The properties table may not have been created yet.
    if not conn.table_exists('nominatim_properties'):
        return None

    with conn.cursor() as cur:
        cur.execute('SELECT value FROM nominatim_properties WHERE property = %s',
                    (name, ))

        row = cur.fetchone() if cur.rowcount > 0 else None

    if row is None:
        return None

    return cast(Optional[str], row[0])
|
||||
143
src/nominatim_db/db/sql_preprocessor.py
Normal file
143
src/nominatim_db/db/sql_preprocessor.py
Normal file
@@ -0,0 +1,143 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
#
|
||||
# This file is part of Nominatim. (https://nominatim.org)
|
||||
#
|
||||
# Copyright (C) 2024 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
"""
|
||||
Preprocessing of SQL files.
|
||||
"""
|
||||
from typing import Set, Dict, Any, cast
|
||||
import jinja2
|
||||
|
||||
from .connection import Connection
|
||||
from .async_connection import WorkerPool
|
||||
from ..config import Configuration
|
||||
|
||||
def _get_partitions(conn: Connection) -> Set[int]:
    """ Get the set of partitions currently in use.
    """
    with conn.cursor() as cur:
        cur.execute('SELECT DISTINCT partition FROM country_name')
        # Partition 0 is always included, regardless of the table content.
        return {0} | {row[0] for row in cur}
|
||||
|
||||
|
||||
def _get_tables(conn: Connection) -> Set[str]:
    """ Return the set of tables currently in use.
    """
    with conn.cursor() as cur:
        cur.execute("SELECT tablename FROM pg_tables WHERE schemaname = 'public'")

        # Build the set directly from the cursor; the previous version
        # materialised the rows into a throwaway list first and wrapped a
        # generator expression in set() needlessly.
        return {row[0] for row in cur}
|
||||
|
||||
def _get_middle_db_format(conn: Connection, tables: Set[str]) -> str:
    """ Returns the version of the slim middle tables.
    """
    # Without a properties table the format defaults to '1'.
    if 'osm2pgsql_properties' not in tables:
        return '1'

    with conn.cursor() as cur:
        cur.execute("SELECT value FROM osm2pgsql_properties WHERE property = 'db_format'")
        row = cur.fetchone()

    if row is None:
        return '1'

    return cast(str, row[0])
|
||||
|
||||
|
||||
def _setup_tablespace_sql(config: Configuration) -> Dict[str, str]:
    """ Returns a dict with tablespace expressions for the different tablespace
        kinds depending on whether a tablespace is configured or not.
    """
    expressions: Dict[str, str] = {}
    for subset in ('ADDRESS', 'SEARCH', 'AUX'):
        for kind in ('DATA', 'INDEX'):
            configured = getattr(config, f'TABLESPACE_{subset}_{kind}')
            # An unset tablespace yields an empty expression so templates
            # can interpolate the value unconditionally.
            expressions[f'{subset.lower()}_{kind.lower()}'] = \
                f'TABLESPACE "{configured}"' if configured else configured
    return expressions
|
||||
|
||||
|
||||
def _setup_postgresql_features(conn: Connection) -> Dict[str, Any]:
    """ Set up a dictionary with various optional Postgresql/Postgis features that
        depend on the database version.
    """
    pg_version = conn.server_version_tuple()
    postgis_version = conn.postgis_version_tuple()
    # server_version_tuple() returns a 2-tuple (major, minor). Comparing it
    # against the 3-tuple (11, 0, 0), as the code previously did, wrongly
    # classifies PostgreSQL 11.0 as older: in Python's lexicographic tuple
    # ordering (11, 0) < (11, 0, 0) because the shorter tuple with an equal
    # prefix sorts first.
    pg11plus = pg_version >= (11, 0)
    ps3 = postgis_version >= (3, 0)
    return {
        'has_index_non_key_column': pg11plus,
        'spgist_geom' : 'SPGIST' if pg11plus and ps3 else 'GIST'
    }
|
||||
|
||||
class SQLPreprocessor:
    """ A environment for preprocessing SQL files from the
        lib-sql directory.

        The preprocessor provides a number of default filters and variables.
        The variables may be overwritten when rendering an SQL file.

        The preprocessing is currently based on the jinja2 templating library
        and follows its syntax.
    """

    def __init__(self, conn: Connection, config: Configuration) -> None:
        # autoescape is off because the templates render SQL, not HTML.
        self.env = jinja2.Environment(autoescape=False,
                                      loader=jinja2.FileSystemLoader(str(config.lib_dir.sql)))

        # Collect facts about the database once and expose them to all
        # templates through the 'db' global.
        db_info: Dict[str, Any] = {}
        db_info['partitions'] = _get_partitions(conn)
        db_info['tables'] = _get_tables(conn)
        db_info['reverse_only'] = 'search_name' not in db_info['tables']
        db_info['tablespace'] = _setup_tablespace_sql(config)
        db_info['middle_db_format'] = _get_middle_db_format(conn, db_info['tables'])

        self.env.globals['config'] = config
        self.env.globals['db'] = db_info
        self.env.globals['postgres'] = _setup_postgresql_features(conn)


    def run_string(self, conn: Connection, template: str, **kwargs: Any) -> None:
        """ Execute the given SQL template string on the connection.
            The keyword arguments may supply additional parameters
            for preprocessing.
        """
        sql = self.env.from_string(template).render(**kwargs)

        with conn.cursor() as cur:
            cur.execute(sql)
        conn.commit()


    def run_sql_file(self, conn: Connection, name: str, **kwargs: Any) -> None:
        """ Execute the given SQL file on the connection. The keyword arguments
            may supply additional parameters for preprocessing.
        """
        sql = self.env.get_template(name).render(**kwargs)

        with conn.cursor() as cur:
            cur.execute(sql)
        conn.commit()


    def run_parallel_sql_file(self, dsn: str, name: str, num_threads: int = 1,
                              **kwargs: Any) -> None:
        """ Execute the given SQL files using parallel asynchronous connections.
            The keyword arguments may supply additional parameters for
            preprocessing.

            After preprocessing the SQL code is cut at lines containing only
            '---'. Each chunk is sent to one of the `num_threads` workers.
        """
        sql = self.env.get_template(name).render(**kwargs)

        parts = sql.split('\n---\n')

        # Leaving the WorkerPool context waits for all pending chunks.
        with WorkerPool(dsn, num_threads) as pool:
            for part in parts:
                pool.next_free_worker().perform(part)
|
||||
127
src/nominatim_db/db/status.py
Normal file
127
src/nominatim_db/db/status.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
#
|
||||
# This file is part of Nominatim. (https://nominatim.org)
|
||||
#
|
||||
# Copyright (C) 2024 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
"""
|
||||
Access and helper functions for the status and status log table.
|
||||
"""
|
||||
from typing import Optional, Tuple, cast
|
||||
import datetime as dt
|
||||
import logging
|
||||
import re
|
||||
|
||||
from .connection import Connection
|
||||
from ..utils.url_utils import get_url
|
||||
from ..errors import UsageError
|
||||
from ..typing import TypedDict
|
||||
|
||||
LOG = logging.getLogger()
|
||||
ISODATE_FORMAT = '%Y-%m-%dT%H:%M:%S'
|
||||
|
||||
|
||||
class StatusRow(TypedDict):
    """ Dictionary of columns of the import_status table.
    """
    # Date of the data in the database (timezone-aware; set_status()
    # requires UTC).
    lastimportdate: dt.datetime
    # Sequence id of the last applied update, if known.
    sequence_id: Optional[int]
    # Whether the data has been indexed, if known.
    indexed: Optional[bool]
|
||||
|
||||
|
||||
def compute_database_date(conn: Connection, offline: bool = False) -> dt.datetime:
    """ Determine the date of the database from the newest object in the
        data base.

        When `offline` is True and no local timestamp is available, a
        UsageError is raised instead of contacting the OSM API.
    """
    # If there is a date from osm2pgsql available, use that.
    if conn.table_exists('osm2pgsql_properties'):
        with conn.cursor() as cur:
            cur.execute(""" SELECT value FROM osm2pgsql_properties
                            WHERE property = 'current_timestamp' """)
            row = cur.fetchone()
            if row is not None:
                return dt.datetime.strptime(row[0], "%Y-%m-%dT%H:%M:%SZ")\
                         .replace(tzinfo=dt.timezone.utc)

    if offline:
        raise UsageError("Cannot determine database date from data in offline mode.")

    # Else, find the node with the highest ID in the database
    with conn.cursor() as cur:
        # A full 'place' table is only present on freshly imported databases;
        # fall back to 'placex' otherwise.
        if conn.table_exists('place'):
            osmid = cur.scalar("SELECT max(osm_id) FROM place WHERE osm_type='N'")
        else:
            osmid = cur.scalar("SELECT max(osm_id) FROM placex WHERE osm_type='N'")

        if osmid is None:
            LOG.fatal("No data found in the database.")
            raise UsageError("No data found in the database.")

    LOG.info("Using node id %d for timestamp lookup", osmid)
    # Get the node from the API to find the timestamp when it was created.
    # Version 1 of the node is requested explicitly.
    node_url = f'https://www.openstreetmap.org/api/0.6/node/{osmid}/1'
    data = get_url(node_url)

    # Extract the creation timestamp from the raw XML response.
    match = re.search(r'timestamp="((\d{4})-(\d{2})-(\d{2})T(\d{2}):(\d{2}):(\d{2}))Z"', data)

    if match is None:
        LOG.fatal("The node data downloaded from the API does not contain valid data.\n"
                  "URL used: %s", node_url)
        raise UsageError("Bad API data.")

    LOG.debug("Found timestamp %s", match.group(1))

    return dt.datetime.strptime(match.group(1), ISODATE_FORMAT).replace(tzinfo=dt.timezone.utc)
|
||||
|
||||
|
||||
def set_status(conn: Connection, date: Optional[dt.datetime],
               seq: Optional[int] = None, indexed: bool = True) -> None:
    """ Replace the current status with the given status. If date is `None`
        then only sequence and indexed will be updated as given. Otherwise
        the whole status is replaced.
        The change will be committed to the database.
    """
    # Dates must always be stored timezone-aware in UTC.
    assert date is None or date.tzinfo == dt.timezone.utc

    with conn.cursor() as cur:
        if date is not None:
            # Full replacement: wipe the single-row table and re-insert.
            cur.execute("TRUNCATE TABLE import_status")
            cur.execute("""INSERT INTO import_status (lastimportdate, sequence_id, indexed)
                           VALUES (%s, %s, %s)""", (date, seq, indexed))
        else:
            cur.execute("UPDATE import_status set sequence_id = %s, indexed = %s",
                        (seq, indexed))

    conn.commit()
|
||||
|
||||
|
||||
def get_status(conn: Connection) -> Tuple[Optional[dt.datetime], Optional[int], Optional[bool]]:
    """ Return the current status as a triple of (date, sequence, indexed).
        If status has not been set up yet, a triple of None is returned.
    """
    with conn.cursor() as cur:
        cur.execute("SELECT * FROM import_status LIMIT 1")
        row = cast(StatusRow, cur.fetchone()) if cur.rowcount >= 1 else None

    if row is None:
        return None, None, None

    return row['lastimportdate'], row['sequence_id'], row['indexed']
|
||||
|
||||
|
||||
def set_indexed(conn: Connection, state: bool) -> None:
    """ Update the indexed flag of the status table to the given state
        and commit the change.
    """
    with conn.cursor() as cursor:
        cursor.execute("UPDATE import_status SET indexed = %s", (state, ))
    conn.commit()
|
||||
|
||||
|
||||
def log_status(conn: Connection, start: dt.datetime,
               event: str, batchsize: Optional[int] = None) -> None:
    """ Write a new status line to the `import_osmosis_log` table.
    """
    # The end date and sequence id are taken from the current import_status
    # row; start time, batch size and event name come from the caller.
    sql = """INSERT INTO import_osmosis_log
             (batchend, batchseq, batchsize, starttime, endtime, event)
             SELECT lastimportdate, sequence_id, %s, %s, now(), %s FROM import_status"""

    with conn.cursor() as cursor:
        cursor.execute(sql, (batchsize, start, event))
    conn.commit()
|
||||
129
src/nominatim_db/db/utils.py
Normal file
129
src/nominatim_db/db/utils.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
#
|
||||
# This file is part of Nominatim. (https://nominatim.org)
|
||||
#
|
||||
# Copyright (C) 2024 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
"""
|
||||
Helper functions for handling DB accesses.
|
||||
"""
|
||||
from typing import IO, Optional, Union, Any, Iterable
|
||||
import subprocess
|
||||
import logging
|
||||
import gzip
|
||||
import io
|
||||
from pathlib import Path
|
||||
|
||||
from .connection import get_pg_env, Cursor
|
||||
from ..errors import UsageError
|
||||
|
||||
LOG = logging.getLogger()
|
||||
|
||||
def _pipe_to_proc(proc: 'subprocess.Popen[bytes]',
                  fdesc: Union[IO[bytes], gzip.GzipFile]) -> int:
    """ Feed the contents of `fdesc` into the stdin of `proc` in 2kB chunks.

        Stops early when the subprocess terminates. Returns the length of
        the last chunk read, i.e. 0 when the file was consumed completely
        and non-zero when the subprocess exited before the end of the file.

        Raises UsageError when the subprocess closes its stdin prematurely.
    """
    assert proc.stdin is not None
    chunk = fdesc.read(2048)
    # Keep writing while there is data left and the subprocess is alive.
    while chunk and proc.poll() is None:
        try:
            proc.stdin.write(chunk)
        except BrokenPipeError as exc:
            raise UsageError("Failed to execute SQL file.") from exc
        chunk = fdesc.read(2048)

    return len(chunk)
|
||||
|
||||
def execute_file(dsn: str, fname: Path,
                 ignore_errors: bool = False,
                 pre_code: Optional[str] = None,
                 post_code: Optional[str] = None) -> None:
    """ Read an SQL file and run its contents against the given database
        using psql. Use `pre_code` and `post_code` to run extra commands
        before or after executing the file. The commands are run within the
        same session, so they may be used to wrap the file execution in a
        transaction.

        Raises UsageError when psql exits with an error or the file could
        not be fed to it completely.
    """
    cmd = ['psql']
    if not ignore_errors:
        # Make psql abort on the first SQL error.
        cmd.extend(('-v', 'ON_ERROR_STOP=1'))
    if not LOG.isEnabledFor(logging.INFO):
        cmd.append('--quiet')

    # Connection parameters are handed to psql via PG* environment variables.
    with subprocess.Popen(cmd, env=get_pg_env(dsn), stdin=subprocess.PIPE) as proc:
        assert proc.stdin is not None
        try:
            if not LOG.isEnabledFor(logging.INFO):
                proc.stdin.write('set client_min_messages to WARNING;'.encode('utf-8'))

            if pre_code:
                proc.stdin.write((pre_code + ';').encode('utf-8'))

            # Gzipped SQL files are decompressed on the fly.
            if fname.suffix == '.gz':
                with gzip.open(str(fname), 'rb') as fdesc:
                    remain = _pipe_to_proc(proc, fdesc)
            else:
                with fname.open('rb') as fdesc:
                    remain = _pipe_to_proc(proc, fdesc)

            # Only run the post code when the whole file went through
            # (remain == 0 means the file was consumed completely).
            if remain == 0 and post_code:
                proc.stdin.write((';' + post_code).encode('utf-8'))
        finally:
            # Closing stdin signals EOF to psql; then wait for it to exit.
            proc.stdin.close()
            ret = proc.wait()

    if ret != 0 or remain > 0:
        raise UsageError("Failed to execute SQL file.")
|
||||
|
||||
|
||||
# List of characters that need to be quoted for the copy command.
# In COPY text format backslash, tab and newline have special meaning and
# must be backslash-escaped in the data.
_SQL_TRANSLATION = {ord('\\'): '\\\\',
                    ord('\t'): '\\t',
                    ord('\n'): '\\n'}
|
||||
|
||||
|
||||
class CopyBuffer:
    """ Data collector for the copy_from command.

        Rows are accumulated in an in-memory text buffer in COPY text
        format and flushed to a table with copy_out().
    """

    def __init__(self) -> None:
        self.buffer = io.StringIO()


    def __enter__(self) -> 'CopyBuffer':
        return self


    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        if self.buffer is not None:
            self.buffer.close()


    def size(self) -> int:
        """ Return the number of bytes the buffer currently contains.
        """
        return self.buffer.tell()


    def add(self, *data: Any) -> None:
        """ Add another row of data to the copy buffer.
        """
        # NULL is encoded as '\N'; everything else is stringified with the
        # COPY special characters escaped.
        fields = ('\\N' if column is None
                  else str(column).translate(_SQL_TRANSLATION)
                  for column in data)
        self.buffer.write('\t'.join(fields))
        self.buffer.write('\n')


    def copy_out(self, cur: Cursor, table: str, columns: Optional[Iterable[str]] = None) -> None:
        """ Copy all collected data into the given table.

            The buffer is empty and reusable after this operation.
        """
        if self.buffer.tell() > 0:
            self.buffer.seek(0)
            cur.copy_from(self.buffer, table, columns=columns)
            self.buffer = io.StringIO()
|
||||
Reference in New Issue
Block a user