split code into submodules

This commit is contained in:
Sarah Hoffmann
2024-05-16 11:55:17 +02:00
parent 0fb4fe8e4d
commit 6e89310a92
137 changed files with 757 additions and 716 deletions

View File

View File

@@ -0,0 +1,374 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2022 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Nominatim configuration accessor.
"""
from typing import Dict, Any, List, Mapping, Optional
import importlib.util
import logging
import os
import sys
from pathlib import Path
import json
import yaml
from dotenv import dotenv_values
from psycopg2.extensions import parse_dsn
from .typing import StrPath
from .errors import UsageError
from . import paths
LOG = logging.getLogger()
CONFIG_CACHE : Dict[str, Any] = {}
def flatten_config_list(content: Any, section: str = '') -> List[Any]:
""" Flatten YAML configuration lists that contain include sections
which are lists themselves.
"""
if not content:
return []
if not isinstance(content, list):
raise UsageError(f"List expected in section '{section}'.")
output = []
for ele in content:
if isinstance(ele, list):
output.extend(flatten_config_list(ele, section))
else:
output.append(ele)
return output
class Configuration:
""" This class wraps access to the configuration settings
for the Nominatim instance in use.
All Nominatim configuration options are prefixed with 'NOMINATIM_' to
avoid conflicts with other environment variables. All settings can
be accessed as properties of the class under the same name as the
setting but with the `NOMINATIM_` prefix removed. In addition, there
are accessor functions that convert the setting values to types
other than string.
"""
def __init__(self, project_dir: Optional[Path],
environ: Optional[Mapping[str, str]] = None) -> None:
self.environ = environ or os.environ
self.project_dir = project_dir
self.config_dir = paths.CONFIG_DIR
self._config = dotenv_values(str(self.config_dir / 'env.defaults'))
if self.project_dir is not None and (self.project_dir / '.env').is_file():
self.project_dir = self.project_dir.resolve()
self._config.update(dotenv_values(str(self.project_dir / '.env')))
class _LibDirs:
module: Path
osm2pgsql: Path
php = paths.PHPLIB_DIR
sql = paths.SQLLIB_DIR
data = paths.DATA_DIR
self.lib_dir = _LibDirs()
self._private_plugins: Dict[str, object] = {}
def set_libdirs(self, **kwargs: StrPath) -> None:
""" Set paths to library functions and data.
"""
for key, value in kwargs.items():
setattr(self.lib_dir, key, None if value is None else Path(value))
def __getattr__(self, name: str) -> str:
name = 'NOMINATIM_' + name
if name in self.environ:
return self.environ[name]
return self._config[name] or ''
def get_bool(self, name: str) -> bool:
""" Return the given configuration parameter as a boolean.
Parameters:
name: Name of the configuration parameter with the NOMINATIM_
prefix removed.
Returns:
`True` for values of '1', 'yes' and 'true', `False` otherwise.
"""
return getattr(self, name).lower() in ('1', 'yes', 'true')
def get_int(self, name: str) -> int:
""" Return the given configuration parameter as an int.
Parameters:
name: Name of the configuration parameter with the NOMINATIM_
prefix removed.
Returns:
The configuration value converted to int.
Raises:
ValueError: when the value is not a number.
"""
try:
return int(getattr(self, name))
except ValueError as exp:
LOG.fatal("Invalid setting NOMINATIM_%s. Needs to be a number.", name)
raise UsageError("Configuration error.") from exp
def get_str_list(self, name: str) -> Optional[List[str]]:
""" Return the given configuration parameter as a list of strings.
The values are assumed to be given as a comma-sparated list and
will be stripped before returning them.
Parameters:
name: Name of the configuration parameter with the NOMINATIM_
prefix removed.
Returns:
(List[str]): The comma-split parameter as a list. The
elements are stripped of leading and final spaces before
being returned.
(None): The configuration parameter was unset or empty.
"""
raw = getattr(self, name)
return [v.strip() for v in raw.split(',')] if raw else None
def get_path(self, name: str) -> Optional[Path]:
""" Return the given configuration parameter as a Path.
Parameters:
name: Name of the configuration parameter with the NOMINATIM_
prefix removed.
Returns:
(Path): A Path object of the parameter value.
If a relative path is configured, then the function converts this
into an absolute path with the project directory as root path.
(None): The configuration parameter was unset or empty.
"""
value = getattr(self, name)
if not value:
return None
cfgpath = Path(value)
if not cfgpath.is_absolute():
assert self.project_dir is not None
cfgpath = self.project_dir / cfgpath
return cfgpath.resolve()
def get_libpq_dsn(self) -> str:
""" Get configured database DSN converted into the key/value format
understood by libpq and psycopg.
"""
dsn = self.DATABASE_DSN
def quote_param(param: str) -> str:
key, val = param.split('=')
val = val.replace('\\', '\\\\').replace("'", "\\'")
if ' ' in val:
val = "'" + val + "'"
return key + '=' + val
if dsn.startswith('pgsql:'):
# Old PHP DSN format. Convert before returning.
return ' '.join([quote_param(p) for p in dsn[6:].split(';')])
return dsn
def get_database_params(self) -> Mapping[str, str]:
""" Get the configured parameters for the database connection
as a mapping.
"""
dsn = self.DATABASE_DSN
if dsn.startswith('pgsql:'):
return dict((p.split('=', 1) for p in dsn[6:].split(';')))
return parse_dsn(dsn)
def get_import_style_file(self) -> Path:
""" Return the import style file as a path object. Translates the
name of the standard styles automatically into a file in the
config style.
"""
style = getattr(self, 'IMPORT_STYLE')
if style in ('admin', 'street', 'address', 'full', 'extratags'):
return self.config_dir / f'import-{style}.lua'
return self.find_config_file('', 'IMPORT_STYLE')
def get_os_env(self) -> Dict[str, str]:
""" Return a copy of the OS environment with the Nominatim configuration
merged in.
"""
env = {k: v for k, v in self._config.items() if v is not None}
env.update(self.environ)
return env
def load_sub_configuration(self, filename: StrPath,
config: Optional[str] = None) -> Any:
""" Load additional configuration from a file. `filename` is the name
of the configuration file. The file is first searched in the
project directory and then in the global settings directory.
If `config` is set, then the name of the configuration file can
be additionally given through a .env configuration option. When
the option is set, then the file will be exclusively loaded as set:
if the name is an absolute path, the file name is taken as is,
if the name is relative, it is taken to be relative to the
project directory.
The format of the file is determined from the filename suffix.
Currently only files with extension '.yaml' are supported.
YAML files support a special '!include' construct. When the
directive is given, the value is taken to be a filename, the file
is loaded using this function and added at the position in the
configuration tree.
"""
configfile = self.find_config_file(filename, config)
if str(configfile) in CONFIG_CACHE:
return CONFIG_CACHE[str(configfile)]
if configfile.suffix in ('.yaml', '.yml'):
result = self._load_from_yaml(configfile)
elif configfile.suffix == '.json':
with configfile.open('r', encoding='utf-8') as cfg:
result = json.load(cfg)
else:
raise UsageError(f"Config file '{configfile}' has unknown format.")
CONFIG_CACHE[str(configfile)] = result
return result
def load_plugin_module(self, module_name: str, internal_path: str) -> Any:
""" Load a Python module as a plugin.
The module_name may have three variants:
* A name without any '.' is assumed to be an internal module
and will be searched relative to `internal_path`.
* If the name ends in `.py`, module_name is assumed to be a
file name relative to the project directory.
* Any other name is assumed to be an absolute module name.
In either of the variants the module name must start with a letter.
"""
if not module_name or not module_name[0].isidentifier():
raise UsageError(f'Invalid module name {module_name}')
if '.' not in module_name:
module_name = module_name.replace('-', '_')
full_module = f'{internal_path}.{module_name}'
return sys.modules.get(full_module) or importlib.import_module(full_module)
if module_name.endswith('.py'):
if self.project_dir is None or not (self.project_dir / module_name).exists():
raise UsageError(f"Cannot find module '{module_name}' in project directory.")
if module_name in self._private_plugins:
return self._private_plugins[module_name]
file_path = str(self.project_dir / module_name)
spec = importlib.util.spec_from_file_location(module_name, file_path)
if spec:
module = importlib.util.module_from_spec(spec)
# Do not add to global modules because there is no standard
# module name that Python can resolve.
self._private_plugins[module_name] = module
assert spec.loader is not None
spec.loader.exec_module(module)
return module
return sys.modules.get(module_name) or importlib.import_module(module_name)
def find_config_file(self, filename: StrPath,
config: Optional[str] = None) -> Path:
""" Resolve the location of a configuration file given a filename and
an optional configuration option with the file name.
Raises a UsageError when the file cannot be found or is not
a regular file.
"""
if config is not None:
cfg_value = getattr(self, config)
if cfg_value:
cfg_filename = Path(cfg_value)
if cfg_filename.is_absolute():
cfg_filename = cfg_filename.resolve()
if not cfg_filename.is_file():
LOG.fatal("Cannot find config file '%s'.", cfg_filename)
raise UsageError("Config file not found.")
return cfg_filename
filename = cfg_filename
search_paths = [self.project_dir, self.config_dir]
for path in search_paths:
if path is not None and (path / filename).is_file():
return path / filename
LOG.fatal("Configuration file '%s' not found.\nDirectories searched: %s",
filename, search_paths)
raise UsageError("Config file not found.")
def _load_from_yaml(self, cfgfile: Path) -> Any:
""" Load a YAML configuration file. This installs a special handler that
allows to include other YAML files using the '!include' operator.
"""
yaml.add_constructor('!include', self._yaml_include_representer,
Loader=yaml.SafeLoader)
return yaml.safe_load(cfgfile.read_text(encoding='utf-8'))
def _yaml_include_representer(self, loader: Any, node: yaml.Node) -> Any:
""" Handler for the '!include' operator in YAML files.
When the filename is relative, then the file is first searched in the
project directory and then in the global settings directory.
"""
fname = loader.construct_scalar(node)
if Path(fname).is_absolute():
configfile = Path(fname)
else:
configfile = self.find_config_file(loader.construct_scalar(node))
if configfile.suffix != '.yaml':
LOG.fatal("Format error while reading '%s': only YAML format supported.",
configfile)
raise UsageError("Cannot handle config file format.")
return yaml.safe_load(configfile.read_text(encoding='utf-8'))

View File

View File

@@ -0,0 +1,236 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
""" Non-blocking database connections.
"""
from typing import Callable, Any, Optional, Iterator, Sequence
import logging
import select
import time
import psycopg2
from psycopg2.extras import wait_select
# psycopg2 emits different exceptions pre and post 2.8. Detect if the new error
# module is available and adapt the error handling accordingly.
try:
import psycopg2.errors # pylint: disable=no-name-in-module,import-error
__has_psycopg2_errors__ = True
except ImportError:
__has_psycopg2_errors__ = False
from ..typing import T_cursor, Query
LOG = logging.getLogger()
class DeadlockHandler:
""" Context manager that catches deadlock exceptions and calls
the given handler function. All other exceptions are passed on
normally.
"""
def __init__(self, handler: Callable[[], None], ignore_sql_errors: bool = False) -> None:
self.handler = handler
self.ignore_sql_errors = ignore_sql_errors
def __enter__(self) -> 'DeadlockHandler':
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> bool:
if __has_psycopg2_errors__:
if exc_type == psycopg2.errors.DeadlockDetected: # pylint: disable=E1101
self.handler()
return True
elif exc_type == psycopg2.extensions.TransactionRollbackError \
and exc_value.pgcode == '40P01':
self.handler()
return True
if self.ignore_sql_errors and isinstance(exc_value, psycopg2.Error):
LOG.info("SQL error ignored: %s", exc_value)
return True
return False
class DBConnection:
""" A single non-blocking database connection.
"""
def __init__(self, dsn: str,
cursor_factory: Optional[Callable[..., T_cursor]] = None,
ignore_sql_errors: bool = False) -> None:
self.dsn = dsn
self.current_query: Optional[Query] = None
self.current_params: Optional[Sequence[Any]] = None
self.ignore_sql_errors = ignore_sql_errors
self.conn: Optional['psycopg2._psycopg.connection'] = None
self.cursor: Optional['psycopg2._psycopg.cursor'] = None
self.connect(cursor_factory=cursor_factory)
def close(self) -> None:
""" Close all open connections. Does not wait for pending requests.
"""
if self.conn is not None:
if self.cursor is not None:
self.cursor.close()
self.cursor = None
self.conn.close()
self.conn = None
def connect(self, cursor_factory: Optional[Callable[..., T_cursor]] = None) -> None:
""" (Re)connect to the database. Creates an asynchronous connection
with JIT and parallel processing disabled. If a connection was
already open, it is closed and a new connection established.
The caller must ensure that no query is pending before reconnecting.
"""
self.close()
# Use a dict to hand in the parameters because async is a reserved
# word in Python3.
self.conn = psycopg2.connect(**{'dsn': self.dsn, 'async': True}) # type: ignore
assert self.conn
self.wait()
if cursor_factory is not None:
self.cursor = self.conn.cursor(cursor_factory=cursor_factory)
else:
self.cursor = self.conn.cursor()
# Disable JIT and parallel workers as they are known to cause problems.
# Update pg_settings instead of using SET because it does not yield
# errors on older versions of Postgres where the settings are not
# implemented.
self.perform(
""" UPDATE pg_settings SET setting = -1 WHERE name = 'jit_above_cost';
UPDATE pg_settings SET setting = 0
WHERE name = 'max_parallel_workers_per_gather';""")
self.wait()
def _deadlock_handler(self) -> None:
LOG.info("Deadlock detected (params = %s), retry.", str(self.current_params))
assert self.cursor is not None
assert self.current_query is not None
assert self.current_params is not None
self.cursor.execute(self.current_query, self.current_params)
def wait(self) -> None:
""" Block until any pending operation is done.
"""
while True:
with DeadlockHandler(self._deadlock_handler, self.ignore_sql_errors):
wait_select(self.conn)
self.current_query = None
return
def perform(self, sql: Query, args: Optional[Sequence[Any]] = None) -> None:
""" Send SQL query to the server. Returns immediately without
blocking.
"""
assert self.cursor is not None
self.current_query = sql
self.current_params = args
self.cursor.execute(sql, args)
def fileno(self) -> int:
""" File descriptor to wait for. (Makes this class select()able.)
"""
assert self.conn is not None
return self.conn.fileno()
def is_done(self) -> bool:
""" Check if the connection is available for a new query.
Also checks if the previous query has run into a deadlock.
If so, then the previous query is repeated.
"""
assert self.conn is not None
if self.current_query is None:
return True
with DeadlockHandler(self._deadlock_handler, self.ignore_sql_errors):
if self.conn.poll() == psycopg2.extensions.POLL_OK:
self.current_query = None
return True
return False
class WorkerPool:
""" A pool of asynchronous database connections.
The pool may be used as a context manager.
"""
REOPEN_CONNECTIONS_AFTER = 100000
def __init__(self, dsn: str, pool_size: int, ignore_sql_errors: bool = False) -> None:
self.threads = [DBConnection(dsn, ignore_sql_errors=ignore_sql_errors)
for _ in range(pool_size)]
self.free_workers = self._yield_free_worker()
self.wait_time = 0.0
def finish_all(self) -> None:
""" Wait for all connection to finish.
"""
for thread in self.threads:
while not thread.is_done():
thread.wait()
self.free_workers = self._yield_free_worker()
def close(self) -> None:
""" Close all connections and clear the pool.
"""
for thread in self.threads:
thread.close()
self.threads = []
self.free_workers = iter([])
def next_free_worker(self) -> DBConnection:
""" Get the next free connection.
"""
return next(self.free_workers)
def _yield_free_worker(self) -> Iterator[DBConnection]:
ready = self.threads
command_stat = 0
while True:
for thread in ready:
if thread.is_done():
command_stat += 1
yield thread
if command_stat > self.REOPEN_CONNECTIONS_AFTER:
self._reconnect_threads()
ready = self.threads
command_stat = 0
else:
tstart = time.time()
_, ready, _ = select.select([], self.threads, [])
self.wait_time += time.time() - tstart
def _reconnect_threads(self) -> None:
for thread in self.threads:
while not thread.is_done():
thread.wait()
thread.connect()
def __enter__(self) -> 'WorkerPool':
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
self.finish_all()
self.close()

View File

@@ -0,0 +1,21 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Import the base library to use with asynchronous SQLAlchemy.
"""
# pylint: disable=invalid-name
from typing import Any
try:
import psycopg
PGCORE_LIB = 'psycopg'
PGCORE_ERROR: Any = psycopg.Error
except ModuleNotFoundError:
import asyncpg
PGCORE_LIB = 'asyncpg'
PGCORE_ERROR = asyncpg.PostgresError

