forked from hans/Nominatim
split code into submodules
This commit is contained in:
0
src/nominatim_core/__init__.py
Normal file
0
src/nominatim_core/__init__.py
Normal file
374
src/nominatim_core/config.py
Normal file
374
src/nominatim_core/config.py
Normal file
@@ -0,0 +1,374 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
#
|
||||
# This file is part of Nominatim. (https://nominatim.org)
|
||||
#
|
||||
# Copyright (C) 2022 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
"""
|
||||
Nominatim configuration accessor.
|
||||
"""
|
||||
from typing import Dict, Any, List, Mapping, Optional
|
||||
import importlib.util
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import json
|
||||
import yaml
|
||||
|
||||
from dotenv import dotenv_values
|
||||
from psycopg2.extensions import parse_dsn
|
||||
|
||||
from .typing import StrPath
|
||||
from .errors import UsageError
|
||||
from . import paths
|
||||
|
||||
LOG = logging.getLogger()
|
||||
CONFIG_CACHE : Dict[str, Any] = {}
|
||||
|
||||
def flatten_config_list(content: Any, section: str = '') -> List[Any]:
|
||||
""" Flatten YAML configuration lists that contain include sections
|
||||
which are lists themselves.
|
||||
"""
|
||||
if not content:
|
||||
return []
|
||||
|
||||
if not isinstance(content, list):
|
||||
raise UsageError(f"List expected in section '{section}'.")
|
||||
|
||||
output = []
|
||||
for ele in content:
|
||||
if isinstance(ele, list):
|
||||
output.extend(flatten_config_list(ele, section))
|
||||
else:
|
||||
output.append(ele)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class Configuration:
|
||||
""" This class wraps access to the configuration settings
|
||||
for the Nominatim instance in use.
|
||||
|
||||
All Nominatim configuration options are prefixed with 'NOMINATIM_' to
|
||||
avoid conflicts with other environment variables. All settings can
|
||||
be accessed as properties of the class under the same name as the
|
||||
setting but with the `NOMINATIM_` prefix removed. In addition, there
|
||||
are accessor functions that convert the setting values to types
|
||||
other than string.
|
||||
"""
|
||||
|
||||
def __init__(self, project_dir: Optional[Path],
|
||||
environ: Optional[Mapping[str, str]] = None) -> None:
|
||||
self.environ = environ or os.environ
|
||||
self.project_dir = project_dir
|
||||
self.config_dir = paths.CONFIG_DIR
|
||||
self._config = dotenv_values(str(self.config_dir / 'env.defaults'))
|
||||
if self.project_dir is not None and (self.project_dir / '.env').is_file():
|
||||
self.project_dir = self.project_dir.resolve()
|
||||
self._config.update(dotenv_values(str(self.project_dir / '.env')))
|
||||
|
||||
class _LibDirs:
|
||||
module: Path
|
||||
osm2pgsql: Path
|
||||
php = paths.PHPLIB_DIR
|
||||
sql = paths.SQLLIB_DIR
|
||||
data = paths.DATA_DIR
|
||||
|
||||
self.lib_dir = _LibDirs()
|
||||
self._private_plugins: Dict[str, object] = {}
|
||||
|
||||
|
||||
def set_libdirs(self, **kwargs: StrPath) -> None:
|
||||
""" Set paths to library functions and data.
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
setattr(self.lib_dir, key, None if value is None else Path(value))
|
||||
|
||||
|
||||
def __getattr__(self, name: str) -> str:
|
||||
name = 'NOMINATIM_' + name
|
||||
|
||||
if name in self.environ:
|
||||
return self.environ[name]
|
||||
|
||||
return self._config[name] or ''
|
||||
|
||||
|
||||
def get_bool(self, name: str) -> bool:
|
||||
""" Return the given configuration parameter as a boolean.
|
||||
|
||||
Parameters:
|
||||
name: Name of the configuration parameter with the NOMINATIM_
|
||||
prefix removed.
|
||||
|
||||
Returns:
|
||||
`True` for values of '1', 'yes' and 'true', `False` otherwise.
|
||||
"""
|
||||
return getattr(self, name).lower() in ('1', 'yes', 'true')
|
||||
|
||||
|
||||
def get_int(self, name: str) -> int:
|
||||
""" Return the given configuration parameter as an int.
|
||||
|
||||
Parameters:
|
||||
name: Name of the configuration parameter with the NOMINATIM_
|
||||
prefix removed.
|
||||
|
||||
Returns:
|
||||
The configuration value converted to int.
|
||||
|
||||
Raises:
|
||||
ValueError: when the value is not a number.
|
||||
"""
|
||||
try:
|
||||
return int(getattr(self, name))
|
||||
except ValueError as exp:
|
||||
LOG.fatal("Invalid setting NOMINATIM_%s. Needs to be a number.", name)
|
||||
raise UsageError("Configuration error.") from exp
|
||||
|
||||
|
||||
def get_str_list(self, name: str) -> Optional[List[str]]:
|
||||
""" Return the given configuration parameter as a list of strings.
|
||||
The values are assumed to be given as a comma-sparated list and
|
||||
will be stripped before returning them.
|
||||
|
||||
Parameters:
|
||||
name: Name of the configuration parameter with the NOMINATIM_
|
||||
prefix removed.
|
||||
|
||||
Returns:
|
||||
(List[str]): The comma-split parameter as a list. The
|
||||
elements are stripped of leading and final spaces before
|
||||
being returned.
|
||||
(None): The configuration parameter was unset or empty.
|
||||
"""
|
||||
raw = getattr(self, name)
|
||||
|
||||
return [v.strip() for v in raw.split(',')] if raw else None
|
||||
|
||||
|
||||
def get_path(self, name: str) -> Optional[Path]:
|
||||
""" Return the given configuration parameter as a Path.
|
||||
|
||||
Parameters:
|
||||
name: Name of the configuration parameter with the NOMINATIM_
|
||||
prefix removed.
|
||||
|
||||
Returns:
|
||||
(Path): A Path object of the parameter value.
|
||||
If a relative path is configured, then the function converts this
|
||||
into an absolute path with the project directory as root path.
|
||||
(None): The configuration parameter was unset or empty.
|
||||
"""
|
||||
value = getattr(self, name)
|
||||
if not value:
|
||||
return None
|
||||
|
||||
cfgpath = Path(value)
|
||||
|
||||
if not cfgpath.is_absolute():
|
||||
assert self.project_dir is not None
|
||||
cfgpath = self.project_dir / cfgpath
|
||||
|
||||
return cfgpath.resolve()
|
||||
|
||||
|
||||
def get_libpq_dsn(self) -> str:
|
||||
""" Get configured database DSN converted into the key/value format
|
||||
understood by libpq and psycopg.
|
||||
"""
|
||||
dsn = self.DATABASE_DSN
|
||||
|
||||
def quote_param(param: str) -> str:
|
||||
key, val = param.split('=')
|
||||
val = val.replace('\\', '\\\\').replace("'", "\\'")
|
||||
if ' ' in val:
|
||||
val = "'" + val + "'"
|
||||
return key + '=' + val
|
||||
|
||||
if dsn.startswith('pgsql:'):
|
||||
# Old PHP DSN format. Convert before returning.
|
||||
return ' '.join([quote_param(p) for p in dsn[6:].split(';')])
|
||||
|
||||
return dsn
|
||||
|
||||
|
||||
def get_database_params(self) -> Mapping[str, str]:
|
||||
""" Get the configured parameters for the database connection
|
||||
as a mapping.
|
||||
"""
|
||||
dsn = self.DATABASE_DSN
|
||||
|
||||
if dsn.startswith('pgsql:'):
|
||||
return dict((p.split('=', 1) for p in dsn[6:].split(';')))
|
||||
|
||||
return parse_dsn(dsn)
|
||||
|
||||
|
||||
def get_import_style_file(self) -> Path:
|
||||
""" Return the import style file as a path object. Translates the
|
||||
name of the standard styles automatically into a file in the
|
||||
config style.
|
||||
"""
|
||||
style = getattr(self, 'IMPORT_STYLE')
|
||||
|
||||
if style in ('admin', 'street', 'address', 'full', 'extratags'):
|
||||
return self.config_dir / f'import-{style}.lua'
|
||||
|
||||
return self.find_config_file('', 'IMPORT_STYLE')
|
||||
|
||||
|
||||
def get_os_env(self) -> Dict[str, str]:
|
||||
""" Return a copy of the OS environment with the Nominatim configuration
|
||||
merged in.
|
||||
"""
|
||||
env = {k: v for k, v in self._config.items() if v is not None}
|
||||
env.update(self.environ)
|
||||
|
||||
return env
|
||||
|
||||
|
||||
def load_sub_configuration(self, filename: StrPath,
|
||||
config: Optional[str] = None) -> Any:
|
||||
""" Load additional configuration from a file. `filename` is the name
|
||||
of the configuration file. The file is first searched in the
|
||||
project directory and then in the global settings directory.
|
||||
|
||||
If `config` is set, then the name of the configuration file can
|
||||
be additionally given through a .env configuration option. When
|
||||
the option is set, then the file will be exclusively loaded as set:
|
||||
if the name is an absolute path, the file name is taken as is,
|
||||
if the name is relative, it is taken to be relative to the
|
||||
project directory.
|
||||
|
||||
The format of the file is determined from the filename suffix.
|
||||
Currently only files with extension '.yaml' are supported.
|
||||
|
||||
YAML files support a special '!include' construct. When the
|
||||
directive is given, the value is taken to be a filename, the file
|
||||
is loaded using this function and added at the position in the
|
||||
configuration tree.
|
||||
"""
|
||||
configfile = self.find_config_file(filename, config)
|
||||
|
||||
if str(configfile) in CONFIG_CACHE:
|
||||
return CONFIG_CACHE[str(configfile)]
|
||||
|
||||
if configfile.suffix in ('.yaml', '.yml'):
|
||||
result = self._load_from_yaml(configfile)
|
||||
elif configfile.suffix == '.json':
|
||||
with configfile.open('r', encoding='utf-8') as cfg:
|
||||
result = json.load(cfg)
|
||||
else:
|
||||
raise UsageError(f"Config file '{configfile}' has unknown format.")
|
||||
|
||||
CONFIG_CACHE[str(configfile)] = result
|
||||
return result
|
||||
|
||||
|
||||
def load_plugin_module(self, module_name: str, internal_path: str) -> Any:
|
||||
""" Load a Python module as a plugin.
|
||||
|
||||
The module_name may have three variants:
|
||||
|
||||
* A name without any '.' is assumed to be an internal module
|
||||
and will be searched relative to `internal_path`.
|
||||
* If the name ends in `.py`, module_name is assumed to be a
|
||||
file name relative to the project directory.
|
||||
* Any other name is assumed to be an absolute module name.
|
||||
|
||||
In either of the variants the module name must start with a letter.
|
||||
"""
|
||||
if not module_name or not module_name[0].isidentifier():
|
||||
raise UsageError(f'Invalid module name {module_name}')
|
||||
|
||||
if '.' not in module_name:
|
||||
module_name = module_name.replace('-', '_')
|
||||
full_module = f'{internal_path}.{module_name}'
|
||||
return sys.modules.get(full_module) or importlib.import_module(full_module)
|
||||
|
||||
if module_name.endswith('.py'):
|
||||
if self.project_dir is None or not (self.project_dir / module_name).exists():
|
||||
raise UsageError(f"Cannot find module '{module_name}' in project directory.")
|
||||
|
||||
if module_name in self._private_plugins:
|
||||
return self._private_plugins[module_name]
|
||||
|
||||
file_path = str(self.project_dir / module_name)
|
||||
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
||||
if spec:
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
# Do not add to global modules because there is no standard
|
||||
# module name that Python can resolve.
|
||||
self._private_plugins[module_name] = module
|
||||
assert spec.loader is not None
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
return module
|
||||
|
||||
return sys.modules.get(module_name) or importlib.import_module(module_name)
|
||||
|
||||
|
||||
def find_config_file(self, filename: StrPath,
|
||||
config: Optional[str] = None) -> Path:
|
||||
""" Resolve the location of a configuration file given a filename and
|
||||
an optional configuration option with the file name.
|
||||
Raises a UsageError when the file cannot be found or is not
|
||||
a regular file.
|
||||
"""
|
||||
if config is not None:
|
||||
cfg_value = getattr(self, config)
|
||||
if cfg_value:
|
||||
cfg_filename = Path(cfg_value)
|
||||
|
||||
if cfg_filename.is_absolute():
|
||||
cfg_filename = cfg_filename.resolve()
|
||||
|
||||
if not cfg_filename.is_file():
|
||||
LOG.fatal("Cannot find config file '%s'.", cfg_filename)
|
||||
raise UsageError("Config file not found.")
|
||||
|
||||
return cfg_filename
|
||||
|
||||
filename = cfg_filename
|
||||
|
||||
|
||||
search_paths = [self.project_dir, self.config_dir]
|
||||
for path in search_paths:
|
||||
if path is not None and (path / filename).is_file():
|
||||
return path / filename
|
||||
|
||||
LOG.fatal("Configuration file '%s' not found.\nDirectories searched: %s",
|
||||
filename, search_paths)
|
||||
raise UsageError("Config file not found.")
|
||||
|
||||
|
||||
def _load_from_yaml(self, cfgfile: Path) -> Any:
|
||||
""" Load a YAML configuration file. This installs a special handler that
|
||||
allows to include other YAML files using the '!include' operator.
|
||||
"""
|
||||
yaml.add_constructor('!include', self._yaml_include_representer,
|
||||
Loader=yaml.SafeLoader)
|
||||
return yaml.safe_load(cfgfile.read_text(encoding='utf-8'))
|
||||
|
||||
|
||||
def _yaml_include_representer(self, loader: Any, node: yaml.Node) -> Any:
|
||||
""" Handler for the '!include' operator in YAML files.
|
||||
|
||||
When the filename is relative, then the file is first searched in the
|
||||
project directory and then in the global settings directory.
|
||||
"""
|
||||
fname = loader.construct_scalar(node)
|
||||
|
||||
if Path(fname).is_absolute():
|
||||
configfile = Path(fname)
|
||||
else:
|
||||
configfile = self.find_config_file(loader.construct_scalar(node))
|
||||
|
||||
if configfile.suffix != '.yaml':
|
||||
LOG.fatal("Format error while reading '%s': only YAML format supported.",
|
||||
configfile)
|
||||
raise UsageError("Cannot handle config file format.")
|
||||
|
||||
return yaml.safe_load(configfile.read_text(encoding='utf-8'))
|
||||
0
src/nominatim_core/db/__init__.py
Normal file
0
src/nominatim_core/db/__init__.py
Normal file
236
src/nominatim_core/db/async_connection.py
Normal file
236
src/nominatim_core/db/async_connection.py
Normal file
@@ -0,0 +1,236 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
#
|
||||
# This file is part of Nominatim. (https://nominatim.org)
|
||||
#
|
||||
# Copyright (C) 2024 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
""" Non-blocking database connections.
|
||||
"""
|
||||
from typing import Callable, Any, Optional, Iterator, Sequence
|
||||
import logging
|
||||
import select
|
||||
import time
|
||||
|
||||
import psycopg2
|
||||
from psycopg2.extras import wait_select
|
||||
|
||||
# psycopg2 emits different exceptions pre and post 2.8. Detect if the new error
|
||||
# module is available and adapt the error handling accordingly.
|
||||
try:
|
||||
import psycopg2.errors # pylint: disable=no-name-in-module,import-error
|
||||
__has_psycopg2_errors__ = True
|
||||
except ImportError:
|
||||
__has_psycopg2_errors__ = False
|
||||
|
||||
from ..typing import T_cursor, Query
|
||||
|
||||
LOG = logging.getLogger()
|
||||
|
||||
class DeadlockHandler:
|
||||
""" Context manager that catches deadlock exceptions and calls
|
||||
the given handler function. All other exceptions are passed on
|
||||
normally.
|
||||
"""
|
||||
|
||||
def __init__(self, handler: Callable[[], None], ignore_sql_errors: bool = False) -> None:
|
||||
self.handler = handler
|
||||
self.ignore_sql_errors = ignore_sql_errors
|
||||
|
||||
def __enter__(self) -> 'DeadlockHandler':
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> bool:
|
||||
if __has_psycopg2_errors__:
|
||||
if exc_type == psycopg2.errors.DeadlockDetected: # pylint: disable=E1101
|
||||
self.handler()
|
||||
return True
|
||||
elif exc_type == psycopg2.extensions.TransactionRollbackError \
|
||||
and exc_value.pgcode == '40P01':
|
||||
self.handler()
|
||||
return True
|
||||
|
||||
if self.ignore_sql_errors and isinstance(exc_value, psycopg2.Error):
|
||||
LOG.info("SQL error ignored: %s", exc_value)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class DBConnection:
|
||||
""" A single non-blocking database connection.
|
||||
"""
|
||||
|
||||
def __init__(self, dsn: str,
|
||||
cursor_factory: Optional[Callable[..., T_cursor]] = None,
|
||||
ignore_sql_errors: bool = False) -> None:
|
||||
self.dsn = dsn
|
||||
|
||||
self.current_query: Optional[Query] = None
|
||||
self.current_params: Optional[Sequence[Any]] = None
|
||||
self.ignore_sql_errors = ignore_sql_errors
|
||||
|
||||
self.conn: Optional['psycopg2._psycopg.connection'] = None
|
||||
self.cursor: Optional['psycopg2._psycopg.cursor'] = None
|
||||
self.connect(cursor_factory=cursor_factory)
|
||||
|
||||
def close(self) -> None:
|
||||
""" Close all open connections. Does not wait for pending requests.
|
||||
"""
|
||||
if self.conn is not None:
|
||||
if self.cursor is not None:
|
||||
self.cursor.close()
|
||||
self.cursor = None
|
||||
self.conn.close()
|
||||
|
||||
self.conn = None
|
||||
|
||||
def connect(self, cursor_factory: Optional[Callable[..., T_cursor]] = None) -> None:
|
||||
""" (Re)connect to the database. Creates an asynchronous connection
|
||||
with JIT and parallel processing disabled. If a connection was
|
||||
already open, it is closed and a new connection established.
|
||||
The caller must ensure that no query is pending before reconnecting.
|
||||
"""
|
||||
self.close()
|
||||
|
||||
# Use a dict to hand in the parameters because async is a reserved
|
||||
# word in Python3.
|
||||
self.conn = psycopg2.connect(**{'dsn': self.dsn, 'async': True}) # type: ignore
|
||||
assert self.conn
|
||||
self.wait()
|
||||
|
||||
if cursor_factory is not None:
|
||||
self.cursor = self.conn.cursor(cursor_factory=cursor_factory)
|
||||
else:
|
||||
self.cursor = self.conn.cursor()
|
||||
# Disable JIT and parallel workers as they are known to cause problems.
|
||||
# Update pg_settings instead of using SET because it does not yield
|
||||
# errors on older versions of Postgres where the settings are not
|
||||
# implemented.
|
||||
self.perform(
|
||||
""" UPDATE pg_settings SET setting = -1 WHERE name = 'jit_above_cost';
|
||||
UPDATE pg_settings SET setting = 0
|
||||
WHERE name = 'max_parallel_workers_per_gather';""")
|
||||
self.wait()
|
||||
|
||||
def _deadlock_handler(self) -> None:
|
||||
LOG.info("Deadlock detected (params = %s), retry.", str(self.current_params))
|
||||
assert self.cursor is not None
|
||||
assert self.current_query is not None
|
||||
assert self.current_params is not None
|
||||
|
||||
self.cursor.execute(self.current_query, self.current_params)
|
||||
|
||||
def wait(self) -> None:
|
||||
""" Block until any pending operation is done.
|
||||
"""
|
||||
while True:
|
||||
with DeadlockHandler(self._deadlock_handler, self.ignore_sql_errors):
|
||||
wait_select(self.conn)
|
||||
self.current_query = None
|
||||
return
|
||||
|
||||
def perform(self, sql: Query, args: Optional[Sequence[Any]] = None) -> None:
|
||||
""" Send SQL query to the server. Returns immediately without
|
||||
blocking.
|
||||
"""
|
||||
assert self.cursor is not None
|
||||
self.current_query = sql
|
||||
self.current_params = args
|
||||
self.cursor.execute(sql, args)
|
||||
|
||||
def fileno(self) -> int:
|
||||
""" File descriptor to wait for. (Makes this class select()able.)
|
||||
"""
|
||||
assert self.conn is not None
|
||||
return self.conn.fileno()
|
||||
|
||||
def is_done(self) -> bool:
|
||||
""" Check if the connection is available for a new query.
|
||||
|
||||
Also checks if the previous query has run into a deadlock.
|
||||
If so, then the previous query is repeated.
|
||||
"""
|
||||
assert self.conn is not None
|
||||
|
||||
if self.current_query is None:
|
||||
return True
|
||||
|
||||
with DeadlockHandler(self._deadlock_handler, self.ignore_sql_errors):
|
||||
if self.conn.poll() == psycopg2.extensions.POLL_OK:
|
||||
self.current_query = None
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class WorkerPool:
|
||||
""" A pool of asynchronous database connections.
|
||||
|
||||
The pool may be used as a context manager.
|
||||
"""
|
||||
REOPEN_CONNECTIONS_AFTER = 100000
|
||||
|
||||
def __init__(self, dsn: str, pool_size: int, ignore_sql_errors: bool = False) -> None:
|
||||
self.threads = [DBConnection(dsn, ignore_sql_errors=ignore_sql_errors)
|
||||
for _ in range(pool_size)]
|
||||
self.free_workers = self._yield_free_worker()
|
||||
self.wait_time = 0.0
|
||||
|
||||
|
||||
def finish_all(self) -> None:
|
||||
""" Wait for all connection to finish.
|
||||
"""
|
||||
for thread in self.threads:
|
||||
while not thread.is_done():
|
||||
thread.wait()
|
||||
|
||||
self.free_workers = self._yield_free_worker()
|
||||
|
||||
def close(self) -> None:
|
||||
""" Close all connections and clear the pool.
|
||||
"""
|
||||
for thread in self.threads:
|
||||
thread.close()
|
||||
self.threads = []
|
||||
self.free_workers = iter([])
|
||||
|
||||
|
||||
def next_free_worker(self) -> DBConnection:
|
||||
""" Get the next free connection.
|
||||
"""
|
||||
return next(self.free_workers)
|
||||
|
||||
|
||||
def _yield_free_worker(self) -> Iterator[DBConnection]:
|
||||
ready = self.threads
|
||||
command_stat = 0
|
||||
while True:
|
||||
for thread in ready:
|
||||
if thread.is_done():
|
||||
command_stat += 1
|
||||
yield thread
|
||||
|
||||
if command_stat > self.REOPEN_CONNECTIONS_AFTER:
|
||||
self._reconnect_threads()
|
||||
ready = self.threads
|
||||
command_stat = 0
|
||||
else:
|
||||
tstart = time.time()
|
||||
_, ready, _ = select.select([], self.threads, [])
|
||||
self.wait_time += time.time() - tstart
|
||||
|
||||
|
||||
def _reconnect_threads(self) -> None:
|
||||
for thread in self.threads:
|
||||
while not thread.is_done():
|
||||
thread.wait()
|
||||
thread.connect()
|
||||
|
||||
|
||||
def __enter__(self) -> 'WorkerPool':
|
||||
return self
|
||||
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
||||
self.finish_all()
|
||||
self.close()
|
||||
21
src/nominatim_core/db/async_core_library.py
Normal file
21
src/nominatim_core/db/async_core_library.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
#
|
||||
# This file is part of Nominatim. (https://nominatim.org)
|
||||
#
|
||||
# Copyright (C) 2024 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
"""
|
||||
Import the base library to use with asynchronous SQLAlchemy.
|
||||
"""
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
import psycopg
|
||||
PGCORE_LIB = 'psycopg'
|
||||
PGCORE_ERROR: Any = psycopg.Error
|
||||
except ModuleNotFoundError:
|
||||
import asyncpg
|
||||
PGCORE_LIB = 'asyncpg'
|
||||
PGCORE_ERROR = asyncpg.PostgresError
|
||||
254
src/nominatim_core/db/connection.py
Normal file
254
src/nominatim_core/db/connection.py
Normal file
@@ -0,0 +1,254 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
#
|
||||
# This file is part of Nominatim. (https://nominatim.org)
|
||||
#
|
||||
# Copyright (C) 2024 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
"""
|
||||
Specialised connection and cursor functions.
|
||||
"""
|
||||
from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple, Iterable
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extensions
|
||||
import psycopg2.extras
|
||||
from psycopg2 import sql as pysql
|
||||
|
||||
from ..typing import SysEnv, Query, T_cursor
|
||||
from ..errors import UsageError
|
||||
|
||||
LOG = logging.getLogger()
|
||||
|
||||
class Cursor(psycopg2.extras.DictCursor):
|
||||
""" A cursor returning dict-like objects and providing specialised
|
||||
execution functions.
|
||||
"""
|
||||
# pylint: disable=arguments-renamed,arguments-differ
|
||||
def execute(self, query: Query, args: Any = None) -> None:
|
||||
""" Query execution that logs the SQL query when debugging is enabled.
|
||||
"""
|
||||
if LOG.isEnabledFor(logging.DEBUG):
|
||||
LOG.debug(self.mogrify(query, args).decode('utf-8'))
|
||||
|
||||
super().execute(query, args)
|
||||
|
||||
|
||||
def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]],
|
||||
template: Optional[Query] = None) -> None:
|
||||
""" Wrapper for the psycopg2 convenience function to execute
|
||||
SQL for a list of values.
|
||||
"""
|
||||
LOG.debug("SQL execute_values(%s, %s)", sql, argslist)
|
||||
|
||||
psycopg2.extras.execute_values(self, sql, argslist, template=template)
|
||||
|
||||
|
||||
def scalar(self, sql: Query, args: Any = None) -> Any:
|
||||
""" Execute query that returns a single value. The value is returned.
|
||||
If the query yields more than one row, a ValueError is raised.
|
||||
"""
|
||||
self.execute(sql, args)
|
||||
|
||||
if self.rowcount != 1:
|
||||
raise RuntimeError("Query did not return a single row.")
|
||||
|
||||
result = self.fetchone()
|
||||
assert result is not None
|
||||
|
||||
return result[0]
|
||||
|
||||
|
||||
def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
|
||||
""" Drop the table with the given name.
|
||||
Set `if_exists` to False if a non-existent table should raise
|
||||
an exception instead of just being ignored. If 'cascade' is set
|
||||
to True then all dependent tables are deleted as well.
|
||||
"""
|
||||
sql = 'DROP TABLE '
|
||||
if if_exists:
|
||||
sql += 'IF EXISTS '
|
||||
sql += '{}'
|
||||
if cascade:
|
||||
sql += ' CASCADE'
|
||||
|
||||
self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
|
||||
|
||||
|
||||
class Connection(psycopg2.extensions.connection):
|
||||
""" A connection that provides the specialised cursor by default and
|
||||
adds convenience functions for administrating the database.
|
||||
"""
|
||||
@overload # type: ignore[override]
|
||||
def cursor(self) -> Cursor:
|
||||
...
|
||||
|
||||
@overload
|
||||
def cursor(self, name: str) -> Cursor:
|
||||
...
|
||||
|
||||
@overload
|
||||
def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor:
|
||||
...
|
||||
|
||||
def cursor(self, cursor_factory = Cursor, **kwargs): # type: ignore
|
||||
""" Return a new cursor. By default the specialised cursor is returned.
|
||||
"""
|
||||
return super().cursor(cursor_factory=cursor_factory, **kwargs)
|
||||
|
||||
|
||||
def table_exists(self, table: str) -> bool:
|
||||
""" Check that a table with the given name exists in the database.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
num = cur.scalar("""SELECT count(*) FROM pg_tables
|
||||
WHERE tablename = %s and schemaname = 'public'""", (table, ))
|
||||
return num == 1 if isinstance(num, int) else False
|
||||
|
||||
|
||||
def table_has_column(self, table: str, column: str) -> bool:
|
||||
""" Check if the table 'table' exists and has a column with name 'column'.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns
|
||||
WHERE table_name = %s
|
||||
and column_name = %s""",
|
||||
(table, column))
|
||||
return has_column > 0 if isinstance(has_column, int) else False
|
||||
|
||||
|
||||
def index_exists(self, index: str, table: Optional[str] = None) -> bool:
|
||||
""" Check that an index with the given name exists in the database.
|
||||
If table is not None then the index must relate to the given
|
||||
table.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
cur.execute("""SELECT tablename FROM pg_indexes
|
||||
WHERE indexname = %s and schemaname = 'public'""", (index, ))
|
||||
if cur.rowcount == 0:
|
||||
return False
|
||||
|
||||
if table is not None:
|
||||
row = cur.fetchone()
|
||||
if row is None or not isinstance(row[0], str):
|
||||
return False
|
||||
return row[0] == table
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
|
||||
""" Drop the table with the given name.
|
||||
Set `if_exists` to False if a non-existent table should raise
|
||||
an exception instead of just being ignored.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
cur.drop_table(name, if_exists, cascade)
|
||||
self.commit()
|
||||
|
||||
|
||||
def server_version_tuple(self) -> Tuple[int, int]:
|
||||
""" Return the server version as a tuple of (major, minor).
|
||||
Converts correctly for pre-10 and post-10 PostgreSQL versions.
|
||||
"""
|
||||
version = self.server_version
|
||||
if version < 100000:
|
||||
return (int(version / 10000), int((version % 10000) / 100))
|
||||
|
||||
return (int(version / 10000), version % 10000)
|
||||
|
||||
|
||||
def postgis_version_tuple(self) -> Tuple[int, int]:
|
||||
""" Return the postgis version installed in the database as a
|
||||
tuple of (major, minor). Assumes that the PostGIS extension
|
||||
has been installed already.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
version = cur.scalar('SELECT postgis_lib_version()')
|
||||
|
||||
version_parts = version.split('.')
|
||||
if len(version_parts) < 2:
|
||||
raise UsageError(f"Error fetching Postgis version. Bad format: {version}")
|
||||
|
||||
return (int(version_parts[0]), int(version_parts[1]))
|
||||
|
||||
|
||||
def extension_loaded(self, extension_name: str) -> bool:
|
||||
""" Return True if the hstore extension is loaded in the database.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
cur.execute('SELECT extname FROM pg_extension WHERE extname = %s', (extension_name, ))
|
||||
return cur.rowcount > 0
|
||||
|
||||
|
||||
class ConnectionContext(ContextManager[Connection]):
|
||||
""" Context manager of the connection that also provides direct access
|
||||
to the underlying connection.
|
||||
"""
|
||||
connection: Connection
|
||||
|
||||
def connect(dsn: str) -> ConnectionContext:
|
||||
""" Open a connection to the database using the specialised connection
|
||||
factory. The returned object may be used in conjunction with 'with'.
|
||||
When used outside a context manager, use the `connection` attribute
|
||||
to get the connection.
|
||||
"""
|
||||
try:
|
||||
conn = psycopg2.connect(dsn, connection_factory=Connection)
|
||||
ctxmgr = cast(ConnectionContext, contextlib.closing(conn))
|
||||
ctxmgr.connection = conn
|
||||
return ctxmgr
|
||||
except psycopg2.OperationalError as err:
|
||||
raise UsageError(f"Cannot connect to database: {err}") from err
|
||||
|
||||
|
||||
# Translation from PG connection string parameters to PG environment variables.
|
||||
# Derived from https://www.postgresql.org/docs/current/libpq-envars.html.
|
||||
_PG_CONNECTION_STRINGS = {
|
||||
'host': 'PGHOST',
|
||||
'hostaddr': 'PGHOSTADDR',
|
||||
'port': 'PGPORT',
|
||||
'dbname': 'PGDATABASE',
|
||||
'user': 'PGUSER',
|
||||
'password': 'PGPASSWORD',
|
||||
'passfile': 'PGPASSFILE',
|
||||
'channel_binding': 'PGCHANNELBINDING',
|
||||
'service': 'PGSERVICE',
|
||||
'options': 'PGOPTIONS',
|
||||
'application_name': 'PGAPPNAME',
|
||||
'sslmode': 'PGSSLMODE',
|
||||
'requiressl': 'PGREQUIRESSL',
|
||||
'sslcompression': 'PGSSLCOMPRESSION',
|
||||
'sslcert': 'PGSSLCERT',
|
||||
'sslkey': 'PGSSLKEY',
|
||||
'sslrootcert': 'PGSSLROOTCERT',
|
||||
'sslcrl': 'PGSSLCRL',
|
||||
'requirepeer': 'PGREQUIREPEER',
|
||||
'ssl_min_protocol_version': 'PGSSLMINPROTOCOLVERSION',
|
||||
'ssl_max_protocol_version': 'PGSSLMAXPROTOCOLVERSION',
|
||||
'gssencmode': 'PGGSSENCMODE',
|
||||
'krbsrvname': 'PGKRBSRVNAME',
|
||||
'gsslib': 'PGGSSLIB',
|
||||
'connect_timeout': 'PGCONNECT_TIMEOUT',
|
||||
'target_session_attrs': 'PGTARGETSESSIONATTRS',
|
||||
}
|
||||
|
||||
|
||||
def get_pg_env(dsn: str,
|
||||
base_env: Optional[SysEnv] = None) -> Dict[str, str]:
|
||||
""" Return a copy of `base_env` with the environment variables for
|
||||
PostgreSQL set up from the given database connection string.
|
||||
If `base_env` is None, then the OS environment is used as a base
|
||||
environment.
|
||||
"""
|
||||
env = dict(base_env if base_env is not None else os.environ)
|
||||
|
||||
for param, value in psycopg2.extensions.parse_dsn(dsn).items():
|
||||
if param in _PG_CONNECTION_STRINGS:
|
||||
env[_PG_CONNECTION_STRINGS[param]] = value
|
||||
else:
|
||||
LOG.error("Unknown connection parameter '%s' ignored.", param)
|
||||
|
||||
return env
|
||||
47
src/nominatim_core/db/properties.py
Normal file
47
src/nominatim_core/db/properties.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
#
|
||||
# This file is part of Nominatim. (https://nominatim.org)
|
||||
#
|
||||
# Copyright (C) 2024 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
"""
|
||||
Query and access functions for the in-database property table.
|
||||
"""
|
||||
from typing import Optional, cast
|
||||
|
||||
from .connection import Connection
|
||||
|
||||
def set_property(conn: Connection, name: str, value: str) -> None:
|
||||
""" Add or replace the property with the given name.
|
||||
"""
|
||||
with conn.cursor() as cur:
|
||||
cur.execute('SELECT value FROM nominatim_properties WHERE property = %s',
|
||||
(name, ))
|
||||
|
||||
if cur.rowcount == 0:
|
||||
sql = 'INSERT INTO nominatim_properties (value, property) VALUES (%s, %s)'
|
||||
else:
|
||||
sql = 'UPDATE nominatim_properties SET value = %s WHERE property = %s'
|
||||
|
||||
cur.execute(sql, (value, name))
|
||||
conn.commit()
|
||||
|
||||
|
||||
def get_property(conn: Connection, name: str) -> Optional[str]:
|
||||
""" Return the current value of the given property or None if the property
|
||||
is not set.
|
||||
"""
|
||||
if not conn.table_exists('nominatim_properties'):
|
||||
return None
|
||||
|
||||
with conn.cursor() as cur:
|
||||
cur.execute('SELECT value FROM nominatim_properties WHERE property = %s',
|
||||
(name, ))
|
||||
|
||||
if cur.rowcount == 0:
|
||||
return None
|
||||
|
||||
result = cur.fetchone()
|
||||
assert result is not None
|
||||
|
||||
return cast(Optional[str], result[0])
|
||||
143
src/nominatim_core/db/sql_preprocessor.py
Normal file
143
src/nominatim_core/db/sql_preprocessor.py
Normal file
@@ -0,0 +1,143 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
#
|
||||
# This file is part of Nominatim. (https://nominatim.org)
|
||||
#
|
||||
# Copyright (C) 2024 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
"""
|
||||
Preprocessing of SQL files.
|
||||
"""
|
||||
from typing import Set, Dict, Any, cast
|
||||
import jinja2
|
||||
|
||||
from .connection import Connection
|
||||
from .async_connection import WorkerPool
|
||||
from ..config import Configuration
|
||||
|
||||
def _get_partitions(conn: Connection) -> Set[int]:
|
||||
""" Get the set of partitions currently in use.
|
||||
"""
|
||||
with conn.cursor() as cur:
|
||||
cur.execute('SELECT DISTINCT partition FROM country_name')
|
||||
partitions = set([0])
|
||||
for row in cur:
|
||||
partitions.add(row[0])
|
||||
|
||||
return partitions
|
||||
|
||||
|
||||
def _get_tables(conn: Connection) -> Set[str]:
|
||||
""" Return the set of tables currently in use.
|
||||
"""
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("SELECT tablename FROM pg_tables WHERE schemaname = 'public'")
|
||||
|
||||
return set((row[0] for row in list(cur)))
|
||||
|
||||
def _get_middle_db_format(conn: Connection, tables: Set[str]) -> str:
|
||||
""" Returns the version of the slim middle tables.
|
||||
"""
|
||||
if 'osm2pgsql_properties' not in tables:
|
||||
return '1'
|
||||
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("SELECT value FROM osm2pgsql_properties WHERE property = 'db_format'")
|
||||
row = cur.fetchone()
|
||||
|
||||
return cast(str, row[0]) if row is not None else '1'
|
||||
|
||||
|
||||
def _setup_tablespace_sql(config: Configuration) -> Dict[str, str]:
|
||||
""" Returns a dict with tablespace expressions for the different tablespace
|
||||
kinds depending on whether a tablespace is configured or not.
|
||||
"""
|
||||
out = {}
|
||||
for subset in ('ADDRESS', 'SEARCH', 'AUX'):
|
||||
for kind in ('DATA', 'INDEX'):
|
||||
tspace = getattr(config, f'TABLESPACE_{subset}_{kind}')
|
||||
if tspace:
|
||||
tspace = f'TABLESPACE "{tspace}"'
|
||||
out[f'{subset.lower()}_{kind.lower()}'] = tspace
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _setup_postgresql_features(conn: Connection) -> Dict[str, Any]:
|
||||
""" Set up a dictionary with various optional Postgresql/Postgis features that
|
||||
depend on the database version.
|
||||
"""
|
||||
pg_version = conn.server_version_tuple()
|
||||
postgis_version = conn.postgis_version_tuple()
|
||||
pg11plus = pg_version >= (11, 0, 0)
|
||||
ps3 = postgis_version >= (3, 0)
|
||||
return {
|
||||
'has_index_non_key_column': pg11plus,
|
||||
'spgist_geom' : 'SPGIST' if pg11plus and ps3 else 'GIST'
|
||||
}
|
||||
|
||||
class SQLPreprocessor:
|
||||
""" A environment for preprocessing SQL files from the
|
||||
lib-sql directory.
|
||||
|
||||
The preprocessor provides a number of default filters and variables.
|
||||
The variables may be overwritten when rendering an SQL file.
|
||||
|
||||
The preprocessing is currently based on the jinja2 templating library
|
||||
and follows its syntax.
|
||||
"""
|
||||
|
||||
def __init__(self, conn: Connection, config: Configuration) -> None:
|
||||
self.env = jinja2.Environment(autoescape=False,
|
||||
loader=jinja2.FileSystemLoader(str(config.lib_dir.sql)))
|
||||
|
||||
db_info: Dict[str, Any] = {}
|
||||
db_info['partitions'] = _get_partitions(conn)
|
||||
db_info['tables'] = _get_tables(conn)
|
||||
db_info['reverse_only'] = 'search_name' not in db_info['tables']
|
||||
db_info['tablespace'] = _setup_tablespace_sql(config)
|
||||
db_info['middle_db_format'] = _get_middle_db_format(conn, db_info['tables'])
|
||||
|
||||
self.env.globals['config'] = config
|
||||
self.env.globals['db'] = db_info
|
||||
self.env.globals['postgres'] = _setup_postgresql_features(conn)
|
||||
|
||||
|
||||
def run_string(self, conn: Connection, template: str, **kwargs: Any) -> None:
|
||||
""" Execute the given SQL template string on the connection.
|
||||
The keyword arguments may supply additional parameters
|
||||
for preprocessing.
|
||||
"""
|
||||
sql = self.env.from_string(template).render(**kwargs)
|
||||
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(sql)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def run_sql_file(self, conn: Connection, name: str, **kwargs: Any) -> None:
|
||||
""" Execute the given SQL file on the connection. The keyword arguments
|
||||
may supply additional parameters for preprocessing.
|
||||
"""
|
||||
sql = self.env.get_template(name).render(**kwargs)
|
||||
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(sql)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def run_parallel_sql_file(self, dsn: str, name: str, num_threads: int = 1,
|
||||
**kwargs: Any) -> None:
|
||||
""" Execute the given SQL files using parallel asynchronous connections.
|
||||
The keyword arguments may supply additional parameters for
|
||||
preprocessing.
|
||||
|
||||
After preprocessing the SQL code is cut at lines containing only
|
||||
'---'. Each chunk is sent to one of the `num_threads` workers.
|
||||
"""
|
||||
sql = self.env.get_template(name).render(**kwargs)
|
||||
|
||||
parts = sql.split('\n---\n')
|
||||
|
||||
with WorkerPool(dsn, num_threads) as pool:
|
||||
for part in parts:
|
||||
pool.next_free_worker().perform(part)
|
||||
119
src/nominatim_core/db/sqlalchemy_schema.py
Normal file
119
src/nominatim_core/db/sqlalchemy_schema.py
Normal file
@@ -0,0 +1,119 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
#
|
||||
# This file is part of Nominatim. (https://nominatim.org)
|
||||
#
|
||||
# Copyright (C) 2024 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
"""
|
||||
SQLAlchemy definitions for all tables used by the frontend.
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
|
||||
from .sqlalchemy_types import Geometry, KeyValueStore, IntArray
|
||||
|
||||
#pylint: disable=too-many-instance-attributes
|
||||
class SearchTables:
|
||||
""" Data class that holds the tables of the Nominatim database.
|
||||
|
||||
This schema strictly reflects the read-access view of the database.
|
||||
Any data used for updates only will not be visible.
|
||||
"""
|
||||
|
||||
def __init__(self, meta: sa.MetaData) -> None:
|
||||
self.meta = meta
|
||||
|
||||
self.import_status = sa.Table('import_status', meta,
|
||||
sa.Column('lastimportdate', sa.DateTime(True), nullable=False),
|
||||
sa.Column('sequence_id', sa.Integer),
|
||||
sa.Column('indexed', sa.Boolean))
|
||||
|
||||
self.properties = sa.Table('nominatim_properties', meta,
|
||||
sa.Column('property', sa.Text, nullable=False),
|
||||
sa.Column('value', sa.Text))
|
||||
|
||||
self.placex = sa.Table('placex', meta,
|
||||
sa.Column('place_id', sa.BigInteger, nullable=False),
|
||||
sa.Column('parent_place_id', sa.BigInteger),
|
||||
sa.Column('linked_place_id', sa.BigInteger),
|
||||
sa.Column('importance', sa.Float),
|
||||
sa.Column('indexed_date', sa.DateTime),
|
||||
sa.Column('rank_address', sa.SmallInteger),
|
||||
sa.Column('rank_search', sa.SmallInteger),
|
||||
sa.Column('indexed_status', sa.SmallInteger),
|
||||
sa.Column('osm_type', sa.String(1), nullable=False),
|
||||
sa.Column('osm_id', sa.BigInteger, nullable=False),
|
||||
sa.Column('class', sa.Text, nullable=False, key='class_'),
|
||||
sa.Column('type', sa.Text, nullable=False),
|
||||
sa.Column('admin_level', sa.SmallInteger),
|
||||
sa.Column('name', KeyValueStore),
|
||||
sa.Column('address', KeyValueStore),
|
||||
sa.Column('extratags', KeyValueStore),
|
||||
sa.Column('geometry', Geometry, nullable=False),
|
||||
sa.Column('wikipedia', sa.Text),
|
||||
sa.Column('country_code', sa.String(2)),
|
||||
sa.Column('housenumber', sa.Text),
|
||||
sa.Column('postcode', sa.Text),
|
||||
sa.Column('centroid', Geometry))
|
||||
|
||||
self.addressline = sa.Table('place_addressline', meta,
|
||||
sa.Column('place_id', sa.BigInteger),
|
||||
sa.Column('address_place_id', sa.BigInteger),
|
||||
sa.Column('distance', sa.Float),
|
||||
sa.Column('fromarea', sa.Boolean),
|
||||
sa.Column('isaddress', sa.Boolean))
|
||||
|
||||
self.postcode = sa.Table('location_postcode', meta,
|
||||
sa.Column('place_id', sa.BigInteger),
|
||||
sa.Column('parent_place_id', sa.BigInteger),
|
||||
sa.Column('rank_search', sa.SmallInteger),
|
||||
sa.Column('rank_address', sa.SmallInteger),
|
||||
sa.Column('indexed_status', sa.SmallInteger),
|
||||
sa.Column('indexed_date', sa.DateTime),
|
||||
sa.Column('country_code', sa.String(2)),
|
||||
sa.Column('postcode', sa.Text),
|
||||
sa.Column('geometry', Geometry))
|
||||
|
||||
self.osmline = sa.Table('location_property_osmline', meta,
|
||||
sa.Column('place_id', sa.BigInteger, nullable=False),
|
||||
sa.Column('osm_id', sa.BigInteger),
|
||||
sa.Column('parent_place_id', sa.BigInteger),
|
||||
sa.Column('indexed_date', sa.DateTime),
|
||||
sa.Column('startnumber', sa.Integer),
|
||||
sa.Column('endnumber', sa.Integer),
|
||||
sa.Column('step', sa.SmallInteger),
|
||||
sa.Column('indexed_status', sa.SmallInteger),
|
||||
sa.Column('linegeo', Geometry),
|
||||
sa.Column('address', KeyValueStore),
|
||||
sa.Column('postcode', sa.Text),
|
||||
sa.Column('country_code', sa.String(2)))
|
||||
|
||||
self.country_name = sa.Table('country_name', meta,
|
||||
sa.Column('country_code', sa.String(2)),
|
||||
sa.Column('name', KeyValueStore),
|
||||
sa.Column('derived_name', KeyValueStore),
|
||||
sa.Column('partition', sa.Integer))
|
||||
|
||||
self.country_grid = sa.Table('country_osm_grid', meta,
|
||||
sa.Column('country_code', sa.String(2)),
|
||||
sa.Column('area', sa.Float),
|
||||
sa.Column('geometry', Geometry))
|
||||
|
||||
# The following tables are not necessarily present.
|
||||
self.search_name = sa.Table('search_name', meta,
|
||||
sa.Column('place_id', sa.BigInteger),
|
||||
sa.Column('importance', sa.Float),
|
||||
sa.Column('search_rank', sa.SmallInteger),
|
||||
sa.Column('address_rank', sa.SmallInteger),
|
||||
sa.Column('name_vector', IntArray),
|
||||
sa.Column('nameaddress_vector', IntArray),
|
||||
sa.Column('country_code', sa.String(2)),
|
||||
sa.Column('centroid', Geometry))
|
||||
|
||||
self.tiger = sa.Table('location_property_tiger', meta,
|
||||
sa.Column('place_id', sa.BigInteger),
|
||||
sa.Column('parent_place_id', sa.BigInteger),
|
||||
sa.Column('startnumber', sa.Integer),
|
||||
sa.Column('endnumber', sa.Integer),
|
||||
sa.Column('step', sa.SmallInteger),
|
||||
sa.Column('linegeo', Geometry),
|
||||
sa.Column('postcode', sa.Text))
|
||||
17
src/nominatim_core/db/sqlalchemy_types/__init__.py
Normal file
17
src/nominatim_core/db/sqlalchemy_types/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
#
|
||||
# This file is part of Nominatim. (https://nominatim.org)
|
||||
#
|
||||
# Copyright (C) 2024 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
"""
|
||||
Module with custom types for SQLAlchemy
|
||||
"""
|
||||
|
||||
# See also https://github.com/PyCQA/pylint/issues/6006
|
||||
# pylint: disable=useless-import-alias
|
||||
|
||||
from .geometry import (Geometry as Geometry)
|
||||
from .int_array import (IntArray as IntArray)
|
||||
from .key_value import (KeyValueStore as KeyValueStore)
|
||||
from .json import (Json as Json)
|
||||
308
src/nominatim_core/db/sqlalchemy_types/geometry.py
Normal file
308
src/nominatim_core/db/sqlalchemy_types/geometry.py
Normal file
@@ -0,0 +1,308 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
#
|
||||
# This file is part of Nominatim. (https://nominatim.org)
|
||||
#
|
||||
# Copyright (C) 2024 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
"""
|
||||
Custom types for SQLAlchemy.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from typing import Callable, Any, cast
|
||||
import sys
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy import types
|
||||
|
||||
from ...typing import SaColumn, SaBind
|
||||
|
||||
#pylint: disable=all
|
||||
|
||||
class Geometry_DistanceSpheroid(sa.sql.expression.FunctionElement[float]):
|
||||
""" Function to compute the spherical distance in meters.
|
||||
"""
|
||||
type = sa.Float()
|
||||
name = 'Geometry_DistanceSpheroid'
|
||||
inherit_cache = True
|
||||
|
||||
|
||||
@compiles(Geometry_DistanceSpheroid) # type: ignore[no-untyped-call, misc]
|
||||
def _default_distance_spheroid(element: Geometry_DistanceSpheroid,
|
||||
compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
return "ST_DistanceSpheroid(%s,"\
|
||||
" 'SPHEROID[\"WGS 84\",6378137,298.257223563, AUTHORITY[\"EPSG\",\"7030\"]]')"\
|
||||
% compiler.process(element.clauses, **kw)
|
||||
|
||||
|
||||
@compiles(Geometry_DistanceSpheroid, 'sqlite') # type: ignore[no-untyped-call, misc]
|
||||
def _spatialite_distance_spheroid(element: Geometry_DistanceSpheroid,
|
||||
compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
return "COALESCE(Distance(%s, true), 0.0)" % compiler.process(element.clauses, **kw)
|
||||
|
||||
|
||||
class Geometry_IsLineLike(sa.sql.expression.FunctionElement[Any]):
|
||||
""" Check if the geometry is a line or multiline.
|
||||
"""
|
||||
name = 'Geometry_IsLineLike'
|
||||
inherit_cache = True
|
||||
|
||||
|
||||
@compiles(Geometry_IsLineLike) # type: ignore[no-untyped-call, misc]
|
||||
def _default_is_line_like(element: Geometry_IsLineLike,
|
||||
compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
return "ST_GeometryType(%s) IN ('ST_LineString', 'ST_MultiLineString')" % \
|
||||
compiler.process(element.clauses, **kw)
|
||||
|
||||
|
||||
@compiles(Geometry_IsLineLike, 'sqlite') # type: ignore[no-untyped-call, misc]
|
||||
def _sqlite_is_line_like(element: Geometry_IsLineLike,
|
||||
compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
return "ST_GeometryType(%s) IN ('LINESTRING', 'MULTILINESTRING')" % \
|
||||
compiler.process(element.clauses, **kw)
|
||||
|
||||
|
||||
class Geometry_IsAreaLike(sa.sql.expression.FunctionElement[Any]):
|
||||
""" Check if the geometry is a polygon or multipolygon.
|
||||
"""
|
||||
name = 'Geometry_IsLineLike'
|
||||
inherit_cache = True
|
||||
|
||||
|
||||
@compiles(Geometry_IsAreaLike) # type: ignore[no-untyped-call, misc]
|
||||
def _default_is_area_like(element: Geometry_IsAreaLike,
|
||||
compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
return "ST_GeometryType(%s) IN ('ST_Polygon', 'ST_MultiPolygon')" % \
|
||||
compiler.process(element.clauses, **kw)
|
||||
|
||||
|
||||
@compiles(Geometry_IsAreaLike, 'sqlite') # type: ignore[no-untyped-call, misc]
|
||||
def _sqlite_is_area_like(element: Geometry_IsAreaLike,
|
||||
compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
return "ST_GeometryType(%s) IN ('POLYGON', 'MULTIPOLYGON')" % \
|
||||
compiler.process(element.clauses, **kw)
|
||||
|
||||
|
||||
class Geometry_IntersectsBbox(sa.sql.expression.FunctionElement[Any]):
|
||||
""" Check if the bounding boxes of the given geometries intersect.
|
||||
"""
|
||||
name = 'Geometry_IntersectsBbox'
|
||||
inherit_cache = True
|
||||
|
||||
|
||||
@compiles(Geometry_IntersectsBbox) # type: ignore[no-untyped-call, misc]
|
||||
def _default_intersects(element: Geometry_IntersectsBbox,
|
||||
compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
arg1, arg2 = list(element.clauses)
|
||||
return "%s && %s" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
|
||||
|
||||
|
||||
@compiles(Geometry_IntersectsBbox, 'sqlite') # type: ignore[no-untyped-call, misc]
|
||||
def _sqlite_intersects(element: Geometry_IntersectsBbox,
|
||||
compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
return "MbrIntersects(%s) = 1" % compiler.process(element.clauses, **kw)
|
||||
|
||||
|
||||
class Geometry_ColumnIntersectsBbox(sa.sql.expression.FunctionElement[Any]):
|
||||
""" Check if the bounding box of the geometry intersects with the
|
||||
given table column, using the spatial index for the column.
|
||||
|
||||
The index must exist or the query may return nothing.
|
||||
"""
|
||||
name = 'Geometry_ColumnIntersectsBbox'
|
||||
inherit_cache = True
|
||||
|
||||
|
||||
@compiles(Geometry_ColumnIntersectsBbox) # type: ignore[no-untyped-call, misc]
|
||||
def default_intersects_column(element: Geometry_ColumnIntersectsBbox,
|
||||
compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
arg1, arg2 = list(element.clauses)
|
||||
return "%s && %s" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
|
||||
|
||||
|
||||
@compiles(Geometry_ColumnIntersectsBbox, 'sqlite') # type: ignore[no-untyped-call, misc]
|
||||
def spatialite_intersects_column(element: Geometry_ColumnIntersectsBbox,
|
||||
compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
arg1, arg2 = list(element.clauses)
|
||||
return "MbrIntersects(%s, %s) = 1 and "\
|
||||
"%s.ROWID IN (SELECT ROWID FROM SpatialIndex "\
|
||||
"WHERE f_table_name = '%s' AND f_geometry_column = '%s' "\
|
||||
"AND search_frame = %s)" %(
|
||||
compiler.process(arg1, **kw),
|
||||
compiler.process(arg2, **kw),
|
||||
arg1.table.name, arg1.table.name, arg1.name,
|
||||
compiler.process(arg2, **kw))
|
||||
|
||||
|
||||
class Geometry_ColumnDWithin(sa.sql.expression.FunctionElement[Any]):
|
||||
""" Check if the geometry is within the distance of the
|
||||
given table column, using the spatial index for the column.
|
||||
|
||||
The index must exist or the query may return nothing.
|
||||
"""
|
||||
name = 'Geometry_ColumnDWithin'
|
||||
inherit_cache = True
|
||||
|
||||
|
||||
@compiles(Geometry_ColumnDWithin) # type: ignore[no-untyped-call, misc]
|
||||
def default_dwithin_column(element: Geometry_ColumnDWithin,
|
||||
compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
return "ST_DWithin(%s)" % compiler.process(element.clauses, **kw)
|
||||
|
||||
@compiles(Geometry_ColumnDWithin, 'sqlite') # type: ignore[no-untyped-call, misc]
|
||||
def spatialite_dwithin_column(element: Geometry_ColumnDWithin,
|
||||
compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
geom1, geom2, dist = list(element.clauses)
|
||||
return "ST_Distance(%s, %s) < %s and "\
|
||||
"%s.ROWID IN (SELECT ROWID FROM SpatialIndex "\
|
||||
"WHERE f_table_name = '%s' AND f_geometry_column = '%s' "\
|
||||
"AND search_frame = ST_Expand(%s, %s))" %(
|
||||
compiler.process(geom1, **kw),
|
||||
compiler.process(geom2, **kw),
|
||||
compiler.process(dist, **kw),
|
||||
geom1.table.name, geom1.table.name, geom1.name,
|
||||
compiler.process(geom2, **kw),
|
||||
compiler.process(dist, **kw))
|
||||
|
||||
|
||||
class Geometry(types.UserDefinedType): # type: ignore[type-arg]
|
||||
""" Simplified type decorator for PostGIS geometry. This type
|
||||
only supports geometries in 4326 projection.
|
||||
"""
|
||||
cache_ok = True
|
||||
|
||||
def __init__(self, subtype: str = 'Geometry'):
|
||||
self.subtype = subtype
|
||||
|
||||
|
||||
def get_col_spec(self) -> str:
|
||||
return f'GEOMETRY({self.subtype}, 4326)'
|
||||
|
||||
|
||||
def bind_processor(self, dialect: 'sa.Dialect') -> Callable[[Any], str]:
|
||||
def process(value: Any) -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
|
||||
return cast(str, value.to_wkt())
|
||||
return process
|
||||
|
||||
|
||||
def result_processor(self, dialect: 'sa.Dialect', coltype: object) -> Callable[[Any], str]:
|
||||
def process(value: Any) -> str:
|
||||
assert isinstance(value, str)
|
||||
return value
|
||||
return process
|
||||
|
||||
|
||||
def column_expression(self, col: SaColumn) -> SaColumn:
|
||||
return sa.func.ST_AsEWKB(col)
|
||||
|
||||
|
||||
def bind_expression(self, bindvalue: SaBind) -> SaColumn:
|
||||
return sa.func.ST_GeomFromText(bindvalue, sa.text('4326'), type_=self)
|
||||
|
||||
|
||||
class comparator_factory(types.UserDefinedType.Comparator): # type: ignore[type-arg]
|
||||
|
||||
def intersects(self, other: SaColumn, use_index: bool = True) -> 'sa.Operators':
|
||||
if not use_index:
|
||||
return Geometry_IntersectsBbox(sa.func.coalesce(sa.null(), self.expr), other)
|
||||
|
||||
if isinstance(self.expr, sa.Column):
|
||||
return Geometry_ColumnIntersectsBbox(self.expr, other)
|
||||
|
||||
return Geometry_IntersectsBbox(self.expr, other)
|
||||
|
||||
|
||||
def is_line_like(self) -> SaColumn:
|
||||
return Geometry_IsLineLike(self)
|
||||
|
||||
|
||||
def is_area(self) -> SaColumn:
|
||||
return Geometry_IsAreaLike(self)
|
||||
|
||||
|
||||
def within_distance(self, other: SaColumn, distance: SaColumn) -> SaColumn:
|
||||
if isinstance(self.expr, sa.Column):
|
||||
return Geometry_ColumnDWithin(self.expr, other, distance)
|
||||
|
||||
return self.ST_Distance(other) < distance
|
||||
|
||||
|
||||
def ST_Distance(self, other: SaColumn) -> SaColumn:
|
||||
return sa.func.ST_Distance(self, other, type_=sa.Float)
|
||||
|
||||
|
||||
def ST_Contains(self, other: SaColumn) -> SaColumn:
|
||||
return sa.func.ST_Contains(self, other, type_=sa.Boolean)
|
||||
|
||||
|
||||
def ST_CoveredBy(self, other: SaColumn) -> SaColumn:
|
||||
return sa.func.ST_CoveredBy(self, other, type_=sa.Boolean)
|
||||
|
||||
|
||||
def ST_ClosestPoint(self, other: SaColumn) -> SaColumn:
|
||||
return sa.func.coalesce(sa.func.ST_ClosestPoint(self, other, type_=Geometry),
|
||||
other)
|
||||
|
||||
|
||||
def ST_Buffer(self, other: SaColumn) -> SaColumn:
|
||||
return sa.func.ST_Buffer(self, other, type_=Geometry)
|
||||
|
||||
|
||||
def ST_Expand(self, other: SaColumn) -> SaColumn:
|
||||
return sa.func.ST_Expand(self, other, type_=Geometry)
|
||||
|
||||
|
||||
def ST_Collect(self) -> SaColumn:
|
||||
return sa.func.ST_Collect(self, type_=Geometry)
|
||||
|
||||
|
||||
def ST_Centroid(self) -> SaColumn:
|
||||
return sa.func.ST_Centroid(self, type_=Geometry)
|
||||
|
||||
|
||||
def ST_LineInterpolatePoint(self, other: SaColumn) -> SaColumn:
|
||||
return sa.func.ST_LineInterpolatePoint(self, other, type_=Geometry)
|
||||
|
||||
|
||||
def ST_LineLocatePoint(self, other: SaColumn) -> SaColumn:
|
||||
return sa.func.ST_LineLocatePoint(self, other, type_=sa.Float)
|
||||
|
||||
|
||||
def distance_spheroid(self, other: SaColumn) -> SaColumn:
|
||||
return Geometry_DistanceSpheroid(self, other)
|
||||
|
||||
|
||||
@compiles(Geometry, 'sqlite') # type: ignore[no-untyped-call]
|
||||
def get_col_spec(self, *args, **kwargs): # type: ignore[no-untyped-def]
|
||||
return 'GEOMETRY'
|
||||
|
||||
|
||||
SQLITE_FUNCTION_ALIAS = (
|
||||
('ST_AsEWKB', sa.Text, 'AsEWKB'),
|
||||
('ST_GeomFromEWKT', Geometry, 'GeomFromEWKT'),
|
||||
('ST_AsGeoJSON', sa.Text, 'AsGeoJSON'),
|
||||
('ST_AsKML', sa.Text, 'AsKML'),
|
||||
('ST_AsSVG', sa.Text, 'AsSVG'),
|
||||
('ST_LineLocatePoint', sa.Float, 'ST_Line_Locate_Point'),
|
||||
('ST_LineInterpolatePoint', sa.Float, 'ST_Line_Interpolate_Point'),
|
||||
)
|
||||
|
||||
def _add_function_alias(func: str, ftype: type, alias: str) -> None:
|
||||
_FuncDef = type(func, (sa.sql.functions.GenericFunction, ), {
|
||||
"type": ftype(),
|
||||
"name": func,
|
||||
"identifier": func,
|
||||
"inherit_cache": True})
|
||||
|
||||
func_templ = f"{alias}(%s)"
|
||||
|
||||
def _sqlite_impl(element: Any, compiler: Any, **kw: Any) -> Any:
|
||||
return func_templ % compiler.process(element.clauses, **kw)
|
||||
|
||||
compiles(_FuncDef, 'sqlite')(_sqlite_impl) # type: ignore[no-untyped-call]
|
||||
|
||||
for alias in SQLITE_FUNCTION_ALIAS:
|
||||
_add_function_alias(*alias)
|
||||
123
src/nominatim_core/db/sqlalchemy_types/int_array.py
Normal file
123
src/nominatim_core/db/sqlalchemy_types/int_array.py
Normal file
@@ -0,0 +1,123 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
#
|
||||
# This file is part of Nominatim. (https://nominatim.org)
|
||||
#
|
||||
# Copyright (C) 2024 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
"""
|
||||
Custom type for an array of integers.
|
||||
"""
|
||||
from typing import Any, List, cast, Optional
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.dialects.postgresql import ARRAY
|
||||
|
||||
from ...typing import SaDialect, SaColumn
|
||||
|
||||
# pylint: disable=all
|
||||
|
||||
class IntList(sa.types.TypeDecorator[Any]):
|
||||
""" A list of integers saved as a text of comma-separated numbers.
|
||||
"""
|
||||
impl = sa.types.Unicode
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value: Optional[Any], dialect: 'sa.Dialect') -> Optional[str]:
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
assert isinstance(value, list)
|
||||
return ','.join(map(str, value))
|
||||
|
||||
def process_result_value(self, value: Optional[Any],
|
||||
dialect: SaDialect) -> Optional[List[int]]:
|
||||
return [int(v) for v in value.split(',')] if value is not None else None
|
||||
|
||||
def copy(self, **kw: Any) -> 'IntList':
|
||||
return IntList(self.impl.length)
|
||||
|
||||
|
||||
class IntArray(sa.types.TypeDecorator[Any]):
|
||||
""" Dialect-independent list of integers.
|
||||
"""
|
||||
impl = IntList
|
||||
cache_ok = True
|
||||
|
||||
def load_dialect_impl(self, dialect: SaDialect) -> sa.types.TypeEngine[Any]:
|
||||
if dialect.name == 'postgresql':
|
||||
return ARRAY(sa.Integer()) #pylint: disable=invalid-name
|
||||
|
||||
return IntList()
|
||||
|
||||
|
||||
class comparator_factory(sa.types.UserDefinedType.Comparator): # type: ignore[type-arg]
|
||||
|
||||
def __add__(self, other: SaColumn) -> 'sa.ColumnOperators':
|
||||
""" Concate the array with the given array. If one of the
|
||||
operants is null, the value of the other will be returned.
|
||||
"""
|
||||
return ArrayCat(self.expr, other)
|
||||
|
||||
|
||||
def contains(self, other: SaColumn, **kwargs: Any) -> 'sa.ColumnOperators':
|
||||
""" Return true if the array contains all the value of the argument
|
||||
array.
|
||||
"""
|
||||
return ArrayContains(self.expr, other)
|
||||
|
||||
|
||||
|
||||
class ArrayAgg(sa.sql.functions.GenericFunction[Any]):
|
||||
""" Aggregate function to collect elements in an array.
|
||||
"""
|
||||
type = IntArray()
|
||||
identifier = 'ArrayAgg'
|
||||
name = 'array_agg'
|
||||
inherit_cache = True
|
||||
|
||||
|
||||
@compiles(ArrayAgg, 'sqlite') # type: ignore[no-untyped-call, misc]
|
||||
def sqlite_array_agg(element: ArrayAgg, compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
return "group_concat(%s, ',')" % compiler.process(element.clauses, **kw)
|
||||
|
||||
|
||||
|
||||
class ArrayContains(sa.sql.expression.FunctionElement[Any]):
|
||||
""" Function to check if an array is fully contained in another.
|
||||
"""
|
||||
name = 'ArrayContains'
|
||||
inherit_cache = True
|
||||
|
||||
|
||||
@compiles(ArrayContains) # type: ignore[no-untyped-call, misc]
|
||||
def generic_array_contains(element: ArrayContains, compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
arg1, arg2 = list(element.clauses)
|
||||
return "(%s @> %s)" % (compiler.process(arg1, **kw),
|
||||
compiler.process(arg2, **kw))
|
||||
|
||||
|
||||
@compiles(ArrayContains, 'sqlite') # type: ignore[no-untyped-call, misc]
|
||||
def sqlite_array_contains(element: ArrayContains, compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
return "array_contains(%s)" % compiler.process(element.clauses, **kw)
|
||||
|
||||
|
||||
|
||||
class ArrayCat(sa.sql.expression.FunctionElement[Any]):
|
||||
""" Function to check if an array is fully contained in another.
|
||||
"""
|
||||
type = IntArray()
|
||||
identifier = 'ArrayCat'
|
||||
inherit_cache = True
|
||||
|
||||
|
||||
@compiles(ArrayCat) # type: ignore[no-untyped-call, misc]
|
||||
def generic_array_cat(element: ArrayCat, compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
return "array_cat(%s)" % compiler.process(element.clauses, **kw)
|
||||
|
||||
|
||||
@compiles(ArrayCat, 'sqlite') # type: ignore[no-untyped-call, misc]
|
||||
def sqlite_array_cat(element: ArrayCat, compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
arg1, arg2 = list(element.clauses)
|
||||
return "(%s || ',' || %s)" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
|
||||
|
||||
30
src/nominatim_core/db/sqlalchemy_types/json.py
Normal file
30
src/nominatim_core/db/sqlalchemy_types/json.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
#
|
||||
# This file is part of Nominatim. (https://nominatim.org)
|
||||
#
|
||||
# Copyright (C) 2024 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
"""
|
||||
Common json type for different dialects.
|
||||
"""
|
||||
from typing import Any
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.dialects.sqlite import JSON as sqlite_json
|
||||
|
||||
from ...typing import SaDialect
|
||||
|
||||
# pylint: disable=all
|
||||
|
||||
class Json(sa.types.TypeDecorator[Any]):
|
||||
""" Dialect-independent type for JSON.
|
||||
"""
|
||||
impl = sa.types.JSON
|
||||
cache_ok = True
|
||||
|
||||
def load_dialect_impl(self, dialect: SaDialect) -> sa.types.TypeEngine[Any]:
|
||||
if dialect.name == 'postgresql':
|
||||
return JSONB(none_as_null=True) # type: ignore[no-untyped-call]
|
||||
|
||||
return sqlite_json(none_as_null=True)
|
||||
62
src/nominatim_core/db/sqlalchemy_types/key_value.py
Normal file
62
src/nominatim_core/db/sqlalchemy_types/key_value.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
#
|
||||
# This file is part of Nominatim. (https://nominatim.org)
|
||||
#
|
||||
# Copyright (C) 2024 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
"""
|
||||
A custom type that implements a simple key-value store of strings.
|
||||
"""
|
||||
from typing import Any
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.dialects.postgresql import HSTORE
|
||||
from sqlalchemy.dialects.sqlite import JSON as sqlite_json
|
||||
|
||||
from ...typing import SaDialect, SaColumn
|
||||
|
||||
# pylint: disable=all
|
||||
|
||||
class KeyValueStore(sa.types.TypeDecorator[Any]):
|
||||
""" Dialect-independent type of a simple key-value store of strings.
|
||||
"""
|
||||
impl = HSTORE
|
||||
cache_ok = True
|
||||
|
||||
def load_dialect_impl(self, dialect: SaDialect) -> sa.types.TypeEngine[Any]:
|
||||
if dialect.name == 'postgresql':
|
||||
return HSTORE() # type: ignore[no-untyped-call]
|
||||
|
||||
return sqlite_json(none_as_null=True)
|
||||
|
||||
|
||||
class comparator_factory(sa.types.UserDefinedType.Comparator): # type: ignore[type-arg]
|
||||
|
||||
def merge(self, other: SaColumn) -> 'sa.Operators':
|
||||
""" Merge the values from the given KeyValueStore into this
|
||||
one, overwriting values where necessary. When the argument
|
||||
is null, nothing happens.
|
||||
"""
|
||||
return KeyValueConcat(self.expr, other)
|
||||
|
||||
|
||||
class KeyValueConcat(sa.sql.expression.FunctionElement[Any]):
|
||||
""" Return the merged key-value store from the input parameters.
|
||||
"""
|
||||
type = KeyValueStore()
|
||||
name = 'JsonConcat'
|
||||
inherit_cache = True
|
||||
|
||||
@compiles(KeyValueConcat) # type: ignore[no-untyped-call, misc]
|
||||
def default_json_concat(element: KeyValueConcat, compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
arg1, arg2 = list(element.clauses)
|
||||
return "(%s || coalesce(%s, ''::hstore))" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
|
||||
|
||||
@compiles(KeyValueConcat, 'sqlite') # type: ignore[no-untyped-call, misc]
|
||||
def sqlite_json_concat(element: KeyValueConcat, compiler: 'sa.Compiled', **kw: Any) -> str:
|
||||
arg1, arg2 = list(element.clauses)
|
||||
return "json_patch(%s, coalesce(%s, '{}'))" % (compiler.process(arg1, **kw), compiler.process(arg2, **kw))
|
||||
|
||||
|
||||
|
||||
127
src/nominatim_core/db/status.py
Normal file
127
src/nominatim_core/db/status.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
#
|
||||
# This file is part of Nominatim. (https://nominatim.org)
|
||||
#
|
||||
# Copyright (C) 2024 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
"""
|
||||
Access and helper functions for the status and status log table.
|
||||
"""
|
||||
from typing import Optional, Tuple, cast
|
||||
import datetime as dt
|
||||
import logging
|
||||
import re
|
||||
|
||||
from .connection import Connection
|
||||
from ..utils.url_utils import get_url
|
||||
from ..errors import UsageError
|
||||
from ..typing import TypedDict
|
||||
|
||||
LOG = logging.getLogger()
|
||||
ISODATE_FORMAT = '%Y-%m-%dT%H:%M:%S'
|
||||
|
||||
|
||||
class StatusRow(TypedDict):
|
||||
""" Dictionary of columns of the import_status table.
|
||||
"""
|
||||
lastimportdate: dt.datetime
|
||||
sequence_id: Optional[int]
|
||||
indexed: Optional[bool]
|
||||
|
||||
|
||||
def compute_database_date(conn: Connection, offline: bool = False) -> dt.datetime:
|
||||
""" Determine the date of the database from the newest object in the
|
||||
data base.
|
||||
"""
|
||||
# If there is a date from osm2pgsql available, use that.
|
||||
if conn.table_exists('osm2pgsql_properties'):
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(""" SELECT value FROM osm2pgsql_properties
|
||||
WHERE property = 'current_timestamp' """)
|
||||
row = cur.fetchone()
|
||||
if row is not None:
|
||||
return dt.datetime.strptime(row[0], "%Y-%m-%dT%H:%M:%SZ")\
|
||||
.replace(tzinfo=dt.timezone.utc)
|
||||
|
||||
if offline:
|
||||
raise UsageError("Cannot determine database date from data in offline mode.")
|
||||
|
||||
# Else, find the node with the highest ID in the database
|
||||
with conn.cursor() as cur:
|
||||
if conn.table_exists('place'):
|
||||
osmid = cur.scalar("SELECT max(osm_id) FROM place WHERE osm_type='N'")
|
||||
else:
|
||||
osmid = cur.scalar("SELECT max(osm_id) FROM placex WHERE osm_type='N'")
|
||||
|
||||
if osmid is None:
|
||||
LOG.fatal("No data found in the database.")
|
||||
raise UsageError("No data found in the database.")
|
||||
|
||||
LOG.info("Using node id %d for timestamp lookup", osmid)
|
||||
# Get the node from the API to find the timestamp when it was created.
|
||||
node_url = f'https://www.openstreetmap.org/api/0.6/node/{osmid}/1'
|
||||
data = get_url(node_url)
|
||||
|
||||
match = re.search(r'timestamp="((\d{4})-(\d{2})-(\d{2})T(\d{2}):(\d{2}):(\d{2}))Z"', data)
|
||||
|
||||
if match is None:
|
||||
LOG.fatal("The node data downloaded from the API does not contain valid data.\n"
|
||||
"URL used: %s", node_url)
|
||||
raise UsageError("Bad API data.")
|
||||
|
||||
LOG.debug("Found timestamp %s", match.group(1))
|
||||
|
||||
return dt.datetime.strptime(match.group(1), ISODATE_FORMAT).replace(tzinfo=dt.timezone.utc)
|
||||
|
||||
|
||||
def set_status(conn: Connection, date: Optional[dt.datetime],
|
||||
seq: Optional[int] = None, indexed: bool = True) -> None:
|
||||
""" Replace the current status with the given status. If date is `None`
|
||||
then only sequence and indexed will be updated as given. Otherwise
|
||||
the whole status is replaced.
|
||||
The change will be committed to the database.
|
||||
"""
|
||||
assert date is None or date.tzinfo == dt.timezone.utc
|
||||
with conn.cursor() as cur:
|
||||
if date is None:
|
||||
cur.execute("UPDATE import_status set sequence_id = %s, indexed = %s",
|
||||
(seq, indexed))
|
||||
else:
|
||||
cur.execute("TRUNCATE TABLE import_status")
|
||||
cur.execute("""INSERT INTO import_status (lastimportdate, sequence_id, indexed)
|
||||
VALUES (%s, %s, %s)""", (date, seq, indexed))
|
||||
|
||||
conn.commit()
|
||||
|
||||
|
||||
def get_status(conn: Connection) -> Tuple[Optional[dt.datetime], Optional[int], Optional[bool]]:
|
||||
""" Return the current status as a triple of (date, sequence, indexed).
|
||||
If status has not been set up yet, a triple of None is returned.
|
||||
"""
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("SELECT * FROM import_status LIMIT 1")
|
||||
if cur.rowcount < 1:
|
||||
return None, None, None
|
||||
|
||||
row = cast(StatusRow, cur.fetchone())
|
||||
return row['lastimportdate'], row['sequence_id'], row['indexed']
|
||||
|
||||
|
||||
def set_indexed(conn: Connection, state: bool) -> None:
|
||||
""" Set the indexed flag in the status table to the given state.
|
||||
"""
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("UPDATE import_status SET indexed = %s", (state, ))
|
||||
conn.commit()
|
||||
|
||||
|
||||
def log_status(conn: Connection, start: dt.datetime,
|
||||
event: str, batchsize: Optional[int] = None) -> None:
|
||||
""" Write a new status line to the `import_osmosis_log` table.
|
||||
"""
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("""INSERT INTO import_osmosis_log
|
||||
(batchend, batchseq, batchsize, starttime, endtime, event)
|
||||
SELECT lastimportdate, sequence_id, %s, %s, now(), %s FROM import_status""",
|
||||
(batchsize, start, event))
|
||||
conn.commit()
|
||||
129
src/nominatim_core/db/utils.py
Normal file
129
src/nominatim_core/db/utils.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
#
|
||||
# This file is part of Nominatim. (https://nominatim.org)
|
||||
#
|
||||
# Copyright (C) 2024 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
"""
|
||||
Helper functions for handling DB accesses.
|
||||
"""
|
||||
from typing import IO, Optional, Union, Any, Iterable
|
||||
import subprocess
|
||||
import logging
|
||||
import gzip
|
||||
import io
|
||||
from pathlib import Path
|
||||
|
||||
from .connection import get_pg_env, Cursor
|
||||
from ..errors import UsageError
|
||||
|
||||
LOG = logging.getLogger()
|
||||
|
||||
def _pipe_to_proc(proc: 'subprocess.Popen[bytes]',
|
||||
fdesc: Union[IO[bytes], gzip.GzipFile]) -> int:
|
||||
assert proc.stdin is not None
|
||||
chunk = fdesc.read(2048)
|
||||
while chunk and proc.poll() is None:
|
||||
try:
|
||||
proc.stdin.write(chunk)
|
||||
except BrokenPipeError as exc:
|
||||
raise UsageError("Failed to execute SQL file.") from exc
|
||||
chunk = fdesc.read(2048)
|
||||
|
||||
return len(chunk)
|
||||
|
||||
def execute_file(dsn: str, fname: Path,
|
||||
ignore_errors: bool = False,
|
||||
pre_code: Optional[str] = None,
|
||||
post_code: Optional[str] = None) -> None:
|
||||
""" Read an SQL file and run its contents against the given database
|
||||
using psql. Use `pre_code` and `post_code` to run extra commands
|
||||
before or after executing the file. The commands are run within the
|
||||
same session, so they may be used to wrap the file execution in a
|
||||
transaction.
|
||||
"""
|
||||
cmd = ['psql']
|
||||
if not ignore_errors:
|
||||
cmd.extend(('-v', 'ON_ERROR_STOP=1'))
|
||||
if not LOG.isEnabledFor(logging.INFO):
|
||||
cmd.append('--quiet')
|
||||
|
||||
with subprocess.Popen(cmd, env=get_pg_env(dsn), stdin=subprocess.PIPE) as proc:
|
||||
assert proc.stdin is not None
|
||||
try:
|
||||
if not LOG.isEnabledFor(logging.INFO):
|
||||
proc.stdin.write('set client_min_messages to WARNING;'.encode('utf-8'))
|
||||
|
||||
if pre_code:
|
||||
proc.stdin.write((pre_code + ';').encode('utf-8'))
|
||||
|
||||
if fname.suffix == '.gz':
|
||||
with gzip.open(str(fname), 'rb') as fdesc:
|
||||
remain = _pipe_to_proc(proc, fdesc)
|
||||
else:
|
||||
with fname.open('rb') as fdesc:
|
||||
remain = _pipe_to_proc(proc, fdesc)
|
||||
|
||||
if remain == 0 and post_code:
|
||||
proc.stdin.write((';' + post_code).encode('utf-8'))
|
||||
finally:
|
||||
proc.stdin.close()
|
||||
ret = proc.wait()
|
||||
|
||||
if ret != 0 or remain > 0:
|
||||
raise UsageError("Failed to execute SQL file.")
|
||||
|
||||
|
||||
# List of characters that need to be quoted for the copy command.
|
||||
_SQL_TRANSLATION = {ord('\\'): '\\\\',
|
||||
ord('\t'): '\\t',
|
||||
ord('\n'): '\\n'}
|
||||
|
||||
|
||||
class CopyBuffer:
|
||||
""" Data collector for the copy_from command.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.buffer = io.StringIO()
|
||||
|
||||
|
||||
def __enter__(self) -> 'CopyBuffer':
|
||||
return self
|
||||
|
||||
|
||||
def size(self) -> int:
|
||||
""" Return the number of bytes the buffer currently contains.
|
||||
"""
|
||||
return self.buffer.tell()
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
||||
if self.buffer is not None:
|
||||
self.buffer.close()
|
||||
|
||||
|
||||
def add(self, *data: Any) -> None:
|
||||
""" Add another row of data to the copy buffer.
|
||||
"""
|
||||
first = True
|
||||
for column in data:
|
||||
if first:
|
||||
first = False
|
||||
else:
|
||||
self.buffer.write('\t')
|
||||
if column is None:
|
||||
self.buffer.write('\\N')
|
||||
else:
|
||||
self.buffer.write(str(column).translate(_SQL_TRANSLATION))
|
||||
self.buffer.write('\n')
|
||||
|
||||
|
||||
def copy_out(self, cur: Cursor, table: str, columns: Optional[Iterable[str]] = None) -> None:
|
||||
""" Copy all collected data into the given table.
|
||||
|
||||
The buffer is empty and reusable after this operation.
|
||||
"""
|
||||
if self.buffer.tell() > 0:
|
||||
self.buffer.seek(0)
|
||||
cur.copy_from(self.buffer, table, columns=columns)
|
||||
self.buffer = io.StringIO()
|
||||
14
src/nominatim_core/errors.py
Normal file
14
src/nominatim_core/errors.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
#
|
||||
# This file is part of Nominatim. (https://nominatim.org)
|
||||
#
|
||||
# Copyright (C) 2024 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
"""
|
||||
Custom exception and error classes for Nominatim.
|
||||
"""
|
||||
|
||||
class UsageError(Exception):
|
||||
""" An error raised because of bad user input. This error will usually
|
||||
not cause a stack trace to be printed unless debugging is enabled.
|
||||
"""
|
||||
15
src/nominatim_core/paths.py
Normal file
15
src/nominatim_core/paths.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
#
|
||||
# This file is part of Nominatim. (https://nominatim.org)
|
||||
#
|
||||
# Copyright (C) 2024 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
"""
|
||||
Path settings for extra data used by Nominatim.
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
PHPLIB_DIR = (Path(__file__) / '..' / '..' / '..' / 'lib-php').resolve()
|
||||
SQLLIB_DIR = (Path(__file__) / '..' / '..' / '..' / 'lib-sql').resolve()
|
||||
DATA_DIR = (Path(__file__) / '..' / '..' / '..' / 'data').resolve()
|
||||
CONFIG_DIR = (Path(__file__) / '..' / '..' / '..' / 'settings').resolve()
|
||||
0
src/nominatim_core/py.typed
Normal file
0
src/nominatim_core/py.typed
Normal file
75
src/nominatim_core/typing.py
Normal file
75
src/nominatim_core/typing.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
#
|
||||
# This file is part of Nominatim. (https://nominatim.org)
|
||||
#
|
||||
# Copyright (C) 2024 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
"""
|
||||
Type definitions for typing annotations.
|
||||
|
||||
Complex type definitions are moved here, to keep the source files readable.
|
||||
"""
|
||||
from typing import Any, Union, Mapping, TypeVar, Sequence, TYPE_CHECKING
|
||||
|
||||
# Generics variable names do not confirm to naming styles, ignore globally here.
|
||||
# pylint: disable=invalid-name,abstract-method,multiple-statements
|
||||
# pylint: disable=missing-class-docstring,useless-import-alias
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import psycopg2.sql
|
||||
import psycopg2.extensions
|
||||
import psycopg2.extras
|
||||
import os
|
||||
|
||||
StrPath = Union[str, 'os.PathLike[str]']
|
||||
|
||||
SysEnv = Mapping[str, str]
|
||||
|
||||
# psycopg2-related types
|
||||
|
||||
Query = Union[str, bytes, 'psycopg2.sql.Composable']
|
||||
|
||||
T_ResultKey = TypeVar('T_ResultKey', int, str)
|
||||
|
||||
class DictCursorResult(Mapping[str, Any]):
|
||||
def __getitem__(self, x: Union[int, str]) -> Any: ...
|
||||
|
||||
DictCursorResults = Sequence[DictCursorResult]
|
||||
|
||||
T_cursor = TypeVar('T_cursor', bound='psycopg2.extensions.cursor')
|
||||
|
||||
# The following typing features require typing_extensions to work
|
||||
# on all supported Python versions.
|
||||
# Only require this for type checking but not for normal operations.
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import (Protocol as Protocol,
|
||||
Final as Final,
|
||||
TypedDict as TypedDict)
|
||||
else:
|
||||
Protocol = object
|
||||
Final = 'Final'
|
||||
TypedDict = dict
|
||||
|
||||
|
||||
# SQLAlchemy introduced generic types in version 2.0 making typing
|
||||
# incompatible with older versions. Add wrappers here so we don't have
|
||||
# to litter the code with bare-string types.
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import sqlalchemy as sa
|
||||
from typing_extensions import (TypeAlias as TypeAlias)
|
||||
else:
|
||||
TypeAlias = str
|
||||
|
||||
SaLambdaSelect: TypeAlias = 'Union[sa.Select[Any], sa.StatementLambdaElement]'
|
||||
SaSelect: TypeAlias = 'sa.Select[Any]'
|
||||
SaScalarSelect: TypeAlias = 'sa.ScalarSelect[Any]'
|
||||
SaRow: TypeAlias = 'sa.Row[Any]'
|
||||
SaColumn: TypeAlias = 'sa.ColumnElement[Any]'
|
||||
SaExpression: TypeAlias = 'sa.ColumnElement[bool]'
|
||||
SaLabel: TypeAlias = 'sa.Label[Any]'
|
||||
SaFromClause: TypeAlias = 'sa.FromClause'
|
||||
SaSelectable: TypeAlias = 'sa.Selectable'
|
||||
SaBind: TypeAlias = 'sa.BindParameter[Any]'
|
||||
SaDialect: TypeAlias = 'sa.Dialect'
|
||||
0
src/nominatim_core/utils/__init__.py
Normal file
0
src/nominatim_core/utils/__init__.py
Normal file
49
src/nominatim_core/utils/centroid.py
Normal file
49
src/nominatim_core/utils/centroid.py
Normal file
@@ -0,0 +1,49 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
#
|
||||
# This file is part of Nominatim. (https://nominatim.org)
|
||||
#
|
||||
# Copyright (C) 2024 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
"""
|
||||
Functions for computation of centroids.
|
||||
"""
|
||||
from typing import Tuple, Any
|
||||
from collections.abc import Collection
|
||||
|
||||
class PointsCentroid:
|
||||
""" Centroid computation from single points using an online algorithm.
|
||||
More points may be added at any time.
|
||||
|
||||
Coordinates are internally treated as a 7-digit fixed-point float
|
||||
(i.e. in OSM style).
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.sum_x = 0
|
||||
self.sum_y = 0
|
||||
self.count = 0
|
||||
|
||||
def centroid(self) -> Tuple[float, float]:
|
||||
""" Return the centroid of all points collected so far.
|
||||
"""
|
||||
if self.count == 0:
|
||||
raise ValueError("No points available for centroid.")
|
||||
|
||||
return (float(self.sum_x/self.count)/10000000,
|
||||
float(self.sum_y/self.count)/10000000)
|
||||
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.count
|
||||
|
||||
|
||||
def __iadd__(self, other: Any) -> 'PointsCentroid':
|
||||
if isinstance(other, Collection) and len(other) == 2:
|
||||
if all(isinstance(p, (float, int)) for p in other):
|
||||
x, y = other
|
||||
self.sum_x += int(x * 10000000)
|
||||
self.sum_y += int(y * 10000000)
|
||||
self.count += 1
|
||||
return self
|
||||
|
||||
raise ValueError("Can only add 2-element tuples to centroid.")
|
||||
149
src/nominatim_core/utils/json_writer.py
Normal file
149
src/nominatim_core/utils/json_writer.py
Normal file
@@ -0,0 +1,149 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
#
|
||||
# This file is part of Nominatim. (https://nominatim.org)
|
||||
#
|
||||
# Copyright (C) 2024 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
"""
|
||||
Streaming JSON encoder.
|
||||
"""
|
||||
from typing import Any, TypeVar, Optional, Callable
|
||||
import io
|
||||
try:
|
||||
import ujson as json
|
||||
except ModuleNotFoundError:
|
||||
import json # type: ignore[no-redef]
|
||||
|
||||
T = TypeVar('T') # pylint: disable=invalid-name
|
||||
|
||||
class JsonWriter:
|
||||
""" JSON encoder that renders the output directly into an output
|
||||
stream. This is a very simple writer which produces JSON in a
|
||||
compact as possible form.
|
||||
|
||||
The writer does not check for syntactic correctness. It is the
|
||||
responsibility of the caller to call the write functions in an
|
||||
order that produces correct JSON.
|
||||
|
||||
All functions return the writer object itself so that function
|
||||
calls can be chained.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.data = io.StringIO()
|
||||
self.pending = ''
|
||||
|
||||
|
||||
def __call__(self) -> str:
|
||||
""" Return the rendered JSON content as a string.
|
||||
The writer remains usable after calling this function.
|
||||
"""
|
||||
if self.pending:
|
||||
assert self.pending in (']', '}')
|
||||
self.data.write(self.pending)
|
||||
self.pending = ''
|
||||
return self.data.getvalue()
|
||||
|
||||
|
||||
def start_object(self) -> 'JsonWriter':
|
||||
""" Write the open bracket of a JSON object.
|
||||
"""
|
||||
if self.pending:
|
||||
self.data.write(self.pending)
|
||||
self.pending = '{'
|
||||
return self
|
||||
|
||||
|
||||
def end_object(self) -> 'JsonWriter':
|
||||
""" Write the closing bracket of a JSON object.
|
||||
"""
|
||||
assert self.pending in (',', '{', '')
|
||||
if self.pending == '{':
|
||||
self.data.write(self.pending)
|
||||
self.pending = '}'
|
||||
return self
|
||||
|
||||
|
||||
def start_array(self) -> 'JsonWriter':
|
||||
""" Write the opening bracket of a JSON array.
|
||||
"""
|
||||
if self.pending:
|
||||
self.data.write(self.pending)
|
||||
self.pending = '['
|
||||
return self
|
||||
|
||||
|
||||
def end_array(self) -> 'JsonWriter':
|
||||
""" Write the closing bracket of a JSON array.
|
||||
"""
|
||||
assert self.pending in (',', '[', ']', ')', '')
|
||||
if self.pending not in (',', ''):
|
||||
self.data.write(self.pending)
|
||||
self.pending = ']'
|
||||
return self
|
||||
|
||||
|
||||
def key(self, name: str) -> 'JsonWriter':
|
||||
""" Write the key string of a JSON object.
|
||||
"""
|
||||
assert self.pending
|
||||
self.data.write(self.pending)
|
||||
self.data.write(json.dumps(name, ensure_ascii=False))
|
||||
self.pending = ':'
|
||||
return self
|
||||
|
||||
|
||||
def value(self, value: Any) -> 'JsonWriter':
|
||||
""" Write out a value as JSON. The function uses the json.dumps()
|
||||
function for encoding the JSON. Thus any value that can be
|
||||
encoded by that function is permissible here.
|
||||
"""
|
||||
return self.raw(json.dumps(value, ensure_ascii=False))
|
||||
|
||||
|
||||
def float(self, value: float, precision: int) -> 'JsonWriter':
|
||||
""" Write out a float value with the given precision.
|
||||
"""
|
||||
return self.raw(f"{value:0.{precision}f}")
|
||||
|
||||
def next(self) -> 'JsonWriter':
|
||||
""" Write out a delimiter comma between JSON object or array elements.
|
||||
"""
|
||||
if self.pending:
|
||||
self.data.write(self.pending)
|
||||
self.pending = ','
|
||||
return self
|
||||
|
||||
|
||||
def raw(self, raw_json: str) -> 'JsonWriter':
|
||||
""" Write out the given value as is. This function is useful if
|
||||
a value is already available in JSON format.
|
||||
"""
|
||||
if self.pending:
|
||||
self.data.write(self.pending)
|
||||
self.pending = ''
|
||||
self.data.write(raw_json)
|
||||
return self
|
||||
|
||||
|
||||
def keyval(self, key: str, value: Any) -> 'JsonWriter':
|
||||
""" Write out an object element with the given key and value.
|
||||
This is a shortcut for calling 'key()', 'value()' and 'next()'.
|
||||
"""
|
||||
self.key(key)
|
||||
self.value(value)
|
||||
return self.next()
|
||||
|
||||
|
||||
def keyval_not_none(self, key: str, value: Optional[T],
|
||||
transform: Optional[Callable[[T], Any]] = None) -> 'JsonWriter':
|
||||
""" Write out an object element only if the value is not None.
|
||||
If 'transform' is given, it must be a function that takes the
|
||||
value type and returns a JSON encodable type. The transform
|
||||
function will be called before the value is written out.
|
||||
"""
|
||||
if value is not None:
|
||||
self.key(key)
|
||||
self.value(transform(value) if transform else value)
|
||||
self.next()
|
||||
return self
|
||||
31
src/nominatim_core/utils/url_utils.py
Normal file
31
src/nominatim_core/utils/url_utils.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
#
|
||||
# This file is part of Nominatim. (https://nominatim.org)
|
||||
#
|
||||
# Copyright (C) 2024 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
"""
|
||||
Helper functions for accessing URL.
|
||||
"""
|
||||
from typing import IO
|
||||
import logging
|
||||
import urllib.request as urlrequest
|
||||
|
||||
from ..version import NOMINATIM_CORE_VERSION
|
||||
|
||||
LOG = logging.getLogger()
|
||||
|
||||
def get_url(url: str) -> str:
|
||||
""" Get the contents from the given URL and return it as a UTF-8 string.
|
||||
|
||||
This version makes sure that an appropriate user agent is sent.
|
||||
"""
|
||||
headers = {"User-Agent": f"Nominatim/{NOMINATIM_CORE_VERSION!s}"}
|
||||
|
||||
try:
|
||||
request = urlrequest.Request(url, headers=headers)
|
||||
with urlrequest.urlopen(request) as response: # type: IO[bytes]
|
||||
return response.read().decode('utf-8')
|
||||
except Exception:
|
||||
LOG.fatal('Failed to load URL: %s', url)
|
||||
raise
|
||||
11
src/nominatim_core/version.py
Normal file
11
src/nominatim_core/version.py
Normal file
@@ -0,0 +1,11 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
#
|
||||
# This file is part of Nominatim. (https://nominatim.org)
|
||||
#
|
||||
# Copyright (C) 2024 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
"""
|
||||
Version information for the Nominatim core package.
|
||||
"""
|
||||
|
||||
NOMINATIM_CORE_VERSION = '4.4.99'
|
||||
Reference in New Issue
Block a user