port code to psycopg3

This commit is contained in:
Sarah Hoffmann
2024-07-05 10:43:10 +02:00
parent 3742fa2929
commit 9659afbade
57 changed files with 800 additions and 1330 deletions

View File

@@ -31,6 +31,7 @@ CREATE INDEX IF NOT EXISTS idx_placex_geometry ON placex
-- Index is needed during import but can be dropped as soon as a full -- Index is needed during import but can be dropped as soon as a full
-- geometry index is in place. The partial index is almost as big as the full -- geometry index is in place. The partial index is almost as big as the full
-- index. -- index.
---
DROP INDEX IF EXISTS idx_placex_geometry_lower_rank_ways; DROP INDEX IF EXISTS idx_placex_geometry_lower_rank_ways;
--- ---
CREATE INDEX IF NOT EXISTS idx_placex_geometry_reverse_lookupPolygon CREATE INDEX IF NOT EXISTS idx_placex_geometry_reverse_lookupPolygon
@@ -60,7 +61,6 @@ CREATE INDEX IF NOT EXISTS idx_postcode_postcode
--- ---
DROP INDEX IF EXISTS idx_placex_geometry_address_area_candidates; DROP INDEX IF EXISTS idx_placex_geometry_address_area_candidates;
DROP INDEX IF EXISTS idx_placex_geometry_buildings; DROP INDEX IF EXISTS idx_placex_geometry_buildings;
DROP INDEX IF EXISTS idx_placex_geometry_lower_rank_ways;
DROP INDEX IF EXISTS idx_placex_wikidata; DROP INDEX IF EXISTS idx_placex_wikidata;
DROP INDEX IF EXISTS idx_placex_rank_address_sector; DROP INDEX IF EXISTS idx_placex_rank_address_sector;
DROP INDEX IF EXISTS idx_placex_rank_boundaries_sector; DROP INDEX IF EXISTS idx_placex_rank_boundaries_sector;

View File

@@ -15,7 +15,7 @@ classifiers = [
"Operating System :: OS Independent", "Operating System :: OS Independent",
] ]
dependencies = [ dependencies = [
"psycopg2-binary", "psycopg[pool]",
"python-dotenv", "python-dotenv",
"jinja2", "jinja2",
"pyYAML>=5.1", "pyYAML>=5.1",

View File

@@ -7,7 +7,7 @@
""" """
Implementation of classes for API access via libraries. Implementation of classes for API access via libraries.
""" """
from typing import Mapping, Optional, Any, AsyncIterator, Dict, Sequence, List, Tuple from typing import Mapping, Optional, Any, AsyncIterator, Dict, Sequence, List, Tuple, cast
import asyncio import asyncio
import sys import sys
import contextlib import contextlib
@@ -107,16 +107,16 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
raise UsageError(f"SQlite database '{params.get('dbname')}' does not exist.") raise UsageError(f"SQlite database '{params.get('dbname')}' does not exist.")
else: else:
dsn = self.config.get_database_params() dsn = self.config.get_database_params()
query = {k: v for k, v in dsn.items() query = {k: str(v) for k, v in dsn.items()
if k not in ('user', 'password', 'dbname', 'host', 'port')} if k not in ('user', 'password', 'dbname', 'host', 'port')}
dburl = sa.engine.URL.create( dburl = sa.engine.URL.create(
f'postgresql+{PGCORE_LIB}', f'postgresql+{PGCORE_LIB}',
database=dsn.get('dbname'), database=cast(str, dsn.get('dbname')),
username=dsn.get('user'), username=cast(str, dsn.get('user')),
password=dsn.get('password'), password=cast(str, dsn.get('password')),
host=dsn.get('host'), host=cast(str, dsn.get('host')),
port=int(dsn['port']) if 'port' in dsn else None, port=int(cast(str, dsn['port'])) if 'port' in dsn else None,
query=query) query=query)
engine = sa_asyncio.create_async_engine(dburl, **extra_args) engine = sa_asyncio.create_async_engine(dburl, **extra_args)

View File

@@ -14,6 +14,7 @@ import logging
import os import os
import sys import sys
import argparse import argparse
import asyncio
from pathlib import Path from pathlib import Path
from .config import Configuration from .config import Configuration
@@ -170,24 +171,32 @@ class AdminServe:
raise UsageError("PHP frontend not configured.") raise UsageError("PHP frontend not configured.")
run_php_server(args.server, args.project_dir / 'website') run_php_server(args.server, args.project_dir / 'website')
else: else:
import uvicorn # pylint: disable=import-outside-toplevel asyncio.run(self.run_uvicorn(args))
server_info = args.server.split(':', 1)
host = server_info[0]
if len(server_info) > 1:
if not server_info[1].isdigit():
raise UsageError('Invalid format for --server parameter. Use <host>:<port>')
port = int(server_info[1])
else:
port = 8088
server_module = importlib.import_module(f'nominatim_api.server.{args.engine}.server')
app = server_module.get_application(args.project_dir)
uvicorn.run(app, host=host, port=port)
return 0 return 0
async def run_uvicorn(self, args: NominatimArgs) -> None:
    """ Start the API frontend with uvicorn inside the running asyncio loop.

        Parses the ``--server`` argument as ``<host>[:<port>]`` (the port
        defaults to 8088 when omitted) and serves the ASGI application of
        the engine selected via ``--engine``. Runs until the server is
        stopped; raises UsageError for a malformed ``--server`` value.
    """
    import uvicorn # pylint: disable=import-outside-toplevel

    server_info = args.server.split(':', 1)
    host = server_info[0]
    if len(server_info) > 1:
        if not server_info[1].isdigit():
            raise UsageError('Invalid format for --server parameter. Use <host>:<port>')
        port = int(server_info[1])
    else:
        port = 8088

    # The server implementation is engine-specific (falcon/starlette);
    # import it lazily so only the selected engine needs to be installed.
    server_module = importlib.import_module(f'nominatim_api.server.{args.engine}.server')

    app = server_module.get_application(args.project_dir)
    config = uvicorn.Config(app, host=host, port=port)
    server = uvicorn.Server(config)
    await server.serve()
def get_set_parser() -> CommandlineParser: def get_set_parser() -> CommandlineParser:
"""\ """\
Initializes the parser and adds various subcommands for Initializes the parser and adds various subcommands for

View File

@@ -10,6 +10,7 @@ Implementation of the 'add-data' subcommand.
from typing import cast from typing import cast
import argparse import argparse
import logging import logging
import asyncio
import psutil import psutil
@@ -64,15 +65,10 @@ class UpdateAddData:
def run(self, args: NominatimArgs) -> int: def run(self, args: NominatimArgs) -> int:
from ..tokenizer import factory as tokenizer_factory from ..tools import add_osm_data
from ..tools import tiger_data, add_osm_data
if args.tiger_data: if args.tiger_data:
tokenizer = tokenizer_factory.get_tokenizer_for_db(args.config) return asyncio.run(self._add_tiger_data(args))
return tiger_data.add_tiger_data(args.tiger_data,
args.config,
args.threads or psutil.cpu_count() or 1,
tokenizer)
osm2pgsql_params = args.osm2pgsql_options(default_cache=1000, default_threads=1) osm2pgsql_params = args.osm2pgsql_options(default_cache=1000, default_threads=1)
if args.file or args.diff: if args.file or args.diff:
@@ -99,3 +95,16 @@ class UpdateAddData:
osm2pgsql_params) osm2pgsql_params)
return 0 return 0
async def _add_tiger_data(self, args: NominatimArgs) -> int:
    """ Import US TIGER house-number data into the database.

        Uses the tokenizer configured for the database and as many
        parallel workers as given by --threads (falling back to the
        CPU count). Returns the exit code of the import.
    """
    from ..tokenizer import factory as tokenizer_factory
    from ..tools import tiger_data

    # run() only dispatches here when --tiger-data was given.
    assert args.tiger_data

    tokenizer = tokenizer_factory.get_tokenizer_for_db(args.config)
    return await tiger_data.add_tiger_data(args.tiger_data,
                                           args.config,
                                           args.threads or psutil.cpu_count() or 1,
                                           tokenizer)

View File

@@ -8,6 +8,7 @@
Implementation of the 'index' subcommand. Implementation of the 'index' subcommand.
""" """
import argparse import argparse
import asyncio
import psutil import psutil
@@ -44,19 +45,7 @@ class UpdateIndex:
def run(self, args: NominatimArgs) -> int: def run(self, args: NominatimArgs) -> int:
from ..indexer.indexer import Indexer asyncio.run(self._do_index(args))
from ..tokenizer import factory as tokenizer_factory
tokenizer = tokenizer_factory.get_tokenizer_for_db(args.config)
indexer = Indexer(args.config.get_libpq_dsn(), tokenizer,
args.threads or psutil.cpu_count() or 1)
if not args.no_boundaries:
indexer.index_boundaries(args.minrank, args.maxrank)
if not args.boundaries_only:
indexer.index_by_rank(args.minrank, args.maxrank)
indexer.index_postcodes()
if not args.no_boundaries and not args.boundaries_only \ if not args.no_boundaries and not args.boundaries_only \
and args.minrank == 0 and args.maxrank == 30: and args.minrank == 0 and args.maxrank == 30:
@@ -64,3 +53,22 @@ class UpdateIndex:
status.set_indexed(conn, True) status.set_indexed(conn, True)
return 0 return 0
async def _do_index(self, args: NominatimArgs) -> None:
    """ Run the indexer until no pending places remain.

        Honours the --no-boundaries/--boundaries-only switches and the
        --minrank/--maxrank window. Repeats the whole indexing pass as
        long as the indexer reports pending work, so places that become
        indexable during a pass are picked up as well.
    """
    from ..tokenizer import factory as tokenizer_factory

    tokenizer = tokenizer_factory.get_tokenizer_for_db(args.config)

    from ..indexer.indexer import Indexer
    indexer = Indexer(args.config.get_libpq_dsn(), tokenizer,
                      args.threads or psutil.cpu_count() or 1)

    has_pending = True # run at least once
    while has_pending:
        if not args.no_boundaries:
            await indexer.index_boundaries(args.minrank, args.maxrank)
        if not args.boundaries_only:
            await indexer.index_by_rank(args.minrank, args.maxrank)
            await indexer.index_postcodes()
        has_pending = indexer.has_pending()

View File

@@ -11,6 +11,7 @@ from typing import Tuple, Optional
import argparse import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import asyncio
from ..config import Configuration from ..config import Configuration
from ..db.connection import connect, table_exists from ..db.connection import connect, table_exists
@@ -99,7 +100,7 @@ class UpdateRefresh:
args.project_dir, tokenizer) args.project_dir, tokenizer)
indexer = Indexer(args.config.get_libpq_dsn(), tokenizer, indexer = Indexer(args.config.get_libpq_dsn(), tokenizer,
args.threads or 1) args.threads or 1)
indexer.index_postcodes() asyncio.run(indexer.index_postcodes())
else: else:
LOG.error("The place table doesn't exist. " LOG.error("The place table doesn't exist. "
"Postcode updates on a frozen database is not possible.") "Postcode updates on a frozen database is not possible.")

View File

@@ -13,6 +13,7 @@ import datetime as dt
import logging import logging
import socket import socket
import time import time
import asyncio
from ..db import status from ..db import status
from ..db.connection import connect from ..db.connection import connect
@@ -123,7 +124,7 @@ class UpdateReplication:
return update_interval return update_interval
def _update(self, args: NominatimArgs) -> None: async def _update(self, args: NominatimArgs) -> None:
# pylint: disable=too-many-locals # pylint: disable=too-many-locals
from ..tools import replication from ..tools import replication
from ..indexer.indexer import Indexer from ..indexer.indexer import Indexer
@@ -161,7 +162,7 @@ class UpdateReplication:
if state is not replication.UpdateState.NO_CHANGES and args.do_index: if state is not replication.UpdateState.NO_CHANGES and args.do_index:
index_start = dt.datetime.now(dt.timezone.utc) index_start = dt.datetime.now(dt.timezone.utc)
indexer.index_full(analyse=False) await indexer.index_full(analyse=False)
with connect(dsn) as conn: with connect(dsn) as conn:
status.set_indexed(conn, True) status.set_indexed(conn, True)
@@ -172,8 +173,7 @@ class UpdateReplication:
if state is replication.UpdateState.NO_CHANGES and \ if state is replication.UpdateState.NO_CHANGES and \
args.catch_up or update_interval > 40*60: args.catch_up or update_interval > 40*60:
while indexer.has_pending(): await indexer.index_full(analyse=False)
indexer.index_full(analyse=False)
if LOG.isEnabledFor(logging.WARNING): if LOG.isEnabledFor(logging.WARNING):
assert batchdate is not None assert batchdate is not None
@@ -196,5 +196,5 @@ class UpdateReplication:
if args.check_for_updates: if args.check_for_updates:
return self._check_for_updates(args) return self._check_for_updates(args)
self._update(args) asyncio.run(self._update(args))
return 0 return 0

View File

@@ -11,6 +11,7 @@ from typing import Optional
import argparse import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import asyncio
import psutil import psutil
@@ -71,14 +72,6 @@ class SetupAll:
def run(self, args: NominatimArgs) -> int: # pylint: disable=too-many-statements, too-many-branches def run(self, args: NominatimArgs) -> int: # pylint: disable=too-many-statements, too-many-branches
from ..data import country_info
from ..tools import database_import, refresh, postcodes, freeze
from ..indexer.indexer import Indexer
num_threads = args.threads or psutil.cpu_count() or 1
country_info.setup_country_config(args.config)
if args.osm_file is None and args.continue_at is None and not args.prepare_database: if args.osm_file is None and args.continue_at is None and not args.prepare_database:
raise UsageError("No input files (use --osm-file).") raise UsageError("No input files (use --osm-file).")
@@ -90,6 +83,16 @@ class SetupAll:
"Cannot use --continue and --prepare-database together." "Cannot use --continue and --prepare-database together."
) )
return asyncio.run(self.async_run(args))
async def async_run(self, args: NominatimArgs) -> int:
from ..data import country_info
from ..tools import database_import, refresh, postcodes, freeze
from ..indexer.indexer import Indexer
num_threads = args.threads or psutil.cpu_count() or 1
country_info.setup_country_config(args.config)
if args.prepare_database or args.continue_at is None: if args.prepare_database or args.continue_at is None:
LOG.warning('Creating database') LOG.warning('Creating database')
@@ -99,39 +102,7 @@ class SetupAll:
return 0 return 0
if args.continue_at in (None, 'import-from-file'): if args.continue_at in (None, 'import-from-file'):
files = args.get_osm_file_list() self._base_import(args)
if not files:
raise UsageError("No input files (use --osm-file).")
if args.continue_at in ('import-from-file', None):
# Check if the correct plugins are installed
database_import.check_existing_database_plugins(args.config.get_libpq_dsn())
LOG.warning('Setting up country tables')
country_info.setup_country_tables(args.config.get_libpq_dsn(),
args.config.lib_dir.data,
args.no_partitions)
LOG.warning('Importing OSM data file')
database_import.import_osm_data(files,
args.osm2pgsql_options(0, 1),
drop=args.no_updates,
ignore_errors=args.ignore_errors)
LOG.warning('Importing wikipedia importance data')
data_path = Path(args.config.WIKIPEDIA_DATA_PATH or args.project_dir)
if refresh.import_wikipedia_articles(args.config.get_libpq_dsn(),
data_path) > 0:
LOG.error('Wikipedia importance dump file not found. '
'Calculating importance values of locations will not '
'use Wikipedia importance data.')
LOG.warning('Importing secondary importance raster data')
if refresh.import_secondary_importance(args.config.get_libpq_dsn(),
args.project_dir) != 0:
LOG.error('Secondary importance file not imported. '
'Falling back to default ranking.')
self._setup_tables(args.config, args.reverse_only)
if args.continue_at in ('import-from-file', 'load-data', None): if args.continue_at in ('import-from-file', 'load-data', None):
LOG.warning('Initialise tables') LOG.warning('Initialise tables')
@@ -139,7 +110,7 @@ class SetupAll:
database_import.truncate_data_tables(conn) database_import.truncate_data_tables(conn)
LOG.warning('Load data into placex table') LOG.warning('Load data into placex table')
database_import.load_data(args.config.get_libpq_dsn(), num_threads) await database_import.load_data(args.config.get_libpq_dsn(), num_threads)
LOG.warning("Setting up tokenizer") LOG.warning("Setting up tokenizer")
tokenizer = self._get_tokenizer(args.continue_at, args.config) tokenizer = self._get_tokenizer(args.continue_at, args.config)
@@ -153,13 +124,13 @@ class SetupAll:
('import-from-file', 'load-data', 'indexing', None): ('import-from-file', 'load-data', 'indexing', None):
LOG.warning('Indexing places') LOG.warning('Indexing places')
indexer = Indexer(args.config.get_libpq_dsn(), tokenizer, num_threads) indexer = Indexer(args.config.get_libpq_dsn(), tokenizer, num_threads)
indexer.index_full(analyse=not args.index_noanalyse) await indexer.index_full(analyse=not args.index_noanalyse)
LOG.warning('Post-process tables') LOG.warning('Post-process tables')
with connect(args.config.get_libpq_dsn()) as conn: with connect(args.config.get_libpq_dsn()) as conn:
database_import.create_search_indices(conn, args.config, await database_import.create_search_indices(conn, args.config,
drop=args.no_updates, drop=args.no_updates,
threads=num_threads) threads=num_threads)
LOG.warning('Create search index for default country names.') LOG.warning('Create search index for default country names.')
country_info.create_country_names(conn, tokenizer, country_info.create_country_names(conn, tokenizer,
args.config.get_str_list('LANGUAGES')) args.config.get_str_list('LANGUAGES'))
@@ -180,6 +151,45 @@ class SetupAll:
return 0 return 0
def _base_import(self, args: NominatimArgs) -> None:
    """ Run the initial data import steps: country tables, OSM data,
        importance data and the base table layout.

        Only executes the file-import steps when the import starts from
        the beginning or continues at 'import-from-file'. Raises
        UsageError when no input files were given.
    """
    from ..tools import database_import, refresh
    from ..data import country_info

    files = args.get_osm_file_list()
    if not files:
        raise UsageError("No input files (use --osm-file).")

    if args.continue_at in ('import-from-file', None):
        # Check if the correct plugins are installed
        database_import.check_existing_database_plugins(args.config.get_libpq_dsn())
        LOG.warning('Setting up country tables')
        country_info.setup_country_tables(args.config.get_libpq_dsn(),
                                          args.config.lib_dir.data,
                                          args.no_partitions)

        LOG.warning('Importing OSM data file')
        database_import.import_osm_data(files,
                                        args.osm2pgsql_options(0, 1),
                                        drop=args.no_updates,
                                        ignore_errors=args.ignore_errors)

        LOG.warning('Importing wikipedia importance data')
        data_path = Path(args.config.WIKIPEDIA_DATA_PATH or args.project_dir)
        # A return value > 0 means the dump file could not be found;
        # the import then falls back to computed importance only.
        if refresh.import_wikipedia_articles(args.config.get_libpq_dsn(),
                                             data_path) > 0:
            LOG.error('Wikipedia importance dump file not found. '
                      'Calculating importance values of locations will not '
                      'use Wikipedia importance data.')

        LOG.warning('Importing secondary importance raster data')
        if refresh.import_secondary_importance(args.config.get_libpq_dsn(),
                                               args.project_dir) != 0:
            LOG.error('Secondary importance file not imported. '
                      'Falling back to default ranking.')

        self._setup_tables(args.config, args.reverse_only)
def _setup_tables(self, config: Configuration, reverse_only: bool) -> None: def _setup_tables(self, config: Configuration, reverse_only: bool) -> None:
""" Set up the basic database layout: tables, indexes and functions. """ Set up the basic database layout: tables, indexes and functions.
""" """

View File

@@ -7,7 +7,7 @@
""" """
Nominatim configuration accessor. Nominatim configuration accessor.
""" """
from typing import Dict, Any, List, Mapping, Optional from typing import Union, Dict, Any, List, Mapping, Optional
import importlib.util import importlib.util
import logging import logging
import os import os
@@ -18,10 +18,7 @@ import yaml
from dotenv import dotenv_values from dotenv import dotenv_values
try: from psycopg.conninfo import conninfo_to_dict
from psycopg2.extensions import parse_dsn
except ModuleNotFoundError:
from psycopg.conninfo import conninfo_to_dict as parse_dsn # type: ignore[assignment]
from .typing import StrPath from .typing import StrPath
from .errors import UsageError from .errors import UsageError
@@ -198,7 +195,7 @@ class Configuration:
return dsn return dsn
def get_database_params(self) -> Mapping[str, str]: def get_database_params(self) -> Mapping[str, Union[str, int, None]]:
""" Get the configured parameters for the database connection """ Get the configured parameters for the database connection
as a mapping. as a mapping.
""" """
@@ -207,7 +204,7 @@ class Configuration:
if dsn.startswith('pgsql:'): if dsn.startswith('pgsql:'):
return dict((p.split('=', 1) for p in dsn[6:].split(';'))) return dict((p.split('=', 1) for p in dsn[6:].split(';')))
return parse_dsn(dsn) return conninfo_to_dict(dsn)
def get_import_style_file(self) -> Path: def get_import_style_file(self) -> Path:

View File