View File

@@ -0,0 +1,254 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Specialised connection and cursor functions.
"""
from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple, Iterable
import contextlib
import logging
import os
import psycopg2
import psycopg2.extensions
import psycopg2.extras
from psycopg2 import sql as pysql
from ..typing import SysEnv, Query, T_cursor
from ..errors import UsageError
LOG = logging.getLogger()
class Cursor(psycopg2.extras.DictCursor):
""" A cursor returning dict-like objects and providing specialised
execution functions.
"""
# pylint: disable=arguments-renamed,arguments-differ
def execute(self, query: Query, args: Any = None) -> None:
""" Query execution that logs the SQL query when debugging is enabled.
"""
if LOG.isEnabledFor(logging.DEBUG):
LOG.debug(self.mogrify(query, args).decode('utf-8'))
super().execute(query, args)
def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]],
template: Optional[Query] = None) -> None:
""" Wrapper for the psycopg2 convenience function to execute
SQL for a list of values.
"""
LOG.debug("SQL execute_values(%s, %s)", sql, argslist)
psycopg2.extras.execute_values(self, sql, argslist, template=template)
def scalar(self, sql: Query, args: Any = None) -> Any:
""" Execute query that returns a single value. The value is returned.
If the query yields more than one row, a ValueError is raised.
"""
self.execute(sql, args)
if self.rowcount != 1:
raise RuntimeError("Query did not return a single row.")
result = self.fetchone()
assert result is not None
return result[0]
def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
""" Drop the table with the given name.
Set `if_exists` to False if a non-existent table should raise
an exception instead of just being ignored. If 'cascade' is set
to True then all dependent tables are deleted as well.
"""
sql = 'DROP TABLE '
if if_exists:
sql += 'IF EXISTS '
sql += '{}'
if cascade:
sql += ' CASCADE'
self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
class Connection(psycopg2.extensions.connection):
""" A connection that provides the specialised cursor by default and
adds convenience functions for administrating the database.
"""
@overload # type: ignore[override]
def cursor(self) -> Cursor:
...
@overload
def cursor(self, name: str) -> Cursor:
...
@overload
def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor:
...
def cursor(self, cursor_factory = Cursor, **kwargs): # type: ignore
""" Return a new cursor. By default the specialised cursor is returned.
"""
return super().cursor(cursor_factory=cursor_factory, **kwargs)
def table_exists(self, table: str) -> bool:
""" Check that a table with the given name exists in the database.
"""
with self.cursor() as cur:
num = cur.scalar("""SELECT count(*) FROM pg_tables
WHERE tablename = %s and schemaname = 'public'""", (table, ))
return num == 1 if isinstance(num, int) else False
def table_has_column(self, table: str, column: str) -> bool:
""" Check if the table 'table' exists and has a column with name 'column'.
"""
with self.cursor() as cur:
has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns
WHERE table_name = %s
and column_name = %s""",
(table, column))
return has_column > 0 if isinstance(has_column, int) else False
def index_exists(self, index: str, table: Optional[str] = None) -> bool:
""" Check that an index with the given name exists in the database.
If table is not None then the index must relate to the given
table.
"""
with self.cursor() as cur:
cur.execute("""SELECT tablename FROM pg_indexes
WHERE indexname = %s and schemaname = 'public'""", (index, ))
if cur.rowcount == 0:
return False
if table is not None:
row = cur.fetchone()
if row is None or not isinstance(row[0], str):
return False
return row[0] == table
return True
def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
""" Drop the table with the given name.
Set `if_exists` to False if a non-existent table should raise
an exception instead of just being ignored.
"""
with self.cursor() as cur:
cur.drop_table(name, if_exists, cascade)
self.commit()
def server_version_tuple(self) -> Tuple[int, int]:
""" Return the server version as a tuple of (major, minor).
Converts correctly for pre-10 and post-10 PostgreSQL versions.
"""
version = self.server_version
if version < 100000:
return (int(version / 10000), int((version % 10000) / 100))
return (int(version / 10000), version % 10000)
def postgis_version_tuple(self) -> Tuple[int, int]:
""" Return the postgis version installed in the database as a
tuple of (major, minor). Assumes that the PostGIS extension
has been installed already.
"""
with self.cursor() as cur:
version = cur.scalar('SELECT postgis_lib_version()')
version_parts = version.split('.')
if len(version_parts) < 2:
raise UsageError(f"Error fetching Postgis version. Bad format: {version}")
return (int(version_parts[0]), int(version_parts[1]))
def extension_loaded(self, extension_name: str) -> bool:
""" Return True if the hstore extension is loaded in the database.
"""
with self.cursor() as cur:
cur.execute('SELECT extname FROM pg_extension WHERE extname = %s', (extension_name, ))
return cur.rowcount > 0
class ConnectionContext(ContextManager[Connection]):
""" Context manager of the connection that also provides direct access
to the underlying connection.
"""
connection: Connection
def connect(dsn: str) -> ConnectionContext:
""" Open a connection to the database using the specialised connection
factory. The returned object may be used in conjunction with 'with'.
When used outside a context manager, use the `connection` attribute
to get the connection.
"""
try:
conn = psycopg2.connect(dsn, connection_factory=Connection)
ctxmgr = cast(ConnectionContext, contextlib.closing(conn))
ctxmgr.connection = conn
return ctxmgr
except psycopg2.OperationalError as err:
raise UsageError(f"Cannot connect to database: {err}") from err
# Translation from PG connection string parameters to PG environment variables.
# Derived from https://www.postgresql.org/docs/current/libpq-envars.html.
_PG_CONNECTION_STRINGS = {
'host': 'PGHOST',
'hostaddr': 'PGHOSTADDR',
'port': 'PGPORT',
'dbname': 'PGDATABASE',
'user': 'PGUSER',
'password': 'PGPASSWORD',
'passfile': 'PGPASSFILE',
'channel_binding': 'PGCHANNELBINDING',
'service': 'PGSERVICE',
'options': 'PGOPTIONS',
'application_name': 'PGAPPNAME',
'sslmode': 'PGSSLMODE',
'requiressl': 'PGREQUIRESSL',
'sslcompression': 'PGSSLCOMPRESSION',
'sslcert': 'PGSSLCERT',
'sslkey': 'PGSSLKEY',
'sslrootcert': 'PGSSLROOTCERT',
'sslcrl': 'PGSSLCRL',
'requirepeer': 'PGREQUIREPEER',
'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION',
'ssl_max_protocol_version': 'PGSSLMAXPROTOCOLVERSION',
'gssencmode': 'PGGSSENCMODE',
'krbsrvname': 'PGKRBSRVNAME',
'gsslib': 'PGGSSLIB',
'connect_timeout': 'PGCONNECT_TIMEOUT',
'target_session_attrs': 'PGTARGETSESSIONATTRS',
}
def get_pg_env(dsn: str,
base_env: Optional[SysEnv] = None) -> Dict[str, str]:
""" Return a copy of `base_env` with the environment variables for
PostgreSQL set up from the given database connection string.
If `base_env` is None, then the OS environment is used as a base
environment.
"""
env = dict(base_env if base_env is not None else os.environ)
for param, value in psycopg2.extensions.parse_dsn(dsn).items():
if param in _PG_CONNECTION_STRINGS:
env[_PG_CONNECTION_STRINGS[param]] = value
else:
LOG.error("Unknown connection parameter '%s' ignored.", param)
return env

