reduce from 3 to 2 packages

This commit is contained in:
Sarah Hoffmann
2024-06-27 21:26:12 +02:00
parent 139cea5720
commit 4da4cbfe27
149 changed files with 419 additions and 422 deletions

View File

View File

@@ -0,0 +1,236 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
""" Non-blocking database connections.
"""
from typing import Callable, Any, Optional, Iterator, Sequence
import logging
import select
import time
import psycopg2
from psycopg2.extras import wait_select
# psycopg2 emits different exceptions pre and post 2.8. Detect if the new error
# module is available and adapt the error handling accordingly.
try:
import psycopg2.errors # pylint: disable=no-name-in-module,import-error
__has_psycopg2_errors__ = True
except ImportError:
__has_psycopg2_errors__ = False
from ..typing import T_cursor, Query
LOG = logging.getLogger()
class DeadlockHandler:
""" Context manager that catches deadlock exceptions and calls
the given handler function. All other exceptions are passed on
normally.
"""
def __init__(self, handler: Callable[[], None], ignore_sql_errors: bool = False) -> None:
self.handler = handler
self.ignore_sql_errors = ignore_sql_errors
def __enter__(self) -> 'DeadlockHandler':
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> bool:
if __has_psycopg2_errors__:
if exc_type == psycopg2.errors.DeadlockDetected: # pylint: disable=E1101
self.handler()
return True
elif exc_type == psycopg2.extensions.TransactionRollbackError \
and exc_value.pgcode == '40P01':
self.handler()
return True
if self.ignore_sql_errors and isinstance(exc_value, psycopg2.Error):
LOG.info("SQL error ignored: %s", exc_value)
return True
return False
class DBConnection:
""" A single non-blocking database connection.
"""
def __init__(self, dsn: str,
cursor_factory: Optional[Callable[..., T_cursor]] = None,
ignore_sql_errors: bool = False) -> None:
self.dsn = dsn
self.current_query: Optional[Query] = None
self.current_params: Optional[Sequence[Any]] = None
self.ignore_sql_errors = ignore_sql_errors
self.conn: Optional['psycopg2._psycopg.connection'] = None
self.cursor: Optional['psycopg2._psycopg.cursor'] = None
self.connect(cursor_factory=cursor_factory)
def close(self) -> None:
""" Close all open connections. Does not wait for pending requests.
"""
if self.conn is not None:
if self.cursor is not None:
self.cursor.close()
self.cursor = None
self.conn.close()
self.conn = None
def connect(self, cursor_factory: Optional[Callable[..., T_cursor]] = None) -> None:
""" (Re)connect to the database. Creates an asynchronous connection
with JIT and parallel processing disabled. If a connection was
already open, it is closed and a new connection established.
The caller must ensure that no query is pending before reconnecting.
"""
self.close()
# Use a dict to hand in the parameters because async is a reserved
# word in Python3.
self.conn = psycopg2.connect(**{'dsn': self.dsn, 'async': True}) # type: ignore
assert self.conn
self.wait()
if cursor_factory is not None:
self.cursor = self.conn.cursor(cursor_factory=cursor_factory)
else:
self.cursor = self.conn.cursor()
# Disable JIT and parallel workers as they are known to cause problems.
# Update pg_settings instead of using SET because it does not yield
# errors on older versions of Postgres where the settings are not
# implemented.
self.perform(
""" UPDATE pg_settings SET setting = -1 WHERE name = 'jit_above_cost';
UPDATE pg_settings SET setting = 0
WHERE name = 'max_parallel_workers_per_gather';""")
self.wait()
def _deadlock_handler(self) -> None:
LOG.info("Deadlock detected (params = %s), retry.", str(self.current_params))
assert self.cursor is not None
assert self.current_query is not None
assert self.current_params is not None
self.cursor.execute(self.current_query, self.current_params)
def wait(self) -> None:
""" Block until any pending operation is done.
"""
while True:
with DeadlockHandler(self._deadlock_handler, self.ignore_sql_errors):
wait_select(self.conn)
self.current_query = None
return
def perform(self, sql: Query, args: Optional[Sequence[Any]] = None) -> None:
""" Send SQL query to the server. Returns immediately without
blocking.
"""
assert self.cursor is not None
self.current_query = sql
self.current_params = args
self.cursor.execute(sql, args)
def fileno(self) -> int:
""" File descriptor to wait for. (Makes this class select()able.)
"""
assert self.conn is not None
return self.conn.fileno()
def is_done(self) -> bool:
""" Check if the connection is available for a new query.
Also checks if the previous query has run into a deadlock.
If so, then the previous query is repeated.
"""
assert self.conn is not None
if self.current_query is None:
return True
with DeadlockHandler(self._deadlock_handler, self.ignore_sql_errors):
if self.conn.poll() == psycopg2.extensions.POLL_OK:
self.current_query = None
return True
return False
class WorkerPool:
""" A pool of asynchronous database connections.
The pool may be used as a context manager.
"""
REOPEN_CONNECTIONS_AFTER = 100000
def __init__(self, dsn: str, pool_size: int, ignore_sql_errors: bool = False) -> None:
self.threads = [DBConnection(dsn, ignore_sql_errors=ignore_sql_errors)
for _ in range(pool_size)]
self.free_workers = self._yield_free_worker()
self.wait_time = 0.0
def finish_all(self) -> None:
""" Wait for all connection to finish.
"""
for thread in self.threads:
while not thread.is_done():
thread.wait()
self.free_workers = self._yield_free_worker()
def close(self) -> None:
""" Close all connections and clear the pool.
"""
for thread in self.threads:
thread.close()
self.threads = []
self.free_workers = iter([])
def next_free_worker(self) -> DBConnection:
""" Get the next free connection.
"""
return next(self.free_workers)
def _yield_free_worker(self) -> Iterator[DBConnection]:
ready = self.threads
command_stat = 0
while True:
for thread in ready:
if thread.is_done():
command_stat += 1
yield thread
if command_stat > self.REOPEN_CONNECTIONS_AFTER:
self._reconnect_threads()
ready = self.threads
command_stat = 0
else:
tstart = time.time()
_, ready, _ = select.select([], self.threads, [])
self.wait_time += time.time() - tstart
def _reconnect_threads(self) -> None:
for thread in self.threads:
while not thread.is_done():
thread.wait()
thread.connect()
def __enter__(self) -> 'WorkerPool':
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
self.finish_all()
self.close()