@@ -138,9 +138,10 @@ def setup_country_tables(dsn: str, sql_dir: Path, ignore_partitions: bool = Fals
country_default_language_code text, country_default_language_code text,
partition integer partition integer
); """) ); """)
cur.execute_values( cur.executemany(
""" INSERT INTO public.country_name """ INSERT INTO public.country_name
(country_code, name, country_default_language_code, partition) VALUES %s (country_code, name, country_default_language_code, partition)
VALUES (%s, %s, %s, %s)
""", params) """, params)
conn.commit() conn.commit()

View File

@@ -1,236 +0,0 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
""" Non-blocking database connections.
"""
from typing import Callable, Any, Optional, Iterator, Sequence
import logging
import select
import time
import psycopg2
from psycopg2.extras import wait_select
# psycopg2 emits different exceptions pre and post 2.8. Detect if the new error
# module is available and adapt the error handling accordingly.
try:
import psycopg2.errors # pylint: disable=no-name-in-module,import-error
__has_psycopg2_errors__ = True
except ImportError:
__has_psycopg2_errors__ = False
from ..typing import T_cursor, Query
LOG = logging.getLogger()
class DeadlockHandler:
    """ Context manager that intercepts deadlock exceptions and runs a
        caller-supplied retry handler instead of propagating them.
        All other exceptions pass through unchanged, except that SQL
        errors may be logged and swallowed when `ignore_sql_errors` is set.
    """

    def __init__(self, handler: Callable[[], None], ignore_sql_errors: bool = False) -> None:
        self.handler = handler
        self.ignore_sql_errors = ignore_sql_errors

    def __enter__(self) -> 'DeadlockHandler':
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> bool:
        # Deadlock detection differs between psycopg2 versions: 2.8+
        # provides a dedicated exception class, older versions signal it
        # via the generic rollback error plus SQLSTATE 40P01.
        is_deadlock = False
        if __has_psycopg2_errors__:
            is_deadlock = exc_type == psycopg2.errors.DeadlockDetected # pylint: disable=E1101
        else:
            is_deadlock = exc_type == psycopg2.extensions.TransactionRollbackError \
                          and exc_value.pgcode == '40P01'

        if is_deadlock:
            self.handler()
            return True

        if self.ignore_sql_errors and isinstance(exc_value, psycopg2.Error):
            LOG.info("SQL error ignored: %s", exc_value)
            return True

        return False
class DBConnection:
    """ A single non-blocking database connection.

        Wraps a psycopg2 connection opened in asynchronous mode. Queries
        are sent with perform() and completed with wait()/is_done();
        deadlocked queries are transparently retried.
    """

    def __init__(self, dsn: str,
                 cursor_factory: Optional[Callable[..., T_cursor]] = None,
                 ignore_sql_errors: bool = False) -> None:
        self.dsn = dsn

        # Last query sent, kept so it can be re-executed after a deadlock.
        self.current_query: Optional[Query] = None
        self.current_params: Optional[Sequence[Any]] = None
        self.ignore_sql_errors = ignore_sql_errors

        self.conn: Optional['psycopg2._psycopg.connection'] = None
        self.cursor: Optional['psycopg2._psycopg.cursor'] = None
        self.connect(cursor_factory=cursor_factory)

    def close(self) -> None:
        """ Close all open connections. Does not wait for pending requests.
        """
        if self.conn is not None:
            if self.cursor is not None:
                self.cursor.close()
                self.cursor = None
            self.conn.close()

        self.conn = None

    def connect(self, cursor_factory: Optional[Callable[..., T_cursor]] = None) -> None:
        """ (Re)connect to the database. Creates an asynchronous connection
            with JIT and parallel processing disabled. If a connection was
            already open, it is closed and a new connection established.
            The caller must ensure that no query is pending before reconnecting.
        """
        self.close()

        # Use a dict to hand in the parameters because async is a reserved
        # word in Python3.
        self.conn = psycopg2.connect(**{'dsn': self.dsn, 'async': True}) # type: ignore
        assert self.conn
        self.wait()

        if cursor_factory is not None:
            self.cursor = self.conn.cursor(cursor_factory=cursor_factory)
        else:
            self.cursor = self.conn.cursor()
        # Disable JIT and parallel workers as they are known to cause problems.
        # Update pg_settings instead of using SET because it does not yield
        # errors on older versions of Postgres where the settings are not
        # implemented.
        self.perform(
            """ UPDATE pg_settings SET setting = -1 WHERE name = 'jit_above_cost';
                UPDATE pg_settings SET setting = 0
                   WHERE name = 'max_parallel_workers_per_gather';""")
        self.wait()

    def _deadlock_handler(self) -> None:
        # Retry callback used with DeadlockHandler: re-issue the query
        # that was interrupted by the deadlock.
        LOG.info("Deadlock detected (params = %s), retry.", str(self.current_params))
        assert self.cursor is not None
        assert self.current_query is not None
        assert self.current_params is not None

        self.cursor.execute(self.current_query, self.current_params)

    def wait(self) -> None:
        """ Block until any pending operation is done.
        """
        while True:
            with DeadlockHandler(self._deadlock_handler, self.ignore_sql_errors):
                # wait_select blocks until the async connection is ready;
                # a deadlock re-executes the query and loops again.
                wait_select(self.conn)
                self.current_query = None
                return

    def perform(self, sql: Query, args: Optional[Sequence[Any]] = None) -> None:
        """ Send SQL query to the server. Returns immediately without
            blocking.
        """
        assert self.cursor is not None
        self.current_query = sql
        self.current_params = args
        self.cursor.execute(sql, args)

    def fileno(self) -> int:
        """ File descriptor to wait for. (Makes this class select()able.)
        """
        assert self.conn is not None
        return self.conn.fileno()

    def is_done(self) -> bool:
        """ Check if the connection is available for a new query.

            Also checks if the previous query has run into a deadlock.
            If so, then the previous query is repeated.
        """
        assert self.conn is not None

        if self.current_query is None:
            return True

        with DeadlockHandler(self._deadlock_handler, self.ignore_sql_errors):
            if self.conn.poll() == psycopg2.extensions.POLL_OK:
                self.current_query = None
                return True

        return False
class WorkerPool:
    """ A pool of asynchronous database connections.

        The pool may be used as a context manager.
    """
    # Reconnect a worker after this many executed commands to avoid
    # long-lived server-side state building up.
    REOPEN_CONNECTIONS_AFTER = 100000

    def __init__(self, dsn: str, pool_size: int, ignore_sql_errors: bool = False) -> None:
        self.threads = [DBConnection(dsn, ignore_sql_errors=ignore_sql_errors)
                        for _ in range(pool_size)]
        self.free_workers = self._yield_free_worker()
        # Total time spent in select() waiting for a free connection.
        self.wait_time = 0.0

    def finish_all(self) -> None:
        """ Wait for all connection to finish.
        """
        for thread in self.threads:
            while not thread.is_done():
                thread.wait()

        # Restart the round-robin generator over the now-idle workers.
        self.free_workers = self._yield_free_worker()

    def close(self) -> None:
        """ Close all connections and clear the pool.
        """
        for thread in self.threads:
            thread.close()
        self.threads = []
        self.free_workers = iter([])

    def next_free_worker(self) -> DBConnection:
        """ Get the next free connection.
        """
        return next(self.free_workers)

    def _yield_free_worker(self) -> Iterator[DBConnection]:
        # Infinite generator that yields idle connections, blocking in
        # select() when none are ready and periodically reconnecting.
        ready = self.threads
        command_stat = 0

        while True:
            for thread in ready:
                if thread.is_done():
                    command_stat += 1
                    yield thread

            if command_stat > self.REOPEN_CONNECTIONS_AFTER:
                self._reconnect_threads()
                ready = self.threads
                command_stat = 0
            else:
                tstart = time.time()
                # Wait until at least one connection becomes writable.
                _, ready, _ = select.select([], self.threads, [])
                self.wait_time += time.time() - tstart

    def _reconnect_threads(self) -> None:
        # Drain and reopen every connection in the pool.
        for thread in self.threads:
            while not thread.is_done():
                thread.wait()
            thread.connect()

    def __enter__(self) -> 'WorkerPool':
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.finish_all()
        self.close()

View File

@@ -7,73 +7,27 @@
""" """
Specialised connection and cursor functions. Specialised connection and cursor functions.
""" """
from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload,\ from typing import Optional, Any, Dict, Tuple
Tuple, Iterable
import contextlib
import logging import logging
import os import os
import psycopg2 import psycopg
import psycopg2.extensions import psycopg.types.hstore
import psycopg2.extras from psycopg import sql as pysql
from psycopg2 import sql as pysql
from ..typing import SysEnv, Query, T_cursor from ..typing import SysEnv
from ..errors import UsageError from ..errors import UsageError
LOG = logging.getLogger() LOG = logging.getLogger()
class Cursor(psycopg2.extras.DictCursor): Cursor = psycopg.Cursor[Any]
""" A cursor returning dict-like objects and providing specialised Connection = psycopg.Connection[Any]
execution functions.
"""
# pylint: disable=arguments-renamed,arguments-differ
def execute(self, query: Query, args: Any = None) -> None:
""" Query execution that logs the SQL query when debugging is enabled.
"""
if LOG.isEnabledFor(logging.DEBUG):
LOG.debug(self.mogrify(query, args).decode('utf-8'))
super().execute(query, args) def execute_scalar(conn: Connection, sql: psycopg.abc.Query, args: Any = None) -> Any:
def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]],
template: Optional[Query] = None) -> None:
""" Wrapper for the psycopg2 convenience function to execute
SQL for a list of values.
"""
LOG.debug("SQL execute_values(%s, %s)", sql, argslist)
psycopg2.extras.execute_values(self, sql, argslist, template=template)
class Connection(psycopg2.extensions.connection):
""" A connection that provides the specialised cursor by default and
adds convenience functions for administrating the database.
"""
@overload # type: ignore[override]
def cursor(self) -> Cursor:
...
@overload
def cursor(self, name: str) -> Cursor:
...
@overload
def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor:
...
def cursor(self, cursor_factory = Cursor, **kwargs): # type: ignore
""" Return a new cursor. By default the specialised cursor is returned.
"""
return super().cursor(cursor_factory=cursor_factory, **kwargs)
def execute_scalar(conn: Connection, sql: Query, args: Any = None) -> Any:
""" Execute query that returns a single value. The value is returned. """ Execute query that returns a single value. The value is returned.
If the query yields more than one row, a ValueError is raised. If the query yields more than one row, a ValueError is raised.
""" """
with conn.cursor() as cur: with conn.cursor(row_factory=psycopg.rows.tuple_row) as cur:
cur.execute(sql, args) cur.execute(sql, args)
if cur.rowcount != 1: if cur.rowcount != 1:
@@ -144,7 +98,7 @@ def server_version_tuple(conn: Connection) -> Tuple[int, int]:
""" Return the server version as a tuple of (major, minor). """ Return the server version as a tuple of (major, minor).
Converts correctly for pre-10 and post-10 PostgreSQL versions. Converts correctly for pre-10 and post-10 PostgreSQL versions.
""" """
version = conn.server_version version = conn.info.server_version
if version < 100000: if version < 100000:
return (int(version / 10000), int((version % 10000) / 100)) return (int(version / 10000), int((version % 10000) / 100))
@@ -164,31 +118,25 @@ def postgis_version_tuple(conn: Connection) -> Tuple[int, int]:
return (int(version_parts[0]), int(version_parts[1])) return (int(version_parts[0]), int(version_parts[1]))
def register_hstore(conn: Connection) -> None: def register_hstore(conn: Connection) -> None:
""" Register the hstore type with psycopg for the connection. """ Register the hstore type with psycopg for the connection.
""" """
psycopg2.extras.register_hstore(conn) info = psycopg.types.TypeInfo.fetch(conn, "hstore")
if info is None:
raise RuntimeError('Hstore extension is requested but not installed.')
psycopg.types.hstore.register_hstore(info, conn)
class ConnectionContext(ContextManager[Connection]): def connect(dsn: str, **kwargs: Any) -> Connection:
""" Context manager of the connection that also provides direct access
to the underlying connection.
"""
connection: Connection
def connect(dsn: str) -> ConnectionContext:
""" Open a connection to the database using the specialised connection """ Open a connection to the database using the specialised connection
factory. The returned object may be used in conjunction with 'with'. factory. The returned object may be used in conjunction with 'with'.
When used outside a context manager, use the `connection` attribute When used outside a context manager, use the `connection` attribute
to get the connection. to get the connection.
""" """
try: try:
conn = psycopg2.connect(dsn, connection_factory=Connection) return psycopg.connect(dsn, row_factory=psycopg.rows.namedtuple_row, **kwargs)
ctxmgr = cast(ConnectionContext, contextlib.closing(conn)) except psycopg.OperationalError as err:
ctxmgr.connection = conn
return ctxmgr
except psycopg2.OperationalError as err:
raise UsageError(f"Cannot connect to database: {err}") from err raise UsageError(f"Cannot connect to database: {err}") from err
@@ -233,10 +181,18 @@ def get_pg_env(dsn: str,
""" """
env = dict(base_env if base_env is not None else os.environ) env = dict(base_env if base_env is not None else os.environ)
for param, value in psycopg2.extensions.parse_dsn(dsn).items(): for param, value in psycopg.conninfo.conninfo_to_dict(dsn).items():
if param in _PG_CONNECTION_STRINGS: if param in _PG_CONNECTION_STRINGS:
env[_PG_CONNECTION_STRINGS[param]] = value env[_PG_CONNECTION_STRINGS[param]] = str(value)
else: else:
LOG.error("Unknown connection parameter '%s' ignored.", param) LOG.error("Unknown connection parameter '%s' ignored.", param)
return env return env
async def run_async_query(dsn: str, query: psycopg.abc.Query) -> None:
    """ Open a connection to the database and run a single query
        asynchronously.
    """
    # Establish the async connection first, then let the context manager
    # take ownership so it is closed (and the transaction finalised) on exit.
    aconn = await psycopg.AsyncConnection.connect(dsn)
    async with aconn:
        await aconn.execute(query)

View File

@@ -0,0 +1,87 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
A connection pool that executes incoming queries in parallel.
"""
from typing import Any, Tuple, Optional
import asyncio
import logging
import time
import psycopg
LOG = logging.getLogger()
QueueItem = Optional[Tuple[psycopg.abc.Query, Any]]
class QueryPool:
    """ Pool to run SQL queries in parallel asynchronous execution.

        All queries are run in autocommit mode. If parallel execution leads
        to a deadlock, then the query is repeated.
        The results of the queries is discarded.
    """

    def __init__(self, dsn: str, pool_size: int = 1, **conn_args: Any) -> None:
        # Accumulated time (seconds) spent blocked on the queue, both when
        # scheduling queries and when waiting for workers to drain at finish().
        self.wait_time = 0.0
        # Bounded queue so producers back off when all workers are busy;
        # 2 * pool_size keeps one query in flight per worker plus one queued.
        self.query_queue: 'asyncio.Queue[QueueItem]' = asyncio.Queue(maxsize=2 * pool_size)

        # One worker task per pool slot; each opens its own connection.
        self.pool = [asyncio.create_task(self._worker_loop(dsn, **conn_args))
                     for _ in range(pool_size)]

    async def put_query(self, query: psycopg.abc.Query, params: Any) -> None:
        """ Schedule a query for execution.

            `params` may be None for a parameterless query. Blocks while the
            queue is full; the blocked time is added to `wait_time`.
        """
        tstart = time.time()
        await self.query_queue.put((query, params))
        self.wait_time += time.time() - tstart
        # Yield to the event loop so a waiting worker can pick up the item.
        await asyncio.sleep(0)

    async def finish(self) -> None:
        """ Wait for all queries to finish and close the pool.

            Re-raises the first exception encountered by any worker task.
        """
        # One None sentinel per worker makes every worker loop terminate.
        for _ in self.pool:
            await self.query_queue.put(None)

        tstart = time.time()
        await asyncio.wait(self.pool)
        self.wait_time += time.time() - tstart

        # Surface worker failures to the caller; tasks are all done here,
        # so task.exception() cannot raise InvalidStateError.
        for task in self.pool:
            excp = task.exception()
            if excp is not None:
                raise excp

    async def _worker_loop(self, dsn: str, **conn_args: Any) -> None:
        # Worker coroutine: consume (query, params) items until the None
        # sentinel arrives. Autocommit is forced so each statement commits
        # independently and deadlock retries are safe.
        conn_args['autocommit'] = True
        aconn = await psycopg.AsyncConnection.connect(dsn, **conn_args)
        async with aconn:
            async with aconn.cursor() as cur:
                item = await self.query_queue.get()
                while item is not None:
                    try:
                        if item[1] is None:
                            await cur.execute(item[0])
                        else:
                            await cur.execute(item[0], item[1])

                        item = await self.query_queue.get()
                    except psycopg.errors.DeadlockDetected:
                        assert item is not None
                        LOG.info("Deadlock detected (sql = %s, params = %s), retry.",
                                 str(item[0]), str(item[1]))
                        # item is still valid here, causing a retry

    async def __aenter__(self) -> 'QueryPool':
        return self

    async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # Context-manager exit always drains and closes the pool, even on error.
        await self.finish()

View File

@@ -8,11 +8,12 @@
Preprocessing of SQL files. Preprocessing of SQL files.
""" """
from typing import Set, Dict, Any, cast from typing import Set, Dict, Any, cast
import jinja2 import jinja2
from .connection import Connection, server_version_tuple, postgis_version_tuple from .connection import Connection, server_version_tuple, postgis_version_tuple
from .async_connection import WorkerPool
from ..config import Configuration from ..config import Configuration
from ..db.query_pool import QueryPool
def _get_partitions(conn: Connection) -> Set[int]: def _get_partitions(conn: Connection) -> Set[int]:
""" Get the set of partitions currently in use. """ Get the set of partitions currently in use.
@@ -125,8 +126,8 @@ class SQLPreprocessor:
conn.commit() conn.commit()
def run_parallel_sql_file(self, dsn: str, name: str, num_threads: int = 1, async def run_parallel_sql_file(self, dsn: str, name: str, num_threads: int = 1,
**kwargs: Any) -> None: **kwargs: Any) -> None:
""" Execute the given SQL files using parallel asynchronous connections. """ Execute the given SQL files using parallel asynchronous connections.
The keyword arguments may supply additional parameters for The keyword arguments may supply additional parameters for
preprocessing. preprocessing.
@@ -138,6 +139,6 @@ class SQLPreprocessor:
parts = sql.split('\n---\n') parts = sql.split('\n---\n')
with WorkerPool(dsn, num_threads) as pool: async with QueryPool(dsn, num_threads) as pool:
for part in parts: for part in parts:
pool.next_free_worker().perform(part) await pool.put_query(part, None)

View File

@@ -7,7 +7,7 @@
""" """
Access and helper functions for the status and status log table. Access and helper functions for the status and status log table.
""" """
from typing import Optional, Tuple, cast from typing import Optional, Tuple
import datetime as dt import datetime as dt
import logging import logging
import re import re
@@ -15,20 +15,11 @@ import re
from .connection import Connection, table_exists, execute_scalar from .connection import Connection, table_exists, execute_scalar
from ..utils.url_utils import get_url from ..utils.url_utils import get_url
from ..errors import UsageError from ..errors import UsageError
from ..typing import TypedDict
LOG = logging.getLogger() LOG = logging.getLogger()
ISODATE_FORMAT = '%Y-%m-%dT%H:%M:%S' ISODATE_FORMAT = '%Y-%m-%dT%H:%M:%S'
class StatusRow(TypedDict):
""" Dictionary of columns of the import_status table.
"""
lastimportdate: dt.datetime
sequence_id: Optional[int]
indexed: Optional[bool]
def compute_database_date(conn: Connection, offline: bool = False) -> dt.datetime: def compute_database_date(conn: Connection, offline: bool = False) -> dt.datetime:
""" Determine the date of the database from the newest object in the """ Determine the date of the database from the newest object in the
data base. data base.
@@ -102,8 +93,9 @@ def get_status(conn: Connection) -> Tuple[Optional[dt.datetime], Optional[int],
if cur.rowcount < 1: if cur.rowcount < 1:
return None, None, None return None, None, None
row = cast(StatusRow, cur.fetchone()) row = cur.fetchone()
return row['lastimportdate'], row['sequence_id'], row['indexed'] assert row
return row.lastimportdate, row.sequence_id, row.indexed
def set_indexed(conn: Connection, state: bool) -> None: def set_indexed(conn: Connection, state: bool) -> None:

View File

@@ -7,14 +7,13 @@
""" """
Helper functions for handling DB accesses. Helper functions for handling DB accesses.
""" """
from typing import IO, Optional, Union, Any, Iterable from typing import IO, Optional, Union
import subprocess import subprocess
import logging import logging
import gzip import gzip
import io
from pathlib import Path from pathlib import Path
from .connection import get_pg_env, Cursor from .connection import get_pg_env
from ..errors import UsageError from ..errors import UsageError
LOG = logging.getLogger() LOG = logging.getLogger()
@@ -72,58 +71,3 @@ def execute_file(dsn: str, fname: Path,
if ret != 0 or remain > 0: if ret != 0 or remain > 0:
raise UsageError("Failed to execute SQL file.") raise UsageError("Failed to execute SQL file.")
# List of characters that need to be quoted for the copy command.
_SQL_TRANSLATION = {ord('\\'): '\\\\',
ord('\t'): '\\t',
ord('\n'): '\\n'}
class CopyBuffer:
""" Data collector for the copy_from command.
"""
def __init__(self) -> None:
self.buffer = io.StringIO()
def __enter__(self) -> 'CopyBuffer':
return self
def size(self) -> int:
""" Return the number of bytes the buffer currently contains.
"""
return self.buffer.tell()
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
if self.buffer is not None:
self.buffer.close()
def add(self, *data: Any) -> None:
""" Add another row of data to the copy buffer.
"""
first = True
for column in data:
if first:
first = False
else:
self.buffer.write('\t')
if column is None:
self.buffer.write('\\N')
else:
self.buffer.write(str(column).translate(_SQL_TRANSLATION))
self.buffer.write('\n')
def copy_out(self, cur: Cursor, table: str, columns: Optional[Iterable[str]] = None) -> None:
""" Copy all collected data into the given table.
The buffer is empty and reusable after this operation.
"""
if self.buffer.tell() > 0:
self.buffer.seek(0)
cur.copy_from(self.buffer, table, columns=columns)
self.buffer = io.StringIO()

View File

@@ -7,92 +7,20 @@
""" """
Main work horse for indexing (computing addresses) the database. Main work horse for indexing (computing addresses) the database.
""" """
from typing import Optional, Any, cast from typing import cast, List, Any
import logging import logging
import time import time
import psycopg2.extras import psycopg
from ..typing import DictCursorResults from ..db.connection import connect, execute_scalar
from ..db.async_connection import DBConnection, WorkerPool from ..db.query_pool import QueryPool
from ..db.connection import connect, Connection, Cursor, execute_scalar, register_hstore
from ..tokenizer.base import AbstractTokenizer from ..tokenizer.base import AbstractTokenizer
from .progress import ProgressLogger from .progress import ProgressLogger
from . import runners from . import runners
LOG = logging.getLogger() LOG = logging.getLogger()
class PlaceFetcher:
""" Asynchronous connection that fetches place details for processing.
"""
def __init__(self, dsn: str, setup_conn: Connection) -> None:
self.wait_time = 0.0
self.current_ids: Optional[DictCursorResults] = None
self.conn: Optional[DBConnection] = DBConnection(dsn,
cursor_factory=psycopg2.extras.DictCursor)
# need to fetch those manually because register_hstore cannot
# fetch them on an asynchronous connection below.
hstore_oid = execute_scalar(setup_conn, "SELECT 'hstore'::regtype::oid")
hstore_array_oid = execute_scalar(setup_conn, "SELECT 'hstore[]'::regtype::oid")
psycopg2.extras.register_hstore(self.conn.conn, oid=hstore_oid,
array_oid=hstore_array_oid)
def close(self) -> None:
""" Close the underlying asynchronous connection.
"""
if self.conn:
self.conn.close()
self.conn = None
def fetch_next_batch(self, cur: Cursor, runner: runners.Runner) -> bool:
""" Send a request for the next batch of places.
If details for the places are required, they will be fetched
asynchronously.
Returns true if there is still data available.
"""
ids = cast(Optional[DictCursorResults], cur.fetchmany(100))
if not ids:
self.current_ids = None
return False
assert self.conn is not None
self.current_ids = runner.get_place_details(self.conn, ids)
return True
def get_batch(self) -> DictCursorResults:
""" Get the next batch of data, previously requested with
`fetch_next_batch`.
"""
assert self.conn is not None
assert self.conn.cursor is not None
if self.current_ids is not None and not self.current_ids:
tstart = time.time()
self.conn.wait()
self.wait_time += time.time() - tstart
self.current_ids = cast(Optional[DictCursorResults],
self.conn.cursor.fetchall())
return self.current_ids if self.current_ids is not None else []
def __enter__(self) -> 'PlaceFetcher':
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
assert self.conn is not None
self.conn.wait()
self.close()
class Indexer: class Indexer:
""" Main indexing routine. """ Main indexing routine.
""" """
@@ -114,7 +42,7 @@ class Indexer:
return cur.rowcount > 0 return cur.rowcount > 0
def index_full(self, analyse: bool = True) -> None: async def index_full(self, analyse: bool = True) -> None:
""" Index the complete database. This will first index boundaries """ Index the complete database. This will first index boundaries
followed by all other objects. When `analyse` is True, then the followed by all other objects. When `analyse` is True, then the
database will be analysed at the appropriate places to database will be analysed at the appropriate places to
@@ -128,23 +56,27 @@ class Indexer:
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute('ANALYZE') cur.execute('ANALYZE')
if self.index_by_rank(0, 4) > 0: while True:
_analyze() if await self.index_by_rank(0, 4) > 0:
_analyze()
if self.index_boundaries(0, 30) > 100: if await self.index_boundaries(0, 30) > 100:
_analyze() _analyze()
if self.index_by_rank(5, 25) > 100: if await self.index_by_rank(5, 25) > 100:
_analyze() _analyze()
if self.index_by_rank(26, 30) > 1000: if await self.index_by_rank(26, 30) > 1000:
_analyze() _analyze()
if self.index_postcodes() > 100: if await self.index_postcodes() > 100:
_analyze() _analyze()
if not self.has_pending():
break
def index_boundaries(self, minrank: int, maxrank: int) -> int: async def index_boundaries(self, minrank: int, maxrank: int) -> int:
""" Index only administrative boundaries within the given rank range. """ Index only administrative boundaries within the given rank range.
""" """
total = 0 total = 0
@@ -153,11 +85,11 @@ class Indexer:
with self.tokenizer.name_analyzer() as analyzer: with self.tokenizer.name_analyzer() as analyzer:
for rank in range(max(minrank, 4), min(maxrank, 26)): for rank in range(max(minrank, 4), min(maxrank, 26)):
total += self._index(runners.BoundaryRunner(rank, analyzer)) total += await self._index(runners.BoundaryRunner(rank, analyzer))
return total return total
def index_by_rank(self, minrank: int, maxrank: int) -> int: async def index_by_rank(self, minrank: int, maxrank: int) -> int:
""" Index all entries of placex in the given rank range (inclusive) """ Index all entries of placex in the given rank range (inclusive)
in order of their address rank. in order of their address rank.
@@ -171,21 +103,27 @@ class Indexer:
with self.tokenizer.name_analyzer() as analyzer: with self.tokenizer.name_analyzer() as analyzer:
for rank in range(max(1, minrank), maxrank + 1): for rank in range(max(1, minrank), maxrank + 1):
total += self._index(runners.RankRunner(rank, analyzer), 20 if rank == 30 else 1) if rank >= 30:
batch = 20
elif rank >= 26:
batch = 5
else:
batch = 1
total += await self._index(runners.RankRunner(rank, analyzer), batch)
if maxrank == 30: if maxrank == 30:
total += self._index(runners.RankRunner(0, analyzer)) total += await self._index(runners.RankRunner(0, analyzer))
total += self._index(runners.InterpolationRunner(analyzer), 20) total += await self._index(runners.InterpolationRunner(analyzer), 20)
return total return total
def index_postcodes(self) -> int: async def index_postcodes(self) -> int:
"""Index the entries of the location_postcode table. """Index the entries of the location_postcode table.
""" """
LOG.warning("Starting indexing postcodes using %s threads", self.num_threads) LOG.warning("Starting indexing postcodes using %s threads", self.num_threads)
return self._index(runners.PostcodeRunner(), 20) return await self._index(runners.PostcodeRunner(), 20)
def update_status_table(self) -> None: def update_status_table(self) -> None:
@@ -197,45 +135,58 @@ class Indexer:
conn.commit() conn.commit()
def _index(self, runner: runners.Runner, batch: int = 1) -> int: async def _index(self, runner: runners.Runner, batch: int = 1) -> int:
""" Index a single rank or table. `runner` describes the SQL to use """ Index a single rank or table. `runner` describes the SQL to use
for indexing. `batch` describes the number of objects that for indexing. `batch` describes the number of objects that
should be processed with a single SQL statement should be processed with a single SQL statement
""" """
LOG.warning("Starting %s (using batch size %s)", runner.name(), batch) LOG.warning("Starting %s (using batch size %s)", runner.name(), batch)
with connect(self.dsn) as conn: total_tuples = self._prepare_indexing(runner)
register_hstore(conn)
total_tuples = execute_scalar(conn, runner.sql_count_objects())
LOG.debug("Total number of rows: %i", total_tuples)
conn.commit() progress = ProgressLogger(runner.name(), total_tuples)
progress = ProgressLogger(runner.name(), total_tuples) if total_tuples > 0:
async with await psycopg.AsyncConnection.connect(
self.dsn, row_factory=psycopg.rows.dict_row) as aconn,\
QueryPool(self.dsn, self.num_threads, autocommit=True) as pool:
fetcher_time = 0.0
tstart = time.time()
async with aconn.cursor(name='places') as cur:
query = runner.index_places_query(batch)
params: List[Any] = []
num_places = 0
async for place in cur.stream(runner.sql_get_objects()):
fetcher_time += time.time() - tstart
if total_tuples > 0: params.extend(runner.index_places_params(place))
with conn.cursor(name='places') as cur: num_places += 1
cur.execute(runner.sql_get_objects())
with PlaceFetcher(self.dsn, conn) as fetcher: if num_places >= batch:
with WorkerPool(self.dsn, self.num_threads) as pool: LOG.debug("Processing places: %s", str(params))
has_more = fetcher.fetch_next_batch(cur, runner) await pool.put_query(query, params)
while has_more: progress.add(num_places)
places = fetcher.get_batch() params = []
num_places = 0
# asynchronously get the next batch tstart = time.time()
has_more = fetcher.fetch_next_batch(cur, runner)
# And insert the current batch if num_places > 0:
for idx in range(0, len(places), batch): await pool.put_query(runner.index_places_query(num_places), params)
part = places[idx:idx + batch]
LOG.debug("Processing places: %s", str(part))
runner.index_places(pool.next_free_worker(), part)
progress.add(len(part))
LOG.info("Wait time: fetcher: %.2fs, pool: %.2fs", LOG.info("Wait time: fetcher: %.2fs, pool: %.2fs",
fetcher.wait_time, pool.wait_time) fetcher_time, pool.wait_time)
conn.commit()
return progress.done() return progress.done()
def _prepare_indexing(self, runner: runners.Runner) -> int:
with connect(self.dsn) as conn:
hstore_info = psycopg.types.TypeInfo.fetch(conn, "hstore")
if hstore_info is None:
raise RuntimeError('Hstore extension is requested but not installed.')
psycopg.types.hstore.register_hstore(hstore_info)
total_tuples = execute_scalar(conn, runner.sql_count_objects())
LOG.debug("Total number of rows: %i", total_tuples)
return cast(int, total_tuples)

View File

@@ -8,14 +8,14 @@
Mix-ins that provide the actual commands for the indexer for various indexing Mix-ins that provide the actual commands for the indexer for various indexing
tasks. tasks.
""" """
from typing import Any, List from typing import Any, Sequence
import functools
from psycopg2 import sql as pysql from psycopg import sql as pysql
import psycopg2.extras from psycopg.abc import Query
from psycopg.rows import DictRow
from psycopg.types.json import Json
from ..typing import Query, DictCursorResult, DictCursorResults, Protocol from ..typing import Protocol
from ..db.async_connection import DBConnection
from ..data.place_info import PlaceInfo from ..data.place_info import PlaceInfo
from ..tokenizer.base import AbstractAnalyzer from ..tokenizer.base import AbstractAnalyzer
@@ -24,58 +24,48 @@ from ..tokenizer.base import AbstractAnalyzer
def _mk_valuelist(template: str, num: int) -> pysql.Composed: def _mk_valuelist(template: str, num: int) -> pysql.Composed:
return pysql.SQL(',').join([pysql.SQL(template)] * num) return pysql.SQL(',').join([pysql.SQL(template)] * num)
def _analyze_place(place: DictCursorResult, analyzer: AbstractAnalyzer) -> psycopg2.extras.Json: def _analyze_place(place: DictRow, analyzer: AbstractAnalyzer) -> Json:
return psycopg2.extras.Json(analyzer.process_place(PlaceInfo(place))) return Json(analyzer.process_place(PlaceInfo(place)))
class Runner(Protocol): class Runner(Protocol):
def name(self) -> str: ... def name(self) -> str: ...
def sql_count_objects(self) -> Query: ... def sql_count_objects(self) -> Query: ...
def sql_get_objects(self) -> Query: ... def sql_get_objects(self) -> Query: ...
def get_place_details(self, worker: DBConnection, def index_places_query(self, batch_size: int) -> Query: ...
ids: DictCursorResults) -> DictCursorResults: ... def index_places_params(self, place: DictRow) -> Sequence[Any]: ...
def index_places(self, worker: DBConnection, places: DictCursorResults) -> None: ...
SELECT_SQL = pysql.SQL("""SELECT place_id, extra.*
FROM (SELECT * FROM placex {}) as px,
LATERAL placex_indexing_prepare(px) as extra """)
UPDATE_LINE = "(%s, %s::hstore, %s::hstore, %s::int, %s::jsonb)"
class AbstractPlacexRunner: class AbstractPlacexRunner:
""" Returns SQL commands for indexing of the placex table. """ Returns SQL commands for indexing of the placex table.
""" """
SELECT_SQL = pysql.SQL('SELECT place_id FROM placex ')
UPDATE_LINE = "(%s, %s::hstore, %s::hstore, %s::int, %s::jsonb)"
def __init__(self, rank: int, analyzer: AbstractAnalyzer) -> None: def __init__(self, rank: int, analyzer: AbstractAnalyzer) -> None:
self.rank = rank self.rank = rank
self.analyzer = analyzer self.analyzer = analyzer
@functools.lru_cache(maxsize=1) def index_places_query(self, batch_size: int) -> Query:
def _index_sql(self, num_places: int) -> pysql.Composed:
return pysql.SQL( return pysql.SQL(
""" UPDATE placex """ UPDATE placex
SET indexed_status = 0, address = v.addr, token_info = v.ti, SET indexed_status = 0, address = v.addr, token_info = v.ti,
name = v.name, linked_place_id = v.linked_place_id name = v.name, linked_place_id = v.linked_place_id
FROM (VALUES {}) as v(id, name, addr, linked_place_id, ti) FROM (VALUES {}) as v(id, name, addr, linked_place_id, ti)
WHERE place_id = v.id WHERE place_id = v.id
""").format(_mk_valuelist(AbstractPlacexRunner.UPDATE_LINE, num_places)) """).format(_mk_valuelist(UPDATE_LINE, batch_size))
def get_place_details(self, worker: DBConnection, ids: DictCursorResults) -> DictCursorResults: def index_places_params(self, place: DictRow) -> Sequence[Any]:
worker.perform("""SELECT place_id, extra.* return (place['place_id'],
FROM placex, LATERAL placex_indexing_prepare(placex) as extra place['name'],
WHERE place_id IN %s""", place['address'],
(tuple((p[0] for p in ids)), )) place['linked_place_id'],
_analyze_place(place, self.analyzer))
return []
def index_places(self, worker: DBConnection, places: DictCursorResults) -> None:
values: List[Any] = []
for place in places:
for field in ('place_id', 'name', 'address', 'linked_place_id'):
values.append(place[field])
values.append(_analyze_place(place, self.analyzer))
worker.perform(self._index_sql(len(places)), values)
class RankRunner(AbstractPlacexRunner): class RankRunner(AbstractPlacexRunner):
@@ -91,10 +81,10 @@ class RankRunner(AbstractPlacexRunner):
""").format(pysql.Literal(self.rank)) """).format(pysql.Literal(self.rank))
def sql_get_objects(self) -> pysql.Composed: def sql_get_objects(self) -> pysql.Composed:
return self.SELECT_SQL + pysql.SQL( return SELECT_SQL.format(pysql.SQL(
"""WHERE indexed_status > 0 and rank_address = {} """WHERE placex.indexed_status > 0 and placex.rank_address = {}
ORDER BY geometry_sector ORDER BY placex.geometry_sector
""").format(pysql.Literal(self.rank)) """).format(pysql.Literal(self.rank)))
class BoundaryRunner(AbstractPlacexRunner): class BoundaryRunner(AbstractPlacexRunner):
@@ -105,19 +95,19 @@ class BoundaryRunner(AbstractPlacexRunner):
def name(self) -> str: def name(self) -> str:
return f"boundaries rank {self.rank}" return f"boundaries rank {self.rank}"
def sql_count_objects(self) -> pysql.Composed: def sql_count_objects(self) -> Query:
return pysql.SQL("""SELECT count(*) FROM placex return pysql.SQL("""SELECT count(*) FROM placex
WHERE indexed_status > 0 WHERE indexed_status > 0
AND rank_search = {} AND rank_search = {}
AND class = 'boundary' and type = 'administrative' AND class = 'boundary' and type = 'administrative'
""").format(pysql.Literal(self.rank)) """).format(pysql.Literal(self.rank))
def sql_get_objects(self) -> pysql.Composed: def sql_get_objects(self) -> Query:
return self.SELECT_SQL + pysql.SQL( return SELECT_SQL.format(pysql.SQL(
"""WHERE indexed_status > 0 and rank_search = {} """WHERE placex.indexed_status > 0 and placex.rank_search = {}
and class = 'boundary' and type = 'administrative' and placex.class = 'boundary' and placex.type = 'administrative'
ORDER BY partition, admin_level ORDER BY placex.partition, placex.admin_level
""").format(pysql.Literal(self.rank)) """).format(pysql.Literal(self.rank)))
class InterpolationRunner: class InterpolationRunner:
@@ -132,40 +122,29 @@ class InterpolationRunner:
def name(self) -> str: def name(self) -> str:
return "interpolation lines (location_property_osmline)" return "interpolation lines (location_property_osmline)"
def sql_count_objects(self) -> str: def sql_count_objects(self) -> Query:
return """SELECT count(*) FROM location_property_osmline return """SELECT count(*) FROM location_property_osmline
WHERE indexed_status > 0""" WHERE indexed_status > 0"""
def sql_get_objects(self) -> str:
return """SELECT place_id def sql_get_objects(self) -> Query:
return """SELECT place_id, get_interpolation_address(address, osm_id) as address
FROM location_property_osmline FROM location_property_osmline
WHERE indexed_status > 0 WHERE indexed_status > 0
ORDER BY geometry_sector""" ORDER BY geometry_sector"""
def get_place_details(self, worker: DBConnection, ids: DictCursorResults) -> DictCursorResults: def index_places_query(self, batch_size: int) -> Query:
worker.perform("""SELECT place_id, get_interpolation_address(address, osm_id) as address
FROM location_property_osmline WHERE place_id IN %s""",
(tuple((p[0] for p in ids)), ))
return []
@functools.lru_cache(maxsize=1)
def _index_sql(self, num_places: int) -> pysql.Composed:
return pysql.SQL("""UPDATE location_property_osmline return pysql.SQL("""UPDATE location_property_osmline
SET indexed_status = 0, address = v.addr, token_info = v.ti SET indexed_status = 0, address = v.addr, token_info = v.ti
FROM (VALUES {}) as v(id, addr, ti) FROM (VALUES {}) as v(id, addr, ti)
WHERE place_id = v.id WHERE place_id = v.id
""").format(_mk_valuelist("(%s, %s::hstore, %s::jsonb)", num_places)) """).format(_mk_valuelist("(%s, %s::hstore, %s::jsonb)", batch_size))
def index_places(self, worker: DBConnection, places: DictCursorResults) -> None: def index_places_params(self, place: DictRow) -> Sequence[Any]:
values: List[Any] = [] return (place['place_id'], place['address'],
for place in places: _analyze_place(place, self.analyzer))
values.extend((place[x] for x in ('place_id', 'address')))
values.append(_analyze_place(place, self.analyzer))
worker.perform(self._index_sql(len(places)), values)
@@ -177,20 +156,21 @@ class PostcodeRunner(Runner):
return "postcodes (location_postcode)" return "postcodes (location_postcode)"
def sql_count_objects(self) -> str: def sql_count_objects(self) -> Query:
return 'SELECT count(*) FROM location_postcode WHERE indexed_status > 0' return 'SELECT count(*) FROM location_postcode WHERE indexed_status > 0'
def sql_get_objects(self) -> str: def sql_get_objects(self) -> Query:
return """SELECT place_id FROM location_postcode return """SELECT place_id FROM location_postcode
WHERE indexed_status > 0 WHERE indexed_status > 0
ORDER BY country_code, postcode""" ORDER BY country_code, postcode"""
def get_place_details(self, worker: DBConnection, ids: DictCursorResults) -> DictCursorResults: def index_places_query(self, batch_size: int) -> Query:
return ids return pysql.SQL("""UPDATE location_postcode SET indexed_status = 0
WHERE place_id IN ({})""")\
.format(pysql.SQL(',').join((pysql.Placeholder() for _ in range(batch_size))))
def index_places(self, worker: DBConnection, places: DictCursorResults) -> None:
worker.perform(pysql.SQL("""UPDATE location_postcode SET indexed_status = 0 def index_places_params(self, place: DictRow) -> Sequence[Any]:
WHERE place_id IN ({})""") return (place['place_id'], )
.format(pysql.SQL(',').join((pysql.Literal(i[0]) for i in places))))

View File

@@ -11,15 +11,16 @@ libICU instead of the PostgreSQL module.
from typing import Optional, Sequence, List, Tuple, Mapping, Any, cast, \ from typing import Optional, Sequence, List, Tuple, Mapping, Any, cast, \
Dict, Set, Iterable Dict, Set, Iterable
import itertools import itertools
import json
import logging import logging
from pathlib import Path from pathlib import Path
from textwrap import dedent from textwrap import dedent
from psycopg.types.json import Jsonb
from psycopg import sql as pysql
from ..db.connection import connect, Connection, Cursor, server_version_tuple,\ from ..db.connection import connect, Connection, Cursor, server_version_tuple,\
drop_tables, table_exists, execute_scalar drop_tables, table_exists, execute_scalar
from ..config import Configuration from ..config import Configuration
from ..db.utils import CopyBuffer
from ..db.sql_preprocessor import SQLPreprocessor from ..db.sql_preprocessor import SQLPreprocessor
from ..data.place_info import PlaceInfo from ..data.place_info import PlaceInfo
from ..data.place_name import PlaceName from ..data.place_name import PlaceName
@@ -115,8 +116,8 @@ class ICUTokenizer(AbstractTokenizer):
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute('ANALYSE search_name') cur.execute('ANALYSE search_name')
if threads > 1: if threads > 1:
cur.execute('SET max_parallel_workers_per_gather TO %s', cur.execute(pysql.SQL('SET max_parallel_workers_per_gather TO {}')
(min(threads, 6),)) .format(pysql.Literal(min(threads, 6),)))
if server_version_tuple(conn) < (12, 0): if server_version_tuple(conn) < (12, 0):
LOG.info('Computing word frequencies') LOG.info('Computing word frequencies')
@@ -391,7 +392,7 @@ class ICUNameAnalyzer(AbstractAnalyzer):
def __init__(self, dsn: str, sanitizer: PlaceSanitizer, def __init__(self, dsn: str, sanitizer: PlaceSanitizer,
token_analysis: ICUTokenAnalysis) -> None: token_analysis: ICUTokenAnalysis) -> None:
self.conn: Optional[Connection] = connect(dsn).connection self.conn: Optional[Connection] = connect(dsn)
self.conn.autocommit = True self.conn.autocommit = True
self.sanitizer = sanitizer self.sanitizer = sanitizer
self.token_analysis = token_analysis self.token_analysis = token_analysis
@@ -533,9 +534,7 @@ class ICUNameAnalyzer(AbstractAnalyzer):
if terms: if terms:
with self.conn.cursor() as cur: with self.conn.cursor() as cur:
cur.execute_values("""SELECT create_postcode_word(pc, var) cur.executemany("""SELECT create_postcode_word(%s, %s)""", terms)
FROM (VALUES %s) AS v(pc, var)""",
terms)
@@ -578,18 +577,15 @@ class ICUNameAnalyzer(AbstractAnalyzer):
to_add = new_phrases - existing_phrases to_add = new_phrases - existing_phrases
added = 0 added = 0
with CopyBuffer() as copystr: with cursor.copy('COPY word(word_token, type, word, info) FROM STDIN') as copy:
for word, cls, typ, oper in to_add: for word, cls, typ, oper in to_add:
term = self._search_normalized(word) term = self._search_normalized(word)
if term: if term:
copystr.add(term, 'S', word, copy.write_row((term, 'S', word,
json.dumps({'class': cls, 'type': typ, Jsonb({'class': cls, 'type': typ,
'op': oper if oper in ('in', 'near') else None})) 'op': oper if oper in ('in', 'near') else None})))
added += 1 added += 1
copystr.copy_out(cursor, 'word',
columns=['word_token', 'type', 'word', 'info'])
return added return added
@@ -602,11 +598,11 @@ class ICUNameAnalyzer(AbstractAnalyzer):
to_delete = existing_phrases - new_phrases to_delete = existing_phrases - new_phrases
if to_delete: if to_delete:
cursor.execute_values( cursor.executemany(
""" DELETE FROM word USING (VALUES %s) as v(name, in_class, in_type, op) """ DELETE FROM word
WHERE type = 'S' and word = name WHERE type = 'S' and word = %s
and info->>'class' = in_class and info->>'type' = in_type and info->>'class' = %s and info->>'type' = %s
and ((op = '-' and info->>'op' is null) or op = info->>'op') and %s = coalesce(info->>'op', '-')
""", to_delete) """, to_delete)
return len(to_delete) return len(to_delete)
@@ -653,7 +649,7 @@ class ICUNameAnalyzer(AbstractAnalyzer):
gone_tokens.update(existing_tokens[False] & word_tokens) gone_tokens.update(existing_tokens[False] & word_tokens)
if gone_tokens: if gone_tokens:
cur.execute("""DELETE FROM word cur.execute("""DELETE FROM word
USING unnest(%s) as token USING unnest(%s::text[]) as token
WHERE type = 'C' and word = %s WHERE type = 'C' and word = %s
and word_token = token""", and word_token = token""",
(list(gone_tokens), country_code)) (list(gone_tokens), country_code))
@@ -666,12 +662,12 @@ class ICUNameAnalyzer(AbstractAnalyzer):
if internal: if internal:
sql = """INSERT INTO word (word_token, type, word, info) sql = """INSERT INTO word (word_token, type, word, info)
(SELECT token, 'C', %s, '{"internal": "yes"}' (SELECT token, 'C', %s, '{"internal": "yes"}'
FROM unnest(%s) as token) FROM unnest(%s::text[]) as token)
""" """
else: else:
sql = """INSERT INTO word (word_token, type, word) sql = """INSERT INTO word (word_token, type, word)
(SELECT token, 'C', %s (SELECT token, 'C', %s
FROM unnest(%s) as token) FROM unnest(%s::text[]) as token)
""" """
cur.execute(sql, (country_code, list(new_tokens))) cur.execute(sql, (country_code, list(new_tokens)))

View File

@@ -17,7 +17,8 @@ import shutil
from textwrap import dedent from textwrap import dedent
from icu import Transliterator from icu import Transliterator
import psycopg2 import psycopg
from psycopg import sql as pysql
from ..errors import UsageError from ..errors import UsageError
from ..db.connection import connect, Connection, drop_tables, table_exists,\ from ..db.connection import connect, Connection, drop_tables, table_exists,\
@@ -78,12 +79,12 @@ def _check_module(module_dir: str, conn: Connection) -> None:
""" """
with conn.cursor() as cur: with conn.cursor() as cur:
try: try:
cur.execute("""CREATE FUNCTION nominatim_test_import_func(text) cur.execute(pysql.SQL("""CREATE FUNCTION nominatim_test_import_func(text)
RETURNS text AS %s, 'transliteration' RETURNS text AS {}, 'transliteration'
LANGUAGE c IMMUTABLE STRICT; LANGUAGE c IMMUTABLE STRICT;
DROP FUNCTION nominatim_test_import_func(text) DROP FUNCTION nominatim_test_import_func(text)
""", (f'{module_dir}/nominatim.so', )) """).format(pysql.Literal(f'{module_dir}/nominatim.so')))
except psycopg2.DatabaseError as err: except psycopg.DatabaseError as err:
LOG.fatal("Error accessing database module: %s", err) LOG.fatal("Error accessing database module: %s", err)
raise UsageError("Database module cannot be accessed.") from err raise UsageError("Database module cannot be accessed.") from err
@@ -181,7 +182,7 @@ class LegacyTokenizer(AbstractTokenizer):
with connect(self.dsn) as conn: with connect(self.dsn) as conn:
try: try:
out = execute_scalar(conn, "SELECT make_standard_name('a')") out = execute_scalar(conn, "SELECT make_standard_name('a')")
except psycopg2.Error as err: except psycopg.Error as err:
return hint.format(error=str(err)) return hint.format(error=str(err))
if out != 'a': if out != 'a':
@@ -312,7 +313,7 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
""" """
def __init__(self, dsn: str, normalizer: Any): def __init__(self, dsn: str, normalizer: Any):
self.conn: Optional[Connection] = connect(dsn).connection self.conn: Optional[Connection] = connect(dsn)
self.conn.autocommit = True self.conn.autocommit = True
self.normalizer = normalizer self.normalizer = normalizer
register_hstore(self.conn) register_hstore(self.conn)
@@ -405,7 +406,7 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
""", (to_delete, )) """, (to_delete, ))
if to_add: if to_add:
cur.execute("""SELECT count(create_postcode_id(pc)) cur.execute("""SELECT count(create_postcode_id(pc))
FROM unnest(%s) as pc FROM unnest(%s::text[]) as pc
""", (to_add, )) """, (to_add, ))
@@ -422,7 +423,7 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
with self.conn.cursor() as cur: with self.conn.cursor() as cur:
# Get the old phrases. # Get the old phrases.
existing_phrases = set() existing_phrases = set()
cur.execute("""SELECT word, class, type, operator FROM word cur.execute("""SELECT word, class as cls, type, operator FROM word
WHERE class != 'place' WHERE class != 'place'
OR (type != 'house' AND type != 'postcode')""") OR (type != 'house' AND type != 'postcode')""")
for label, cls, typ, oper in cur: for label, cls, typ, oper in cur:
@@ -432,18 +433,19 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
to_delete = existing_phrases - norm_phrases to_delete = existing_phrases - norm_phrases
if to_add: if to_add:
cur.execute_values( cur.executemany(
""" INSERT INTO word (word_id, word_token, word, class, type, """ INSERT INTO word (word_id, word_token, word, class, type,
search_name_count, operator) search_name_count, operator)
(SELECT nextval('seq_word'), ' ' || make_standard_name(name), name, (SELECT nextval('seq_word'), ' ' || make_standard_name(name), name,
class, type, 0, class, type, 0,
CASE WHEN op in ('in', 'near') THEN op ELSE null END CASE WHEN op in ('in', 'near') THEN op ELSE null END
FROM (VALUES %s) as v(name, class, type, op))""", FROM (VALUES (%s, %s, %s, %s)) as v(name, class, type, op))""",
to_add) to_add)
if to_delete and should_replace: if to_delete and should_replace:
cur.execute_values( cur.executemany(
""" DELETE FROM word USING (VALUES %s) as v(name, in_class, in_type, op) """ DELETE FROM word
USING (VALUES (%s, %s, %s, %s)) as v(name, in_class, in_type, op)
WHERE word = name and class = in_class and type = in_type WHERE word = name and class = in_class and type = in_type
and ((op = '-' and operator is null) or op = operator)""", and ((op = '-' and operator is null) or op = operator)""",
to_delete) to_delete)
@@ -462,7 +464,7 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
"""INSERT INTO word (word_id, word_token, country_code) """INSERT INTO word (word_id, word_token, country_code)
(SELECT nextval('seq_word'), lookup_token, %s (SELECT nextval('seq_word'), lookup_token, %s
FROM (SELECT DISTINCT ' ' || make_standard_name(n) as lookup_token FROM (SELECT DISTINCT ' ' || make_standard_name(n) as lookup_token
FROM unnest(%s)n) y FROM unnest(%s::TEXT[])n) y
WHERE NOT EXISTS(SELECT * FROM word WHERE NOT EXISTS(SELECT * FROM word
WHERE word_token = lookup_token and country_code = %s)) WHERE word_token = lookup_token and country_code = %s))
""", (country_code, list(names.values()), country_code)) """, (country_code, list(names.values()), country_code))

View File

@@ -10,8 +10,8 @@ Functions for database analysis and maintenance.
from typing import Optional, Tuple, Any, cast from typing import Optional, Tuple, Any, cast
import logging import logging
from psycopg2.extras import Json import psycopg
from psycopg2 import DataError from psycopg.types.json import Json
from ..typing import DictCursorResult from ..typing import DictCursorResult
from ..config import Configuration from ..config import Configuration
@@ -59,7 +59,7 @@ def analyse_indexing(config: Configuration, osm_id: Optional[str] = None,
""" """
with connect(config.get_libpq_dsn()) as conn: with connect(config.get_libpq_dsn()) as conn:
register_hstore(conn) register_hstore(conn)
with conn.cursor() as cur: with conn.cursor(row_factory=psycopg.rows.dict_row) as cur:
place = _get_place_info(cur, osm_id, place_id) place = _get_place_info(cur, osm_id, place_id)
cur.execute("update placex set indexed_status = 2 where place_id = %s", cur.execute("update placex set indexed_status = 2 where place_id = %s",
@@ -74,6 +74,9 @@ def analyse_indexing(config: Configuration, osm_id: Optional[str] = None,
tokenizer = tokenizer_factory.get_tokenizer_for_db(config) tokenizer = tokenizer_factory.get_tokenizer_for_db(config)
# Enable printing of messages.
conn.add_notice_handler(lambda diag: print(diag.message_primary))
with tokenizer.name_analyzer() as analyzer: with tokenizer.name_analyzer() as analyzer:
cur.execute("""UPDATE placex cur.execute("""UPDATE placex
SET indexed_status = 0, address = %s, token_info = %s, SET indexed_status = 0, address = %s, token_info = %s,
@@ -86,9 +89,6 @@ def analyse_indexing(config: Configuration, osm_id: Optional[str] = None,
# we do not want to keep the results # we do not want to keep the results
conn.rollback() conn.rollback()
for msg in conn.notices:
print(msg)
def clean_deleted_relations(config: Configuration, age: str) -> None: def clean_deleted_relations(config: Configuration, age: str) -> None:
""" Clean deleted relations older than a given age """ Clean deleted relations older than a given age
@@ -101,6 +101,6 @@ def clean_deleted_relations(config: Configuration, age: str) -> None:
WHERE p.osm_type = d.osm_type AND p.osm_id = d.osm_id WHERE p.osm_type = d.osm_type AND p.osm_id = d.osm_id
AND age(p.indexed_date) > %s::interval""", AND age(p.indexed_date) > %s::interval""",
(age, )) (age, ))
except DataError as exc: except psycopg.DataError as exc:
raise UsageError('Invalid PostgreSQL time interval format') from exc raise UsageError('Invalid PostgreSQL time interval format') from exc
conn.commit() conn.commit()

View File

@@ -81,7 +81,7 @@ def check_database(config: Configuration) -> int:
""" Run a number of checks on the database and return the status. """ Run a number of checks on the database and return the status.
""" """
try: try:
conn = connect(config.get_libpq_dsn()).connection conn = connect(config.get_libpq_dsn())
except UsageError as err: except UsageError as err:
conn = _BadConnection(str(err)) # type: ignore[assignment] conn = _BadConnection(str(err)) # type: ignore[assignment]

View File

@@ -15,7 +15,6 @@ from pathlib import Path
from typing import List, Optional, Union from typing import List, Optional, Union
import psutil import psutil
from psycopg2.extensions import make_dsn
from ..config import Configuration from ..config import Configuration
from ..db.connection import connect, server_version_tuple, execute_scalar from ..db.connection import connect, server_version_tuple, execute_scalar
@@ -97,7 +96,7 @@ def report_system_information(config: Configuration) -> None:
"""Generate a report about the host system including software versions, memory, """Generate a report about the host system including software versions, memory,
storage, and database configuration.""" storage, and database configuration."""
with connect(make_dsn(config.get_libpq_dsn(), dbname='postgres')) as conn: with connect(config.get_libpq_dsn(), dbname='postgres') as conn:
postgresql_ver: str = '.'.join(map(str, server_version_tuple(conn))) postgresql_ver: str = '.'.join(map(str, server_version_tuple(conn)))
with conn.cursor() as cur: with conn.cursor() as cur:

View File

@@ -10,19 +10,20 @@ Functions for setting up and importing a new Nominatim database.
from typing import Tuple, Optional, Union, Sequence, MutableMapping, Any from typing import Tuple, Optional, Union, Sequence, MutableMapping, Any
import logging import logging
import os import os
import selectors
import subprocess import subprocess
import asyncio
from pathlib import Path from pathlib import Path
import psutil import psutil
from psycopg2 import sql as pysql import psycopg
from psycopg import sql as pysql
from ..errors import UsageError from ..errors import UsageError
from ..config import Configuration from ..config import Configuration
from ..db.connection import connect, get_pg_env, Connection, server_version_tuple,\ from ..db.connection import connect, get_pg_env, Connection, server_version_tuple,\
postgis_version_tuple, drop_tables, table_exists, execute_scalar postgis_version_tuple, drop_tables, table_exists, execute_scalar
from ..db.async_connection import DBConnection
from ..db.sql_preprocessor import SQLPreprocessor from ..db.sql_preprocessor import SQLPreprocessor
from ..db.query_pool import QueryPool
from .exec_utils import run_osm2pgsql from .exec_utils import run_osm2pgsql
from ..version import POSTGRESQL_REQUIRED_VERSION, POSTGIS_REQUIRED_VERSION from ..version import POSTGRESQL_REQUIRED_VERSION, POSTGIS_REQUIRED_VERSION
@@ -136,7 +137,7 @@ def import_osm_data(osm_files: Union[Path, Sequence[Path]],
with connect(options['dsn']) as conn: with connect(options['dsn']) as conn:
if not ignore_errors: if not ignore_errors:
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute('SELECT * FROM place LIMIT 1') cur.execute('SELECT true FROM place LIMIT 1')
if cur.rowcount == 0: if cur.rowcount == 0:
raise UsageError('No data imported by osm2pgsql.') raise UsageError('No data imported by osm2pgsql.')
@@ -205,54 +206,51 @@ _COPY_COLUMNS = pysql.SQL(',').join(map(pysql.Identifier,
'extratags', 'geometry'))) 'extratags', 'geometry')))
def load_data(dsn: str, threads: int) -> None: async def load_data(dsn: str, threads: int) -> None:
""" Copy data into the word and placex table. """ Copy data into the word and placex table.
""" """
sel = selectors.DefaultSelector() placex_threads = max(1, threads - 1)
# Then copy data from place to placex in <threads - 1> chunks.
place_threads = max(1, threads - 1)
for imod in range(place_threads):
conn = DBConnection(dsn)
conn.connect()
conn.perform(
pysql.SQL("""INSERT INTO placex ({columns})
SELECT {columns} FROM place
WHERE osm_id % {total} = {mod}
AND NOT (class='place' and (type='houses' or type='postcode'))
AND ST_IsValid(geometry)
""").format(columns=_COPY_COLUMNS,
total=pysql.Literal(place_threads),
mod=pysql.Literal(imod)))
sel.register(conn, selectors.EVENT_READ, conn)
# Address interpolations go into another table. progress = asyncio.create_task(_progress_print())
conn = DBConnection(dsn)
conn.connect()
conn.perform("""INSERT INTO location_property_osmline (osm_id, address, linegeo)
SELECT osm_id, address, geometry FROM place
WHERE class='place' and type='houses' and osm_type='W'
and ST_GeometryType(geometry) = 'ST_LineString'
""")
sel.register(conn, selectors.EVENT_READ, conn)
# Now wait for all of them to finish. async with QueryPool(dsn, placex_threads + 1) as pool:
todo = place_threads + 1 # Copy data from place to placex in <threads - 1> chunks.
while todo > 0: for imod in range(placex_threads):
for key, _ in sel.select(1): await pool.put_query(
conn = key.data pysql.SQL("""INSERT INTO placex ({columns})
sel.unregister(conn) SELECT {columns} FROM place
conn.wait() WHERE osm_id % {total} = {mod}
conn.close() AND NOT (class='place'
todo -= 1 and (type='houses' or type='postcode'))
AND ST_IsValid(geometry)
""").format(columns=_COPY_COLUMNS,
total=pysql.Literal(placex_threads),
mod=pysql.Literal(imod)), None)
# Interpolations need to be copied seperately
await pool.put_query("""
INSERT INTO location_property_osmline (osm_id, address, linegeo)
SELECT osm_id, address, geometry FROM place
WHERE class='place' and type='houses' and osm_type='W'
and ST_GeometryType(geometry) = 'ST_LineString' """, None)
progress.cancel()
async with await psycopg.AsyncConnection.connect(dsn) as aconn:
await aconn.execute('ANALYSE')
async def _progress_print() -> None:
while True:
try:
await asyncio.sleep(1)
except asyncio.CancelledError:
print('', flush=True)
break
print('.', end='', flush=True) print('.', end='', flush=True)
print('\n')
with connect(dsn) as syn_conn:
with syn_conn.cursor() as cur:
cur.execute('ANALYSE')
def create_search_indices(conn: Connection, config: Configuration, async def create_search_indices(conn: Connection, config: Configuration,
drop: bool = False, threads: int = 1) -> None: drop: bool = False, threads: int = 1) -> None:
""" Create tables that have explicit partitioning. """ Create tables that have explicit partitioning.
""" """
@@ -271,5 +269,5 @@ def create_search_indices(conn: Connection, config: Configuration,
sql = SQLPreprocessor(conn, config) sql = SQLPreprocessor(conn, config)
sql.run_parallel_sql_file(config.get_libpq_dsn(), await sql.run_parallel_sql_file(config.get_libpq_dsn(),
'indices.sql', min(8, threads), drop=drop) 'indices.sql', min(8, threads), drop=drop)

View File

@@ -10,7 +10,7 @@ Functions for removing unnecessary data from the database.
from typing import Optional from typing import Optional
from pathlib import Path from pathlib import Path
from psycopg2 import sql as pysql from psycopg import sql as pysql
from ..db.connection import Connection, drop_tables, table_exists from ..db.connection import Connection, drop_tables, table_exists

View File

@@ -10,7 +10,7 @@ Functions for database migration to newer software versions.
from typing import List, Tuple, Callable, Any from typing import List, Tuple, Callable, Any
import logging import logging
from psycopg2 import sql as pysql from psycopg import sql as pysql
from ..errors import UsageError from ..errors import UsageError
from ..config import Configuration from ..config import Configuration

View File

@@ -16,7 +16,7 @@ import gzip
import logging import logging
from math import isfinite from math import isfinite
from psycopg2 import sql as pysql from psycopg import sql as pysql
from ..db.connection import connect, Connection, table_exists from ..db.connection import connect, Connection, table_exists
from ..utils.centroid import PointsCentroid from ..utils.centroid import PointsCentroid
@@ -76,30 +76,30 @@ class _PostcodeCollector:
with conn.cursor() as cur: with conn.cursor() as cur:
if to_add: if to_add:
cur.execute_values( cur.executemany(pysql.SQL(
"""INSERT INTO location_postcode """INSERT INTO location_postcode
(place_id, indexed_status, country_code, (place_id, indexed_status, country_code,
postcode, geometry) VALUES %s""", postcode, geometry)
to_add, VALUES (nextval('seq_place'), 1, {}, %s,
template=pysql.SQL("""(nextval('seq_place'), 1, {}, ST_SetSRID(ST_MakePoint(%s, %s), 4326))
%s, 'SRID=4326;POINT(%s %s)') """).format(pysql.Literal(self.country)),
""").format(pysql.Literal(self.country))) to_add)
if to_delete: if to_delete:
cur.execute("""DELETE FROM location_postcode cur.execute("""DELETE FROM location_postcode
WHERE country_code = %s and postcode = any(%s) WHERE country_code = %s and postcode = any(%s)
""", (self.country, to_delete)) """, (self.country, to_delete))
if to_update: if to_update:
cur.execute_values( cur.executemany(
pysql.SQL("""UPDATE location_postcode pysql.SQL("""UPDATE location_postcode
SET indexed_status = 2, SET indexed_status = 2,
geometry = ST_SetSRID(ST_Point(v.x, v.y), 4326) geometry = ST_SetSRID(ST_Point(%s, %s), 4326)
FROM (VALUES %s) AS v (pc, x, y) WHERE country_code = {} and postcode = %s
WHERE country_code = {} and postcode = pc """).format(pysql.Literal(self.country)),
""").format(pysql.Literal(self.country)), to_update) to_update)
def _compute_changes(self, conn: Connection) \ def _compute_changes(self, conn: Connection) \
-> Tuple[List[Tuple[str, float, float]], List[str], List[Tuple[str, float, float]]]: -> Tuple[List[Tuple[str, float, float]], List[str], List[Tuple[float, float, str]]]:
""" Compute which postcodes from the collected postcodes have to be """ Compute which postcodes from the collected postcodes have to be
added or modified and which from the location_postcode table added or modified and which from the location_postcode table
have to be deleted. have to be deleted.
@@ -116,7 +116,7 @@ class _PostcodeCollector:
if pcobj: if pcobj:
newx, newy = pcobj.centroid() newx, newy = pcobj.centroid()
if (x - newx) > 0.0000001 or (y - newy) > 0.0000001: if (x - newx) > 0.0000001 or (y - newy) > 0.0000001:
to_update.append((postcode, newx, newy)) to_update.append((newx, newy, postcode))
else: else:
to_delete.append(postcode) to_delete.append(postcode)

View File

@@ -14,12 +14,12 @@ import logging
from textwrap import dedent from textwrap import dedent
from pathlib import Path from pathlib import Path
from psycopg2 import sql as pysql from psycopg import sql as pysql
from ..config import Configuration from ..config import Configuration
from ..db.connection import Connection, connect, postgis_version_tuple,\ from ..db.connection import Connection, connect, postgis_version_tuple,\
drop_tables, table_exists drop_tables, table_exists
from ..db.utils import execute_file, CopyBuffer from ..db.utils import execute_file
from ..db.sql_preprocessor import SQLPreprocessor from ..db.sql_preprocessor import SQLPreprocessor
from ..version import NOMINATIM_VERSION from ..version import NOMINATIM_VERSION
@@ -68,8 +68,8 @@ def load_address_levels(conn: Connection, table: str, levels: Sequence[Mapping[s
rank_address SMALLINT) rank_address SMALLINT)
""").format(pysql.Identifier(table))) """).format(pysql.Identifier(table)))
cur.execute_values(pysql.SQL("INSERT INTO {} VALUES %s") cur.executemany(pysql.SQL("INSERT INTO {} VALUES (%s, %s, %s, %s, %s)")
.format(pysql.Identifier(table)), rows) .format(pysql.Identifier(table)), rows)
cur.execute(pysql.SQL('CREATE UNIQUE INDEX ON {} (country_code, class, type)') cur.execute(pysql.SQL('CREATE UNIQUE INDEX ON {} (country_code, class, type)')
.format(pysql.Identifier(table))) .format(pysql.Identifier(table)))
@@ -155,7 +155,7 @@ def import_importance_csv(dsn: str, data_file: Path) -> int:
if not data_file.exists(): if not data_file.exists():
return 1 return 1
# Only import the first occurence of a wikidata ID. # Only import the first occurrence of a wikidata ID.
# This keeps indexes and table small. # This keeps indexes and table small.
wd_done = set() wd_done = set()
@@ -169,24 +169,17 @@ def import_importance_csv(dsn: str, data_file: Path) -> int:
wikidata TEXT wikidata TEXT
) """) ) """)
with gzip.open(str(data_file), 'rt') as fd, CopyBuffer() as buf: copy_cmd = """COPY wikimedia_importance(language, title, importance, wikidata)
for row in csv.DictReader(fd, delimiter='\t', quotechar='|'): FROM STDIN"""
wd_id = int(row['wikidata_id'][1:]) with gzip.open(str(data_file), 'rt') as fd, cur.copy(copy_cmd) as copy:
buf.add(row['language'], row['title'], row['importance'], for row in csv.DictReader(fd, delimiter='\t', quotechar='|'):
None if wd_id in wd_done else row['wikidata_id']) wd_id = int(row['wikidata_id'][1:])
wd_done.add(wd_id) copy.write_row((row['language'],
row['title'],
row['importance'],
None if wd_id in wd_done else row['wikidata_id']))
wd_done.add(wd_id)
if buf.size() > 10000000:
with conn.cursor() as cur:
buf.copy_out(cur, 'wikimedia_importance',
columns=['language', 'title', 'importance',
'wikidata'])
with conn.cursor() as cur:
buf.copy_out(cur, 'wikimedia_importance',
columns=['language', 'title', 'importance', 'wikidata'])
with conn.cursor() as cur:
cur.execute("""CREATE INDEX IF NOT EXISTS idx_wikimedia_importance_title cur.execute("""CREATE INDEX IF NOT EXISTS idx_wikimedia_importance_title
ON wikimedia_importance (title)""") ON wikimedia_importance (title)""")
cur.execute("""CREATE INDEX IF NOT EXISTS idx_wikimedia_importance_wikidata cur.execute("""CREATE INDEX IF NOT EXISTS idx_wikimedia_importance_wikidata

View File

@@ -17,7 +17,7 @@ from typing import Iterable, Tuple, Mapping, Sequence, Optional, Set
import logging import logging
import re import re
from psycopg2.sql import Identifier, SQL from psycopg.sql import Identifier, SQL
from ...typing import Protocol from ...typing import Protocol
from ...config import Configuration from ...config import Configuration

View File

@@ -7,22 +7,22 @@
""" """
Functions for importing tiger data and handling tarbar and directory files Functions for importing tiger data and handling tarbar and directory files
""" """
from typing import Any, TextIO, List, Union, cast from typing import Any, TextIO, List, Union, cast, Iterator, Dict
import csv import csv
import io import io
import logging import logging
import os import os
import tarfile import tarfile
from psycopg2.extras import Json from psycopg.types.json import Json
from ..config import Configuration from ..config import Configuration
from ..db.connection import connect from ..db.connection import connect
from ..db.async_connection import WorkerPool
from ..db.sql_preprocessor import SQLPreprocessor from ..db.sql_preprocessor import SQLPreprocessor
from ..errors import UsageError from ..errors import UsageError
from ..db.query_pool import QueryPool
from ..data.place_info import PlaceInfo from ..data.place_info import PlaceInfo
from ..tokenizer.base import AbstractAnalyzer, AbstractTokenizer from ..tokenizer.base import AbstractTokenizer
from . import freeze from . import freeze
LOG = logging.getLogger() LOG = logging.getLogger()
@@ -63,13 +63,13 @@ class TigerInput:
self.tar_handle.close() self.tar_handle.close()
self.tar_handle = None self.tar_handle = None
def __bool__(self) -> bool:
return bool(self.files)
def next_file(self) -> TextIO: def get_file(self, fname: Union[str, tarfile.TarInfo]) -> TextIO:
""" Return a file handle to the next file to be processed. """ Return a file handle to the next file to be processed.
Raises an IndexError if there is no file left. Raises an IndexError if there is no file left.
""" """
fname = self.files.pop(0)
if self.tar_handle is not None: if self.tar_handle is not None:
extracted = self.tar_handle.extractfile(fname) extracted = self.tar_handle.extractfile(fname)
assert extracted is not None assert extracted is not None
@@ -78,47 +78,22 @@ class TigerInput:
return open(cast(str, fname), encoding='utf-8') return open(cast(str, fname), encoding='utf-8')
def __len__(self) -> int: def __iter__(self) -> Iterator[Dict[str, Any]]:
return len(self.files) """ Iterate over the lines in each file.
"""
for fname in self.files:
fd = self.get_file(fname)
yield from csv.DictReader(fd, delimiter=';')
def handle_threaded_sql_statements(pool: WorkerPool, fd: TextIO, async def add_tiger_data(data_dir: str, config: Configuration, threads: int,
analyzer: AbstractAnalyzer) -> None:
""" Handles sql statement with multiplexing
"""
lines = 0
# Using pool of database connections to execute sql statements
sql = "SELECT tiger_line_import(%s, %s, %s, %s, %s, %s)"
for row in csv.DictReader(fd, delimiter=';'):
try:
address = dict(street=row['street'], postcode=row['postcode'])
args = ('SRID=4326;' + row['geometry'],
int(row['from']), int(row['to']), row['interpolation'],
Json(analyzer.process_place(PlaceInfo({'address': address}))),
analyzer.normalize_postcode(row['postcode']))
except ValueError:
continue
pool.next_free_worker().perform(sql, args=args)
lines += 1
if lines == 1000:
print('.', end='', flush=True)
lines = 0
def add_tiger_data(data_dir: str, config: Configuration, threads: int,
tokenizer: AbstractTokenizer) -> int: tokenizer: AbstractTokenizer) -> int:
""" Import tiger data from directory or tar file `data dir`. """ Import tiger data from directory or tar file `data dir`.
""" """
dsn = config.get_libpq_dsn() dsn = config.get_libpq_dsn()
with connect(dsn) as conn: with connect(dsn) as conn:
is_frozen = freeze.is_frozen(conn) if freeze.is_frozen(conn):
conn.close()
if is_frozen:
raise UsageError("Tiger cannot be imported when database frozen (Github issue #3048)") raise UsageError("Tiger cannot be imported when database frozen (Github issue #3048)")
with TigerInput(data_dir) as tar: with TigerInput(data_dir) as tar:
@@ -133,13 +108,30 @@ def add_tiger_data(data_dir: str, config: Configuration, threads: int,
# sql_query in <threads - 1> chunks. # sql_query in <threads - 1> chunks.
place_threads = max(1, threads - 1) place_threads = max(1, threads - 1)
with WorkerPool(dsn, place_threads, ignore_sql_errors=True) as pool: async with QueryPool(dsn, place_threads, autocommit=True) as pool:
with tokenizer.name_analyzer() as analyzer: with tokenizer.name_analyzer() as analyzer:
while tar: lines = 0
with tar.next_file() as fd: for row in tar:
handle_threaded_sql_statements(pool, fd, analyzer) try:
address = dict(street=row['street'], postcode=row['postcode'])
args = ('SRID=4326;' + row['geometry'],
int(row['from']), int(row['to']), row['interpolation'],
Json(analyzer.process_place(PlaceInfo({'address': address}))),
analyzer.normalize_postcode(row['postcode']))
except ValueError:
continue
print('\n') await pool.put_query(
"""SELECT tiger_line_import(%s::GEOMETRY, %s::INT,
%s::INT, %s::TEXT, %s::JSONB, %s::TEXT)""",
args)
lines += 1
if lines == 1000:
print('.', end='', flush=True)
lines = 0
print('', flush=True)
LOG.warning("Creating indexes on Tiger data") LOG.warning("Creating indexes on Tiger data")
with connect(dsn) as conn: with connect(dsn) as conn:

View File

@@ -16,18 +16,13 @@ from typing import Any, Union, Mapping, TypeVar, Sequence, TYPE_CHECKING
# pylint: disable=missing-class-docstring,useless-import-alias # pylint: disable=missing-class-docstring,useless-import-alias
if TYPE_CHECKING: if TYPE_CHECKING:
import psycopg2.sql
import psycopg2.extensions
import psycopg2.extras
import os import os
StrPath = Union[str, 'os.PathLike[str]'] StrPath = Union[str, 'os.PathLike[str]']
SysEnv = Mapping[str, str] SysEnv = Mapping[str, str]
# psycopg2-related types # psycopg-related types
Query = Union[str, bytes, 'psycopg2.sql.Composable']
T_ResultKey = TypeVar('T_ResultKey', int, str) T_ResultKey = TypeVar('T_ResultKey', int, str)
@@ -36,8 +31,6 @@ class DictCursorResult(Mapping[str, Any]):
DictCursorResults = Sequence[DictCursorResult] DictCursorResults = Sequence[DictCursorResult]
T_cursor = TypeVar('T_cursor', bound='psycopg2.extensions.cursor')
# The following typing features require typing_extensions to work # The following typing features require typing_extensions to work
# on all supported Python versions. # on all supported Python versions.
# Only require this for type checking but not for normal operations. # Only require this for type checking but not for normal operations.

View File

@@ -31,7 +31,7 @@ class NominatimVersion(NamedTuple):
major: int major: int
minor: int minor: int
patch_level: int patch_level: int
db_patch_level: Optional[int] db_patch_level: int
def __str__(self) -> str: def __str__(self) -> str:
if self.db_patch_level is None: if self.db_patch_level is None:
@@ -47,6 +47,7 @@ class NominatimVersion(NamedTuple):
return f"{self.major}.{self.minor}.{self.patch_level}" return f"{self.major}.{self.minor}.{self.patch_level}"
def parse_version(version: str) -> NominatimVersion: def parse_version(version: str) -> NominatimVersion:
""" Parse a version string into a version consisting of a tuple of """ Parse a version string into a version consisting of a tuple of
four ints: major, minor, patch level, database patch level four ints: major, minor, patch level, database patch level

View File

@@ -9,14 +9,14 @@ import importlib
import sys import sys
import tempfile import tempfile
import psycopg2 import psycopg
import psycopg2.extras from psycopg import sql as pysql
sys.path.insert(1, str((Path(__file__) / '..' / '..' / '..' / '..'/ 'src').resolve())) sys.path.insert(1, str((Path(__file__) / '..' / '..' / '..' / '..'/ 'src').resolve()))
from nominatim_db import cli from nominatim_db import cli
from nominatim_db.config import Configuration from nominatim_db.config import Configuration
from nominatim_db.db.connection import Connection from nominatim_db.db.connection import Connection, register_hstore, execute_scalar
from nominatim_db.tools import refresh from nominatim_db.tools import refresh
from nominatim_db.tokenizer import factory as tokenizer_factory from nominatim_db.tokenizer import factory as tokenizer_factory
from steps.utils import run_script from steps.utils import run_script
@@ -60,7 +60,7 @@ class NominatimEnvironment:
""" Return a connection to the database with the given name. """ Return a connection to the database with the given name.
Uses configured host, user and port. Uses configured host, user and port.
""" """
dbargs = {'database': dbname} dbargs = {'dbname': dbname, 'row_factory': psycopg.rows.dict_row}
if self.db_host: if self.db_host:
dbargs['host'] = self.db_host dbargs['host'] = self.db_host
if self.db_port: if self.db_port:
@@ -69,8 +69,7 @@ class NominatimEnvironment:
dbargs['user'] = self.db_user dbargs['user'] = self.db_user
if self.db_pass: if self.db_pass:
dbargs['password'] = self.db_pass dbargs['password'] = self.db_pass
conn = psycopg2.connect(connection_factory=Connection, **dbargs) return psycopg.connect(**dbargs)
return conn
def next_code_coverage_file(self): def next_code_coverage_file(self):
""" Generate the next name for a coverage file. """ Generate the next name for a coverage file.
@@ -132,6 +131,8 @@ class NominatimEnvironment:
conn = False conn = False
refresh.setup_website(Path(self.website_dir.name) / 'website', refresh.setup_website(Path(self.website_dir.name) / 'website',
self.get_test_config(), conn) self.get_test_config(), conn)
if conn:
conn.close()
def get_test_config(self): def get_test_config(self):
@@ -160,11 +161,10 @@ class NominatimEnvironment:
def db_drop_database(self, name): def db_drop_database(self, name):
""" Drop the database with the given name. """ Drop the database with the given name.
""" """
conn = self.connect_database('postgres') with self.connect_database('postgres') as conn:
conn.set_isolation_level(0) conn.autocommit = True
cur = conn.cursor() conn.execute(pysql.SQL('DROP DATABASE IF EXISTS')
cur.execute('DROP DATABASE IF EXISTS {}'.format(name)) + pysql.Identifier(name))
conn.close()
def setup_template_db(self): def setup_template_db(self):
""" Setup a template database that already contains common test data. """ Setup a template database that already contains common test data.
@@ -249,16 +249,18 @@ class NominatimEnvironment:
""" Setup a test against a fresh, empty test database. """ Setup a test against a fresh, empty test database.
""" """
self.setup_template_db() self.setup_template_db()
conn = self.connect_database(self.template_db) with self.connect_database(self.template_db) as conn:
conn.set_isolation_level(0) conn.autocommit = True
cur = conn.cursor() conn.execute(pysql.SQL('DROP DATABASE IF EXISTS')
cur.execute('DROP DATABASE IF EXISTS {}'.format(self.test_db)) + pysql.Identifier(self.test_db))
cur.execute('CREATE DATABASE {} TEMPLATE = {}'.format(self.test_db, self.template_db)) conn.execute(pysql.SQL('CREATE DATABASE {} TEMPLATE = {}').format(
conn.close() pysql.Identifier(self.test_db),
pysql.Identifier(self.template_db)))
self.write_nominatim_config(self.test_db) self.write_nominatim_config(self.test_db)
context.db = self.connect_database(self.test_db) context.db = self.connect_database(self.test_db)
context.db.autocommit = True context.db.autocommit = True
psycopg2.extras.register_hstore(context.db, globally=False) register_hstore(context.db)
def teardown_db(self, context, force_drop=False): def teardown_db(self, context, force_drop=False):
""" Remove the test database, if it exists. """ Remove the test database, if it exists.
@@ -276,31 +278,26 @@ class NominatimEnvironment:
dropped and always false returned. dropped and always false returned.
""" """
if self.reuse_template: if self.reuse_template:
conn = self.connect_database('postgres') with self.connect_database('postgres') as conn:
with conn.cursor() as cur: num = execute_scalar(conn,
cur.execute('select count(*) from pg_database where datname = %s', 'select count(*) from pg_database where datname = %s',
(name,)) (name,))
if cur.fetchone()[0] == 1: if num == 1:
return True return True
conn.close()
else: else:
self.db_drop_database(name) self.db_drop_database(name)
return False return False
def reindex_placex(self, db): def reindex_placex(self, db):
""" Run the indexing step until all data in the placex has """ Run the indexing step until all data in the placex has
been processed. Indexing during updates can produce more data been processed. Indexing during updates can produce more data
to index under some circumstances. That is why indexing may have to index under some circumstances. That is why indexing may have
to be run multiple times. to be run multiple times.
""" """
with db.cursor() as cur: self.run_nominatim('index')
while True:
self.run_nominatim('index')
cur.execute("SELECT 'a' FROM placex WHERE indexed_status != 0 LIMIT 1")
if cur.rowcount == 0:
return
def run_nominatim(self, *cmdline): def run_nominatim(self, *cmdline):
""" Run the nominatim command-line tool via the library. """ Run the nominatim command-line tool via the library.

View File

@@ -7,7 +7,8 @@
import logging import logging
from itertools import chain from itertools import chain
import psycopg2.extras import psycopg
from psycopg import sql as pysql
from place_inserter import PlaceColumn from place_inserter import PlaceColumn
from table_compare import NominatimID, DBRow from table_compare import NominatimID, DBRow
@@ -18,7 +19,7 @@ from nominatim_db.tokenizer import factory as tokenizer_factory
def check_database_integrity(context): def check_database_integrity(context):
""" Check some generic constraints on the tables. """ Check some generic constraints on the tables.
""" """
with context.db.cursor() as cur: with context.db.cursor(row_factory=psycopg.rows.tuple_row) as cur:
# place_addressline should not have duplicate (place_id, address_place_id) # place_addressline should not have duplicate (place_id, address_place_id)
cur.execute("""SELECT count(*) FROM cur.execute("""SELECT count(*) FROM
(SELECT place_id, address_place_id, count(*) as c (SELECT place_id, address_place_id, count(*) as c
@@ -54,7 +55,7 @@ def add_data_to_planet_relations(context):
with context.db.cursor() as cur: with context.db.cursor() as cur:
cur.execute("SELECT value FROM osm2pgsql_properties WHERE property = 'db_format'") cur.execute("SELECT value FROM osm2pgsql_properties WHERE property = 'db_format'")
row = cur.fetchone() row = cur.fetchone()
if row is None or row[0] == '1': if row is None or row['value'] == '1':
for r in context.table: for r in context.table:
last_node = 0 last_node = 0
last_way = 0 last_way = 0
@@ -96,8 +97,8 @@ def add_data_to_planet_relations(context):
cur.execute("""INSERT INTO planet_osm_rels (id, tags, members) cur.execute("""INSERT INTO planet_osm_rels (id, tags, members)
VALUES (%s, %s, %s)""", VALUES (%s, %s, %s)""",
(r['id'], psycopg2.extras.Json(tags), (r['id'], psycopg.types.json.Json(tags),
psycopg2.extras.Json(members))) psycopg.types.json.Json(members)))
@given("the ways") @given("the ways")
def add_data_to_planet_ways(context): def add_data_to_planet_ways(context):
@@ -107,10 +108,10 @@ def add_data_to_planet_ways(context):
with context.db.cursor() as cur: with context.db.cursor() as cur:
cur.execute("SELECT value FROM osm2pgsql_properties WHERE property = 'db_format'") cur.execute("SELECT value FROM osm2pgsql_properties WHERE property = 'db_format'")
row = cur.fetchone() row = cur.fetchone()
json_tags = row is not None and row[0] != '1' json_tags = row is not None and row['value'] != '1'
for r in context.table: for r in context.table:
if json_tags: if json_tags:
tags = psycopg2.extras.Json({h[5:]: r[h] for h in r.headings if h.startswith("tags+")}) tags = psycopg.types.json.Json({h[5:]: r[h] for h in r.headings if h.startswith("tags+")})
else: else:
tags = list(chain.from_iterable([(h[5:], r[h]) tags = list(chain.from_iterable([(h[5:], r[h])
for h in r.headings if h.startswith("tags+")])) for h in r.headings if h.startswith("tags+")]))
@@ -197,7 +198,7 @@ def check_place_contents(context, table, exact):
expected rows are expected to be present with at least one database row. expected rows are expected to be present with at least one database row.
When 'exactly' is given, there must not be additional rows in the database. When 'exactly' is given, there must not be additional rows in the database.
""" """
with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: with context.db.cursor() as cur:
expected_content = set() expected_content = set()
for row in context.table: for row in context.table:
nid = NominatimID(row['object']) nid = NominatimID(row['object'])
@@ -215,8 +216,9 @@ def check_place_contents(context, table, exact):
DBRow(nid, res, context).assert_row(row, ['object']) DBRow(nid, res, context).assert_row(row, ['object'])
if exact: if exact:
cur.execute('SELECT osm_type, osm_id, class from {}'.format(table)) cur.execute(pysql.SQL('SELECT osm_type, osm_id, class from')
actual = set([(r[0], r[1], r[2]) for r in cur]) + pysql.Identifier(table))
actual = set([(r['osm_type'], r['osm_id'], r['class']) for r in cur])
assert expected_content == actual, \ assert expected_content == actual, \
f"Missing entries: {expected_content - actual}\n" \ f"Missing entries: {expected_content - actual}\n" \
f"Not expected in table: {actual - expected_content}" f"Not expected in table: {actual - expected_content}"
@@ -227,7 +229,7 @@ def check_place_has_entry(context, table, oid):
""" Ensure that no database row for the given object exists. The ID """ Ensure that no database row for the given object exists. The ID
must be of the form '<NRW><osm id>[:<class>]'. must be of the form '<NRW><osm id>[:<class>]'.
""" """
with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: with context.db.cursor() as cur:
NominatimID(oid).query_osm_id(cur, "SELECT * FROM %s where {}" % table) NominatimID(oid).query_osm_id(cur, "SELECT * FROM %s where {}" % table)
assert cur.rowcount == 0, \ assert cur.rowcount == 0, \
"Found {} entries for ID {}".format(cur.rowcount, oid) "Found {} entries for ID {}".format(cur.rowcount, oid)
@@ -244,7 +246,7 @@ def check_search_name_contents(context, exclude):
tokenizer = tokenizer_factory.get_tokenizer_for_db(context.nominatim.get_test_config()) tokenizer = tokenizer_factory.get_tokenizer_for_db(context.nominatim.get_test_config())
with tokenizer.name_analyzer() as analyzer: with tokenizer.name_analyzer() as analyzer:
with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: with context.db.cursor() as cur:
for row in context.table: for row in context.table:
nid = NominatimID(row['object']) nid = NominatimID(row['object'])
nid.row_by_place_id(cur, 'search_name', nid.row_by_place_id(cur, 'search_name',
@@ -276,7 +278,7 @@ def check_search_name_has_entry(context, oid):
""" Check that there is noentry in the search_name table for the given """ Check that there is noentry in the search_name table for the given
objects. IDs are in format '<NRW><osm id>[:<class>]'. objects. IDs are in format '<NRW><osm id>[:<class>]'.
""" """
with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: with context.db.cursor() as cur:
NominatimID(oid).row_by_place_id(cur, 'search_name') NominatimID(oid).row_by_place_id(cur, 'search_name')
assert cur.rowcount == 0, \ assert cur.rowcount == 0, \
@@ -290,7 +292,7 @@ def check_location_postcode(context):
All rows must be present as excepted and there must not be additional All rows must be present as excepted and there must not be additional
rows. rows.
""" """
with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: with context.db.cursor() as cur:
cur.execute("SELECT *, ST_AsText(geometry) as geomtxt FROM location_postcode") cur.execute("SELECT *, ST_AsText(geometry) as geomtxt FROM location_postcode")
assert cur.rowcount == len(list(context.table)), \ assert cur.rowcount == len(list(context.table)), \
"Postcode table has {} rows, expected {}.".format(cur.rowcount, len(list(context.table))) "Postcode table has {} rows, expected {}.".format(cur.rowcount, len(list(context.table)))
@@ -321,7 +323,7 @@ def check_word_table_for_postcodes(context, exclude, postcodes):
plist.sort() plist.sort()
with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: with context.db.cursor() as cur:
if nctx.tokenizer != 'legacy': if nctx.tokenizer != 'legacy':
cur.execute("SELECT word FROM word WHERE type = 'P' and word = any(%s)", cur.execute("SELECT word FROM word WHERE type = 'P' and word = any(%s)",
(plist,)) (plist,))
@@ -330,7 +332,7 @@ def check_word_table_for_postcodes(context, exclude, postcodes):
and class = 'place' and type = 'postcode'""", and class = 'place' and type = 'postcode'""",
(plist,)) (plist,))
found = [row[0] for row in cur] found = [row['word'] for row in cur]
assert len(found) == len(set(found)), f"Duplicate rows for postcodes: {found}" assert len(found) == len(set(found)), f"Duplicate rows for postcodes: {found}"
if exclude: if exclude:
@@ -347,7 +349,7 @@ def check_place_addressline(context):
representing the addressee and the 'address' column, representing the representing the addressee and the 'address' column, representing the
address item. address item.
""" """
with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: with context.db.cursor() as cur:
for row in context.table: for row in context.table:
nid = NominatimID(row['object']) nid = NominatimID(row['object'])
pid = nid.get_place_id(cur) pid = nid.get_place_id(cur)
@@ -366,7 +368,7 @@ def check_place_addressline_exclude(context):
""" Check that the place_addressline doesn't contain any entries for the """ Check that the place_addressline doesn't contain any entries for the
given addressee/address item pairs. given addressee/address item pairs.
""" """
with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: with context.db.cursor() as cur:
for row in context.table: for row in context.table:
pid = NominatimID(row['object']).get_place_id(cur) pid = NominatimID(row['object']).get_place_id(cur)
apid = NominatimID(row['address']).get_place_id(cur, allow_empty=True) apid = NominatimID(row['address']).get_place_id(cur, allow_empty=True)
@@ -381,7 +383,7 @@ def check_place_addressline_exclude(context):
def check_location_property_osmline(context, oid, neg): def check_location_property_osmline(context, oid, neg):
""" Check that the given way is present in the interpolation table. """ Check that the given way is present in the interpolation table.
""" """
with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: with context.db.cursor() as cur:
cur.execute("""SELECT *, ST_AsText(linegeo) as geomtxt cur.execute("""SELECT *, ST_AsText(linegeo) as geomtxt
FROM location_property_osmline FROM location_property_osmline
WHERE osm_id = %s AND startnumber IS NOT NULL""", WHERE osm_id = %s AND startnumber IS NOT NULL""",
@@ -417,7 +419,7 @@ def check_place_contents(context, exact):
expected rows are expected to be present with at least one database row. expected rows are expected to be present with at least one database row.
When 'exactly' is given, there must not be additional rows in the database. When 'exactly' is given, there must not be additional rows in the database.
""" """
with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: with context.db.cursor() as cur:
expected_content = set() expected_content = set()
for row in context.table: for row in context.table:
if ':' in row['object']: if ':' in row['object']:
@@ -447,7 +449,7 @@ def check_place_contents(context, exact):
if exact: if exact:
cur.execute('SELECT osm_id, startnumber from location_property_osmline') cur.execute('SELECT osm_id, startnumber from location_property_osmline')
actual = set([(r[0], r[1]) for r in cur]) actual = set([(r['osm_id'], r['startnumber']) for r in cur])
assert expected_content == actual, \ assert expected_content == actual, \
f"Missing entries: {expected_content - actual}\n" \ f"Missing entries: {expected_content - actual}\n" \
f"Not expected in table: {actual - expected_content}" f"Not expected in table: {actual - expected_content}"

View File

@@ -10,6 +10,9 @@ Functions to facilitate accessing and comparing the content of DB tables.
import re import re
import json import json
import psycopg
from psycopg import sql as pysql
from steps.check_functions import Almost from steps.check_functions import Almost
ID_REGEX = re.compile(r"(?P<typ>[NRW])(?P<oid>\d+)(:(?P<cls>\w+))?") ID_REGEX = re.compile(r"(?P<typ>[NRW])(?P<oid>\d+)(:(?P<cls>\w+))?")
@@ -73,7 +76,7 @@ class NominatimID:
assert cur.rowcount == 1, \ assert cur.rowcount == 1, \
"Place ID {!s} not unique. Found {} entries.".format(self, cur.rowcount) "Place ID {!s} not unique. Found {} entries.".format(self, cur.rowcount)
return cur.fetchone()[0] return cur.fetchone()['place_id']
class DBRow: class DBRow:
@@ -152,9 +155,10 @@ class DBRow:
def _has_centroid(self, expected): def _has_centroid(self, expected):
if expected == 'in geometry': if expected == 'in geometry':
with self.context.db.cursor() as cur: with self.context.db.cursor(row_factory=psycopg.rows.tuple_row) as cur:
cur.execute("""SELECT ST_Within(ST_SetSRID(ST_Point({cx}, {cy}), 4326), cur.execute("""SELECT ST_Within(ST_SetSRID(ST_Point(%(cx)s, %(cy)s), 4326),
ST_SetSRID('{geomtxt}'::geometry, 4326))""".format(**self.db_row)) ST_SetSRID(%(geomtxt)s::geometry, 4326))""",
(self.db_row))
return cur.fetchone()[0] return cur.fetchone()[0]
if ' ' in expected: if ' ' in expected:
@@ -166,10 +170,11 @@ class DBRow:
def _has_geometry(self, expected): def _has_geometry(self, expected):
geom = self.context.osm.parse_geometry(expected) geom = self.context.osm.parse_geometry(expected)
with self.context.db.cursor() as cur: with self.context.db.cursor(row_factory=psycopg.rows.tuple_row) as cur:
cur.execute("""SELECT ST_Equals(ST_SnapToGrid({}, 0.00001, 0.00001), cur.execute(pysql.SQL("""SELECT ST_Equals(ST_SnapToGrid({}, 0.00001, 0.00001),
ST_SnapToGrid(ST_SetSRID('{}'::geometry, 4326), 0.00001, 0.00001))""".format( ST_SnapToGrid(ST_SetSRID({}::geometry, 4326), 0.00001, 0.00001))""")
geom, self.db_row['geomtxt'])) .format(pysql.SQL(geom),
pysql.Literal(self.db_row['geomtxt'])))
return cur.fetchone()[0] return cur.fetchone()[0]
def assert_msg(self, name, value): def assert_msg(self, name, value):
@@ -209,7 +214,7 @@ class DBRow:
if actual == 0: if actual == 0:
return "place ID 0" return "place ID 0"
with self.context.db.cursor() as cur: with self.context.db.cursor(row_factory=psycopg.rows.tuple_row) as cur:
cur.execute("""SELECT osm_type, osm_id, class cur.execute("""SELECT osm_type, osm_id, class
FROM placex WHERE place_id = %s""", FROM placex WHERE place_id = %s""",
(actual, )) (actual, ))