View File

@@ -0,0 +1,47 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Query and access functions for the in-database property table.
"""
from typing import Optional, cast
from .connection import Connection
def set_property(conn: Connection, name: str, value: str) -> None:
""" Add or replace the property with the given name.
"""
with conn.cursor() as cur:
cur.execute('SELECT value FROM nominatim_properties WHERE property = %s',
(name, ))
if cur.rowcount == 0:
sql = 'INSERT INTO nominatim_properties (value, property) VALUES (%s, %s)'
else:
sql = 'UPDATE nominatim_properties SET value = %s WHERE property = %s'
cur.execute(sql, (value, name))
conn.commit()
def get_property(conn: Connection, name: str) -> Optional[str]:
""" Return the current value of the given property or None if the property
is not set.
"""
if not conn.table_exists('nominatim_properties'):
return None
with conn.cursor() as cur:
cur.execute('SELECT value FROM nominatim_properties WHERE property = %s',
(name, ))
if cur.rowcount == 0:
return None
result = cur.fetchone()
assert result is not None
return cast(Optional[str], result[0])

View File

@@ -0,0 +1,143 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Preprocessing of SQL files.
"""
from typing import Set, Dict, Any, cast
import jinja2
from .connection import Connection
from .async_connection import WorkerPool
from ..config import Configuration
def _get_partitions(conn: Connection) -> Set[int]:
""" Get the set of partitions currently in use.
"""
with conn.cursor() as cur:
cur.execute('SELECT DISTINCT partition FROM country_name')
partitions = set([0])
for row in cur:
partitions.add(row[0])
return partitions
def _get_tables(conn: Connection) -> Set[str]:
""" Return the set of tables currently in use.
"""
with conn.cursor() as cur:
cur.execute("SELECT tablename FROM pg_tables WHERE schemaname = 'public'")
return set((row[0] for row in list(cur)))
def _get_middle_db_format(conn: Connection, tables: Set[str]) -> str:
""" Returns the version of the slim middle tables.
"""
if 'osm2pgsql_properties' not in tables:
return '1'
with conn.cursor() as cur:
cur.execute("SELECT value FROM osm2pgsql_properties WHERE property = 'db_format'")
row = cur.fetchone()
return cast(str, row[0]) if row is not None else '1'
def _setup_tablespace_sql(config: Configuration) -> Dict[str, str]:
""" Returns a dict with tablespace expressions for the different tablespace
kinds depending on whether a tablespace is configured or not.
"""
out = {}
for subset in ('ADDRESS', 'SEARCH', 'AUX'):
for kind in ('DATA', 'INDEX'):
tspace = getattr(config, f'TABLESPACE_{subset}_{kind}')
if tspace:
tspace = f'TABLESPACE "{tspace}"'
out[f'{subset.lower()}_{kind.lower()}'] = tspace
return out
def _setup_postgresql_features(conn: Connection) -> Dict[str, Any]:
""" Set up a dictionary with various optional Postgresql/Postgis features that
depend on the database version.
"""
pg_version = conn.server_version_tuple()
postgis_version = conn.postgis_version_tuple()
pg11plus = pg_version >= (11, 0, 0)
ps3 = postgis_version >= (3, 0)
return {
'has_index_non_key_column': pg11plus,
'spgist_geom' : 'SPGIST' if pg11plus and ps3 else 'GIST'
}
class SQLPreprocessor:
""" A environment for preprocessing SQL files from the
lib-sql directory.
The preprocessor provides a number of default filters and variables.
The variables may be overwritten when rendering an SQL file.
The preprocessing is currently based on the jinja2 templating library
and follows its syntax.
"""
def __init__(self, conn: Connection, config: Configuration) -> None:
self.env = jinja2.Environment(autoescape=False,
loader=jinja2.FileSystemLoader(str(config.lib_dir.sql)))
db_info: Dict[str, Any] = {}
db_info['partitions'] = _get_partitions(conn)
db_info['tables'] = _get_tables(conn)
db_info['reverse_only'] = 'search_name' not in db_info['tables']
db_info['tablespace'] = _setup_tablespace_sql(config)
db_info['middle_db_format'] = _get_middle_db_format(conn, db_info['tables'])
self.env.globals['config'] = config
self.env.globals['db'] = db_info
self.env.globals['postgres'] = _setup_postgresql_features(conn)
def run_string(self, conn: Connection, template: str, **kwargs: Any) -> None:
""" Execute the given SQL template string on the connection.
The keyword arguments may supply additional parameters
for preprocessing.
"""
sql = self.env.from_string(template).render(**kwargs)
with conn.cursor() as cur:
cur.execute(sql)
conn.commit()
def run_sql_file(self, conn: Connection, name: str, **kwargs: Any) -> None:
""" Execute the given SQL file on the connection. The keyword arguments
may supply additional parameters for preprocessing.
"""
sql = self.env.get_template(name).render(**kwargs)
with conn.cursor() as cur:
cur.execute(sql)
conn.commit()
def run_parallel_sql_file(self, dsn: str, name: str, num_threads: int = 1,
**kwargs: Any) -> None:
""" Execute the given SQL files using parallel asynchronous connections.
The keyword arguments may supply additional parameters for
preprocessing.
After preprocessing the SQL code is cut at lines containing only
'---'. Each chunk is sent to one of the `num_threads` workers.
"""
sql = self.env.get_template(name).render(**kwargs)
parts = sql.split('\n---\n')
with WorkerPool(dsn, num_threads) as pool:
for part in parts:
pool.next_free_worker().perform(part)