View File

@@ -0,0 +1,254 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Specialised connection and cursor functions.
"""
from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple, Iterable
import contextlib
import logging
import os
import psycopg2
import psycopg2.extensions
import psycopg2.extras
from psycopg2 import sql as pysql
from ..typing import SysEnv, Query, T_cursor
from ..errors import UsageError
LOG = logging.getLogger()
class Cursor(psycopg2.extras.DictCursor):
""" A cursor returning dict-like objects and providing specialised
execution functions.
"""
# pylint: disable=arguments-renamed,arguments-differ
def execute(self, query: Query, args: Any = None) -> None:
""" Query execution that logs the SQL query when debugging is enabled.
"""
if LOG.isEnabledFor(logging.DEBUG):
LOG.debug(self.mogrify(query, args).decode('utf-8'))
super().execute(query, args)
def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]],
template: Optional[Query] = None) -> None:
""" Wrapper for the psycopg2 convenience function to execute
SQL for a list of values.
"""
LOG.debug("SQL execute_values(%s, %s)", sql, argslist)
psycopg2.extras.execute_values(self, sql, argslist, template=template)
def scalar(self, sql: Query, args: Any = None) -> Any:
""" Execute query that returns a single value. The value is returned.
If the query yields more than one row, a ValueError is raised.
"""
self.execute(sql, args)
if self.rowcount != 1:
raise RuntimeError("Query did not return a single row.")
result = self.fetchone()
assert result is not None
return result[0]
def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
""" Drop the table with the given name.
Set `if_exists` to False if a non-existent table should raise
an exception instead of just being ignored. If 'cascade' is set
to True then all dependent tables are deleted as well.
"""
sql = 'DROP TABLE '
if if_exists:
sql += 'IF EXISTS '
sql += '{}'
if cascade:
sql += ' CASCADE'
self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
class Connection(psycopg2.extensions.connection):
""" A connection that provides the specialised cursor by default and
adds convenience functions for administrating the database.
"""
@overload # type: ignore[override]
def cursor(self) -> Cursor:
...
@overload
def cursor(self, name: str) -> Cursor:
...
@overload
def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor:
...
def cursor(self, cursor_factory = Cursor, **kwargs): # type: ignore
""" Return a new cursor. By default the specialised cursor is returned.
"""
return super().cursor(cursor_factory=cursor_factory, **kwargs)
def table_exists(self, table: str) -> bool:
""" Check that a table with the given name exists in the database.
"""
with self.cursor() as cur:
num = cur.scalar("""SELECT count(*) FROM pg_tables
WHERE tablename = %s and schemaname = 'public'""", (table, ))
return num == 1 if isinstance(num, int) else False
def table_has_column(self, table: str, column: str) -> bool:
""" Check if the table 'table' exists and has a column with name 'column'.
"""
with self.cursor() as cur:
has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns
WHERE table_name = %s
and column_name = %s""",
(table, column))
return has_column > 0 if isinstance(has_column, int) else False
def index_exists(self, index: str, table: Optional[str] = None) -> bool:
""" Check that an index with the given name exists in the database.
If table is not None then the index must relate to the given
table.
"""
with self.cursor() as cur:
cur.execute("""SELECT tablename FROM pg_indexes
WHERE indexname = %s and schemaname = 'public'""", (index, ))
if cur.rowcount == 0:
return False
if table is not None:
row = cur.fetchone()
if row is None or not isinstance(row[0], str):
return False
return row[0] == table
return True
def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
""" Drop the table with the given name.
Set `if_exists` to False if a non-existent table should raise
an exception instead of just being ignored.
"""
with self.cursor() as cur:
cur.drop_table(name, if_exists, cascade)
self.commit()
def server_version_tuple(self) -> Tuple[int, int]:
""" Return the server version as a tuple of (major, minor).
Converts correctly for pre-10 and post-10 PostgreSQL versions.
"""
version = self.server_version
if version < 100000:
return (int(version / 10000), int((version % 10000) / 100))
return (int(version / 10000), version % 10000)
def postgis_version_tuple(self) -> Tuple[int, int]:
""" Return the postgis version installed in the database as a
tuple of (major, minor). Assumes that the PostGIS extension
has been installed already.
"""
with self.cursor() as cur:
version = cur.scalar('SELECT postgis_lib_version()')
version_parts = version.split('.')
if len(version_parts) < 2:
raise UsageError(f"Error fetching Postgis version. Bad format: {version}")
return (int(version_parts[0]), int(version_parts[1]))
def extension_loaded(self, extension_name: str) -> bool:
""" Return True if the hstore extension is loaded in the database.
"""
with self.cursor() as cur:
cur.execute('SELECT extname FROM pg_extension WHERE extname = %s', (extension_name, ))
return cur.rowcount > 0
class ConnectionContext(ContextManager[Connection]):
""" Context manager of the connection that also provides direct access
to the underlying connection.
"""
connection: Connection
def connect(dsn: str) -> ConnectionContext:
""" Open a connection to the database using the specialised connection
factory. The returned object may be used in conjunction with 'with'.
When used outside a context manager, use the `connection` attribute
to get the connection.
"""
try:
conn = psycopg2.connect(dsn, connection_factory=Connection)
ctxmgr = cast(ConnectionContext, contextlib.closing(conn))
ctxmgr.connection = conn
return ctxmgr
except psycopg2.OperationalError as err:
raise UsageError(f"Cannot connect to database: {err}") from err
# Translation from PG connection string parameters to PG environment variables.
# Derived from https://www.postgresql.org/docs/current/libpq-envars.html.
_PG_CONNECTION_STRINGS = {
'host': 'PGHOST',
'hostaddr': 'PGHOSTADDR',
'port': 'PGPORT',
'dbname': 'PGDATABASE',
'user': 'PGUSER',
'password': 'PGPASSWORD',
'passfile': 'PGPASSFILE',
'channel_binding': 'PGCHANNELBINDING',
'service': 'PGSERVICE',
'options': 'PGOPTIONS',
'application_name': 'PGAPPNAME',
'sslmode': 'PGSSLMODE',
'requiressl': 'PGREQUIRESSL',
'sslcompression': 'PGSSLCOMPRESSION',
'sslcert': 'PGSSLCERT',
'sslkey': 'PGSSLKEY',
'sslrootcert': 'PGSSLROOTCERT',
'sslcrl': 'PGSSLCRL',
'requirepeer': 'PGREQUIREPEER',
'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION',
'ssl_max_protocol_version': 'PGSSLMAXPROTOCOLVERSION',
'gssencmode': 'PGGSSENCMODE',
'krbsrvname': 'PGKRBSRVNAME',
'gsslib': 'PGGSSLIB',
'connect_timeout': 'PGCONNECT_TIMEOUT',
'target_session_attrs': 'PGTARGETSESSIONATTRS',
}
def get_pg_env(dsn: str,
base_env: Optional[SysEnv] = None) -> Dict[str, str]:
""" Return a copy of `base_env` with the environment variables for
PostgreSQL set up from the given database connection string.
If `base_env` is None, then the OS environment is used as a base
environment.
"""
env = dict(base_env if base_env is not None else os.environ)
for param, value in psycopg2.extensions.parse_dsn(dsn).items():
if param in _PG_CONNECTION_STRINGS:
env[_PG_CONNECTION_STRINGS[param]] = value
else:
LOG.error("Unknown connection parameter '%s' ignored.", param)
return env