View File

@@ -13,8 +13,6 @@ from pathlib import Path
import pytest import pytest
import pytest_asyncio import pytest_asyncio
import psycopg2.extras
from fake_adaptor import FakeAdaptor, FakeError, FakeResponse from fake_adaptor import FakeAdaptor, FakeError, FakeResponse
import nominatim_api.v1.server_glue as glue import nominatim_api.v1.server_glue as glue
@@ -31,7 +29,6 @@ class TestDeletableEndPoint:
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def setup_deletable_table(self, temp_db_cursor, table_factory, temp_db_with_extensions): def setup_deletable_table(self, temp_db_cursor, table_factory, temp_db_with_extensions):
psycopg2.extras.register_hstore(temp_db_cursor)
table_factory('import_polygon_delete', table_factory('import_polygon_delete',
definition='osm_id bigint, osm_type char(1), class text, type text', definition='osm_id bigint, osm_type char(1), class text, type text',
content=[(345, 'N', 'boundary', 'administrative'), content=[(345, 'N', 'boundary', 'administrative'),

View File

@@ -14,8 +14,6 @@ from pathlib import Path
import pytest import pytest
import pytest_asyncio import pytest_asyncio
import psycopg2.extras
from fake_adaptor import FakeAdaptor, FakeError, FakeResponse from fake_adaptor import FakeAdaptor, FakeError, FakeResponse
import nominatim_api.v1.server_glue as glue import nominatim_api.v1.server_glue as glue
@@ -32,8 +30,6 @@ class TestPolygonsEndPoint:
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def setup_deletable_table(self, temp_db_cursor, table_factory, temp_db_with_extensions): def setup_deletable_table(self, temp_db_cursor, table_factory, temp_db_with_extensions):
psycopg2.extras.register_hstore(temp_db_cursor)
self.now = dt.datetime.now() self.now = dt.datetime.now()
self.recent = dt.datetime.now() - dt.timedelta(days=3) self.recent = dt.datetime.now() - dt.timedelta(days=3)

View File

@@ -25,6 +25,23 @@ class MockParamCapture:
return self.return_value return self.return_value
class AsyncMockParamCapture:
""" Mock that records the parameters with which a function was called
as well as the number of calls.
"""
def __init__(self, retval=0):
self.called = 0
self.return_value = retval
self.last_args = None
self.last_kwargs = None
async def __call__(self, *args, **kwargs):
self.called += 1
self.last_args = args
self.last_kwargs = kwargs
return self.return_value
class DummyTokenizer: class DummyTokenizer:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.update_sql_functions_called = False self.update_sql_functions_called = False
@@ -69,6 +86,17 @@ def mock_func_factory(monkeypatch):
return get_mock return get_mock
@pytest.fixture
def async_mock_func_factory(monkeypatch):
def get_mock(module, func):
mock = AsyncMockParamCapture()
mock.func_name = func
monkeypatch.setattr(module, func, mock)
return mock
return get_mock
@pytest.fixture @pytest.fixture
def cli_tokenizer_mock(monkeypatch): def cli_tokenizer_mock(monkeypatch):
tok = DummyTokenizer() tok = DummyTokenizer()

View File

@@ -17,6 +17,7 @@ import pytest
import nominatim_db.indexer.indexer import nominatim_db.indexer.indexer
import nominatim_db.tools.add_osm_data import nominatim_db.tools.add_osm_data
import nominatim_db.tools.freeze import nominatim_db.tools.freeze
import nominatim_db.tools.tiger_data
def test_cli_help(cli_call, capsys): def test_cli_help(cli_call, capsys):
@@ -52,8 +53,8 @@ def test_cli_add_data_object_command(cli_call, mock_func_factory, name, oid):
def test_cli_add_data_tiger_data(cli_call, cli_tokenizer_mock, mock_func_factory): def test_cli_add_data_tiger_data(cli_call, cli_tokenizer_mock, async_mock_func_factory):
mock = mock_func_factory(nominatim_db.tools.tiger_data, 'add_tiger_data') mock = async_mock_func_factory(nominatim_db.tools.tiger_data, 'add_tiger_data')
assert cli_call('add-data', '--tiger-data', 'somewhere') == 0 assert cli_call('add-data', '--tiger-data', 'somewhere') == 0
@@ -68,38 +69,6 @@ def test_cli_serve_php(cli_call, mock_func_factory):
assert func.called == 1 assert func.called == 1
def test_cli_serve_starlette_custom_server(cli_call, mock_func_factory):
pytest.importorskip("starlette")
mod = pytest.importorskip("uvicorn")
func = mock_func_factory(mod, "run")
cli_call('serve', '--engine', 'starlette', '--server', 'foobar:4545') == 0
assert func.called == 1
assert func.last_kwargs['host'] == 'foobar'
assert func.last_kwargs['port'] == 4545
def test_cli_serve_starlette_custom_server_bad_port(cli_call, mock_func_factory):
pytest.importorskip("starlette")
mod = pytest.importorskip("uvicorn")
func = mock_func_factory(mod, "run")
cli_call('serve', '--engine', 'starlette', '--server', 'foobar:45:45') == 1
@pytest.mark.parametrize("engine", ['falcon', 'starlette'])
def test_cli_serve_uvicorn_based(cli_call, engine, mock_func_factory):
pytest.importorskip(engine)
mod = pytest.importorskip("uvicorn")
func = mock_func_factory(mod, "run")
cli_call('serve', '--engine', engine) == 0
assert func.called == 1
assert func.last_kwargs['host'] == '127.0.0.1'
assert func.last_kwargs['port'] == 8088
class TestCliWithDb: class TestCliWithDb:
@@ -120,16 +89,19 @@ class TestCliWithDb:
@pytest.mark.parametrize("params,do_bnds,do_ranks", [ @pytest.mark.parametrize("params,do_bnds,do_ranks", [
([], 1, 1), ([], 2, 2),
(['--boundaries-only'], 1, 0), (['--boundaries-only'], 2, 0),
(['--no-boundaries'], 0, 1), (['--no-boundaries'], 0, 2),
(['--boundaries-only', '--no-boundaries'], 0, 0)]) (['--boundaries-only', '--no-boundaries'], 0, 0)])
def test_index_command(self, mock_func_factory, table_factory, def test_index_command(self, monkeypatch, async_mock_func_factory, table_factory,
params, do_bnds, do_ranks): params, do_bnds, do_ranks):
table_factory('import_status', 'indexed bool') table_factory('import_status', 'indexed bool')
bnd_mock = mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_boundaries') bnd_mock = async_mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_boundaries')
rank_mock = mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_by_rank') rank_mock = async_mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_by_rank')
postcode_mock = mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_postcodes') postcode_mock = async_mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_postcodes')
monkeypatch.setattr(nominatim_db.indexer.indexer.Indexer, 'has_pending',
[False, True].pop)
assert self.call_nominatim('index', *params) == 0 assert self.call_nominatim('index', *params) == 0

View File

@@ -34,7 +34,8 @@ class TestCliImportWithDb:
@pytest.mark.parametrize('with_updates', [True, False]) @pytest.mark.parametrize('with_updates', [True, False])
def test_import_full(self, mock_func_factory, with_updates, place_table, property_table): def test_import_full(self, mock_func_factory, async_mock_func_factory,
with_updates, place_table, property_table):
mocks = [ mocks = [
mock_func_factory(nominatim_db.tools.database_import, 'setup_database_skeleton'), mock_func_factory(nominatim_db.tools.database_import, 'setup_database_skeleton'),
mock_func_factory(nominatim_db.data.country_info, 'setup_country_tables'), mock_func_factory(nominatim_db.data.country_info, 'setup_country_tables'),
@@ -42,15 +43,15 @@ class TestCliImportWithDb:
mock_func_factory(nominatim_db.tools.refresh, 'import_wikipedia_articles'), mock_func_factory(nominatim_db.tools.refresh, 'import_wikipedia_articles'),
mock_func_factory(nominatim_db.tools.refresh, 'import_secondary_importance'), mock_func_factory(nominatim_db.tools.refresh, 'import_secondary_importance'),
mock_func_factory(nominatim_db.tools.database_import, 'truncate_data_tables'), mock_func_factory(nominatim_db.tools.database_import, 'truncate_data_tables'),
mock_func_factory(nominatim_db.tools.database_import, 'load_data'), async_mock_func_factory(nominatim_db.tools.database_import, 'load_data'),
mock_func_factory(nominatim_db.tools.database_import, 'create_tables'), mock_func_factory(nominatim_db.tools.database_import, 'create_tables'),
mock_func_factory(nominatim_db.tools.database_import, 'create_table_triggers'), mock_func_factory(nominatim_db.tools.database_import, 'create_table_triggers'),
mock_func_factory(nominatim_db.tools.database_import, 'create_partition_tables'), mock_func_factory(nominatim_db.tools.database_import, 'create_partition_tables'),
mock_func_factory(nominatim_db.tools.database_import, 'create_search_indices'), async_mock_func_factory(nominatim_db.tools.database_import, 'create_search_indices'),
mock_func_factory(nominatim_db.data.country_info, 'create_country_names'), mock_func_factory(nominatim_db.data.country_info, 'create_country_names'),
mock_func_factory(nominatim_db.tools.refresh, 'load_address_levels_from_config'), mock_func_factory(nominatim_db.tools.refresh, 'load_address_levels_from_config'),
mock_func_factory(nominatim_db.tools.postcodes, 'update_postcodes'), mock_func_factory(nominatim_db.tools.postcodes, 'update_postcodes'),
mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_full'), async_mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_full'),
mock_func_factory(nominatim_db.tools.refresh, 'setup_website'), mock_func_factory(nominatim_db.tools.refresh, 'setup_website'),
] ]
@@ -72,14 +73,14 @@ class TestCliImportWithDb:
assert mock.called == 1, "Mock '{}' not called".format(mock.func_name) assert mock.called == 1, "Mock '{}' not called".format(mock.func_name)
def test_import_continue_load_data(self, mock_func_factory): def test_import_continue_load_data(self, mock_func_factory, async_mock_func_factory):
mocks = [ mocks = [
mock_func_factory(nominatim_db.tools.database_import, 'truncate_data_tables'), mock_func_factory(nominatim_db.tools.database_import, 'truncate_data_tables'),
mock_func_factory(nominatim_db.tools.database_import, 'load_data'), async_mock_func_factory(nominatim_db.tools.database_import, 'load_data'),
mock_func_factory(nominatim_db.tools.database_import, 'create_search_indices'), async_mock_func_factory(nominatim_db.tools.database_import, 'create_search_indices'),
mock_func_factory(nominatim_db.data.country_info, 'create_country_names'), mock_func_factory(nominatim_db.data.country_info, 'create_country_names'),
mock_func_factory(nominatim_db.tools.postcodes, 'update_postcodes'), mock_func_factory(nominatim_db.tools.postcodes, 'update_postcodes'),
mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_full'), async_mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_full'),
mock_func_factory(nominatim_db.tools.refresh, 'setup_website'), mock_func_factory(nominatim_db.tools.refresh, 'setup_website'),
mock_func_factory(nominatim_db.db.properties, 'set_property') mock_func_factory(nominatim_db.db.properties, 'set_property')
] ]
@@ -91,12 +92,12 @@ class TestCliImportWithDb:
assert mock.called == 1, "Mock '{}' not called".format(mock.func_name) assert mock.called == 1, "Mock '{}' not called".format(mock.func_name)
def test_import_continue_indexing(self, mock_func_factory, placex_table, def test_import_continue_indexing(self, mock_func_factory, async_mock_func_factory,
temp_db_conn): placex_table, temp_db_conn):
mocks = [ mocks = [
mock_func_factory(nominatim_db.tools.database_import, 'create_search_indices'), async_mock_func_factory(nominatim_db.tools.database_import, 'create_search_indices'),
mock_func_factory(nominatim_db.data.country_info, 'create_country_names'), mock_func_factory(nominatim_db.data.country_info, 'create_country_names'),
mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_full'), async_mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_full'),
mock_func_factory(nominatim_db.tools.refresh, 'setup_website'), mock_func_factory(nominatim_db.tools.refresh, 'setup_website'),
mock_func_factory(nominatim_db.db.properties, 'set_property') mock_func_factory(nominatim_db.db.properties, 'set_property')
] ]
@@ -110,9 +111,9 @@ class TestCliImportWithDb:
assert self.call_nominatim('import', '--continue', 'indexing') == 0 assert self.call_nominatim('import', '--continue', 'indexing') == 0
def test_import_continue_postprocess(self, mock_func_factory): def test_import_continue_postprocess(self, mock_func_factory, async_mock_func_factory):
mocks = [ mocks = [
mock_func_factory(nominatim_db.tools.database_import, 'create_search_indices'), async_mock_func_factory(nominatim_db.tools.database_import, 'create_search_indices'),
mock_func_factory(nominatim_db.data.country_info, 'create_country_names'), mock_func_factory(nominatim_db.data.country_info, 'create_country_names'),
mock_func_factory(nominatim_db.tools.refresh, 'setup_website'), mock_func_factory(nominatim_db.tools.refresh, 'setup_website'),
mock_func_factory(nominatim_db.db.properties, 'set_property') mock_func_factory(nominatim_db.db.properties, 'set_property')

View File

@@ -45,9 +45,9 @@ class TestRefresh:
assert self.tokenizer_mock.update_word_tokens_called assert self.tokenizer_mock.update_word_tokens_called
def test_refresh_postcodes(self, mock_func_factory, place_table): def test_refresh_postcodes(self, async_mock_func_factory, mock_func_factory, place_table):
func_mock = mock_func_factory(nominatim_db.tools.postcodes, 'update_postcodes') func_mock = mock_func_factory(nominatim_db.tools.postcodes, 'update_postcodes')
idx_mock = mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_postcodes') idx_mock = async_mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_postcodes')
assert self.call_nominatim('refresh', '--postcodes') == 0 assert self.call_nominatim('refresh', '--postcodes') == 0
assert func_mock.called == 1 assert func_mock.called == 1

View File

@@ -47,8 +47,8 @@ def init_status(temp_db_conn, status_table):
@pytest.fixture @pytest.fixture
def index_mock(mock_func_factory, tokenizer_mock, init_status): def index_mock(async_mock_func_factory, tokenizer_mock, init_status):
return mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_full') return async_mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_full')
@pytest.fixture @pytest.fixture