View File

@@ -0,0 +1,119 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
SQLAlchemy definitions for all tables used by the frontend.
"""
import sqlalchemy as sa
from .sqlalchemy_types import Geometry, KeyValueStore, IntArray
#pylint: disable=too-many-instance-attributes
class SearchTables:
""" Data class that holds the tables of the Nominatim database.
This schema strictly reflects the read-access view of the database.
Any data used for updates only will not be visible.
"""
def __init__(self, meta: sa.MetaData) -> None:
self.meta = meta
self.import_status = sa.Table('import_status', meta,
sa.Column('lastimportdate', sa.DateTime(True), nullable=False),
sa.Column('sequence_id', sa.Integer),
sa.Column('indexed', sa.Boolean))
self.properties = sa.Table('nominatim_properties', meta,
sa.Column('property', sa.Text, nullable=False),
sa.Column('value', sa.Text))
self.placex = sa.Table('placex', meta,
sa.Column('place_id', sa.BigInteger, nullable=False),
sa.Column('parent_place_id', sa.BigInteger),
sa.Column('linked_place_id', sa.BigInteger),
sa.Column('importance', sa.Float),
sa.Column('indexed_date', sa.DateTime),
sa.Column('rank_address', sa.SmallInteger),
sa.Column('rank_search', sa.SmallInteger),
sa.Column('indexed_status', sa.SmallInteger),
sa.Column('osm_type', sa.String(1), nullable=False),
sa.Column('osm_id', sa.BigInteger, nullable=False),
sa.Column('class', sa.Text, nullable=False, key='class_'),
sa.Column('type', sa.Text, nullable=False),
sa.Column('admin_level', sa.SmallInteger),
sa.Column('name', KeyValueStore),
sa.Column('address', KeyValueStore),
sa.Column('extratags', KeyValueStore),
sa.Column('geometry', Geometry, nullable=False),
sa.Column('wikipedia', sa.Text),
sa.Column('country_code', sa.String(2)),
sa.Column('housenumber', sa.Text),
sa.Column('postcode', sa.Text),
sa.Column('centroid', Geometry))
self.addressline = sa.Table('place_addressline', meta,
sa.Column('place_id', sa.BigInteger),
sa.Column('address_place_id', sa.BigInteger),
sa.Column('distance', sa.Float),
sa.Column('fromarea', sa.Boolean),
sa.Column('isaddress', sa.Boolean))
self.postcode = sa.Table('location_postcode', meta,
sa.Column('place_id', sa.BigInteger),
sa.Column('parent_place_id', sa.BigInteger),
sa.Column('rank_search', sa.SmallInteger),
sa.Column('rank_address', sa.SmallInteger),
sa.Column('indexed_status', sa.SmallInteger),
sa.Column('indexed_date', sa.DateTime),
sa.Column('country_code', sa.String(2)),
sa.Column('postcode', sa.Text),
sa.Column('geometry', Geometry))
self.osmline = sa.Table('location_property_osmline', meta,
sa.Column('place_id', sa.BigInteger, nullable=False),
sa.Column('osm_id', sa.BigInteger),
sa.Column('parent_place_id', sa.BigInteger),
sa.Column('indexed_date', sa.DateTime),
sa.Column('startnumber', sa.Integer),
sa.Column('endnumber', sa.Integer),
sa.Column('step', sa.SmallInteger),
sa.Column('indexed_status', sa.SmallInteger),
sa.Column('linegeo', Geometry),
sa.Column('address', KeyValueStore),
sa.Column('postcode', sa.Text),
sa.Column('country_code', sa.String(2)))
self.country_name = sa.Table('country_name', meta,
sa.Column('country_code', sa.String(2)),
sa.Column('name', KeyValueStore),
sa.Column('derived_name', KeyValueStore),
sa.Column('partition', sa.Integer))
self.country_grid = sa.Table('country_osm_grid', meta,
sa.Column('country_code', sa.String(2)),
sa.Column('area', sa.Float),
sa.Column('geometry', Geometry))
# The following tables are not necessarily present.
self.search_name = sa.Table('search_name', meta,
sa.Column('place_id', sa.BigInteger),
sa.Column('importance', sa.Float),
sa.Column('search_rank', sa.SmallInteger),
sa.Column('address_rank', sa.SmallInteger),
sa.Column('name_vector', IntArray),
sa.Column('nameaddress_vector', IntArray),
sa.Column('country_code', sa.String(2)),
sa.Column('centroid', Geometry))
self.tiger = sa.Table('location_property_tiger', meta,
sa.Column('place_id', sa.BigInteger),
sa.Column('parent_place_id', sa.BigInteger),
sa.Column('startnumber', sa.Integer),
sa.Column('endnumber', sa.Integer),
sa.Column('step', sa.SmallInteger),
sa.Column('linegeo', Geometry),
sa.Column('postcode', sa.Text))

View File

@@ -0,0 +1,17 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Module with custom types for SQLAlchemy
"""
# See also https://github.com/PyCQA/pylint/issues/6006
# pylint: disable=useless-import-alias
from .geometry import (Geometry as Geometry)
from .int_array import (IntArray as IntArray)
from .key_value import (KeyValueStore as KeyValueStore)
from .json import (Json as Json)