View File

@@ -0,0 +1,47 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Query and access functions for the in-database property table.
"""
from typing import Optional, cast
from .connection import Connection
def set_property(conn: Connection, name: str, value: str) -> None:
""" Add or replace the property with the given name.
"""
with conn.cursor() as cur:
cur.execute('SELECT value FROM nominatim_properties WHERE property = %s',
(name, ))
if cur.rowcount == 0:
sql = 'INSERT INTO nominatim_properties (value, property) VALUES (%s, %s)'
else:
sql = 'UPDATE nominatim_properties SET value = %s WHERE property = %s'
cur.execute(sql, (value, name))
conn.commit()
def get_property(conn: Connection, name: str) -> Optional[str]:
""" Return the current value of the given property or None if the property
is not set.
"""
if not conn.table_exists('nominatim_properties'):
return None
with conn.cursor() as cur:
cur.execute('SELECT value FROM nominatim_properties WHERE property = %s',
(name, ))
if cur.rowcount == 0:
return None
result = cur.fetchone()
assert result is not None
return cast(Optional[str], result[0])

View File

@@ -0,0 +1,143 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Preprocessing of SQL files.
"""
from typing import Set, Dict, Any, cast
import jinja2
from .connection import Connection
from .async_connection import WorkerPool
from ..config import Configuration
def _get_partitions(conn: Connection) -> Set[int]:
""" Get the set of partitions currently in use.
"""
with conn.cursor() as cur:
cur.execute('SELECT DISTINCT partition FROM country_name')
partitions = set([0])
for row in cur:
partitions.add(row[0])
return partitions
def _get_tables(conn: Connection) -> Set[str]:
""" Return the set of tables currently in use.
"""
with conn.cursor() as cur:
cur.execute("SELECT tablename FROM pg_tables WHERE schemaname = 'public'")
return set((row[0] for row in list(cur)))
def _get_middle_db_format(conn: Connection, tables: Set[str]) -> str:
""" Returns the version of the slim middle tables.
"""
if 'osm2pgsql_properties' not in tables:
return '1'
with conn.cursor() as cur:
cur.execute("SELECT value FROM osm2pgsql_properties WHERE property = 'db_format'")
row = cur.fetchone()
return cast(str, row[0]) if row is not None else '1'
def _setup_tablespace_sql(config: Configuration) -> Dict[str, str]:
""" Returns a dict with tablespace expressions for the different tablespace
kinds depending on whether a tablespace is configured or not.
"""
out = {}
for subset in ('ADDRESS', 'SEARCH', 'AUX'):
for kind in ('DATA', 'INDEX'):
tspace = getattr(config, f'TABLESPACE_{subset}_{kind}')
if tspace:
tspace = f'TABLESPACE "{tspace}"'
out[f'{subset.lower()}_{kind.lower()}'] = tspace
return out
def _setup_postgresql_features(conn: Connection) -> Dict[str, Any]:
""" Set up a dictionary with various optional Postgresql/Postgis features that
depend on the database version.
"""
pg_version = conn.server_version_tuple()
postgis_version = conn.postgis_version_tuple()
pg11plus = pg_version >= (11, 0, 0)
ps3 = postgis_version >= (3, 0)
return {
'has_index_non_key_column': pg11plus,
'spgist_geom' : 'SPGIST' if pg11plus and ps3 else 'GIST'
}
class SQLPreprocessor:
""" A environment for preprocessing SQL files from the
lib-sql directory.
The preprocessor provides a number of default filters and variables.
The variables may be overwritten when rendering an SQL file.
The preprocessing is currently based on the jinja2 templating library
and follows its syntax.
"""
def __init__(self, conn: Connection, config: Configuration) -> None:
self.env = jinja2.Environment(autoescape=False,
loader=jinja2.FileSystemLoader(str(config.lib_dir.sql)))
db_info: Dict[str, Any] = {}
db_info['partitions'] = _get_partitions(conn)
db_info['tables'] = _get_tables(conn)
db_info['reverse_only'] = 'search_name' not in db_info['tables']
db_info['tablespace'] = _setup_tablespace_sql(config)
db_info['middle_db_format'] = _get_middle_db_format(conn, db_info['tables'])
self.env.globals['config'] = config
self.env.globals['db'] = db_info
self.env.globals['postgres'] = _setup_postgresql_features(conn)
def run_string(self, conn: Connection, template: str, **kwargs: Any) -> None:
""" Execute the given SQL template string on the connection.
The keyword arguments may supply additional parameters
for preprocessing.
"""
sql = self.env.from_string(template).render(**kwargs)
with conn.cursor() as cur:
cur.execute(sql)
conn.commit()
def run_sql_file(self, conn: Connection, name: str, **kwargs: Any) -> None:
""" Execute the given SQL file on the connection. The keyword arguments
may supply additional parameters for preprocessing.
"""
sql = self.env.get_template(name).render(**kwargs)
with conn.cursor() as cur:
cur.execute(sql)
conn.commit()
def run_parallel_sql_file(self, dsn: str, name: str, num_threads: int = 1,
**kwargs: Any) -> None:
""" Execute the given SQL files using parallel asynchronous connections.
The keyword arguments may supply additional parameters for
preprocessing.
After preprocessing the SQL code is cut at lines containing only
'---'. Each chunk is sent to one of the `num_threads` workers.
"""
sql = self.env.get_template(name).render(**kwargs)
parts = sql.split('\n---\n')
with WorkerPool(dsn, num_threads) as pool:
for part in parts:
pool.next_free_worker().perform(part)