View File

@@ -8,7 +8,8 @@ import itertools
import sys import sys
from pathlib import Path from pathlib import Path
import psycopg2 import psycopg
from psycopg import sql as pysql
import pytest import pytest
# always test against the source # always test against the source
@@ -36,26 +37,23 @@ def temp_db(monkeypatch):
exported into NOMINATIM_DATABASE_DSN. exported into NOMINATIM_DATABASE_DSN.
""" """
name = 'test_nominatim_python_unittest' name = 'test_nominatim_python_unittest'
conn = psycopg2.connect(database='postgres')
conn.set_isolation_level(0) with psycopg.connect(dbname='postgres', autocommit=True) as conn:
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute('DROP DATABASE IF EXISTS {}'.format(name)) cur.execute(pysql.SQL('DROP DATABASE IF EXISTS') + pysql.Identifier(name))
cur.execute('CREATE DATABASE {}'.format(name)) cur.execute(pysql.SQL('CREATE DATABASE') + pysql.Identifier(name))
conn.close()
monkeypatch.setenv('NOMINATIM_DATABASE_DSN', 'dbname=' + name) monkeypatch.setenv('NOMINATIM_DATABASE_DSN', 'dbname=' + name)
with psycopg.connect(dbname=name) as conn:
with conn.cursor() as cur:
cur.execute('CREATE EXTENSION hstore')
yield name yield name
conn = psycopg2.connect(database='postgres') with psycopg.connect(dbname='postgres', autocommit=True) as conn:
with conn.cursor() as cur:
conn.set_isolation_level(0) cur.execute('DROP DATABASE IF EXISTS {}'.format(name))
with conn.cursor() as cur:
cur.execute('DROP DATABASE IF EXISTS {}'.format(name))
conn.close()
@pytest.fixture @pytest.fixture
@@ -65,11 +63,9 @@ def dsn(temp_db):
@pytest.fixture @pytest.fixture
def temp_db_with_extensions(temp_db): def temp_db_with_extensions(temp_db):
conn = psycopg2.connect(database=temp_db) with psycopg.connect(dbname=temp_db) as conn:
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute('CREATE EXTENSION hstore; CREATE EXTENSION postgis;') cur.execute('CREATE EXTENSION postgis')
conn.commit()
conn.close()
return temp_db return temp_db
@@ -77,7 +73,8 @@ def temp_db_with_extensions(temp_db):
def temp_db_conn(temp_db): def temp_db_conn(temp_db):
""" Connection to the test database. """ Connection to the test database.
""" """
with connection.connect('dbname=' + temp_db) as conn: with connection.connect('', autocommit=True, dbname=temp_db) as conn:
connection.register_hstore(conn)
yield conn yield conn
@@ -86,22 +83,25 @@ def temp_db_cursor(temp_db):
""" Connection and cursor towards the test database. The connection will """ Connection and cursor towards the test database. The connection will
be in auto-commit mode. be in auto-commit mode.
""" """
conn = psycopg2.connect('dbname=' + temp_db) with psycopg.connect(dbname=temp_db, autocommit=True, cursor_factory=CursorForTesting) as conn:
conn.set_isolation_level(0) connection.register_hstore(conn)
with conn.cursor(cursor_factory=CursorForTesting) as cur: with conn.cursor() as cur:
yield cur yield cur
conn.close()
@pytest.fixture @pytest.fixture
def table_factory(temp_db_cursor): def table_factory(temp_db_conn):
""" A fixture that creates new SQL tables, potentially filled with """ A fixture that creates new SQL tables, potentially filled with
content. content.
""" """
def mk_table(name, definition='id INT', content=None): def mk_table(name, definition='id INT', content=None):
temp_db_cursor.execute('CREATE TABLE {} ({})'.format(name, definition)) with psycopg.ClientCursor(temp_db_conn) as cur:
if content is not None: cur.execute('CREATE TABLE {} ({})'.format(name, definition))
temp_db_cursor.execute_values("INSERT INTO {} VALUES %s".format(name), content) if content:
sql = pysql.SQL("INSERT INTO {} VALUES ({})")\
.format(pysql.Identifier(name),
pysql.SQL(',').join([pysql.Placeholder() for _ in range(len(content[0]))]))
cur.executemany(sql , content)
return mk_table return mk_table
@@ -168,7 +168,6 @@ def place_row(place_table, temp_db_cursor):
""" A factory for rows in the place table. The table is created as a """ A factory for rows in the place table. The table is created as a
prerequisite to the fixture. prerequisite to the fixture.
""" """
psycopg2.extras.register_hstore(temp_db_cursor)
idseq = itertools.count(1001) idseq = itertools.count(1001)
def _insert(osm_type='N', osm_id=None, cls='amenity', typ='cafe', names=None, def _insert(osm_type='N', osm_id=None, cls='amenity', typ='cafe', names=None,
admin_level=None, address=None, extratags=None, geom=None): admin_level=None, address=None, extratags=None, geom=None):