View File

@@ -0,0 +1,308 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Custom types for SQLAlchemy.
"""
from __future__ import annotations
from typing import Callable, Any, cast
import sys
import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles
from sqlalchemy import types
from ...typing import SaColumn, SaBind
#pylint: disable=all
class Geometry_DistanceSpheroid(sa.sql.expression.FunctionElement[float]):
""" Function to compute the spherical distance in meters.
"""
type = sa.Float()
name = 'Geometry_DistanceSpheroid'
inherit_cache = True
@compiles(Geometry_DistanceSpheroid) # type: ignore[no-untyped-call, misc]
def _default_distance_spheroid(element: Geometry_DistanceSpheroid,
compiler: 'sa.Compiled', **kw: Any) -> str:
return "ST_DistanceSpheroid(%s,"\
" 'SPHEROID[\"WGS 84\",6378137,298.257223563, AUTHORITY[\"EPSG\",\"7030\"]]')"\
% compiler.process(element.clauses, **kw)
@compiles(Geometry_DistanceSpheroid, 'sqlite') # type: ignore[no-untyped-call, misc]
def _spatialite_distance_spheroid(element: Geometry_DistanceSpheroid,
compiler: 'sa.Compiled', **kw: Any) -> str:
return "COALESCE(Distance(%s, true), 0.0)" % compiler.process(element.clauses, **kw)
class Geometry_IsLineLike(sa.sql.expression.FunctionElement[Any]):
""" Check if the geometry is a line or multiline.
"""
name = 'Geometry_IsLineLike'
inherit_cache = True
@compiles(Geometry_IsLineLike) # type: ignore[no-untyped-call, misc]
def _default_is_line_like(element: Geometry_IsLineLike,
compiler: 'sa.Compiled', **kw: Any) -> str:
return "ST_GeometryType(%s) IN ('ST_LineString', 'ST_MultiLineString')" % \
compiler.process(element.clauses, **kw)
@compiles(Geometry_IsLineLike, 'sqlite') # type: ignore[no-untyped-call, misc]
def _sqlite_is_line_like(element: Geometry_IsLineLike,
compiler: 'sa.Compiled', **kw: Any) -> str:
return "ST_GeometryType(%s) IN ('LINESTRING', 'MULTILINESTRING')" % \
compiler.process(element.clauses, **kw)
class Geometry_IsAreaLike(sa.sql.expression.FunctionElement[Any]):
""" Check if the geometry is a polygon or multipolygon.
"""
name = 'Geometry_IsLineLike'
inherit_cache = True
@compiles(Geometry_IsAreaLike) # type: ignore[no-untyped-call, misc]
def _default_is_area_like(element: Geometry_IsAreaLike,
compiler: 'sa.Compiled', **kw: Any) -> str:
return "ST_GeometryType(%s) IN ('ST_Polygon', 'ST_MultiPolygon')" % \
compiler.process(element.clauses, **kw)
@compiles(Geometry_IsAreaLike, 'sqlite') # type: ignore[no-untyped-call, misc]
def _sqlite_is_area_like(element: Geometry_IsAreaLike,
compiler: 'sa.Compiled', **kw: Any) -> str:
return "ST_GeometryType(%s) IN ('POLYGON', 'MULTIPOLYGON')" % \
compiler.process(element.clauses, **kw)
class Geometry_IntersectsBbox(sa.sql.expression.FunctionElement[Any]):
""" Check if the bounding boxes of the given geometries intersect.
"""
name = 'Geometry_IntersectsBbox'
inherit_cache = True
@compiles(Geometry_IntersectsBbox) # type: ignore[no-untyped-call, misc]
def _default_intersects(element: Geometry_IntersectsBbox,
compiler: 'sa.Compiled', **kw: Any) -> str:
arg1, arg2 = list(element.clauses)
return "%s && %s" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
@compiles(Geometry_IntersectsBbox, 'sqlite') # type: ignore[no-untyped-call, misc]
def _sqlite_intersects(element: Geometry_IntersectsBbox,
compiler: 'sa.Compiled', **kw: Any) -> str:
return "MbrIntersects(%s) = 1" % compiler.process(element.clauses, **kw)
class Geometry_ColumnIntersectsBbox(sa.sql.expression.FunctionElement[Any]):
""" Check if the bounding box of the geometry intersects with the
given table column, using the spatial index for the column.
The index must exist or the query may return nothing.
"""
name = 'Geometry_ColumnIntersectsBbox'
inherit_cache = True
@compiles(Geometry_ColumnIntersectsBbox) # type: ignore[no-untyped-call, misc]
def default_intersects_column(element: Geometry_ColumnIntersectsBbox,
compiler: 'sa.Compiled', **kw: Any) -> str:
arg1, arg2 = list(element.clauses)
return "%s && %s" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
@compiles(Geometry_ColumnIntersectsBbox, 'sqlite') # type: ignore[no-untyped-call, misc]
def spatialite_intersects_column(element: Geometry_ColumnIntersectsBbox,
compiler: 'sa.Compiled', **kw: Any) -> str:
arg1, arg2 = list(element.clauses)
return "MbrIntersects(%s, %s) = 1 and "\
"%s.ROWID IN (SELECT ROWID FROM SpatialIndex "\
"WHERE f_table_name = '%s' AND f_geometry_column = '%s' "\
"AND search_frame = %s)" %(
compiler.process(arg1, **kw),
compiler.process(arg2, **kw),
arg1.table.name, arg1.table.name, arg1.name,
compiler.process(arg2, **kw))
class Geometry_ColumnDWithin(sa.sql.expression.FunctionElement[Any]):
""" Check if the geometry is within the distance of the
given table column, using the spatial index for the column.
The index must exist or the query may return nothing.
"""
name = 'Geometry_ColumnDWithin'
inherit_cache = True
@compiles(Geometry_ColumnDWithin) # type: ignore[no-untyped-call, misc]
def default_dwithin_column(element: Geometry_ColumnDWithin,
compiler: 'sa.Compiled', **kw: Any) -> str:
return "ST_DWithin(%s)" % compiler.process(element.clauses, **kw)
@compiles(Geometry_ColumnDWithin, 'sqlite') # type: ignore[no-untyped-call, misc]
def spatialite_dwithin_column(element: Geometry_ColumnDWithin,
compiler: 'sa.Compiled', **kw: Any) -> str:
geom1, geom2, dist = list(element.clauses)
return "ST_Distance(%s, %s) < %s and "\
"%s.ROWID IN (SELECT ROWID FROM SpatialIndex "\
"WHERE f_table_name = '%s' AND f_geometry_column = '%s' "\
"AND search_frame = ST_Expand(%s, %s))" %(
compiler.process(geom1, **kw),
compiler.process(geom2, **kw),
compiler.process(dist, **kw),
geom1.table.name, geom1.table.name, geom1.name,
compiler.process(geom2, **kw),
compiler.process(dist, **kw))
class Geometry(types.UserDefinedType): # type: ignore[type-arg]
""" Simplified type decorator for PostGIS geometry. This type
only supports geometries in 4326 projection.
"""
cache_ok = True
def __init__(self, subtype: str = 'Geometry'):
self.subtype = subtype
def get_col_spec(self) -> str:
return f'GEOMETRY({self.subtype}, 4326)'
def bind_processor(self, dialect: 'sa.Dialect') -> Callable[[Any], str]:
def process(value: Any) -> str:
if isinstance(value, str):
return value
return cast(str, value.to_wkt())
return process
def result_processor(self, dialect: 'sa.Dialect', coltype: object) -> Callable[[Any], str]:
def process(value: Any) -> str:
assert isinstance(value, str)
return value
return process
def column_expression(self, col: SaColumn) -> SaColumn:
return sa.func.ST_AsEWKB(col)
def bind_expression(self, bindvalue: SaBind) -> SaColumn:
return sa.func.ST_GeomFromText(bindvalue, sa.text('4326'), type_=self)
class comparator_factory(types.UserDefinedType.Comparator): # type: ignore[type-arg]
def intersects(self, other: SaColumn, use_index: bool = True) -> 'sa.Operators':
if not use_index:
return Geometry_IntersectsBbox(sa.func.coalesce(sa.null(), self.expr), other)
if isinstance(self.expr, sa.Column):
return Geometry_ColumnIntersectsBbox(self.expr, other)
return Geometry_IntersectsBbox(self.expr, other)
def is_line_like(self) -> SaColumn:
return Geometry_IsLineLike(self)
def is_area(self) -> SaColumn:
return Geometry_IsAreaLike(self)
def within_distance(self, other: SaColumn, distance: SaColumn) -> SaColumn:
if isinstance(self.expr, sa.Column):
return Geometry_ColumnDWithin(self.expr, other, distance)
return self.ST_Distance(other) < distance
def ST_Distance(self, other: SaColumn) -> SaColumn:
return sa.func.ST_Distance(self, other, type_=sa.Float)
def ST_Contains(self, other: SaColumn) -> SaColumn:
return sa.func.ST_Contains(self, other, type_=sa.Boolean)
def ST_CoveredBy(self, other: SaColumn) -> SaColumn:
return sa.func.ST_CoveredBy(self, other, type_=sa.Boolean)
def ST_ClosestPoint(self, other: SaColumn) -> SaColumn:
return sa.func.coalesce(sa.func.ST_ClosestPoint(self, other, type_=Geometry),
other)
def ST_Buffer(self, other: SaColumn) -> SaColumn:
return sa.func.ST_Buffer(self, other, type_=Geometry)
def ST_Expand(self, other: SaColumn) -> SaColumn:
return sa.func.ST_Expand(self, other, type_=Geometry)
def ST_Collect(self) -> SaColumn:
return sa.func.ST_Collect(self, type_=Geometry)
def ST_Centroid(self) -> SaColumn:
return sa.func.ST_Centroid(self, type_=Geometry)
def ST_LineInterpolatePoint(self, other: SaColumn) -> SaColumn:
return sa.func.ST_LineInterpolatePoint(self, other, type_=Geometry)
def ST_LineLocatePoint(self, other: SaColumn) -> SaColumn:
return sa.func.ST_LineLocatePoint(self, other, type_=sa.Float)
def distance_spheroid(self, other: SaColumn) -> SaColumn:
return Geometry_DistanceSpheroid(self, other)
@compiles(Geometry, 'sqlite') # type: ignore[no-untyped-call]
def get_col_spec(self, *args, **kwargs): # type: ignore[no-untyped-def]
return 'GEOMETRY'
SQLITE_FUNCTION_ALIAS = (
('ST_AsEWKB', sa.Text, 'AsEWKB'),
('ST_GeomFromEWKT', Geometry, 'GeomFromEWKT'),
('ST_AsGeoJSON', sa.Text, 'AsGeoJSON'),
('ST_AsKML', sa.Text, 'AsKML'),
('ST_AsSVG', sa.Text, 'AsSVG'),
('ST_LineLocatePoint', sa.Float, 'ST_Line_Locate_Point'),
('ST_LineInterpolatePoint', sa.Float, 'ST_Line_Interpolate_Point'),
)
def _add_function_alias(func: str, ftype: type, alias: str) -> None:
_FuncDef = type(func, (sa.sql.functions.GenericFunction, ), {
"type": ftype(),
"name": func,
"identifier": func,
"inherit_cache": True})
func_templ = f"{alias}(%s)"
def _sqlite_impl(element: Any, compiler: Any, **kw: Any) -> Any:
return func_templ % compiler.process(element.clauses, **kw)
compiles(_FuncDef, 'sqlite')(_sqlite_impl) # type: ignore[no-untyped-call]
for alias in SQLITE_FUNCTION_ALIAS:
_add_function_alias(*alias)

View File

@@ -0,0 +1,123 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Custom type for an array of integers.
"""
from typing import Any, List, cast, Optional
import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.dialects.postgresql import ARRAY
from ...typing import SaDialect, SaColumn
# pylint: disable=all
class IntList(sa.types.TypeDecorator[Any]):
""" A list of integers saved as a text of comma-separated numbers.
"""
impl = sa.types.Unicode
cache_ok = True
def process_bind_param(self, value: Optional[Any], dialect: 'sa.Dialect') -> Optional[str]:
if value is None:
return None
assert isinstance(value, list)
return ','.join(map(str, value))
def process_result_value(self, value: Optional[Any],
dialect: SaDialect) -> Optional[List[int]]:
return [int(v) for v in value.split(',')] if value is not None else None
def copy(self, **kw: Any) -> 'IntList':
return IntList(self.impl.length)
class IntArray(sa.types.TypeDecorator[Any]):
""" Dialect-independent list of integers.
"""
impl = IntList
cache_ok = True
def load_dialect_impl(self, dialect: SaDialect) -> sa.types.TypeEngine[Any]:
if dialect.name == 'postgresql':
return ARRAY(sa.Integer()) #pylint: disable=invalid-name
return IntList()
class comparator_factory(sa.types.UserDefinedType.Comparator): # type: ignore[type-arg]
def __add__(self, other: SaColumn) -> 'sa.ColumnOperators':
""" Concate the array with the given array. If one of the
operants is null, the value of the other will be returned.
"""
return ArrayCat(self.expr, other)
def contains(self, other: SaColumn, **kwargs: Any) -> 'sa.ColumnOperators':
""" Return true if the array contains all the value of the argument
array.
"""
return ArrayContains(self.expr, other)
class ArrayAgg(sa.sql.functions.GenericFunction[Any]):
""" Aggregate function to collect elements in an array.
"""
type = IntArray()
identifier = 'ArrayAgg'
name = 'array_agg'
inherit_cache = True
@compiles(ArrayAgg, 'sqlite') # type: ignore[no-untyped-call, misc]
def sqlite_array_agg(element: ArrayAgg, compiler: 'sa.Compiled', **kw: Any) -> str:
return "group_concat(%s, ',')" % compiler.process(element.clauses, **kw)
class ArrayContains(sa.sql.expression.FunctionElement[Any]):
""" Function to check if an array is fully contained in another.
"""
name = 'ArrayContains'
inherit_cache = True
@compiles(ArrayContains) # type: ignore[no-untyped-call, misc]
def generic_array_contains(element: ArrayContains, compiler: 'sa.Compiled', **kw: Any) -> str:
arg1, arg2 = list(element.clauses)
return "(%s @> %s)" % (compiler.process(arg1, **kw),
compiler.process(arg2, **kw))
@compiles(ArrayContains, 'sqlite') # type: ignore[no-untyped-call, misc]
def sqlite_array_contains(element: ArrayContains, compiler: 'sa.Compiled', **kw: Any) -> str:
return "array_contains(%s)" % compiler.process(element.clauses, **kw)
class ArrayCat(sa.sql.expression.FunctionElement[Any]):
""" Function to check if an array is fully contained in another.
"""
type = IntArray()
identifier = 'ArrayCat'
inherit_cache = True
@compiles(ArrayCat) # type: ignore[no-untyped-call, misc]
def generic_array_cat(element: ArrayCat, compiler: 'sa.Compiled', **kw: Any) -> str:
return "array_cat(%s)" % compiler.process(element.clauses, **kw)
@compiles(ArrayCat, 'sqlite') # type: ignore[no-untyped-call, misc]
def sqlite_array_cat(element: ArrayCat, compiler: 'sa.Compiled', **kw: Any) -> str:
arg1, arg2 = list(element.clauses)
return "(%s || ',' || %s)" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))