View File

@@ -0,0 +1,127 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Access and helper functions for the status and status log table.
"""
from typing import Optional, Tuple, cast
import datetime as dt
import logging
import re
from .connection import Connection
from ..utils.url_utils import get_url
from ..errors import UsageError
from ..typing import TypedDict
LOG = logging.getLogger()
ISODATE_FORMAT = '%Y-%m-%dT%H:%M:%S'
class StatusRow(TypedDict):
""" Dictionary of columns of the import_status table.
"""
lastimportdate: dt.datetime
sequence_id: Optional[int]
indexed: Optional[bool]
def compute_database_date(conn: Connection, offline: bool = False) -> dt.datetime:
""" Determine the date of the database from the newest object in the
data base.
"""
# If there is a date from osm2pgsql available, use that.
if conn.table_exists('osm2pgsql_properties'):
with conn.cursor() as cur:
cur.execute(""" SELECT value FROM osm2pgsql_properties
WHERE property = 'current_timestamp' """)
row = cur.fetchone()
if row is not None:
return dt.datetime.strptime(row[0], "%Y-%m-%dT%H:%M:%SZ")\
.replace(tzinfo=dt.timezone.utc)
if offline:
raise UsageError("Cannot determine database date from data in offline mode.")
# Else, find the node with the highest ID in the database
with conn.cursor() as cur:
if conn.table_exists('place'):
osmid = cur.scalar("SELECT max(osm_id) FROM place WHERE osm_type='N'")
else:
osmid = cur.scalar("SELECT max(osm_id) FROM placex WHERE osm_type='N'")
if osmid is None:
LOG.fatal("No data found in the database.")
raise UsageError("No data found in the database.")
LOG.info("Using node id %d for timestamp lookup", osmid)
# Get the node from the API to find the timestamp when it was created.
node_url = f'https://www.openstreetmap.org/api/0.6/node/{osmid}/1'
data = get_url(node_url)
match = re.search(r'timestamp="((\d{4})-(\d{2})-(\d{2})T(\d{2}):(\d{2}):(\d{2}))Z"', data)
if match is None:
LOG.fatal("The node data downloaded from the API does not contain valid data.\n"
"URL used: %s", node_url)
raise UsageError("Bad API data.")
LOG.debug("Found timestamp %s", match.group(1))
return dt.datetime.strptime(match.group(1), ISODATE_FORMAT).replace(tzinfo=dt.timezone.utc)
def set_status(conn: Connection, date: Optional[dt.datetime],
seq: Optional[int] = None, indexed: bool = True) -> None:
""" Replace the current status with the given status. If date is `None`
then only sequence and indexed will be updated as given. Otherwise
the whole status is replaced.
The change will be committed to the database.
"""
assert date is None or date.tzinfo == dt.timezone.utc
with conn.cursor() as cur:
if date is None:
cur.execute("UPDATE import_status set sequence_id = %s, indexed = %s",
(seq, indexed))
else:
cur.execute("TRUNCATE TABLE import_status")
cur.execute("""INSERT INTO import_status (lastimportdate, sequence_id, indexed)
VALUES (%s, %s, %s)""", (date, seq, indexed))
conn.commit()
def get_status(conn: Connection) -> Tuple[Optional[dt.datetime], Optional[int], Optional[bool]]:
""" Return the current status as a triple of (date, sequence, indexed).
If status has not been set up yet, a triple of None is returned.
"""
with conn.cursor() as cur:
cur.execute("SELECT * FROM import_status LIMIT 1")
if cur.rowcount < 1:
return None, None, None
row = cast(StatusRow, cur.fetchone())
return row['lastimportdate'], row['sequence_id'], row['indexed']
def set_indexed(conn: Connection, state: bool) -> None:
""" Set the indexed flag in the status table to the given state.
"""
with conn.cursor() as cur:
cur.execute("UPDATE import_status SET indexed = %s", (state, ))
conn.commit()
def log_status(conn: Connection, start: dt.datetime,
event: str, batchsize: Optional[int] = None) -> None:
""" Write a new status line to the `import_osmosis_log` table.
"""
with conn.cursor() as cur:
cur.execute("""INSERT INTO import_osmosis_log
(batchend, batchseq, batchsize, starttime, endtime, event)
SELECT lastimportdate, sequence_id, %s, %s, now(), %s FROM import_status""",
(batchsize, start, event))
conn.commit()

View File

@@ -0,0 +1,129 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Helper functions for handling DB accesses.
"""
from typing import IO, Optional, Union, Any, Iterable
import subprocess
import logging
import gzip
import io
from pathlib import Path
from .connection import get_pg_env, Cursor
from ..errors import UsageError
LOG = logging.getLogger()
def _pipe_to_proc(proc: 'subprocess.Popen[bytes]',
fdesc: Union[IO[bytes], gzip.GzipFile]) -> int:
assert proc.stdin is not None
chunk = fdesc.read(2048)
while chunk and proc.poll() is None:
try:
proc.stdin.write(chunk)
except BrokenPipeError as exc:
raise UsageError("Failed to execute SQL file.") from exc
chunk = fdesc.read(2048)
return len(chunk)
def execute_file(dsn: str, fname: Path,
ignore_errors: bool = False,
pre_code: Optional[str] = None,
post_code: Optional[str] = None) -> None:
""" Read an SQL file and run its contents against the given database
using psql. Use `pre_code` and `post_code` to run extra commands
before or after executing the file. The commands are run within the
same session, so they may be used to wrap the file execution in a
transaction.
"""
cmd = ['psql']
if not ignore_errors:
cmd.extend(('-v', 'ON_ERROR_STOP=1'))
if not LOG.isEnabledFor(logging.INFO):
cmd.append('--quiet')
with subprocess.Popen(cmd, env=get_pg_env(dsn), stdin=subprocess.PIPE) as proc:
assert proc.stdin is not None
try:
if not LOG.isEnabledFor(logging.INFO):
proc.stdin.write('set client_min_messages to WARNING;'.encode('utf-8'))
if pre_code:
proc.stdin.write((pre_code + ';').encode('utf-8'))
if fname.suffix == '.gz':
with gzip.open(str(fname), 'rb') as fdesc:
remain = _pipe_to_proc(proc, fdesc)
else:
with fname.open('rb') as fdesc:
remain = _pipe_to_proc(proc, fdesc)
if remain == 0 and post_code:
proc.stdin.write((';' + post_code).encode('utf-8'))
finally:
proc.stdin.close()
ret = proc.wait()
if ret != 0 or remain > 0:
raise UsageError("Failed to execute SQL file.")
# List of characters that need to be quoted for the copy command.
_SQL_TRANSLATION = {ord('\\'): '\\\\',
ord('\t'): '\\t',
ord('\n'): '\\n'}
class CopyBuffer:
""" Data collector for the copy_from command.
"""
def __init__(self) -> None:
self.buffer = io.StringIO()
def __enter__(self) -> 'CopyBuffer':
return self
def size(self) -> int:
""" Return the number of bytes the buffer currently contains.
"""
return self.buffer.tell()
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
if self.buffer is not None:
self.buffer.close()
def add(self, *data: Any) -> None:
""" Add another row of data to the copy buffer.
"""
first = True
for column in data:
if first:
first = False
else:
self.buffer.write('\t')
if column is None:
self.buffer.write('\\N')
else:
self.buffer.write(str(column).translate(_SQL_TRANSLATION))
self.buffer.write('\n')
def copy_out(self, cur: Cursor, table: str, columns: Optional[Iterable[str]] = None) -> None:
""" Copy all collected data into the given table.
The buffer is empty and reusable after this operation.
"""
if self.buffer.tell() > 0:
self.buffer.seek(0)
cur.copy_from(self.buffer, table, columns=columns)
self.buffer = io.StringIO()