View File

@@ -5,11 +5,11 @@
# Copyright (C) 2024 by the Nominatim developer community. # Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log. # For a full list of authors see the git log.
""" """
Specialised psycopg2 cursor with shortcut functions useful for testing. Specialised psycopg cursor with shortcut functions useful for testing.
""" """
import psycopg2.extras import psycopg
class CursorForTesting(psycopg2.extras.DictCursor): class CursorForTesting(psycopg.Cursor):
""" Extension to the DictCursor class that provides execution """ Extension to the DictCursor class that provides execution
short-cuts that simplify writing assertions. short-cuts that simplify writing assertions.
""" """
@@ -59,9 +59,3 @@ class CursorForTesting(psycopg2.extras.DictCursor):
return self.scalar('SELECT count(*) FROM ' + table) return self.scalar('SELECT count(*) FROM ' + table)
return self.scalar('SELECT count(*) FROM {} WHERE {}'.format(table, where)) return self.scalar('SELECT count(*) FROM {} WHERE {}'.format(table, where))
def execute_values(self, *args, **kwargs):
""" Execute the execute_values() function on the cursor.
"""
psycopg2.extras.execute_values(self, *args, **kwargs)

View File

@@ -1,113 +0,0 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Tests for function providing a non-blocking query interface towards PostgreSQL.
"""
from contextlib import closing
import concurrent.futures
import pytest
import psycopg2
from nominatim_db.db.async_connection import DBConnection, DeadlockHandler
@pytest.fixture
def conn(temp_db):
with closing(DBConnection('dbname=' + temp_db)) as connection:
yield connection
@pytest.fixture
def simple_conns(temp_db):
conn1 = psycopg2.connect('dbname=' + temp_db)
conn2 = psycopg2.connect('dbname=' + temp_db)
yield conn1.cursor(), conn2.cursor()
conn1.close()
conn2.close()
def test_simple_query(conn, temp_db_cursor):
conn.connect()
conn.perform('CREATE TABLE foo (id INT)')
conn.wait()
assert temp_db_cursor.table_exists('foo')
def test_wait_for_query(conn):
conn.connect()
conn.perform('SELECT pg_sleep(1)')
assert not conn.is_done()
conn.wait()
def test_bad_query(conn):
conn.connect()
conn.perform('SELECT efasfjsea')
with pytest.raises(psycopg2.ProgrammingError):
conn.wait()
def test_bad_query_ignore(temp_db):
with closing(DBConnection('dbname=' + temp_db, ignore_sql_errors=True)) as conn:
conn.connect()
conn.perform('SELECT efasfjsea')
conn.wait()
def exec_with_deadlock(cur, sql, detector):
with DeadlockHandler(lambda *args: detector.append(1)):
cur.execute(sql)
def test_deadlock(simple_conns):
cur1, cur2 = simple_conns
cur1.execute("""CREATE TABLE t1 (id INT PRIMARY KEY, t TEXT);
INSERT into t1 VALUES (1, 'a'), (2, 'b')""")
cur1.connection.commit()
cur1.execute("UPDATE t1 SET t = 'x' WHERE id = 1")
cur2.execute("UPDATE t1 SET t = 'x' WHERE id = 2")
# This is the tricky part of the test. The first SQL command runs into
# a lock and blocks, so we have to run it in a separate thread. When the
# second deadlocking SQL statement is issued, Postgresql will abort one of
# the two transactions that cause the deadlock. There is no way to tell
# which one of the two. Therefore wrap both in a DeadlockHandler and
# expect that exactly one of the two triggers.
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
deadlock_check = []
try:
future = executor.submit(exec_with_deadlock, cur2,
"UPDATE t1 SET t = 'y' WHERE id = 1",
deadlock_check)
while not future.running():
pass
exec_with_deadlock(cur1, "UPDATE t1 SET t = 'y' WHERE id = 2",
deadlock_check)
finally:
# Whatever happens, make sure the deadlock gets resolved.
cur1.connection.rollback()
future.result()
assert len(deadlock_check) == 1

View File

@@ -8,7 +8,7 @@
Tests for specialised connection and cursor classes. Tests for specialised connection and cursor classes.
""" """
import pytest import pytest
import psycopg2 import psycopg
import nominatim_db.db.connection as nc import nominatim_db.db.connection as nc
@@ -73,7 +73,7 @@ def test_drop_many_tables(db, table_factory):
def test_drop_table_non_existing_force(db): def test_drop_table_non_existing_force(db):
with pytest.raises(psycopg2.ProgrammingError, match='.*does not exist.*'): with pytest.raises(psycopg.ProgrammingError, match='.*does not exist.*'):
nc.drop_tables(db, 'dfkjgjriogjigjgjrdghehtre', if_exists=False) nc.drop_tables(db, 'dfkjgjriogjigjgjrdghehtre', if_exists=False)
def test_connection_server_version_tuple(db): def test_connection_server_version_tuple(db):

View File

@@ -8,6 +8,7 @@
Tests for SQL preprocessing. Tests for SQL preprocessing.
""" """
import pytest import pytest
import pytest_asyncio
from nominatim_db.db.sql_preprocessor import SQLPreprocessor from nominatim_db.db.sql_preprocessor import SQLPreprocessor
@@ -54,3 +55,17 @@ def test_load_file_with_params(sql_preprocessor, sql_factory, temp_db_conn, temp
sql_preprocessor.run_sql_file(temp_db_conn, sqlfile, bar='XX', foo='ZZ') sql_preprocessor.run_sql_file(temp_db_conn, sqlfile, bar='XX', foo='ZZ')
assert temp_db_cursor.scalar('SELECT test()') == 'ZZ XX' assert temp_db_cursor.scalar('SELECT test()') == 'ZZ XX'
@pytest.mark.asyncio
async def test_load_parallel_file(dsn, sql_preprocessor, tmp_path, temp_db_cursor):
(tmp_path / 'test.sql').write_text("""
CREATE TABLE foo (a TEXT);
CREATE TABLE foo2(a TEXT);""" +
"\n---\nCREATE TABLE bar (b INT);")
await sql_preprocessor.run_parallel_sql_file(dsn, 'test.sql', num_threads=4)
assert temp_db_cursor.table_exists('foo')
assert temp_db_cursor.table_exists('foo2')
assert temp_db_cursor.table_exists('bar')

View File

@@ -58,103 +58,3 @@ def test_execute_file_with_post_code(dsn, tmp_path, temp_db_cursor):
db_utils.execute_file(dsn, tmpfile, post_code='INSERT INTO test VALUES(23)') db_utils.execute_file(dsn, tmpfile, post_code='INSERT INTO test VALUES(23)')
assert temp_db_cursor.row_set('SELECT * FROM test') == {(23, )} assert temp_db_cursor.row_set('SELECT * FROM test') == {(23, )}
class TestCopyBuffer:
TABLE_NAME = 'copytable'
@pytest.fixture(autouse=True)
def setup_test_table(self, table_factory):
table_factory(self.TABLE_NAME, 'col_a INT, col_b TEXT')
def table_rows(self, cursor):
return cursor.row_set('SELECT * FROM ' + self.TABLE_NAME)
def test_copybuffer_empty(self):
with db_utils.CopyBuffer() as buf:
buf.copy_out(None, "dummy")
def test_all_columns(self, temp_db_cursor):
with db_utils.CopyBuffer() as buf:
buf.add(3, 'hum')
buf.add(None, 'f\\t')
buf.copy_out(temp_db_cursor, self.TABLE_NAME)
assert self.table_rows(temp_db_cursor) == {(3, 'hum'), (None, 'f\\t')}
def test_selected_columns(self, temp_db_cursor):
with db_utils.CopyBuffer() as buf:
buf.add('foo')
buf.copy_out(temp_db_cursor, self.TABLE_NAME,
columns=['col_b'])
assert self.table_rows(temp_db_cursor) == {(None, 'foo')}
def test_reordered_columns(self, temp_db_cursor):
with db_utils.CopyBuffer() as buf:
buf.add('one', 1)
buf.add(' two ', 2)
buf.copy_out(temp_db_cursor, self.TABLE_NAME,
columns=['col_b', 'col_a'])
assert self.table_rows(temp_db_cursor) == {(1, 'one'), (2, ' two ')}
def test_special_characters(self, temp_db_cursor):
with db_utils.CopyBuffer() as buf:
buf.add('foo\tbar')
buf.add('sun\nson')
buf.add('\\N')
buf.copy_out(temp_db_cursor, self.TABLE_NAME,
columns=['col_b'])
assert self.table_rows(temp_db_cursor) == {(None, 'foo\tbar'),
(None, 'sun\nson'),
(None, '\\N')}
class TestCopyBufferJson:
TABLE_NAME = 'copytable'
@pytest.fixture(autouse=True)
def setup_test_table(self, table_factory):
table_factory(self.TABLE_NAME, 'col_a INT, col_b JSONB')
def table_rows(self, cursor):
cursor.execute('SELECT * FROM ' + self.TABLE_NAME)
results = {k: v for k,v in cursor}
assert len(results) == cursor.rowcount
return results
def test_json_object(self, temp_db_cursor):
with db_utils.CopyBuffer() as buf:
buf.add(1, json.dumps({'test': 'value', 'number': 1}))
buf.copy_out(temp_db_cursor, self.TABLE_NAME)
assert self.table_rows(temp_db_cursor) == \
{1: {'test': 'value', 'number': 1}}
def test_json_object_special_chras(self, temp_db_cursor):
with db_utils.CopyBuffer() as buf:
buf.add(1, json.dumps({'te\tst': 'va\nlue', 'nu"mber': None}))
buf.copy_out(temp_db_cursor, self.TABLE_NAME)
assert self.table_rows(temp_db_cursor) == \
{1: {'te\tst': 'va\nlue', 'nu"mber': None}}

View File

@@ -9,6 +9,7 @@ Tests for running the indexing.
""" """
import itertools import itertools
import pytest import pytest
import pytest_asyncio
from nominatim_db.indexer import indexer from nominatim_db.indexer import indexer
from nominatim_db.tokenizer import factory from nominatim_db.tokenizer import factory
@@ -21,9 +22,8 @@ class IndexerTestDB:
self.postcode_id = itertools.count(700000) self.postcode_id = itertools.count(700000)
self.conn = conn self.conn = conn
self.conn.set_isolation_level(0) self.conn.autocimmit = True
with self.conn.cursor() as cur: with self.conn.cursor() as cur:
cur.execute('CREATE EXTENSION hstore')
cur.execute("""CREATE TABLE placex (place_id BIGINT, cur.execute("""CREATE TABLE placex (place_id BIGINT,
name HSTORE, name HSTORE,
class TEXT, class TEXT,
@@ -156,7 +156,8 @@ def test_tokenizer(tokenizer_mock, project_env):
@pytest.mark.parametrize("threads", [1, 15]) @pytest.mark.parametrize("threads", [1, 15])
def test_index_all_by_rank(test_db, threads, test_tokenizer): @pytest.mark.asyncio
async def test_index_all_by_rank(test_db, threads, test_tokenizer):
for rank in range(31): for rank in range(31):
test_db.add_place(rank_address=rank, rank_search=rank) test_db.add_place(rank_address=rank, rank_search=rank)
test_db.add_osmline() test_db.add_osmline()
@@ -165,7 +166,7 @@ def test_index_all_by_rank(test_db, threads, test_tokenizer):
assert test_db.osmline_unindexed() == 1 assert test_db.osmline_unindexed() == 1
idx = indexer.Indexer('dbname=test_nominatim_python_unittest', test_tokenizer, threads) idx = indexer.Indexer('dbname=test_nominatim_python_unittest', test_tokenizer, threads)
idx.index_by_rank(0, 30) await idx.index_by_rank(0, 30)
assert test_db.placex_unindexed() == 0 assert test_db.placex_unindexed() == 0
assert test_db.osmline_unindexed() == 0 assert test_db.osmline_unindexed() == 0
@@ -190,7 +191,8 @@ def test_index_all_by_rank(test_db, threads, test_tokenizer):
@pytest.mark.parametrize("threads", [1, 15]) @pytest.mark.parametrize("threads", [1, 15])
def test_index_partial_without_30(test_db, threads, test_tokenizer): @pytest.mark.asyncio
async def test_index_partial_without_30(test_db, threads, test_tokenizer):
for rank in range(31): for rank in range(31):
test_db.add_place(rank_address=rank, rank_search=rank) test_db.add_place(rank_address=rank, rank_search=rank)
test_db.add_osmline() test_db.add_osmline()
@@ -200,7 +202,7 @@ def test_index_partial_without_30(test_db, threads, test_tokenizer):
idx = indexer.Indexer('dbname=test_nominatim_python_unittest', idx = indexer.Indexer('dbname=test_nominatim_python_unittest',
test_tokenizer, threads) test_tokenizer, threads)
idx.index_by_rank(4, 15) await idx.index_by_rank(4, 15)
assert test_db.placex_unindexed() == 19 assert test_db.placex_unindexed() == 19
assert test_db.osmline_unindexed() == 1 assert test_db.osmline_unindexed() == 1
@@ -211,7 +213,8 @@ def test_index_partial_without_30(test_db, threads, test_tokenizer):
@pytest.mark.parametrize("threads", [1, 15]) @pytest.mark.parametrize("threads", [1, 15])
def test_index_partial_with_30(test_db, threads, test_tokenizer): @pytest.mark.asyncio
async def test_index_partial_with_30(test_db, threads, test_tokenizer):
for rank in range(31): for rank in range(31):
test_db.add_place(rank_address=rank, rank_search=rank) test_db.add_place(rank_address=rank, rank_search=rank)
test_db.add_osmline() test_db.add_osmline()
@@ -220,7 +223,7 @@ def test_index_partial_with_30(test_db, threads, test_tokenizer):
assert test_db.osmline_unindexed() == 1 assert test_db.osmline_unindexed() == 1
idx = indexer.Indexer('dbname=test_nominatim_python_unittest', test_tokenizer, threads) idx = indexer.Indexer('dbname=test_nominatim_python_unittest', test_tokenizer, threads)
idx.index_by_rank(28, 30) await idx.index_by_rank(28, 30)
assert test_db.placex_unindexed() == 27 assert test_db.placex_unindexed() == 27
assert test_db.osmline_unindexed() == 0 assert test_db.osmline_unindexed() == 0
@@ -230,7 +233,8 @@ def test_index_partial_with_30(test_db, threads, test_tokenizer):
WHERE indexed_status = 0 AND rank_address between 1 and 27""") == 0 WHERE indexed_status = 0 AND rank_address between 1 and 27""") == 0
@pytest.mark.parametrize("threads", [1, 15]) @pytest.mark.parametrize("threads", [1, 15])
def test_index_boundaries(test_db, threads, test_tokenizer): @pytest.mark.asyncio
async def test_index_boundaries(test_db, threads, test_tokenizer):
for rank in range(4, 10): for rank in range(4, 10):
test_db.add_admin(rank_address=rank, rank_search=rank) test_db.add_admin(rank_address=rank, rank_search=rank)
for rank in range(31): for rank in range(31):
@@ -241,7 +245,7 @@ def test_index_boundaries(test_db, threads, test_tokenizer):
assert test_db.osmline_unindexed() == 1 assert test_db.osmline_unindexed() == 1
idx = indexer.Indexer('dbname=test_nominatim_python_unittest', test_tokenizer, threads) idx = indexer.Indexer('dbname=test_nominatim_python_unittest', test_tokenizer, threads)
idx.index_boundaries(0, 30) await idx.index_boundaries(0, 30)
assert test_db.placex_unindexed() == 31 assert test_db.placex_unindexed() == 31
assert test_db.osmline_unindexed() == 1 assert test_db.osmline_unindexed() == 1
@@ -252,21 +256,23 @@ def test_index_boundaries(test_db, threads, test_tokenizer):
@pytest.mark.parametrize("threads", [1, 15]) @pytest.mark.parametrize("threads", [1, 15])
def test_index_postcodes(test_db, threads, test_tokenizer): @pytest.mark.asyncio
async def test_index_postcodes(test_db, threads, test_tokenizer):
for postcode in range(1000): for postcode in range(1000):
test_db.add_postcode('de', postcode) test_db.add_postcode('de', postcode)
for postcode in range(32000, 33000): for postcode in range(32000, 33000):
test_db.add_postcode('us', postcode) test_db.add_postcode('us', postcode)
idx = indexer.Indexer('dbname=test_nominatim_python_unittest', test_tokenizer, threads) idx = indexer.Indexer('dbname=test_nominatim_python_unittest', test_tokenizer, threads)
idx.index_postcodes() await idx.index_postcodes()
assert test_db.scalar("""SELECT count(*) FROM location_postcode assert test_db.scalar("""SELECT count(*) FROM location_postcode
WHERE indexed_status != 0""") == 0 WHERE indexed_status != 0""") == 0
@pytest.mark.parametrize("analyse", [True, False]) @pytest.mark.parametrize("analyse", [True, False])
def test_index_full(test_db, analyse, test_tokenizer): @pytest.mark.asyncio
async def test_index_full(test_db, analyse, test_tokenizer):
for rank in range(4, 10): for rank in range(4, 10):
test_db.add_admin(rank_address=rank, rank_search=rank) test_db.add_admin(rank_address=rank, rank_search=rank)
for rank in range(31): for rank in range(31):
@@ -276,22 +282,9 @@ def test_index_full(test_db, analyse, test_tokenizer):
test_db.add_postcode('de', postcode) test_db.add_postcode('de', postcode)
idx = indexer.Indexer('dbname=test_nominatim_python_unittest', test_tokenizer, 4) idx = indexer.Indexer('dbname=test_nominatim_python_unittest', test_tokenizer, 4)
idx.index_full(analyse=analyse) await idx.index_full(analyse=analyse)
assert test_db.placex_unindexed() == 0 assert test_db.placex_unindexed() == 0
assert test_db.osmline_unindexed() == 0 assert test_db.osmline_unindexed() == 0
assert test_db.scalar("""SELECT count(*) FROM location_postcode assert test_db.scalar("""SELECT count(*) FROM location_postcode
WHERE indexed_status != 0""") == 0 WHERE indexed_status != 0""") == 0
@pytest.mark.parametrize("threads", [1, 15])
def test_index_reopen_connection(test_db, threads, monkeypatch, test_tokenizer):
monkeypatch.setattr(indexer.WorkerPool, "REOPEN_CONNECTIONS_AFTER", 15)
for _ in range(1000):
test_db.add_place(rank_address=30, rank_search=30)
idx = indexer.Indexer('dbname=test_nominatim_python_unittest', test_tokenizer, threads)
idx.index_by_rank(28, 30)
assert test_db.placex_unindexed() == 0

View File

@@ -36,9 +36,9 @@ class MockIcuWordTable:
with self.conn.cursor() as cur: with self.conn.cursor() as cur:
cur.execute("""INSERT INTO word (word_token, type, word, info) cur.execute("""INSERT INTO word (word_token, type, word, info)
VALUES (%s, 'S', %s, VALUES (%s, 'S', %s,
json_build_object('class', %s, json_build_object('class', %s::text,
'type', %s, 'type', %s::text,
'op', %s)) 'op', %s::text))
""", (word_token, word, cls, typ, oper)) """, (word_token, word, cls, typ, oper))
self.conn.commit() self.conn.commit()
@@ -71,7 +71,7 @@ class MockIcuWordTable:
word = word_tokens[0] word = word_tokens[0]
for token in word_tokens: for token in word_tokens:
cur.execute("""INSERT INTO word (word_id, word_token, type, word, info) cur.execute("""INSERT INTO word (word_id, word_token, type, word, info)
VALUES (%s, %s, 'H', %s, jsonb_build_object('lookup', %s)) VALUES (%s, %s, 'H', %s, jsonb_build_object('lookup', %s::text))
""", (word_id, token, word, word_tokens[0])) """, (word_id, token, word, word_tokens[0]))
self.conn.commit() self.conn.commit()

View File

@@ -68,7 +68,7 @@ class MockLegacyWordTable:
def get_special(self): def get_special(self):
with self.conn.cursor() as cur: with self.conn.cursor() as cur:
cur.execute("""SELECT word_token, word, class, type, operator cur.execute("""SELECT word_token, word, class as cls, type, operator
FROM word WHERE class != 'place'""") FROM word WHERE class != 'place'""")
result = set((tuple(row) for row in cur)) result = set((tuple(row) for row in cur))
assert len(result) == cur.rowcount, "Word table has duplicates." assert len(result) == cur.rowcount, "Word table has duplicates."

View File

@@ -9,8 +9,6 @@ Custom mocks for testing.
""" """
import itertools import itertools
import psycopg2.extras
from nominatim_db.db import properties from nominatim_db.db import properties
# This must always point to the mock word table for the default tokenizer. # This must always point to the mock word table for the default tokenizer.
@@ -56,7 +54,6 @@ class MockPlacexTable:
admin_level=None, address=None, extratags=None, geom='POINT(10 4)', admin_level=None, address=None, extratags=None, geom='POINT(10 4)',
country=None, housenumber=None, rank_search=30): country=None, housenumber=None, rank_search=30):
with self.conn.cursor() as cur: with self.conn.cursor() as cur:
psycopg2.extras.register_hstore(cur)
cur.execute("""INSERT INTO placex (place_id, osm_type, osm_id, class, cur.execute("""INSERT INTO placex (place_id, osm_type, osm_id, class,
type, name, admin_level, address, type, name, admin_level, address,
housenumber, rank_search, housenumber, rank_search,

View File

@@ -8,10 +8,11 @@
Tests for functions to import a new database. Tests for functions to import a new database.
""" """
from pathlib import Path from pathlib import Path
from contextlib import closing
import pytest import pytest
import psycopg2 import pytest_asyncio
import psycopg
from psycopg import sql as pysql
from nominatim_db.tools import database_import from nominatim_db.tools import database_import
from nominatim_db.errors import UsageError from nominatim_db.errors import UsageError
@@ -21,10 +22,7 @@ class TestDatabaseSetup:
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def setup_nonexistant_db(self): def setup_nonexistant_db(self):
conn = psycopg2.connect(database='postgres') with psycopg.connect(dbname='postgres', autocommit=True) as conn:
try:
conn.set_isolation_level(0)
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute(f'DROP DATABASE IF EXISTS {self.DBNAME}') cur.execute(f'DROP DATABASE IF EXISTS {self.DBNAME}')
@@ -32,22 +30,17 @@ class TestDatabaseSetup:
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute(f'DROP DATABASE IF EXISTS {self.DBNAME}') cur.execute(f'DROP DATABASE IF EXISTS {self.DBNAME}')
finally:
conn.close()
@pytest.fixture @pytest.fixture
def cursor(self): def cursor(self):
conn = psycopg2.connect(database=self.DBNAME) with psycopg.connect(dbname=self.DBNAME) as conn:
try:
with conn.cursor() as cur: with conn.cursor() as cur:
yield cur yield cur
finally:
conn.close()
def conn(self): def conn(self):
return closing(psycopg2.connect(database=self.DBNAME)) return psycopg.connect(dbname=self.DBNAME)
def test_setup_skeleton(self): def test_setup_skeleton(self):
@@ -178,18 +171,19 @@ def test_truncate_database_tables(temp_db_conn, temp_db_cursor, table_factory, w
@pytest.mark.parametrize("threads", (1, 5)) @pytest.mark.parametrize("threads", (1, 5))
def test_load_data(dsn, place_row, placex_table, osmline_table, @pytest.mark.asyncio
async def test_load_data(dsn, place_row, placex_table, osmline_table,
temp_db_cursor, threads): temp_db_cursor, threads):
for func in ('precompute_words', 'getorcreate_housenumber_id', 'make_standard_name'): for func in ('precompute_words', 'getorcreate_housenumber_id', 'make_standard_name'):
temp_db_cursor.execute(f"""CREATE FUNCTION {func} (src TEXT) temp_db_cursor.execute(pysql.SQL("""CREATE FUNCTION {} (src TEXT)
RETURNS TEXT AS $$ SELECT 'a'::TEXT $$ LANGUAGE SQL RETURNS TEXT AS $$ SELECT 'a'::TEXT $$ LANGUAGE SQL
""") """).format(pysql.Identifier(func)))
for oid in range(100, 130): for oid in range(100, 130):
place_row(osm_id=oid) place_row(osm_id=oid)
place_row(osm_type='W', osm_id=342, cls='place', typ='houses', place_row(osm_type='W', osm_id=342, cls='place', typ='houses',
geom='SRID=4326;LINESTRING(0 0, 10 10)') geom='SRID=4326;LINESTRING(0 0, 10 10)')
database_import.load_data(dsn, threads) await database_import.load_data(dsn, threads)
assert temp_db_cursor.table_rows('placex') == 30 assert temp_db_cursor.table_rows('placex') == 30
assert temp_db_cursor.table_rows('location_property_osmline') == 1 assert temp_db_cursor.table_rows('location_property_osmline') == 1
@@ -241,11 +235,12 @@ class TestSetupSQL:
@pytest.mark.parametrize("drop", [True, False]) @pytest.mark.parametrize("drop", [True, False])
def test_create_search_indices(self, temp_db_conn, temp_db_cursor, drop): @pytest.mark.asyncio
async def test_create_search_indices(self, temp_db_conn, temp_db_cursor, drop):
self.write_sql('indices.sql', self.write_sql('indices.sql',
"""CREATE FUNCTION test() RETURNS bool """CREATE FUNCTION test() RETURNS bool
AS $$ SELECT {{drop}} $$ LANGUAGE SQL""") AS $$ SELECT {{drop}} $$ LANGUAGE SQL""")
database_import.create_search_indices(temp_db_conn, self.config, drop) await database_import.create_search_indices(temp_db_conn, self.config, drop)
temp_db_cursor.scalar('SELECT test()') == drop temp_db_cursor.scalar('SELECT test()') == drop

View File

@@ -8,7 +8,6 @@
Tests for migration functions Tests for migration functions
""" """
import pytest import pytest
import psycopg2.extras
from nominatim_db.tools import migration from nominatim_db.tools import migration
from nominatim_db.errors import UsageError from nominatim_db.errors import UsageError
@@ -44,7 +43,6 @@ def test_no_migration_old_versions(temp_db_with_extensions, table_factory, def_c
def test_set_up_migration_for_36(temp_db_with_extensions, temp_db_cursor, def test_set_up_migration_for_36(temp_db_with_extensions, temp_db_cursor,
table_factory, def_config, monkeypatch, table_factory, def_config, monkeypatch,
postprocess_mock): postprocess_mock):
psycopg2.extras.register_hstore(temp_db_cursor)
# don't actually run any migration, except the property table creation # don't actually run any migration, except the property table creation
monkeypatch.setattr(migration, '_MIGRATION_FUNCTIONS', monkeypatch.setattr(migration, '_MIGRATION_FUNCTIONS',
[((3, 5, 0, 99), migration.add_nominatim_property_table)]) [((3, 5, 0, 99), migration.add_nominatim_property_table)])

View File

@@ -47,7 +47,7 @@ class MockPostcodeTable:
country_code, postcode, country_code, postcode,
geometry) geometry)
VALUES (nextval('seq_place'), 1, %s, %s, VALUES (nextval('seq_place'), 1, %s, %s,
'SRID=4326;POINT(%s %s)')""", ST_SetSRID(ST_MakePoint(%s, %s), 4326))""",
(country, postcode, x, y)) (country, postcode, x, y))
self.conn.commit() self.conn.commit()

View File

@@ -11,6 +11,7 @@ import tarfile
from textwrap import dedent from textwrap import dedent
import pytest import pytest
import pytest_asyncio
from nominatim_db.db.connection import execute_scalar from nominatim_db.db.connection import execute_scalar
from nominatim_db.tools import tiger_data, freeze from nominatim_db.tools import tiger_data, freeze
@@ -76,82 +77,91 @@ def csv_factory(tmp_path):
@pytest.mark.parametrize("threads", (1, 5)) @pytest.mark.parametrize("threads", (1, 5))
def test_add_tiger_data(def_config, src_dir, tiger_table, tokenizer_mock, threads): @pytest.mark.asyncio
tiger_data.add_tiger_data(str(src_dir / 'test' / 'testdb' / 'tiger'), async def test_add_tiger_data(def_config, src_dir, tiger_table, tokenizer_mock, threads):
def_config, threads, tokenizer_mock()) await tiger_data.add_tiger_data(str(src_dir / 'test' / 'testdb' / 'tiger'),
def_config, threads, tokenizer_mock())
assert tiger_table.count() == 6213 assert tiger_table.count() == 6213
def test_add_tiger_data_database_frozen(def_config, temp_db_conn, tiger_table, tokenizer_mock, @pytest.mark.asyncio
async def test_add_tiger_data_database_frozen(def_config, temp_db_conn, tiger_table, tokenizer_mock,
tmp_path): tmp_path):
freeze.drop_update_tables(temp_db_conn) freeze.drop_update_tables(temp_db_conn)
with pytest.raises(UsageError) as excinfo: with pytest.raises(UsageError) as excinfo:
tiger_data.add_tiger_data(str(tmp_path), def_config, 1, tokenizer_mock()) await tiger_data.add_tiger_data(str(tmp_path), def_config, 1, tokenizer_mock())
assert "database frozen" in str(excinfo.value) assert "database frozen" in str(excinfo.value)
assert tiger_table.count() == 0 assert tiger_table.count() == 0
def test_add_tiger_data_no_files(def_config, tiger_table, tokenizer_mock,
@pytest.mark.asyncio
async def test_add_tiger_data_no_files(def_config, tiger_table, tokenizer_mock,
tmp_path): tmp_path):
tiger_data.add_tiger_data(str(tmp_path), def_config, 1, tokenizer_mock()) await tiger_data.add_tiger_data(str(tmp_path), def_config, 1, tokenizer_mock())
assert tiger_table.count() == 0 assert tiger_table.count() == 0
def test_add_tiger_data_bad_file(def_config, tiger_table, tokenizer_mock, @pytest.mark.asyncio
async def test_add_tiger_data_bad_file(def_config, tiger_table, tokenizer_mock,
tmp_path): tmp_path):
sqlfile = tmp_path / '1010.csv' sqlfile = tmp_path / '1010.csv'
sqlfile.write_text("""Random text""") sqlfile.write_text("""Random text""")
tiger_data.add_tiger_data(str(tmp_path), def_config, 1, tokenizer_mock()) await tiger_data.add_tiger_data(str(tmp_path), def_config, 1, tokenizer_mock())
assert tiger_table.count() == 0 assert tiger_table.count() == 0
def test_add_tiger_data_hnr_nan(def_config, tiger_table, tokenizer_mock, @pytest.mark.asyncio
async def test_add_tiger_data_hnr_nan(def_config, tiger_table, tokenizer_mock,
csv_factory, tmp_path): csv_factory, tmp_path):
csv_factory('file1', hnr_from=99) csv_factory('file1', hnr_from=99)
csv_factory('file2', hnr_from='L12') csv_factory('file2', hnr_from='L12')
csv_factory('file3', hnr_to='12.4') csv_factory('file3', hnr_to='12.4')
tiger_data.add_tiger_data(str(tmp_path), def_config, 1, tokenizer_mock()) await tiger_data.add_tiger_data(str(tmp_path), def_config, 1, tokenizer_mock())
assert tiger_table.count() == 1 assert tiger_table.count() == 1
assert tiger_table.row()['start'] == 99 assert tiger_table.row().start == 99
@pytest.mark.parametrize("threads", (1, 5)) @pytest.mark.parametrize("threads", (1, 5))
def test_add_tiger_data_tarfile(def_config, tiger_table, tokenizer_mock, @pytest.mark.asyncio
async def test_add_tiger_data_tarfile(def_config, tiger_table, tokenizer_mock,
tmp_path, src_dir, threads): tmp_path, src_dir, threads):
tar = tarfile.open(str(tmp_path / 'sample.tar.gz'), "w:gz") tar = tarfile.open(str(tmp_path / 'sample.tar.gz'), "w:gz")
tar.add(str(src_dir / 'test' / 'testdb' / 'tiger' / '01001.csv')) tar.add(str(src_dir / 'test' / 'testdb' / 'tiger' / '01001.csv'))
tar.close() tar.close()
tiger_data.add_tiger_data(str(tmp_path / 'sample.tar.gz'), def_config, threads, await tiger_data.add_tiger_data(str(tmp_path / 'sample.tar.gz'), def_config, threads,
tokenizer_mock()) tokenizer_mock())
assert tiger_table.count() == 6213 assert tiger_table.count() == 6213
def test_add_tiger_data_bad_tarfile(def_config, tiger_table, tokenizer_mock, @pytest.mark.asyncio
async def test_add_tiger_data_bad_tarfile(def_config, tiger_table, tokenizer_mock,
tmp_path): tmp_path):
tarfile = tmp_path / 'sample.tar.gz' tarfile = tmp_path / 'sample.tar.gz'
tarfile.write_text("""Random text""") tarfile.write_text("""Random text""")
with pytest.raises(UsageError): with pytest.raises(UsageError):
tiger_data.add_tiger_data(str(tarfile), def_config, 1, tokenizer_mock()) await tiger_data.add_tiger_data(str(tarfile), def_config, 1, tokenizer_mock())
def test_add_tiger_data_empty_tarfile(def_config, tiger_table, tokenizer_mock, @pytest.mark.asyncio
async def test_add_tiger_data_empty_tarfile(def_config, tiger_table, tokenizer_mock,
tmp_path): tmp_path):
tar = tarfile.open(str(tmp_path / 'sample.tar.gz'), "w:gz") tar = tarfile.open(str(tmp_path / 'sample.tar.gz'), "w:gz")
tar.add(__file__) tar.add(__file__)
tar.close() tar.close()
tiger_data.add_tiger_data(str(tmp_path / 'sample.tar.gz'), def_config, 1, await tiger_data.add_tiger_data(str(tmp_path / 'sample.tar.gz'), def_config, 1,
tokenizer_mock()) tokenizer_mock())
assert tiger_table.count() == 0 assert tiger_table.count() == 0