View File

@@ -0,0 +1,30 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Common json type for different dialects.
"""
from typing import Any
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.dialects.sqlite import JSON as sqlite_json
from ...typing import SaDialect
# pylint: disable=all
class Json(sa.types.TypeDecorator[Any]):
""" Dialect-independent type for JSON.
"""
impl = sa.types.JSON
cache_ok = True
def load_dialect_impl(self, dialect: SaDialect) -> sa.types.TypeEngine[Any]:
if dialect.name == 'postgresql':
return JSONB(none_as_null=True) # type: ignore[no-untyped-call]
return sqlite_json(none_as_null=True)

View File

@@ -0,0 +1,62 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
A custom type that implements a simple key-value store of strings.
"""
from typing import Any
import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.dialects.postgresql import HSTORE
from sqlalchemy.dialects.sqlite import JSON as sqlite_json
from ...typing import SaDialect, SaColumn
# pylint: disable=all
class KeyValueStore(sa.types.TypeDecorator[Any]):
""" Dialect-independent type of a simple key-value store of strings.
"""
impl = HSTORE
cache_ok = True
def load_dialect_impl(self, dialect: SaDialect) -> sa.types.TypeEngine[Any]:
if dialect.name == 'postgresql':
return HSTORE() # type: ignore[no-untyped-call]
return sqlite_json(none_as_null=True)
class comparator_factory(sa.types.UserDefinedType.Comparator): # type: ignore[type-arg]
def merge(self, other: SaColumn) -> 'sa.Operators':
""" Merge the values from the given KeyValueStore into this
one, overwriting values where necessary. When the argument
is null, nothing happens.
"""
return KeyValueConcat(self.expr, other)
class KeyValueConcat(sa.sql.expression.FunctionElement[Any]):
""" Return the merged key-value store from the input parameters.
"""
type = KeyValueStore()
name = 'JsonConcat'
inherit_cache = True
@compiles(KeyValueConcat) # type: ignore[no-untyped-call, misc]
def default_json_concat(element: KeyValueConcat, compiler: 'sa.Compiled', **kw: Any) -> str:
arg1, arg2 = list(element.clauses)
return "(%s || coalesce(%s, ''::hstore))" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
@compiles(KeyValueConcat, 'sqlite') # type: ignore[no-untyped-call, misc]
def sqlite_json_concat(element: KeyValueConcat, compiler: 'sa.Compiled', **kw: Any) -> str:
arg1, arg2 = list(element.clauses)
return "json_patch(%s, coalesce(%s, '{}'))" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))

View File

@@ -0,0 +1,127 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Access and helper functions for the status and status log table.
"""
from typing import Optional, Tuple, cast
import datetime as dt
import logging
import re
from .connection import Connection
from ..utils.url_utils import get_url
from ..errors import UsageError
from ..typing import TypedDict
LOG = logging.getLogger()
ISODATE_FORMAT = '%Y-%m-%dT%H:%M:%S'
class StatusRow(TypedDict):
""" Dictionary of columns of the import_status table.
"""
lastimportdate: dt.datetime
sequence_id: Optional[int]
indexed: Optional[bool]
def compute_database_date(conn: Connection, offline: bool = False) -> dt.datetime:
""" Determine the date of the database from the newest object in the
data base.
"""
# If there is a date from osm2pgsql available, use that.
if conn.table_exists('osm2pgsql_properties'):
with conn.cursor() as cur:
cur.execute(""" SELECT value FROM osm2pgsql_properties
WHERE property = 'current_timestamp' """)
row = cur.fetchone()
if row is not None:
return dt.datetime.strptime(row[0], "%Y-%m-%dT%H:%M:%SZ")\
.replace(tzinfo=dt.timezone.utc)
if offline:
raise UsageError("Cannot determine database date from data in offline mode.")
# Else, find the node with the highest ID in the database
with conn.cursor() as cur:
if conn.table_exists('place'):
osmid = cur.scalar("SELECT max(osm_id) FROM place WHERE osm_type='N'")
else:
osmid = cur.scalar("SELECT max(osm_id) FROM placex WHERE osm_type='N'")
if osmid is None:
LOG.fatal("No data found in the database.")
raise UsageError("No data found in the database.")
LOG.info("Using node id %d for timestamp lookup", osmid)
# Get the node from the API to find the timestamp when it was created.
node_url = f'https://www.openstreetmap.org/api/0.6/node/{osmid}/1'
data = get_url(node_url)
match = re.search(r'timestamp="((\d{4})-(\d{2})-(\d{2})T(\d{2}):(\d{2}):(\d{2}))Z"', data)
if match is None:
LOG.fatal("The node data downloaded from the API does not contain valid data.\n"
"URL used: %s", node_url)
raise UsageError("Bad API data.")
LOG.debug("Found timestamp %s", match.group(1))
return dt.datetime.strptime(match.group(1), ISODATE_FORMAT).replace(tzinfo=dt.timezone.utc)
def set_status(conn: Connection, date: Optional[dt.datetime],
seq: Optional[int] = None, indexed: bool = True) -> None:
""" Replace the current status with the given status. If date is `None`
then only sequence and indexed will be updated as given. Otherwise
the whole status is replaced.
The change will be committed to the database.
"""
assert date is None or date.tzinfo == dt.timezone.utc
with conn.cursor() as cur:
if date is None:
cur.execute("UPDATE import_status set sequence_id = %s, indexed = %s",
(seq, indexed))
else:
cur.execute("TRUNCATE TABLE import_status")
cur.execute("""INSERT INTO import_status (lastimportdate, sequence_id, indexed)
VALUES (%s, %s, %s)""", (date, seq, indexed))
conn.commit()
def get_status(conn: Connection) -> Tuple[Optional[dt.datetime], Optional[int], Optional[bool]]:
""" Return the current status as a triple of (date, sequence, indexed).
If status has not been set up yet, a triple of None is returned.
"""
with conn.cursor() as cur:
cur.execute("SELECT * FROM import_status LIMIT 1")
if cur.rowcount < 1:
return None, None, None
row = cast(StatusRow, cur.fetchone())
return row['lastimportdate'], row['sequence_id'], row['indexed']
def set_indexed(conn: Connection, state: bool) -> None:
""" Set the indexed flag in the status table to the given state.
"""
with conn.cursor() as cur:
cur.execute("UPDATE import_status SET indexed = %s", (state, ))
conn.commit()
def log_status(conn: Connection, start: dt.datetime,
event: str, batchsize: Optional[int] = None) -> None:
""" Write a new status line to the `import_osmosis_log` table.
"""
with conn.cursor() as cur:
cur.execute("""INSERT INTO import_osmosis_log
(batchend, batchseq, batchsize, starttime, endtime, event)
SELECT lastimportdate, sequence_id, %s, %s, now(), %s FROM import_status""",
(batchsize, start, event))
conn.commit()

View File

@@ -0,0 +1,129 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Helper functions for handling DB accesses.
"""
from typing import IO, Optional, Union, Any, Iterable
import subprocess
import logging
import gzip
import io
from pathlib import Path
from .connection import get_pg_env, Cursor
from ..errors import UsageError
LOG = logging.getLogger()
def _pipe_to_proc(proc: 'subprocess.Popen[bytes]',
fdesc: Union[IO[bytes], gzip.GzipFile]) -> int:
assert proc.stdin is not None
chunk = fdesc.read(2048)
while chunk and proc.poll() is None:
try:
proc.stdin.write(chunk)
except BrokenPipeError as exc:
raise UsageError("Failed to execute SQL file.") from exc
chunk = fdesc.read(2048)
return len(chunk)
def execute_file(dsn: str, fname: Path,
ignore_errors: bool = False,
pre_code: Optional[str] = None,
post_code: Optional[str] = None) -> None:
""" Read an SQL file and run its contents against the given database
using psql. Use `pre_code` and `post_code` to run extra commands
before or after executing the file. The commands are run within the
same session, so they may be used to wrap the file execution in a
transaction.
"""
cmd = ['psql']
if not ignore_errors:
cmd.extend(('-v', 'ON_ERROR_STOP=1'))
if not LOG.isEnabledFor(logging.INFO):
cmd.append('--quiet')
with subprocess.Popen(cmd, env=get_pg_env(dsn), stdin=subprocess.PIPE) as proc:
assert proc.stdin is not None
try:
if not LOG.isEnabledFor(logging.INFO):
proc.stdin.write('set client_min_messages to WARNING;'.encode('utf-8'))
if pre_code:
proc.stdin.write((pre_code + ';').encode('utf-8'))
if fname.suffix == '.gz':
with gzip.open(str(fname), 'rb') as fdesc:
remain = _pipe_to_proc(proc, fdesc)
else:
with fname.open('rb') as fdesc:
remain = _pipe_to_proc(proc, fdesc)
if remain == 0 and post_code:
proc.stdin.write((';' + post_code).encode('utf-8'))
finally:
proc.stdin.close()
ret = proc.wait()
if ret != 0 or remain > 0:
raise UsageError("Failed to execute SQL file.")
# List of characters that need to be quoted for the copy command.
_SQL_TRANSLATION = {ord('\\'): '\\\\',
ord('\t'): '\\t',
ord('\n'): '\\n'}
class CopyBuffer:
""" Data collector for the copy_from command.
"""
def __init__(self) -> None:
self.buffer = io.StringIO()
def __enter__(self) -> 'CopyBuffer':
return self
def size(self) -> int:
""" Return the number of bytes the buffer currently contains.
"""
return self.buffer.tell()
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
if self.buffer is not None:
self.buffer.close()
def add(self, *data: Any) -> None:
""" Add another row of data to the copy buffer.
"""
first = True
for column in data:
if first:
first = False
else:
self.buffer.write('\t')
if column is None:
self.buffer.write('\\N')
else:
self.buffer.write(str(column).translate(_SQL_TRANSLATION))
self.buffer.write('\n')
def copy_out(self, cur: Cursor, table: str, columns: Optional[Iterable[str]] = None) -> None:
""" Copy all collected data into the given table.
The buffer is empty and reusable after this operation.
"""
if self.buffer.tell() > 0:
self.buffer.seek(0)
cur.copy_from(self.buffer, table, columns=columns)
self.buffer = io.StringIO()

View File

@@ -0,0 +1,14 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Custom exception and error classes for Nominatim.
"""
class UsageError(Exception):
""" An error raised because of bad user input. This error will usually
not cause a stack trace to be printed unless debugging is enabled.
"""

View File

@@ -0,0 +1,15 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Path settings for extra data used by Nominatim.
"""
from pathlib import Path
PHPLIB_DIR = (Path(__file__) / '..' / '..' / '..' / 'lib-php').resolve()
SQLLIB_DIR = (Path(__file__) / '..' / '..' / '..' / 'lib-sql').resolve()
DATA_DIR = (Path(__file__) / '..' / '..' / '..' / 'data').resolve()
CONFIG_DIR = (Path(__file__) / '..' / '..' / '..' / 'settings').resolve()

View File

View File

@@ -0,0 +1,75 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Type definitions for typing annotations.
Complex type definitions are moved here, to keep the source files readable.
"""
from typing import Any, Union, Mapping, TypeVar, Sequence, TYPE_CHECKING
# Generics variable names do not confirm to naming styles, ignore globally here.
# pylint: disable=invalid-name,abstract-method,multiple-statements
# pylint: disable=missing-class-docstring,useless-import-alias
if TYPE_CHECKING:
import psycopg2.sql
import psycopg2.extensions
import psycopg2.extras
import os
StrPath = Union[str, 'os.PathLike[str]']
SysEnv = Mapping[str, str]
# psycopg2-related types
Query = Union[str, bytes, 'psycopg2.sql.Composable']
T_ResultKey = TypeVar('T_ResultKey', int, str)
class DictCursorResult(Mapping[str, Any]):
def __getitem__(self, x: Union[int, str]) -> Any: ...
DictCursorResults = Sequence[DictCursorResult]
T_cursor = TypeVar('T_cursor', bound='psycopg2.extensions.cursor')
# The following typing features require typing_extensions to work
# on all supported Python versions.
# Only require this for type checking but not for normal operations.
if TYPE_CHECKING:
from typing_extensions import (Protocol as Protocol,
Final as Final,
TypedDict as TypedDict)
else:
Protocol = object
Final = 'Final'
TypedDict = dict
# SQLAlchemy introduced generic types in version 2.0 making typing
# incompatible with older versions. Add wrappers here so we don't have
# to litter the code with bare-string types.
if TYPE_CHECKING:
import sqlalchemy as sa
from typing_extensions import (TypeAlias as TypeAlias)
else:
TypeAlias = str
SaLambdaSelect: TypeAlias = 'Union[sa.Select[Any], sa.StatementLambdaElement]'
SaSelect: TypeAlias = 'sa.Select[Any]'
SaScalarSelect: TypeAlias = 'sa.ScalarSelect[Any]'
SaRow: TypeAlias = 'sa.Row[Any]'
SaColumn: TypeAlias = 'sa.ColumnElement[Any]'
SaExpression: TypeAlias = 'sa.ColumnElement[bool]'
SaLabel: TypeAlias = 'sa.Label[Any]'
SaFromClause: TypeAlias = 'sa.FromClause'
SaSelectable: TypeAlias = 'sa.Selectable'
SaBind: TypeAlias = 'sa.BindParameter[Any]'
SaDialect: TypeAlias = 'sa.Dialect'

View File

View File

@@ -0,0 +1,49 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Functions for computation of centroids.
"""
from typing import Tuple, Any
from collections.abc import Collection
class PointsCentroid:
""" Centroid computation from single points using an online algorithm.
More points may be added at any time.
Coordinates are internally treated as a 7-digit fixed-point float
(i.e. in OSM style).
"""
def __init__(self) -> None:
self.sum_x = 0
self.sum_y = 0
self.count = 0
def centroid(self) -> Tuple[float, float]:
""" Return the centroid of all points collected so far.
"""
if self.count == 0:
raise ValueError("No points available for centroid.")
return (float(self.sum_x/self.count)/10000000,
float(self.sum_y/self.count)/10000000)
def __len__(self) -> int:
return self.count
def __iadd__(self, other: Any) -> 'PointsCentroid':
if isinstance(other, Collection) and len(other) == 2:
if all(isinstance(p, (float, int)) for p in other):
x, y = other
self.sum_x += int(x * 10000000)
self.sum_y += int(y * 10000000)
self.count += 1
return self
raise ValueError("Can only add 2-element tuples to centroid.")

View File

@@ -0,0 +1,149 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Streaming JSON encoder.
"""
from typing import Any, TypeVar, Optional, Callable
import io
try:
import ujson as json
except ModuleNotFoundError:
import json # type: ignore[no-redef]
T = TypeVar('T') # pylint: disable=invalid-name
class JsonWriter:
""" JSON encoder that renders the output directly into an output
stream. This is a very simple writer which produces JSON in a
compact as possible form.
The writer does not check for syntactic correctness. It is the
responsibility of the caller to call the write functions in an
order that produces correct JSON.
All functions return the writer object itself so that function
calls can be chained.
"""
def __init__(self) -> None:
self.data = io.StringIO()
self.pending = ''
def __call__(self) -> str:
""" Return the rendered JSON content as a string.
The writer remains usable after calling this function.
"""
if self.pending:
assert self.pending in (']', '}')
self.data.write(self.pending)
self.pending = ''
return self.data.getvalue()
def start_object(self) -> 'JsonWriter':
""" Write the open bracket of a JSON object.
"""
if self.pending:
self.data.write(self.pending)
self.pending = '{'
return self
def end_object(self) -> 'JsonWriter':
""" Write the closing bracket of a JSON object.
"""
assert self.pending in (',', '{', '')
if self.pending == '{':
self.data.write(self.pending)
self.pending = '}'
return self
def start_array(self) -> 'JsonWriter':
""" Write the opening bracket of a JSON array.
"""
if self.pending:
self.data.write(self.pending)
self.pending = '['
return self
def end_array(self) -> 'JsonWriter':
""" Write the closing bracket of a JSON array.
"""
assert self.pending in (',', '[', ']', ')', '')
if self.pending not in (',', ''):
self.data.write(self.pending)
self.pending = ']'
return self
def key(self, name: str) -> 'JsonWriter':
""" Write the key string of a JSON object.
"""
assert self.pending
self.data.write(self.pending)
self.data.write(json.dumps(name, ensure_ascii=False))
self.pending = ':'
return self
def value(self, value: Any) -> 'JsonWriter':
""" Write out a value as JSON. The function uses the json.dumps()
function for encoding the JSON. Thus any value that can be
encoded by that function is permissible here.
"""
return self.raw(json.dumps(value, ensure_ascii=False))
def float(self, value: float, precision: int) -> 'JsonWriter':
""" Write out a float value with the given precision.
"""
return self.raw(f"{value:0.{precision}f}")
def next(self) -> 'JsonWriter':
""" Write out a delimiter comma between JSON object or array elements.
"""
if self.pending:
self.data.write(self.pending)
self.pending = ','
return self
def raw(self, raw_json: str) -> 'JsonWriter':
""" Write out the given value as is. This function is useful if
a value is already available in JSON format.
"""
if self.pending:
self.data.write(self.pending)
self.pending = ''
self.data.write(raw_json)
return self
def keyval(self, key: str, value: Any) -> 'JsonWriter':
""" Write out an object element with the given key and value.
This is a shortcut for calling 'key()', 'value()' and 'next()'.
"""
self.key(key)
self.value(value)
return self.next()
def keyval_not_none(self, key: str, value: Optional[T],
transform: Optional[Callable[[T], Any]] = None) -> 'JsonWriter':
""" Write out an object element only if the value is not None.
If 'transform' is given, it must be a function that takes the
value type and returns a JSON encodable type. The transform
function will be called before the value is written out.
"""
if value is not None:
self.key(key)
self.value(transform(value) if transform else value)
self.next()
return self

View File

@@ -0,0 +1,31 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Helper functions for accessing URL.
"""
from typing import IO
import logging
import urllib.request as urlrequest
from ..version import NOMINATIM_CORE_VERSION
LOG = logging.getLogger()
def get_url(url: str) -> str:
""" Get the contents from the given URL and return it as a UTF-8 string.
This version makes sure that an appropriate user agent is sent.
"""
headers = {"User-Agent": f"Nominatim/{NOMINATIM_CORE_VERSION!s}"}
try:
request = urlrequest.Request(url, headers=headers)
with urlrequest.urlopen(request) as response: # type: IO[bytes]
return response.read().decode('utf-8')
except Exception:
LOG.fatal('Failed to load URL: %s', url)
raise

View File

@@ -0,0 +1,11 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Version information for the Nominatim core package.
"""
NOMINATIM_CORE_VERSION = '4.4.99'