make DB helper functions free functions

Also changes the drop function so that it can drop multiple tables
at once.
This commit is contained in:
Sarah Hoffmann
2024-07-02 15:15:50 +02:00
parent 71249bd94a
commit 3742fa2929
30 changed files with 347 additions and 364 deletions

View File

@@ -12,7 +12,7 @@ import argparse
import random import random
from ..errors import UsageError from ..errors import UsageError
from ..db.connection import connect from ..db.connection import connect, table_exists
from .args import NominatimArgs from .args import NominatimArgs
# Do not repeat documentation of subcommand classes. # Do not repeat documentation of subcommand classes.
@@ -115,7 +115,7 @@ class AdminFuncs:
tokenizer = tokenizer_factory.get_tokenizer_for_db(args.config) tokenizer = tokenizer_factory.get_tokenizer_for_db(args.config)
with connect(args.config.get_libpq_dsn()) as conn: with connect(args.config.get_libpq_dsn()) as conn:
if conn.table_exists('search_name'): if table_exists(conn, 'search_name'):
words = tokenizer.most_frequent_words(conn, 1000) words = tokenizer.most_frequent_words(conn, 1000)
else: else:
words = [] words = []

View File

@@ -13,7 +13,7 @@ import logging
from pathlib import Path from pathlib import Path
from ..config import Configuration from ..config import Configuration
from ..db.connection import connect from ..db.connection import connect, table_exists
from ..tokenizer.base import AbstractTokenizer from ..tokenizer.base import AbstractTokenizer
from .args import NominatimArgs from .args import NominatimArgs
@@ -124,7 +124,7 @@ class UpdateRefresh:
with connect(args.config.get_libpq_dsn()) as conn: with connect(args.config.get_libpq_dsn()) as conn:
# If the table did not exist before, then the importance code # If the table did not exist before, then the importance code
# needs to be enabled. # needs to be enabled.
if not conn.table_exists('secondary_importance'): if not table_exists(conn, 'secondary_importance'):
args.functions = True args.functions = True
LOG.warning('Import secondary importance raster data from %s', args.project_dir) LOG.warning('Import secondary importance raster data from %s', args.project_dir)

View File

@@ -9,10 +9,9 @@ Functions for importing and managing static country information.
""" """
from typing import Dict, Any, Iterable, Tuple, Optional, Container, overload from typing import Dict, Any, Iterable, Tuple, Optional, Container, overload
from pathlib import Path from pathlib import Path
import psycopg2.extras
from ..db import utils as db_utils from ..db import utils as db_utils
from ..db.connection import connect, Connection from ..db.connection import connect, Connection, register_hstore
from ..errors import UsageError from ..errors import UsageError
from ..config import Configuration from ..config import Configuration
from ..tokenizer.base import AbstractTokenizer from ..tokenizer.base import AbstractTokenizer
@@ -129,8 +128,8 @@ def setup_country_tables(dsn: str, sql_dir: Path, ignore_partitions: bool = Fals
params.append((ccode, props['names'], lang, partition)) params.append((ccode, props['names'], lang, partition))
with connect(dsn) as conn: with connect(dsn) as conn:
register_hstore(conn)
with conn.cursor() as cur: with conn.cursor() as cur:
psycopg2.extras.register_hstore(cur)
cur.execute( cur.execute(
""" CREATE TABLE public.country_name ( """ CREATE TABLE public.country_name (
country_code character varying(2), country_code character varying(2),
@@ -157,8 +156,8 @@ def create_country_names(conn: Connection, tokenizer: AbstractTokenizer,
return ':' not in key or not languages or \ return ':' not in key or not languages or \
key[key.index(':') + 1:] in languages key[key.index(':') + 1:] in languages
register_hstore(conn)
with conn.cursor() as cur: with conn.cursor() as cur:
psycopg2.extras.register_hstore(cur)
cur.execute("""SELECT country_code, name FROM country_name cur.execute("""SELECT country_code, name FROM country_name
WHERE country_code is not null""") WHERE country_code is not null""")

View File

@@ -7,7 +7,8 @@
""" """
Specialised connection and cursor functions. Specialised connection and cursor functions.
""" """
from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple, Iterable from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload,\
Tuple, Iterable
import contextlib import contextlib
import logging import logging
import os import os
@@ -46,37 +47,6 @@ class Cursor(psycopg2.extras.DictCursor):
psycopg2.extras.execute_values(self, sql, argslist, template=template) psycopg2.extras.execute_values(self, sql, argslist, template=template)
def scalar(self, sql: Query, args: Any = None) -> Any:
""" Execute query that returns a single value. The value is returned.
If the query yields more than one row, a ValueError is raised.
"""
self.execute(sql, args)
if self.rowcount != 1:
raise RuntimeError("Query did not return a single row.")
result = self.fetchone()
assert result is not None
return result[0]
def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
""" Drop the table with the given name.
Set `if_exists` to False if a non-existent table should raise
an exception instead of just being ignored. If 'cascade' is set
to True then all dependent tables are deleted as well.
"""
sql = 'DROP TABLE '
if if_exists:
sql += 'IF EXISTS '
sql += '{}'
if cascade:
sql += ' CASCADE'
self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
class Connection(psycopg2.extensions.connection): class Connection(psycopg2.extensions.connection):
""" A connection that provides the specialised cursor by default and """ A connection that provides the specialised cursor by default and
adds convenience functions for administrating the database. adds convenience functions for administrating the database.
@@ -99,80 +69,105 @@ class Connection(psycopg2.extensions.connection):
return super().cursor(cursor_factory=cursor_factory, **kwargs) return super().cursor(cursor_factory=cursor_factory, **kwargs)
def table_exists(self, table: str) -> bool: def execute_scalar(conn: Connection, sql: Query, args: Any = None) -> Any:
""" Check that a table with the given name exists in the database. """ Execute query that returns a single value. The value is returned.
""" If the query yields more than one row, a ValueError is raised.
with self.cursor() as cur: """
num = cur.scalar("""SELECT count(*) FROM pg_tables with conn.cursor() as cur:
WHERE tablename = %s and schemaname = 'public'""", (table, )) cur.execute(sql, args)
return num == 1 if isinstance(num, int) else False
if cur.rowcount != 1:
raise RuntimeError("Query did not return a single row.")
result = cur.fetchone()
assert result is not None
return result[0]
def table_has_column(self, table: str, column: str) -> bool: def table_exists(conn: Connection, table: str) -> bool:
""" Check if the table 'table' exists and has a column with name 'column'. """ Check that a table with the given name exists in the database.
""" """
with self.cursor() as cur: num = execute_scalar(conn,
has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns """SELECT count(*) FROM pg_tables
WHERE table_name = %s WHERE tablename = %s and schemaname = 'public'""", (table, ))
and column_name = %s""", return num == 1 if isinstance(num, int) else False
(table, column))
return has_column > 0 if isinstance(has_column, int) else False
def index_exists(self, index: str, table: Optional[str] = None) -> bool: def table_has_column(conn: Connection, table: str, column: str) -> bool:
""" Check that an index with the given name exists in the database. """ Check if the table 'table' exists and has a column with name 'column'.
If table is not None then the index must relate to the given """
table. has_column = execute_scalar(conn,
""" """SELECT count(*) FROM information_schema.columns
with self.cursor() as cur: WHERE table_name = %s and column_name = %s""",
cur.execute("""SELECT tablename FROM pg_indexes (table, column))
WHERE indexname = %s and schemaname = 'public'""", (index, )) return has_column > 0 if isinstance(has_column, int) else False
if cur.rowcount == 0:
def index_exists(conn: Connection, index: str, table: Optional[str] = None) -> bool:
""" Check that an index with the given name exists in the database.
If table is not None then the index must relate to the given
table.
"""
with conn.cursor() as cur:
cur.execute("""SELECT tablename FROM pg_indexes
WHERE indexname = %s and schemaname = 'public'""", (index, ))
if cur.rowcount == 0:
return False
if table is not None:
row = cur.fetchone()
if row is None or not isinstance(row[0], str):
return False return False
return row[0] == table
if table is not None: return True
row = cur.fetchone()
if row is None or not isinstance(row[0], str):
return False
return row[0] == table
return True def drop_tables(conn: Connection, *names: str,
if_exists: bool = True, cascade: bool = False) -> None:
""" Drop one or more tables with the given names.
Set `if_exists` to False if a non-existent table should raise
an exception instead of just being ignored. `cascade` will cause
depended objects to be dropped as well.
The caller needs to take care of committing the change.
"""
sql = pysql.SQL('DROP TABLE%s{}%s' % (
' IF EXISTS ' if if_exists else ' ',
' CASCADE' if cascade else ''))
with conn.cursor() as cur:
for name in names:
cur.execute(sql.format(pysql.Identifier(name)))
def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None: def server_version_tuple(conn: Connection) -> Tuple[int, int]:
""" Drop the table with the given name. """ Return the server version as a tuple of (major, minor).
Set `if_exists` to False if a non-existent table should raise Converts correctly for pre-10 and post-10 PostgreSQL versions.
an exception instead of just being ignored. """
""" version = conn.server_version
with self.cursor() as cur: if version < 100000:
cur.drop_table(name, if_exists, cascade) return (int(version / 10000), int((version % 10000) / 100))
self.commit()
return (int(version / 10000), version % 10000)
def server_version_tuple(self) -> Tuple[int, int]: def postgis_version_tuple(conn: Connection) -> Tuple[int, int]:
""" Return the server version as a tuple of (major, minor). """ Return the postgis version installed in the database as a
Converts correctly for pre-10 and post-10 PostgreSQL versions. tuple of (major, minor). Assumes that the PostGIS extension
""" has been installed already.
version = self.server_version """
if version < 100000: version = execute_scalar(conn, 'SELECT postgis_lib_version()')
return (int(version / 10000), int((version % 10000) / 100))
return (int(version / 10000), version % 10000) version_parts = version.split('.')
if len(version_parts) < 2:
raise UsageError(f"Error fetching Postgis version. Bad format: {version}")
return (int(version_parts[0]), int(version_parts[1]))
def postgis_version_tuple(self) -> Tuple[int, int]: def register_hstore(conn: Connection) -> None:
""" Return the postgis version installed in the database as a """ Register the hstore type with psycopg for the connection.
tuple of (major, minor). Assumes that the PostGIS extension """
has been installed already. psycopg2.extras.register_hstore(conn)
"""
with self.cursor() as cur:
version = cur.scalar('SELECT postgis_lib_version()')
version_parts = version.split('.')
if len(version_parts) < 2:
raise UsageError(f"Error fetching Postgis version. Bad format: {version}")
return (int(version_parts[0]), int(version_parts[1]))
class ConnectionContext(ContextManager[Connection]): class ConnectionContext(ContextManager[Connection]):

View File

@@ -9,7 +9,7 @@ Query and access functions for the in-database property table.
""" """
from typing import Optional, cast from typing import Optional, cast
from .connection import Connection from .connection import Connection, table_exists
def set_property(conn: Connection, name: str, value: str) -> None: def set_property(conn: Connection, name: str, value: str) -> None:
""" Add or replace the property with the given name. """ Add or replace the property with the given name.
@@ -31,7 +31,7 @@ def get_property(conn: Connection, name: str) -> Optional[str]:
""" Return the current value of the given property or None if the property """ Return the current value of the given property or None if the property
is not set. is not set.
""" """
if not conn.table_exists('nominatim_properties'): if not table_exists(conn, 'nominatim_properties'):
return None return None
with conn.cursor() as cur: with conn.cursor() as cur:

View File

@@ -10,7 +10,7 @@ Preprocessing of SQL files.
from typing import Set, Dict, Any, cast from typing import Set, Dict, Any, cast
import jinja2 import jinja2
from .connection import Connection from .connection import Connection, server_version_tuple, postgis_version_tuple
from .async_connection import WorkerPool from .async_connection import WorkerPool
from ..config import Configuration from ..config import Configuration
@@ -66,8 +66,8 @@ def _setup_postgresql_features(conn: Connection) -> Dict[str, Any]:
""" Set up a dictionary with various optional Postgresql/Postgis features that """ Set up a dictionary with various optional Postgresql/Postgis features that
depend on the database version. depend on the database version.
""" """
pg_version = conn.server_version_tuple() pg_version = server_version_tuple(conn)
postgis_version = conn.postgis_version_tuple() postgis_version = postgis_version_tuple(conn)
pg11plus = pg_version >= (11, 0, 0) pg11plus = pg_version >= (11, 0, 0)
ps3 = postgis_version >= (3, 0) ps3 = postgis_version >= (3, 0)
return { return {

View File

@@ -12,7 +12,7 @@ import datetime as dt
import logging import logging
import re import re
from .connection import Connection from .connection import Connection, table_exists, execute_scalar
from ..utils.url_utils import get_url from ..utils.url_utils import get_url
from ..errors import UsageError from ..errors import UsageError
from ..typing import TypedDict from ..typing import TypedDict
@@ -34,7 +34,7 @@ def compute_database_date(conn: Connection, offline: bool = False) -> dt.datetim
data base. data base.
""" """
# If there is a date from osm2pgsql available, use that. # If there is a date from osm2pgsql available, use that.
if conn.table_exists('osm2pgsql_properties'): if table_exists(conn, 'osm2pgsql_properties'):
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute(""" SELECT value FROM osm2pgsql_properties cur.execute(""" SELECT value FROM osm2pgsql_properties
WHERE property = 'current_timestamp' """) WHERE property = 'current_timestamp' """)
@@ -47,15 +47,14 @@ def compute_database_date(conn: Connection, offline: bool = False) -> dt.datetim
raise UsageError("Cannot determine database date from data in offline mode.") raise UsageError("Cannot determine database date from data in offline mode.")
# Else, find the node with the highest ID in the database # Else, find the node with the highest ID in the database
with conn.cursor() as cur: if table_exists(conn, 'place'):
if conn.table_exists('place'): osmid = execute_scalar(conn, "SELECT max(osm_id) FROM place WHERE osm_type='N'")
osmid = cur.scalar("SELECT max(osm_id) FROM place WHERE osm_type='N'") else:
else: osmid = execute_scalar(conn, "SELECT max(osm_id) FROM placex WHERE osm_type='N'")
osmid = cur.scalar("SELECT max(osm_id) FROM placex WHERE osm_type='N'")
if osmid is None: if osmid is None:
LOG.fatal("No data found in the database.") LOG.fatal("No data found in the database.")
raise UsageError("No data found in the database.") raise UsageError("No data found in the database.")
LOG.info("Using node id %d for timestamp lookup", osmid) LOG.info("Using node id %d for timestamp lookup", osmid)
# Get the node from the API to find the timestamp when it was created. # Get the node from the API to find the timestamp when it was created.

View File

@@ -15,7 +15,7 @@ import psycopg2.extras
from ..typing import DictCursorResults from ..typing import DictCursorResults
from ..db.async_connection import DBConnection, WorkerPool from ..db.async_connection import DBConnection, WorkerPool
from ..db.connection import connect, Connection, Cursor from ..db.connection import connect, Connection, Cursor, execute_scalar, register_hstore
from ..tokenizer.base import AbstractTokenizer from ..tokenizer.base import AbstractTokenizer
from .progress import ProgressLogger from .progress import ProgressLogger
from . import runners from . import runners
@@ -32,15 +32,15 @@ class PlaceFetcher:
self.conn: Optional[DBConnection] = DBConnection(dsn, self.conn: Optional[DBConnection] = DBConnection(dsn,
cursor_factory=psycopg2.extras.DictCursor) cursor_factory=psycopg2.extras.DictCursor)
with setup_conn.cursor() as cur: # need to fetch those manually because register_hstore cannot
# need to fetch those manually because register_hstore cannot # fetch them on an asynchronous connection below.
# fetch them on an asynchronous connection below. hstore_oid = execute_scalar(setup_conn, "SELECT 'hstore'::regtype::oid")
hstore_oid = cur.scalar("SELECT 'hstore'::regtype::oid") hstore_array_oid = execute_scalar(setup_conn, "SELECT 'hstore[]'::regtype::oid")
hstore_array_oid = cur.scalar("SELECT 'hstore[]'::regtype::oid")
psycopg2.extras.register_hstore(self.conn.conn, oid=hstore_oid, psycopg2.extras.register_hstore(self.conn.conn, oid=hstore_oid,
array_oid=hstore_array_oid) array_oid=hstore_array_oid)
def close(self) -> None: def close(self) -> None:
""" Close the underlying asynchronous connection. """ Close the underlying asynchronous connection.
""" """
@@ -205,10 +205,9 @@ class Indexer:
LOG.warning("Starting %s (using batch size %s)", runner.name(), batch) LOG.warning("Starting %s (using batch size %s)", runner.name(), batch)
with connect(self.dsn) as conn: with connect(self.dsn) as conn:
psycopg2.extras.register_hstore(conn) register_hstore(conn)
with conn.cursor() as cur: total_tuples = execute_scalar(conn, runner.sql_count_objects())
total_tuples = cur.scalar(runner.sql_count_objects()) LOG.debug("Total number of rows: %i", total_tuples)
LOG.debug("Total number of rows: %i", total_tuples)
conn.commit() conn.commit()

View File

@@ -16,7 +16,8 @@ import logging
from pathlib import Path from pathlib import Path
from textwrap import dedent from textwrap import dedent
from ..db.connection import connect, Connection, Cursor from ..db.connection import connect, Connection, Cursor, server_version_tuple,\
drop_tables, table_exists, execute_scalar
from ..config import Configuration from ..config import Configuration
from ..db.utils import CopyBuffer from ..db.utils import CopyBuffer
from ..db.sql_preprocessor import SQLPreprocessor from ..db.sql_preprocessor import SQLPreprocessor
@@ -108,7 +109,7 @@ class ICUTokenizer(AbstractTokenizer):
""" Recompute frequencies for all name words. """ Recompute frequencies for all name words.
""" """
with connect(self.dsn) as conn: with connect(self.dsn) as conn:
if not conn.table_exists('search_name'): if not table_exists(conn, 'search_name'):
return return
with conn.cursor() as cur: with conn.cursor() as cur:
@@ -117,10 +118,9 @@ class ICUTokenizer(AbstractTokenizer):
cur.execute('SET max_parallel_workers_per_gather TO %s', cur.execute('SET max_parallel_workers_per_gather TO %s',
(min(threads, 6),)) (min(threads, 6),))
if conn.server_version_tuple() < (12, 0): if server_version_tuple(conn) < (12, 0):
LOG.info('Computing word frequencies') LOG.info('Computing word frequencies')
cur.drop_table('word_frequencies') drop_tables(conn, 'word_frequencies', 'addressword_frequencies')
cur.drop_table('addressword_frequencies')
cur.execute("""CREATE TEMP TABLE word_frequencies AS cur.execute("""CREATE TEMP TABLE word_frequencies AS
SELECT unnest(name_vector) as id, count(*) SELECT unnest(name_vector) as id, count(*)
FROM search_name GROUP BY id""") FROM search_name GROUP BY id""")
@@ -152,17 +152,16 @@ class ICUTokenizer(AbstractTokenizer):
$$ LANGUAGE plpgsql IMMUTABLE; $$ LANGUAGE plpgsql IMMUTABLE;
""") """)
LOG.info('Update word table with recomputed frequencies') LOG.info('Update word table with recomputed frequencies')
cur.drop_table('tmp_word') drop_tables(conn, 'tmp_word')
cur.execute("""CREATE TABLE tmp_word AS cur.execute("""CREATE TABLE tmp_word AS
SELECT word_id, word_token, type, word, SELECT word_id, word_token, type, word,
word_freq_update(word_id, info) as info word_freq_update(word_id, info) as info
FROM word FROM word
""") """)
cur.drop_table('word_frequencies') drop_tables(conn, 'word_frequencies', 'addressword_frequencies')
cur.drop_table('addressword_frequencies')
else: else:
LOG.info('Computing word frequencies') LOG.info('Computing word frequencies')
cur.drop_table('word_frequencies') drop_tables(conn, 'word_frequencies')
cur.execute(""" cur.execute("""
CREATE TEMP TABLE word_frequencies AS CREATE TEMP TABLE word_frequencies AS
WITH word_freq AS MATERIALIZED ( WITH word_freq AS MATERIALIZED (
@@ -182,7 +181,7 @@ class ICUTokenizer(AbstractTokenizer):
cur.execute('CREATE UNIQUE INDEX ON word_frequencies(id) INCLUDE(info)') cur.execute('CREATE UNIQUE INDEX ON word_frequencies(id) INCLUDE(info)')
cur.execute('ANALYSE word_frequencies') cur.execute('ANALYSE word_frequencies')
LOG.info('Update word table with recomputed frequencies') LOG.info('Update word table with recomputed frequencies')
cur.drop_table('tmp_word') drop_tables(conn, 'tmp_word')
cur.execute("""CREATE TABLE tmp_word AS cur.execute("""CREATE TABLE tmp_word AS
SELECT word_id, word_token, type, word, SELECT word_id, word_token, type, word,
(CASE WHEN wf.info is null THEN word.info (CASE WHEN wf.info is null THEN word.info
@@ -191,7 +190,7 @@ class ICUTokenizer(AbstractTokenizer):
FROM word LEFT JOIN word_frequencies wf FROM word LEFT JOIN word_frequencies wf
ON word.word_id = wf.id ON word.word_id = wf.id
""") """)
cur.drop_table('word_frequencies') drop_tables(conn, 'word_frequencies')
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute('SET max_parallel_workers_per_gather TO 0') cur.execute('SET max_parallel_workers_per_gather TO 0')
@@ -210,7 +209,7 @@ class ICUTokenizer(AbstractTokenizer):
""" Remove unused house numbers. """ Remove unused house numbers.
""" """
with connect(self.dsn) as conn: with connect(self.dsn) as conn:
if not conn.table_exists('search_name'): if not table_exists(conn, 'search_name'):
return return
with conn.cursor(name="hnr_counter") as cur: with conn.cursor(name="hnr_counter") as cur:
cur.execute("""SELECT DISTINCT word_id, coalesce(info->>'lookup', word_token) cur.execute("""SELECT DISTINCT word_id, coalesce(info->>'lookup', word_token)
@@ -311,8 +310,7 @@ class ICUTokenizer(AbstractTokenizer):
frequencies. frequencies.
""" """
with connect(self.dsn) as conn: with connect(self.dsn) as conn:
with conn.cursor() as cur: drop_tables(conn, 'word')
cur.drop_table('word')
sqlp = SQLPreprocessor(conn, config) sqlp = SQLPreprocessor(conn, config)
sqlp.run_string(conn, """ sqlp.run_string(conn, """
CREATE TABLE word ( CREATE TABLE word (
@@ -370,8 +368,8 @@ class ICUTokenizer(AbstractTokenizer):
""" Rename all tables and indexes used by the tokenizer. """ Rename all tables and indexes used by the tokenizer.
""" """
with connect(self.dsn) as conn: with connect(self.dsn) as conn:
drop_tables(conn, 'word')
with conn.cursor() as cur: with conn.cursor() as cur:
cur.drop_table('word')
cur.execute(f"ALTER TABLE {old} RENAME TO word") cur.execute(f"ALTER TABLE {old} RENAME TO word")
for idx in ('word_token', 'word_id'): for idx in ('word_token', 'word_id'):
cur.execute(f"""ALTER INDEX idx_{old}_{idx} cur.execute(f"""ALTER INDEX idx_{old}_{idx}
@@ -733,11 +731,10 @@ class ICUNameAnalyzer(AbstractAnalyzer):
if norm_name: if norm_name:
result = self._cache.housenumbers.get(norm_name, result) result = self._cache.housenumbers.get(norm_name, result)
if result[0] is None: if result[0] is None:
with self.conn.cursor() as cur: hid = execute_scalar(self.conn, "SELECT getorcreate_hnr_id(%s)", (norm_name, ))
hid = cur.scalar("SELECT getorcreate_hnr_id(%s)", (norm_name, ))
result = hid, norm_name result = hid, norm_name
self._cache.housenumbers[norm_name] = result self._cache.housenumbers[norm_name] = result
else: else:
# Otherwise use the analyzer to determine the canonical name. # Otherwise use the analyzer to determine the canonical name.
# Per convention we use the first variant as the 'lookup name', the # Per convention we use the first variant as the 'lookup name', the
@@ -748,11 +745,10 @@ class ICUNameAnalyzer(AbstractAnalyzer):
if result[0] is None: if result[0] is None:
variants = analyzer.compute_variants(word_id) variants = analyzer.compute_variants(word_id)
if variants: if variants:
with self.conn.cursor() as cur: hid = execute_scalar(self.conn, "SELECT create_analyzed_hnr_id(%s, %s)",
hid = cur.scalar("SELECT create_analyzed_hnr_id(%s, %s)",
(word_id, list(variants))) (word_id, list(variants)))
result = hid, variants[0] result = hid, variants[0]
self._cache.housenumbers[word_id] = result self._cache.housenumbers[word_id] = result
return result return result

View File

@@ -18,10 +18,10 @@ from textwrap import dedent
from icu import Transliterator from icu import Transliterator
import psycopg2 import psycopg2
import psycopg2.extras
from ..errors import UsageError from ..errors import UsageError
from ..db.connection import connect, Connection from ..db.connection import connect, Connection, drop_tables, table_exists,\
execute_scalar, register_hstore
from ..config import Configuration from ..config import Configuration
from ..db import properties from ..db import properties
from ..db import utils as db_utils from ..db import utils as db_utils
@@ -179,11 +179,10 @@ class LegacyTokenizer(AbstractTokenizer):
* Can nominatim.so be accessed by the database user? * Can nominatim.so be accessed by the database user?
""" """
with connect(self.dsn) as conn: with connect(self.dsn) as conn:
with conn.cursor() as cur: try:
try: out = execute_scalar(conn, "SELECT make_standard_name('a')")
out = cur.scalar("SELECT make_standard_name('a')") except psycopg2.Error as err:
except psycopg2.Error as err: return hint.format(error=str(err))
return hint.format(error=str(err))
if out != 'a': if out != 'a':
return hint.format(error='Unexpected result for make_standard_name()') return hint.format(error='Unexpected result for make_standard_name()')
@@ -214,9 +213,9 @@ class LegacyTokenizer(AbstractTokenizer):
""" Recompute the frequency of full words. """ Recompute the frequency of full words.
""" """
with connect(self.dsn) as conn: with connect(self.dsn) as conn:
if conn.table_exists('search_name'): if table_exists(conn, 'search_name'):
drop_tables(conn, "word_frequencies")
with conn.cursor() as cur: with conn.cursor() as cur:
cur.drop_table("word_frequencies")
LOG.info("Computing word frequencies") LOG.info("Computing word frequencies")
cur.execute("""CREATE TEMP TABLE word_frequencies AS cur.execute("""CREATE TEMP TABLE word_frequencies AS
SELECT unnest(name_vector) as id, count(*) SELECT unnest(name_vector) as id, count(*)
@@ -226,7 +225,7 @@ class LegacyTokenizer(AbstractTokenizer):
cur.execute("""UPDATE word SET search_name_count = count cur.execute("""UPDATE word SET search_name_count = count
FROM word_frequencies FROM word_frequencies
WHERE word_token like ' %' and word_id = id""") WHERE word_token like ' %' and word_id = id""")
cur.drop_table("word_frequencies") drop_tables(conn, "word_frequencies")
conn.commit() conn.commit()
@@ -316,7 +315,7 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
self.conn: Optional[Connection] = connect(dsn).connection self.conn: Optional[Connection] = connect(dsn).connection
self.conn.autocommit = True self.conn.autocommit = True
self.normalizer = normalizer self.normalizer = normalizer
psycopg2.extras.register_hstore(self.conn) register_hstore(self.conn)
self._cache = _TokenCache(self.conn) self._cache = _TokenCache(self.conn)
@@ -536,9 +535,8 @@ class _TokenInfo:
def add_names(self, conn: Connection, names: Mapping[str, str]) -> None: def add_names(self, conn: Connection, names: Mapping[str, str]) -> None:
""" Add token information for the names of the place. """ Add token information for the names of the place.
""" """
with conn.cursor() as cur: # Create the token IDs for all names.
# Create the token IDs for all names. self.data['names'] = execute_scalar(conn, "SELECT make_keywords(%s)::text",
self.data['names'] = cur.scalar("SELECT make_keywords(%s)::text",
(names, )) (names, ))
@@ -576,9 +574,8 @@ class _TokenInfo:
""" Add addr:street match terms. """ Add addr:street match terms.
""" """
def _get_street(name: str) -> Optional[str]: def _get_street(name: str) -> Optional[str]:
with conn.cursor() as cur: return cast(Optional[str],
return cast(Optional[str], execute_scalar(conn, "SELECT word_ids_from_name(%s)::text", (name, )))
cur.scalar("SELECT word_ids_from_name(%s)::text", (name, )))
tokens = self.cache.streets.get(street, _get_street) tokens = self.cache.streets.get(street, _get_street)
self.data['street'] = tokens or '{}' self.data['street'] = tokens or '{}'

View File

@@ -10,12 +10,12 @@ Functions for database analysis and maintenance.
from typing import Optional, Tuple, Any, cast from typing import Optional, Tuple, Any, cast
import logging import logging
from psycopg2.extras import Json, register_hstore from psycopg2.extras import Json
from psycopg2 import DataError from psycopg2 import DataError
from ..typing import DictCursorResult from ..typing import DictCursorResult
from ..config import Configuration from ..config import Configuration
from ..db.connection import connect, Cursor from ..db.connection import connect, Cursor, register_hstore
from ..errors import UsageError from ..errors import UsageError
from ..tokenizer import factory as tokenizer_factory from ..tokenizer import factory as tokenizer_factory
from ..data.place_info import PlaceInfo from ..data.place_info import PlaceInfo

View File

@@ -12,7 +12,8 @@ from enum import Enum
from textwrap import dedent from textwrap import dedent
from ..config import Configuration from ..config import Configuration
from ..db.connection import connect, Connection from ..db.connection import connect, Connection, server_version_tuple,\
index_exists, table_exists, execute_scalar
from ..db import properties from ..db import properties
from ..errors import UsageError from ..errors import UsageError
from ..tokenizer import factory as tokenizer_factory from ..tokenizer import factory as tokenizer_factory
@@ -109,14 +110,14 @@ def _get_indexes(conn: Connection) -> List[str]:
'idx_postcode_id', 'idx_postcode_id',
'idx_postcode_postcode' 'idx_postcode_postcode'
] ]
if conn.table_exists('search_name'): if table_exists(conn, 'search_name'):
indexes.extend(('idx_search_name_nameaddress_vector', indexes.extend(('idx_search_name_nameaddress_vector',
'idx_search_name_name_vector', 'idx_search_name_name_vector',
'idx_search_name_centroid')) 'idx_search_name_centroid'))
if conn.server_version_tuple() >= (11, 0, 0): if server_version_tuple(conn) >= (11, 0, 0):
indexes.extend(('idx_placex_housenumber', indexes.extend(('idx_placex_housenumber',
'idx_osmline_parent_osm_id_with_hnr')) 'idx_osmline_parent_osm_id_with_hnr'))
if conn.table_exists('place'): if table_exists(conn, 'place'):
indexes.extend(('idx_location_area_country_place_id', indexes.extend(('idx_location_area_country_place_id',
'idx_place_osm_unique', 'idx_place_osm_unique',
'idx_placex_rank_address_sector', 'idx_placex_rank_address_sector',
@@ -153,7 +154,7 @@ def check_connection(conn: Any, config: Configuration) -> CheckResult:
Hints: Hints:
* Are you connecting to the correct database? * Are you connecting to the correct database?
{instruction} {instruction}
Check the Migration chapter of the Administration Guide. Check the Migration chapter of the Administration Guide.
@@ -165,7 +166,7 @@ def check_database_version(conn: Connection, config: Configuration) -> CheckResu
""" Checking database_version matches Nominatim software version """ Checking database_version matches Nominatim software version
""" """
if conn.table_exists('nominatim_properties'): if table_exists(conn, 'nominatim_properties'):
db_version_str = properties.get_property(conn, 'database_version') db_version_str = properties.get_property(conn, 'database_version')
else: else:
db_version_str = None db_version_str = None
@@ -202,7 +203,7 @@ def check_database_version(conn: Connection, config: Configuration) -> CheckResu
def check_placex_table(conn: Connection, config: Configuration) -> CheckResult: def check_placex_table(conn: Connection, config: Configuration) -> CheckResult:
""" Checking for placex table """ Checking for placex table
""" """
if conn.table_exists('placex'): if table_exists(conn, 'placex'):
return CheckState.OK return CheckState.OK
return CheckState.FATAL, dict(config=config) return CheckState.FATAL, dict(config=config)
@@ -212,8 +213,7 @@ def check_placex_table(conn: Connection, config: Configuration) -> CheckResult:
def check_placex_size(conn: Connection, _: Configuration) -> CheckResult: def check_placex_size(conn: Connection, _: Configuration) -> CheckResult:
""" Checking for placex content """ Checking for placex content
""" """
with conn.cursor() as cur: cnt = execute_scalar(conn, 'SELECT count(*) FROM (SELECT * FROM placex LIMIT 100) x')
cnt = cur.scalar('SELECT count(*) FROM (SELECT * FROM placex LIMIT 100) x')
return CheckState.OK if cnt > 0 else CheckState.FATAL return CheckState.OK if cnt > 0 else CheckState.FATAL
@@ -244,16 +244,15 @@ def check_tokenizer(_: Connection, config: Configuration) -> CheckResult:
def check_existance_wikipedia(conn: Connection, _: Configuration) -> CheckResult: def check_existance_wikipedia(conn: Connection, _: Configuration) -> CheckResult:
""" Checking for wikipedia/wikidata data """ Checking for wikipedia/wikidata data
""" """
if not conn.table_exists('search_name') or not conn.table_exists('place'): if not table_exists(conn, 'search_name') or not table_exists(conn, 'place'):
return CheckState.NOT_APPLICABLE return CheckState.NOT_APPLICABLE
with conn.cursor() as cur: if table_exists(conn, 'wikimedia_importance'):
if conn.table_exists('wikimedia_importance'): cnt = execute_scalar(conn, 'SELECT count(*) FROM wikimedia_importance')
cnt = cur.scalar('SELECT count(*) FROM wikimedia_importance') else:
else: cnt = execute_scalar(conn, 'SELECT count(*) FROM wikipedia_article')
cnt = cur.scalar('SELECT count(*) FROM wikipedia_article')
return CheckState.WARN if cnt == 0 else CheckState.OK return CheckState.WARN if cnt == 0 else CheckState.OK
@_check(hint="""\ @_check(hint="""\
@@ -264,8 +263,7 @@ def check_existance_wikipedia(conn: Connection, _: Configuration) -> CheckResult
def check_indexing(conn: Connection, _: Configuration) -> CheckResult: def check_indexing(conn: Connection, _: Configuration) -> CheckResult:
""" Checking indexing status """ Checking indexing status
""" """
with conn.cursor() as cur: cnt = execute_scalar(conn, 'SELECT count(*) FROM placex WHERE indexed_status > 0')
cnt = cur.scalar('SELECT count(*) FROM placex WHERE indexed_status > 0')
if cnt == 0: if cnt == 0:
return CheckState.OK return CheckState.OK
@@ -276,7 +274,7 @@ def check_indexing(conn: Connection, _: Configuration) -> CheckResult:
Low counts of unindexed places are fine.""" Low counts of unindexed places are fine."""
return CheckState.WARN, dict(count=cnt, index_cmd=index_cmd) return CheckState.WARN, dict(count=cnt, index_cmd=index_cmd)
if conn.index_exists('idx_placex_rank_search'): if index_exists(conn, 'idx_placex_rank_search'):
# Likely just an interrupted update. # Likely just an interrupted update.
index_cmd = 'nominatim index' index_cmd = 'nominatim index'
else: else:
@@ -297,7 +295,7 @@ def check_database_indexes(conn: Connection, _: Configuration) -> CheckResult:
""" """
missing = [] missing = []
for index in _get_indexes(conn): for index in _get_indexes(conn):
if not conn.index_exists(index): if not index_exists(conn, index):
missing.append(index) missing.append(index)
if missing: if missing:
@@ -340,11 +338,10 @@ def check_tiger_table(conn: Connection, config: Configuration) -> CheckResult:
if not config.get_bool('USE_US_TIGER_DATA'): if not config.get_bool('USE_US_TIGER_DATA'):
return CheckState.NOT_APPLICABLE return CheckState.NOT_APPLICABLE
if not conn.table_exists('location_property_tiger'): if not table_exists(conn, 'location_property_tiger'):
return CheckState.FAIL, dict(error='TIGER data table not found.') return CheckState.FAIL, dict(error='TIGER data table not found.')
with conn.cursor() as cur: if execute_scalar(conn, 'SELECT count(*) FROM location_property_tiger') == 0:
if cur.scalar('SELECT count(*) FROM location_property_tiger') == 0: return CheckState.FAIL, dict(error='TIGER data table is empty.')
return CheckState.FAIL, dict(error='TIGER data table is empty.')
return CheckState.OK return CheckState.OK

View File

@@ -12,21 +12,16 @@ import os
import subprocess import subprocess
import sys import sys
from pathlib import Path from pathlib import Path
from typing import List, Optional, Tuple, Union from typing import List, Optional, Union
import psutil import psutil
from psycopg2.extensions import make_dsn, parse_dsn from psycopg2.extensions import make_dsn
from ..config import Configuration from ..config import Configuration
from ..db.connection import connect from ..db.connection import connect, server_version_tuple, execute_scalar
from ..version import NOMINATIM_VERSION from ..version import NOMINATIM_VERSION
def convert_version(ver_tup: Tuple[int, int]) -> str:
"""converts tuple version (ver_tup) to a string representation"""
return ".".join(map(str, ver_tup))
def friendly_memory_string(mem: float) -> str: def friendly_memory_string(mem: float) -> str:
"""Create a user friendly string for the amount of memory specified as mem""" """Create a user friendly string for the amount of memory specified as mem"""
mem_magnitude = ("bytes", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB") mem_magnitude = ("bytes", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
@@ -103,16 +98,16 @@ def report_system_information(config: Configuration) -> None:
storage, and database configuration.""" storage, and database configuration."""
with connect(make_dsn(config.get_libpq_dsn(), dbname='postgres')) as conn: with connect(make_dsn(config.get_libpq_dsn(), dbname='postgres')) as conn:
postgresql_ver: str = convert_version(conn.server_version_tuple()) postgresql_ver: str = '.'.join(map(str, server_version_tuple(conn)))
with conn.cursor() as cur: with conn.cursor() as cur:
num = cur.scalar("SELECT count(*) FROM pg_catalog.pg_database WHERE datname=%s", cur.execute("SELECT datname FROM pg_catalog.pg_database WHERE datname=%s",
(parse_dsn(config.get_libpq_dsn())['dbname'], )) (config.get_database_params()['dbname'], ))
nominatim_db_exists = num == 1 if isinstance(num, int) else False nominatim_db_exists = cur.rowcount > 0
if nominatim_db_exists: if nominatim_db_exists:
with connect(config.get_libpq_dsn()) as conn: with connect(config.get_libpq_dsn()) as conn:
postgis_ver: str = convert_version(conn.postgis_version_tuple()) postgis_ver: str = execute_scalar(conn, 'SELECT postgis_lib_version()')
else: else:
postgis_ver = "Unable to connect to database" postgis_ver = "Unable to connect to database"

View File

@@ -19,7 +19,8 @@ from psycopg2 import sql as pysql
from ..errors import UsageError from ..errors import UsageError
from ..config import Configuration from ..config import Configuration
from ..db.connection import connect, get_pg_env, Connection from ..db.connection import connect, get_pg_env, Connection, server_version_tuple,\
postgis_version_tuple, drop_tables, table_exists, execute_scalar
from ..db.async_connection import DBConnection from ..db.async_connection import DBConnection
from ..db.sql_preprocessor import SQLPreprocessor from ..db.sql_preprocessor import SQLPreprocessor
from .exec_utils import run_osm2pgsql from .exec_utils import run_osm2pgsql
@@ -51,10 +52,10 @@ def check_existing_database_plugins(dsn: str) -> None:
""" Check that the database has the required plugins installed.""" """ Check that the database has the required plugins installed."""
with connect(dsn) as conn: with connect(dsn) as conn:
_require_version('PostgreSQL server', _require_version('PostgreSQL server',
conn.server_version_tuple(), server_version_tuple(conn),
POSTGRESQL_REQUIRED_VERSION) POSTGRESQL_REQUIRED_VERSION)
_require_version('PostGIS', _require_version('PostGIS',
conn.postgis_version_tuple(), postgis_version_tuple(conn),
POSTGIS_REQUIRED_VERSION) POSTGIS_REQUIRED_VERSION)
_require_loaded('hstore', conn) _require_loaded('hstore', conn)
@@ -80,31 +81,30 @@ def setup_database_skeleton(dsn: str, rouser: Optional[str] = None) -> None:
with connect(dsn) as conn: with connect(dsn) as conn:
_require_version('PostgreSQL server', _require_version('PostgreSQL server',
conn.server_version_tuple(), server_version_tuple(conn),
POSTGRESQL_REQUIRED_VERSION) POSTGRESQL_REQUIRED_VERSION)
if rouser is not None: if rouser is not None:
with conn.cursor() as cur: cnt = execute_scalar(conn, 'SELECT count(*) FROM pg_user where usename = %s',
cnt = cur.scalar('SELECT count(*) FROM pg_user where usename = %s',
(rouser, )) (rouser, ))
if cnt == 0: if cnt == 0:
LOG.fatal("Web user '%s' does not exist. Create it with:\n" LOG.fatal("Web user '%s' does not exist. Create it with:\n"
"\n createuser %s", rouser, rouser) "\n createuser %s", rouser, rouser)
raise UsageError('Missing read-only user.') raise UsageError('Missing read-only user.')
# Create extensions. # Create extensions.
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute('CREATE EXTENSION IF NOT EXISTS hstore') cur.execute('CREATE EXTENSION IF NOT EXISTS hstore')
cur.execute('CREATE EXTENSION IF NOT EXISTS postgis') cur.execute('CREATE EXTENSION IF NOT EXISTS postgis')
postgis_version = conn.postgis_version_tuple() postgis_version = postgis_version_tuple(conn)
if postgis_version[0] >= 3: if postgis_version[0] >= 3:
cur.execute('CREATE EXTENSION IF NOT EXISTS postgis_raster') cur.execute('CREATE EXTENSION IF NOT EXISTS postgis_raster')
conn.commit() conn.commit()
_require_version('PostGIS', _require_version('PostGIS',
conn.postgis_version_tuple(), postgis_version_tuple(conn),
POSTGIS_REQUIRED_VERSION) POSTGIS_REQUIRED_VERSION)
@@ -141,7 +141,8 @@ def import_osm_data(osm_files: Union[Path, Sequence[Path]],
raise UsageError('No data imported by osm2pgsql.') raise UsageError('No data imported by osm2pgsql.')
if drop: if drop:
conn.drop_table('planet_osm_nodes') drop_tables(conn, 'planet_osm_nodes')
conn.commit()
if drop and options['flatnode_file']: if drop and options['flatnode_file']:
Path(options['flatnode_file']).unlink() Path(options['flatnode_file']).unlink()
@@ -184,7 +185,7 @@ def truncate_data_tables(conn: Connection) -> None:
cur.execute('TRUNCATE location_property_tiger') cur.execute('TRUNCATE location_property_tiger')
cur.execute('TRUNCATE location_property_osmline') cur.execute('TRUNCATE location_property_osmline')
cur.execute('TRUNCATE location_postcode') cur.execute('TRUNCATE location_postcode')
if conn.table_exists('search_name'): if table_exists(conn, 'search_name'):
cur.execute('TRUNCATE search_name') cur.execute('TRUNCATE search_name')
cur.execute('DROP SEQUENCE IF EXISTS seq_place') cur.execute('DROP SEQUENCE IF EXISTS seq_place')
cur.execute('CREATE SEQUENCE seq_place start 100000') cur.execute('CREATE SEQUENCE seq_place start 100000')

View File

@@ -12,7 +12,7 @@ from pathlib import Path
from psycopg2 import sql as pysql from psycopg2 import sql as pysql
from ..db.connection import Connection from ..db.connection import Connection, drop_tables, table_exists
UPDATE_TABLES = [ UPDATE_TABLES = [
'address_levels', 'address_levels',
@@ -39,9 +39,7 @@ def drop_update_tables(conn: Connection) -> None:
+ pysql.SQL(' or ').join(parts)) + pysql.SQL(' or ').join(parts))
tables = [r[0] for r in cur] tables = [r[0] for r in cur]
for table in tables: drop_tables(conn, *tables, cascade=True)
cur.drop_table(table, cascade=True)
conn.commit() conn.commit()
@@ -55,4 +53,4 @@ def is_frozen(conn: Connection) -> bool:
""" Returns true if database is in a frozen state """ Returns true if database is in a frozen state
""" """
return conn.table_exists('place') is False return table_exists(conn, 'place') is False

View File

@@ -15,7 +15,8 @@ from psycopg2 import sql as pysql
from ..errors import UsageError from ..errors import UsageError
from ..config import Configuration from ..config import Configuration
from ..db import properties from ..db import properties
from ..db.connection import connect, Connection from ..db.connection import connect, Connection, server_version_tuple,\
table_has_column, table_exists, execute_scalar, register_hstore
from ..version import NominatimVersion, NOMINATIM_VERSION, parse_version from ..version import NominatimVersion, NOMINATIM_VERSION, parse_version
from ..tokenizer import factory as tokenizer_factory from ..tokenizer import factory as tokenizer_factory
from . import refresh from . import refresh
@@ -29,7 +30,8 @@ def migrate(config: Configuration, paths: Any) -> int:
if necesssary. if necesssary.
""" """
with connect(config.get_libpq_dsn()) as conn: with connect(config.get_libpq_dsn()) as conn:
if conn.table_exists('nominatim_properties'): register_hstore(conn)
if table_exists(conn, 'nominatim_properties'):
db_version_str = properties.get_property(conn, 'database_version') db_version_str = properties.get_property(conn, 'database_version')
else: else:
db_version_str = None db_version_str = None
@@ -72,16 +74,15 @@ def _guess_version(conn: Connection) -> NominatimVersion:
Only migrations for 3.6 and later are supported, so bail out Only migrations for 3.6 and later are supported, so bail out
when the version seems older. when the version seems older.
""" """
with conn.cursor() as cur: # In version 3.6, the country_name table was updated. Check for that.
# In version 3.6, the country_name table was updated. Check for that. cnt = execute_scalar(conn, """SELECT count(*) FROM
cnt = cur.scalar("""SELECT count(*) FROM (SELECT svals(name) FROM country_name
(SELECT svals(name) FROM country_name WHERE country_code = 'gb')x;
WHERE country_code = 'gb')x; """)
""") if cnt < 100:
if cnt < 100: LOG.fatal('It looks like your database was imported with a version '
LOG.fatal('It looks like your database was imported with a version ' 'prior to 3.6.0. Automatic migration not possible.')
'prior to 3.6.0. Automatic migration not possible.') raise UsageError('Migration not possible.')
raise UsageError('Migration not possible.')
return NominatimVersion(3, 5, 0, 99) return NominatimVersion(3, 5, 0, 99)
@@ -125,7 +126,7 @@ def import_status_timestamp_change(conn: Connection, **_: Any) -> None:
def add_nominatim_property_table(conn: Connection, config: Configuration, **_: Any) -> None: def add_nominatim_property_table(conn: Connection, config: Configuration, **_: Any) -> None:
""" Add nominatim_property table. """ Add nominatim_property table.
""" """
if not conn.table_exists('nominatim_properties'): if not table_exists(conn, 'nominatim_properties'):
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute(pysql.SQL("""CREATE TABLE nominatim_properties ( cur.execute(pysql.SQL("""CREATE TABLE nominatim_properties (
property TEXT, property TEXT,
@@ -189,13 +190,9 @@ def install_legacy_tokenizer(conn: Connection, config: Configuration, **_: Any)
configuration for the backwards-compatible legacy tokenizer configuration for the backwards-compatible legacy tokenizer
""" """
if properties.get_property(conn, 'tokenizer') is None: if properties.get_property(conn, 'tokenizer') is None:
with conn.cursor() as cur: for table in ('placex', 'location_property_osmline'):
for table in ('placex', 'location_property_osmline'): if not table_has_column(conn, table, 'token_info'):
has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns with conn.cursor() as cur:
WHERE table_name = %s
and column_name = 'token_info'""",
(table, ))
if has_column == 0:
cur.execute(pysql.SQL('ALTER TABLE {} ADD COLUMN token_info JSONB') cur.execute(pysql.SQL('ALTER TABLE {} ADD COLUMN token_info JSONB')
.format(pysql.Identifier(table))) .format(pysql.Identifier(table)))
tokenizer = tokenizer_factory.create_tokenizer(config, init_db=False, tokenizer = tokenizer_factory.create_tokenizer(config, init_db=False,
@@ -212,7 +209,7 @@ def create_tiger_housenumber_index(conn: Connection, **_: Any) -> None:
The inclusion is needed for efficient lookup of housenumbers in The inclusion is needed for efficient lookup of housenumbers in
full address searches. full address searches.
""" """
if conn.server_version_tuple() >= (11, 0, 0): if server_version_tuple(conn) >= (11, 0, 0):
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute(""" CREATE INDEX IF NOT EXISTS cur.execute(""" CREATE INDEX IF NOT EXISTS
idx_location_property_tiger_housenumber_migrated idx_location_property_tiger_housenumber_migrated
@@ -239,7 +236,7 @@ def add_step_column_for_interpolation(conn: Connection, **_: Any) -> None:
Also converts the data into the stricter format which requires that Also converts the data into the stricter format which requires that
startnumbers comply with the odd/even requirements. startnumbers comply with the odd/even requirements.
""" """
if conn.table_has_column('location_property_osmline', 'step'): if table_has_column(conn, 'location_property_osmline', 'step'):
return return
with conn.cursor() as cur: with conn.cursor() as cur:
@@ -271,7 +268,7 @@ def add_step_column_for_interpolation(conn: Connection, **_: Any) -> None:
def add_step_column_for_tiger(conn: Connection, **_: Any) -> None: def add_step_column_for_tiger(conn: Connection, **_: Any) -> None:
""" Add a new column 'step' to the tiger data table. """ Add a new column 'step' to the tiger data table.
""" """
if conn.table_has_column('location_property_tiger', 'step'): if table_has_column(conn, 'location_property_tiger', 'step'):
return return
with conn.cursor() as cur: with conn.cursor() as cur:
@@ -287,7 +284,7 @@ def add_derived_name_column_for_country_names(conn: Connection, **_: Any) -> Non
""" Add a new column 'derived_name' which in the future takes the """ Add a new column 'derived_name' which in the future takes the
country names as imported from OSM data. country names as imported from OSM data.
""" """
if not conn.table_has_column('country_name', 'derived_name'): if not table_has_column(conn, 'country_name', 'derived_name'):
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute("ALTER TABLE country_name ADD COLUMN derived_name public.HSTORE") cur.execute("ALTER TABLE country_name ADD COLUMN derived_name public.HSTORE")
@@ -297,12 +294,9 @@ def mark_internal_country_names(conn: Connection, config: Configuration, **_: An
""" Names from the country table should be marked as internal to prevent """ Names from the country table should be marked as internal to prevent
them from being deleted. Only necessary for ICU tokenizer. them from being deleted. Only necessary for ICU tokenizer.
""" """
import psycopg2.extras # pylint: disable=import-outside-toplevel
tokenizer = tokenizer_factory.get_tokenizer_for_db(config) tokenizer = tokenizer_factory.get_tokenizer_for_db(config)
with tokenizer.name_analyzer() as analyzer: with tokenizer.name_analyzer() as analyzer:
with conn.cursor() as cur: with conn.cursor() as cur:
psycopg2.extras.register_hstore(cur)
cur.execute("SELECT country_code, name FROM country_name") cur.execute("SELECT country_code, name FROM country_name")
for country_code, names in cur: for country_code, names in cur:
@@ -319,7 +313,7 @@ def add_place_deletion_todo_table(conn: Connection, **_: Any) -> None:
The table is only necessary when updates are possible, i.e. The table is only necessary when updates are possible, i.e.
the database is not in freeze mode. the database is not in freeze mode.
""" """
if conn.table_exists('place'): if table_exists(conn, 'place'):
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute("""CREATE TABLE IF NOT EXISTS place_to_be_deleted ( cur.execute("""CREATE TABLE IF NOT EXISTS place_to_be_deleted (
osm_type CHAR(1), osm_type CHAR(1),
@@ -333,7 +327,7 @@ def add_place_deletion_todo_table(conn: Connection, **_: Any) -> None:
def split_pending_index(conn: Connection, **_: Any) -> None: def split_pending_index(conn: Connection, **_: Any) -> None:
""" Reorganise indexes for pending updates. """ Reorganise indexes for pending updates.
""" """
if conn.table_exists('place'): if table_exists(conn, 'place'):
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute("""CREATE INDEX IF NOT EXISTS idx_placex_rank_address_sector cur.execute("""CREATE INDEX IF NOT EXISTS idx_placex_rank_address_sector
ON placex USING BTREE (rank_address, geometry_sector) ON placex USING BTREE (rank_address, geometry_sector)
@@ -349,7 +343,7 @@ def split_pending_index(conn: Connection, **_: Any) -> None:
def enable_forward_dependencies(conn: Connection, **_: Any) -> None: def enable_forward_dependencies(conn: Connection, **_: Any) -> None:
""" Create indexes for updates with forward dependency tracking (long-running). """ Create indexes for updates with forward dependency tracking (long-running).
""" """
if conn.table_exists('planet_osm_ways'): if table_exists(conn, 'planet_osm_ways'):
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute("""SELECT * FROM pg_indexes cur.execute("""SELECT * FROM pg_indexes
WHERE tablename = 'planet_osm_ways' WHERE tablename = 'planet_osm_ways'
@@ -398,7 +392,7 @@ def create_postcode_area_lookup_index(conn: Connection, **_: Any) -> None:
def create_postcode_parent_index(conn: Connection, **_: Any) -> None: def create_postcode_parent_index(conn: Connection, **_: Any) -> None:
""" Create index needed for updating postcodes when a parent changes. """ Create index needed for updating postcodes when a parent changes.
""" """
if conn.table_exists('planet_osm_ways'): if table_exists(conn, 'planet_osm_ways'):
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute("""CREATE INDEX IF NOT EXISTS cur.execute("""CREATE INDEX IF NOT EXISTS
idx_location_postcode_parent_place_id idx_location_postcode_parent_place_id

View File

@@ -18,7 +18,7 @@ from math import isfinite
from psycopg2 import sql as pysql from psycopg2 import sql as pysql
from ..db.connection import connect, Connection from ..db.connection import connect, Connection, table_exists
from ..utils.centroid import PointsCentroid from ..utils.centroid import PointsCentroid
from ..data.postcode_format import PostcodeFormatter, CountryPostcodeMatcher from ..data.postcode_format import PostcodeFormatter, CountryPostcodeMatcher
from ..tokenizer.base import AbstractAnalyzer, AbstractTokenizer from ..tokenizer.base import AbstractAnalyzer, AbstractTokenizer
@@ -231,4 +231,4 @@ def can_compute(dsn: str) -> bool:
postcodes can be computed. postcodes can be computed.
""" """
with connect(dsn) as conn: with connect(dsn) as conn:
return conn.table_exists('place') return table_exists(conn, 'place')

View File

@@ -17,7 +17,8 @@ from pathlib import Path
from psycopg2 import sql as pysql from psycopg2 import sql as pysql
from ..config import Configuration from ..config import Configuration
from ..db.connection import Connection, connect from ..db.connection import Connection, connect, postgis_version_tuple,\
drop_tables, table_exists
from ..db.utils import execute_file, CopyBuffer from ..db.utils import execute_file, CopyBuffer
from ..db.sql_preprocessor import SQLPreprocessor from ..db.sql_preprocessor import SQLPreprocessor
from ..version import NOMINATIM_VERSION from ..version import NOMINATIM_VERSION
@@ -56,9 +57,9 @@ def load_address_levels(conn: Connection, table: str, levels: Sequence[Mapping[s
for entry in levels: for entry in levels:
_add_address_level_rows_from_entry(rows, entry) _add_address_level_rows_from_entry(rows, entry)
with conn.cursor() as cur: drop_tables(conn, table)
cur.drop_table(table)
with conn.cursor() as cur:
cur.execute(pysql.SQL("""CREATE TABLE {} ( cur.execute(pysql.SQL("""CREATE TABLE {} (
country_code varchar(2), country_code varchar(2),
class TEXT, class TEXT,
@@ -159,10 +160,8 @@ def import_importance_csv(dsn: str, data_file: Path) -> int:
wd_done = set() wd_done = set()
with connect(dsn) as conn: with connect(dsn) as conn:
drop_tables(conn, 'wikipedia_article', 'wikipedia_redirect', 'wikimedia_importance')
with conn.cursor() as cur: with conn.cursor() as cur:
cur.drop_table('wikipedia_article')
cur.drop_table('wikipedia_redirect')
cur.drop_table('wikimedia_importance')
cur.execute("""CREATE TABLE wikimedia_importance ( cur.execute("""CREATE TABLE wikimedia_importance (
language TEXT NOT NULL, language TEXT NOT NULL,
title TEXT NOT NULL, title TEXT NOT NULL,
@@ -228,7 +227,7 @@ def import_secondary_importance(dsn: str, data_path: Path, ignore_errors: bool =
return 1 return 1
with connect(dsn) as conn: with connect(dsn) as conn:
postgis_version = conn.postgis_version_tuple() postgis_version = postgis_version_tuple(conn)
if postgis_version[0] < 3: if postgis_version[0] < 3:
LOG.error('PostGIS version is too old for using OSM raster data.') LOG.error('PostGIS version is too old for using OSM raster data.')
return 2 return 2
@@ -309,7 +308,7 @@ def setup_website(basedir: Path, config: Configuration, conn: Connection) -> Non
template = "\nrequire_once(CONST_LibDir.'/website/{}');\n" template = "\nrequire_once(CONST_LibDir.'/website/{}');\n"
search_name_table_exists = bool(conn and conn.table_exists('search_name')) search_name_table_exists = bool(conn and table_exists(conn, 'search_name'))
for script in WEBSITE_SCRIPTS: for script in WEBSITE_SCRIPTS:
if not search_name_table_exists and script == 'search.php': if not search_name_table_exists and script == 'search.php':

View File

@@ -20,7 +20,7 @@ import requests
from ..errors import UsageError from ..errors import UsageError
from ..db import status from ..db import status
from ..db.connection import Connection, connect from ..db.connection import Connection, connect, server_version_tuple
from .exec_utils import run_osm2pgsql from .exec_utils import run_osm2pgsql
try: try:
@@ -155,7 +155,7 @@ def run_osm2pgsql_updates(conn: Connection, options: MutableMapping[str, Any]) -
# Consume updates with osm2pgsql. # Consume updates with osm2pgsql.
options['append'] = True options['append'] = True
options['disable_jit'] = conn.server_version_tuple() >= (11, 0) options['disable_jit'] = server_version_tuple(conn) >= (11, 0)
run_osm2pgsql(options) run_osm2pgsql(options)
# Handle deletions # Handle deletions

View File

@@ -21,7 +21,7 @@ from psycopg2.sql import Identifier, SQL
from ...typing import Protocol from ...typing import Protocol
from ...config import Configuration from ...config import Configuration
from ...db.connection import Connection from ...db.connection import Connection, drop_tables, index_exists
from .importer_statistics import SpecialPhrasesImporterStatistics from .importer_statistics import SpecialPhrasesImporterStatistics
from .special_phrase import SpecialPhrase from .special_phrase import SpecialPhrase
from ...tokenizer.base import AbstractTokenizer from ...tokenizer.base import AbstractTokenizer
@@ -233,7 +233,7 @@ class SPImporter():
index_prefix = f'idx_place_classtype_{phrase_class}_{phrase_type}_' index_prefix = f'idx_place_classtype_{phrase_class}_{phrase_type}_'
base_table = _classtype_table(phrase_class, phrase_type) base_table = _classtype_table(phrase_class, phrase_type)
# Index on centroid # Index on centroid
if not self.db_connection.index_exists(index_prefix + 'centroid'): if not index_exists(self.db_connection, index_prefix + 'centroid'):
with self.db_connection.cursor() as db_cursor: with self.db_connection.cursor() as db_cursor:
db_cursor.execute(SQL("CREATE INDEX {} ON {} USING GIST (centroid) {}") db_cursor.execute(SQL("CREATE INDEX {} ON {} USING GIST (centroid) {}")
.format(Identifier(index_prefix + 'centroid'), .format(Identifier(index_prefix + 'centroid'),
@@ -241,7 +241,7 @@ class SPImporter():
SQL(sql_tablespace))) SQL(sql_tablespace)))
# Index on place_id # Index on place_id
if not self.db_connection.index_exists(index_prefix + 'place_id'): if not index_exists(self.db_connection, index_prefix + 'place_id'):
with self.db_connection.cursor() as db_cursor: with self.db_connection.cursor() as db_cursor:
db_cursor.execute(SQL("CREATE INDEX {} ON {} USING btree(place_id) {}") db_cursor.execute(SQL("CREATE INDEX {} ON {} USING btree(place_id) {}")
.format(Identifier(index_prefix + 'place_id'), .format(Identifier(index_prefix + 'place_id'),
@@ -259,6 +259,7 @@ class SPImporter():
.format(Identifier(table_name), .format(Identifier(table_name),
Identifier(self.config.DATABASE_WEBUSER))) Identifier(self.config.DATABASE_WEBUSER)))
def _remove_non_existent_tables_from_db(self) -> None: def _remove_non_existent_tables_from_db(self) -> None:
""" """
Remove special phrases which doesn't exist on the wiki anymore. Remove special phrases which doesn't exist on the wiki anymore.
@@ -268,7 +269,6 @@ class SPImporter():
# Delete place_classtype tables corresponding to class/type which # Delete place_classtype tables corresponding to class/type which
# are not on the wiki anymore. # are not on the wiki anymore.
with self.db_connection.cursor() as db_cursor: drop_tables(self.db_connection, *self.table_phrases_to_delete)
for table in self.table_phrases_to_delete: for _ in self.table_phrases_to_delete:
self.statistics_handler.notify_one_table_deleted() self.statistics_handler.notify_one_table_deleted()
db_cursor.drop_table(table)

View File

@@ -33,13 +33,13 @@ def simple_conns(temp_db):
conn2.close() conn2.close()
def test_simple_query(conn, temp_db_conn): def test_simple_query(conn, temp_db_cursor):
conn.connect() conn.connect()
conn.perform('CREATE TABLE foo (id INT)') conn.perform('CREATE TABLE foo (id INT)')
conn.wait() conn.wait()
temp_db_conn.table_exists('foo') assert temp_db_cursor.table_exists('foo')
def test_wait_for_query(conn): def test_wait_for_query(conn):

View File

@@ -10,61 +10,74 @@ Tests for specialised connection and cursor classes.
import pytest import pytest
import psycopg2 import psycopg2
from nominatim_db.db.connection import connect, get_pg_env import nominatim_db.db.connection as nc
@pytest.fixture @pytest.fixture
def db(dsn): def db(dsn):
with connect(dsn) as conn: with nc.connect(dsn) as conn:
yield conn yield conn
def test_connection_table_exists(db, table_factory): def test_connection_table_exists(db, table_factory):
assert not db.table_exists('foobar') assert not nc.table_exists(db, 'foobar')
table_factory('foobar') table_factory('foobar')
assert db.table_exists('foobar') assert nc.table_exists(db, 'foobar')
def test_has_column_no_table(db): def test_has_column_no_table(db):
assert not db.table_has_column('sometable', 'somecolumn') assert not nc.table_has_column(db, 'sometable', 'somecolumn')
@pytest.mark.parametrize('name,result', [('tram', True), ('car', False)]) @pytest.mark.parametrize('name,result', [('tram', True), ('car', False)])
def test_has_column(db, table_factory, name, result): def test_has_column(db, table_factory, name, result):
table_factory('stuff', 'tram TEXT') table_factory('stuff', 'tram TEXT')
assert db.table_has_column('stuff', name) == result assert nc.table_has_column(db, 'stuff', name) == result
def test_connection_index_exists(db, table_factory, temp_db_cursor): def test_connection_index_exists(db, table_factory, temp_db_cursor):
assert not db.index_exists('some_index') assert not nc.index_exists(db, 'some_index')
table_factory('foobar') table_factory('foobar')
temp_db_cursor.execute('CREATE INDEX some_index ON foobar(id)') temp_db_cursor.execute('CREATE INDEX some_index ON foobar(id)')
assert db.index_exists('some_index') assert nc.index_exists(db, 'some_index')
assert db.index_exists('some_index', table='foobar') assert nc.index_exists(db, 'some_index', table='foobar')
assert not db.index_exists('some_index', table='bar') assert not nc.index_exists(db, 'some_index', table='bar')
def test_drop_table_existing(db, table_factory): def test_drop_table_existing(db, table_factory):
table_factory('dummy') table_factory('dummy')
assert db.table_exists('dummy') assert nc.table_exists(db, 'dummy')
db.drop_table('dummy') nc.drop_tables(db, 'dummy')
assert not db.table_exists('dummy') assert not nc.table_exists(db, 'dummy')
def test_drop_table_non_existsing(db): def test_drop_table_non_existing(db):
db.drop_table('dfkjgjriogjigjgjrdghehtre') nc.drop_tables(db, 'dfkjgjriogjigjgjrdghehtre')
def test_drop_many_tables(db, table_factory):
tables = [f'table{n}' for n in range(5)]
for t in tables:
table_factory(t)
assert nc.table_exists(db, t)
nc.drop_tables(db, *tables)
for t in tables:
assert not nc.table_exists(db, t)
def test_drop_table_non_existing_force(db): def test_drop_table_non_existing_force(db):
with pytest.raises(psycopg2.ProgrammingError, match='.*does not exist.*'): with pytest.raises(psycopg2.ProgrammingError, match='.*does not exist.*'):
db.drop_table('dfkjgjriogjigjgjrdghehtre', if_exists=False) nc.drop_tables(db, 'dfkjgjriogjigjgjrdghehtre', if_exists=False)
def test_connection_server_version_tuple(db): def test_connection_server_version_tuple(db):
ver = db.server_version_tuple() ver = nc.server_version_tuple(db)
assert isinstance(ver, tuple) assert isinstance(ver, tuple)
assert len(ver) == 2 assert len(ver) == 2
@@ -72,7 +85,7 @@ def test_connection_server_version_tuple(db):
def test_connection_postgis_version_tuple(db, temp_db_with_extensions): def test_connection_postgis_version_tuple(db, temp_db_with_extensions):
ver = db.postgis_version_tuple() ver = nc.postgis_version_tuple(db)
assert isinstance(ver, tuple) assert isinstance(ver, tuple)
assert len(ver) == 2 assert len(ver) == 2
@@ -82,27 +95,24 @@ def test_connection_postgis_version_tuple(db, temp_db_with_extensions):
def test_cursor_scalar(db, table_factory): def test_cursor_scalar(db, table_factory):
table_factory('dummy') table_factory('dummy')
with db.cursor() as cur: assert nc.execute_scalar(db, 'SELECT count(*) FROM dummy') == 0
assert cur.scalar('SELECT count(*) FROM dummy') == 0
def test_cursor_scalar_many_rows(db): def test_cursor_scalar_many_rows(db):
with db.cursor() as cur: with pytest.raises(RuntimeError, match='Query did not return a single row.'):
with pytest.raises(RuntimeError): nc.execute_scalar(db, 'SELECT * FROM pg_tables')
cur.scalar('SELECT * FROM pg_tables')
def test_cursor_scalar_no_rows(db, table_factory): def test_cursor_scalar_no_rows(db, table_factory):
table_factory('dummy') table_factory('dummy')
with db.cursor() as cur: with pytest.raises(RuntimeError, match='Query did not return a single row.'):
with pytest.raises(RuntimeError): nc.execute_scalar(db, 'SELECT id FROM dummy')
cur.scalar('SELECT id FROM dummy')
def test_get_pg_env_add_variable(monkeypatch): def test_get_pg_env_add_variable(monkeypatch):
monkeypatch.delenv('PGPASSWORD', raising=False) monkeypatch.delenv('PGPASSWORD', raising=False)
env = get_pg_env('user=fooF') env = nc.get_pg_env('user=fooF')
assert env['PGUSER'] == 'fooF' assert env['PGUSER'] == 'fooF'
assert 'PGPASSWORD' not in env assert 'PGPASSWORD' not in env
@@ -110,12 +120,12 @@ def test_get_pg_env_add_variable(monkeypatch):
def test_get_pg_env_overwrite_variable(monkeypatch): def test_get_pg_env_overwrite_variable(monkeypatch):
monkeypatch.setenv('PGUSER', 'some default') monkeypatch.setenv('PGUSER', 'some default')
env = get_pg_env('user=overwriter') env = nc.get_pg_env('user=overwriter')
assert env['PGUSER'] == 'overwriter' assert env['PGUSER'] == 'overwriter'
def test_get_pg_env_ignore_unknown(): def test_get_pg_env_ignore_unknown():
env = get_pg_env('client_encoding=stuff', base_env={}) env = nc.get_pg_env('client_encoding=stuff', base_env={})
assert env == {} assert env == {}

View File

@@ -8,6 +8,7 @@
Legacy word table for testing with functions to prefil and test contents Legacy word table for testing with functions to prefil and test contents
of the table. of the table.
""" """
from nominatim_db.db.connection import execute_scalar
class MockIcuWordTable: class MockIcuWordTable:
""" A word table for testing using legacy word table structure. """ A word table for testing using legacy word table structure.
@@ -77,18 +78,15 @@ class MockIcuWordTable:
def count(self): def count(self):
with self.conn.cursor() as cur: return execute_scalar(self.conn, "SELECT count(*) FROM word")
return cur.scalar("SELECT count(*) FROM word")
def count_special(self): def count_special(self):
with self.conn.cursor() as cur: return execute_scalar(self.conn, "SELECT count(*) FROM word WHERE type = 'S'")
return cur.scalar("SELECT count(*) FROM word WHERE type = 'S'")
def count_housenumbers(self): def count_housenumbers(self):
with self.conn.cursor() as cur: return execute_scalar(self.conn, "SELECT count(*) FROM word WHERE type = 'H'")
return cur.scalar("SELECT count(*) FROM word WHERE type = 'H'")
def get_special(self): def get_special(self):

View File

@@ -8,6 +8,7 @@
Legacy word table for testing with functions to prefil and test contents Legacy word table for testing with functions to prefil and test contents
of the table. of the table.
""" """
from nominatim_db.db.connection import execute_scalar
class MockLegacyWordTable: class MockLegacyWordTable:
""" A word table for testing using legacy word table structure. """ A word table for testing using legacy word table structure.
@@ -58,13 +59,11 @@ class MockLegacyWordTable:
def count(self): def count(self):
with self.conn.cursor() as cur: return execute_scalar(self.conn, "SELECT count(*) FROM word")
return cur.scalar("SELECT count(*) FROM word")
def count_special(self): def count_special(self):
with self.conn.cursor() as cur: return execute_scalar(self.conn, "SELECT count(*) FROM word WHERE class != 'place'")
return cur.scalar("SELECT count(*) FROM word WHERE class != 'place'")
def get_special(self): def get_special(self):

View File

@@ -199,16 +199,16 @@ def test_update_sql_functions(db_prop, temp_db_cursor,
assert test_content == set((('1133', ), )) assert test_content == set((('1133', ), ))
def test_finalize_import(tokenizer_factory, temp_db_conn, def test_finalize_import(tokenizer_factory, temp_db_cursor,
temp_db_cursor, test_config, sql_preprocessor_cfg): test_config, sql_preprocessor_cfg):
tok = tokenizer_factory() tok = tokenizer_factory()
tok.init_new_db(test_config) tok.init_new_db(test_config)
assert not temp_db_conn.index_exists('idx_word_word_id') assert not temp_db_cursor.index_exists('word', 'idx_word_word_id')
tok.finalize_import(test_config) tok.finalize_import(test_config)
assert temp_db_conn.index_exists('idx_word_word_id') assert temp_db_cursor.index_exists('word', 'idx_word_word_id')
def test_check_database(test_config, tokenizer_factory, def test_check_database(test_config, tokenizer_factory,

View File

@@ -132,7 +132,7 @@ def test_import_osm_data_simple_ignore_no_data(table_factory, osm2pgsql_options)
ignore_errors=True) ignore_errors=True)
def test_import_osm_data_drop(table_factory, temp_db_conn, tmp_path, osm2pgsql_options): def test_import_osm_data_drop(table_factory, temp_db_cursor, tmp_path, osm2pgsql_options):
table_factory('place', content=((1, ), )) table_factory('place', content=((1, ), ))
table_factory('planet_osm_nodes') table_factory('planet_osm_nodes')
@@ -144,7 +144,7 @@ def test_import_osm_data_drop(table_factory, temp_db_conn, tmp_path, osm2pgsql_o
database_import.import_osm_data(Path('file.pbf'), osm2pgsql_options, drop=True) database_import.import_osm_data(Path('file.pbf'), osm2pgsql_options, drop=True)
assert not flatfile.exists() assert not flatfile.exists()
assert not temp_db_conn.table_exists('planet_osm_nodes') assert not temp_db_cursor.table_exists('planet_osm_nodes')
def test_import_osm_data_default_cache(table_factory, osm2pgsql_options, capfd): def test_import_osm_data_default_cache(table_factory, osm2pgsql_options, capfd):

View File

@@ -75,7 +75,8 @@ def test_load_white_and_black_lists(sp_importer):
assert isinstance(black_list, dict) and isinstance(white_list, dict) assert isinstance(black_list, dict) and isinstance(white_list, dict)
def test_create_place_classtype_indexes(temp_db_with_extensions, temp_db_conn, def test_create_place_classtype_indexes(temp_db_with_extensions,
temp_db_conn, temp_db_cursor,
table_factory, sp_importer): table_factory, sp_importer):
""" """
Test that _create_place_classtype_indexes() create the Test that _create_place_classtype_indexes() create the
@@ -88,10 +89,11 @@ def test_create_place_classtype_indexes(temp_db_with_extensions, temp_db_conn,
table_factory(table_name, 'place_id BIGINT, centroid GEOMETRY') table_factory(table_name, 'place_id BIGINT, centroid GEOMETRY')
sp_importer._create_place_classtype_indexes('', phrase_class, phrase_type) sp_importer._create_place_classtype_indexes('', phrase_class, phrase_type)
temp_db_conn.commit()
assert check_placeid_and_centroid_indexes(temp_db_conn, phrase_class, phrase_type) assert check_placeid_and_centroid_indexes(temp_db_cursor, phrase_class, phrase_type)
def test_create_place_classtype_table(temp_db_conn, placex_table, sp_importer): def test_create_place_classtype_table(temp_db_conn, temp_db_cursor, placex_table, sp_importer):
""" """
Test that _create_place_classtype_table() create Test that _create_place_classtype_table() create
the right place_classtype table. the right place_classtype table.
@@ -99,10 +101,12 @@ def test_create_place_classtype_table(temp_db_conn, placex_table, sp_importer):
phrase_class = 'class' phrase_class = 'class'
phrase_type = 'type' phrase_type = 'type'
sp_importer._create_place_classtype_table('', phrase_class, phrase_type) sp_importer._create_place_classtype_table('', phrase_class, phrase_type)
temp_db_conn.commit()
assert check_table_exist(temp_db_conn, phrase_class, phrase_type) assert check_table_exist(temp_db_cursor, phrase_class, phrase_type)
def test_grant_access_to_web_user(temp_db_conn, table_factory, def_config, sp_importer): def test_grant_access_to_web_user(temp_db_conn, temp_db_cursor, table_factory,
def_config, sp_importer):
""" """
Test that _grant_access_to_webuser() give Test that _grant_access_to_webuser() give
right access to the web user. right access to the web user.
@@ -114,12 +118,13 @@ def test_grant_access_to_web_user(temp_db_conn, table_factory, def_config, sp_im
table_factory(table_name) table_factory(table_name)
sp_importer._grant_access_to_webuser(phrase_class, phrase_type) sp_importer._grant_access_to_webuser(phrase_class, phrase_type)
temp_db_conn.commit()
assert check_grant_access(temp_db_conn, def_config.DATABASE_WEBUSER, phrase_class, phrase_type) assert check_grant_access(temp_db_cursor, def_config.DATABASE_WEBUSER, phrase_class, phrase_type)
def test_create_place_classtype_table_and_indexes( def test_create_place_classtype_table_and_indexes(
temp_db_conn, def_config, placex_table, temp_db_cursor, def_config, placex_table,
sp_importer): sp_importer, temp_db_conn):
""" """
Test that _create_place_classtype_table_and_indexes() Test that _create_place_classtype_table_and_indexes()
create the right place_classtype tables and place_id indexes create the right place_classtype tables and place_id indexes
@@ -129,14 +134,15 @@ def test_create_place_classtype_table_and_indexes(
pairs = set([('class1', 'type1'), ('class2', 'type2')]) pairs = set([('class1', 'type1'), ('class2', 'type2')])
sp_importer._create_classtype_table_and_indexes(pairs) sp_importer._create_classtype_table_and_indexes(pairs)
temp_db_conn.commit()
for pair in pairs: for pair in pairs:
assert check_table_exist(temp_db_conn, pair[0], pair[1]) assert check_table_exist(temp_db_cursor, pair[0], pair[1])
assert check_placeid_and_centroid_indexes(temp_db_conn, pair[0], pair[1]) assert check_placeid_and_centroid_indexes(temp_db_cursor, pair[0], pair[1])
assert check_grant_access(temp_db_conn, def_config.DATABASE_WEBUSER, pair[0], pair[1]) assert check_grant_access(temp_db_cursor, def_config.DATABASE_WEBUSER, pair[0], pair[1])
def test_remove_non_existent_tables_from_db(sp_importer, default_phrases, def test_remove_non_existent_tables_from_db(sp_importer, default_phrases,
temp_db_conn): temp_db_conn, temp_db_cursor):
""" """
Check for the remove_non_existent_phrases_from_db() method. Check for the remove_non_existent_phrases_from_db() method.
@@ -159,15 +165,14 @@ def test_remove_non_existent_tables_from_db(sp_importer, default_phrases,
""" """
sp_importer._remove_non_existent_tables_from_db() sp_importer._remove_non_existent_tables_from_db()
temp_db_conn.commit()
# Changes are not committed yet. Use temp_db_conn for checking results. assert temp_db_cursor.row_set(query_tables) \
with temp_db_conn.cursor(cursor_factory=CursorForTesting) as cur:
assert cur.row_set(query_tables) \
== {('place_classtype_testclasstypetable_to_keep', )} == {('place_classtype_testclasstypetable_to_keep', )}
@pytest.mark.parametrize("should_replace", [(True), (False)]) @pytest.mark.parametrize("should_replace", [(True), (False)])
def test_import_phrases(monkeypatch, temp_db_conn, def_config, sp_importer, def test_import_phrases(monkeypatch, temp_db_cursor, def_config, sp_importer,
placex_table, table_factory, tokenizer_mock, placex_table, table_factory, tokenizer_mock,
xml_wiki_content, should_replace): xml_wiki_content, should_replace):
""" """
@@ -193,49 +198,49 @@ def test_import_phrases(monkeypatch, temp_db_conn, def_config, sp_importer,
class_test = 'aerialway' class_test = 'aerialway'
type_test = 'zip_line' type_test = 'zip_line'
assert check_table_exist(temp_db_conn, class_test, type_test) assert check_table_exist(temp_db_cursor, class_test, type_test)
assert check_placeid_and_centroid_indexes(temp_db_conn, class_test, type_test) assert check_placeid_and_centroid_indexes(temp_db_cursor, class_test, type_test)
assert check_grant_access(temp_db_conn, def_config.DATABASE_WEBUSER, class_test, type_test) assert check_grant_access(temp_db_cursor, def_config.DATABASE_WEBUSER, class_test, type_test)
assert check_table_exist(temp_db_conn, 'amenity', 'animal_shelter') assert check_table_exist(temp_db_cursor, 'amenity', 'animal_shelter')
if should_replace: if should_replace:
assert not check_table_exist(temp_db_conn, 'wrong_class', 'wrong_type') assert not check_table_exist(temp_db_cursor, 'wrong_class', 'wrong_type')
assert temp_db_conn.table_exists('place_classtype_amenity_animal_shelter') assert temp_db_cursor.table_exists('place_classtype_amenity_animal_shelter')
if should_replace: if should_replace:
assert not temp_db_conn.table_exists('place_classtype_wrongclass_wrongtype') assert not temp_db_cursor.table_exists('place_classtype_wrongclass_wrongtype')
def check_table_exist(temp_db_conn, phrase_class, phrase_type): def check_table_exist(temp_db_cursor, phrase_class, phrase_type):
""" """
Verify that the place_classtype table exists for the given Verify that the place_classtype table exists for the given
phrase_class and phrase_type. phrase_class and phrase_type.
""" """
return temp_db_conn.table_exists('place_classtype_{}_{}'.format(phrase_class, phrase_type)) return temp_db_cursor.table_exists('place_classtype_{}_{}'.format(phrase_class, phrase_type))
def check_grant_access(temp_db_conn, user, phrase_class, phrase_type): def check_grant_access(temp_db_cursor, user, phrase_class, phrase_type):
""" """
Check that the web user has been granted right access to the Check that the web user has been granted right access to the
place_classtype table of the given phrase_class and phrase_type. place_classtype table of the given phrase_class and phrase_type.
""" """
table_name = 'place_classtype_{}_{}'.format(phrase_class, phrase_type) table_name = 'place_classtype_{}_{}'.format(phrase_class, phrase_type)
with temp_db_conn.cursor() as temp_db_cursor: temp_db_cursor.execute("""
temp_db_cursor.execute(""" SELECT * FROM information_schema.role_table_grants
SELECT * FROM information_schema.role_table_grants WHERE table_name='{}'
WHERE table_name='{}' AND grantee='{}'
AND grantee='{}' AND privilege_type='SELECT'""".format(table_name, user))
AND privilege_type='SELECT'""".format(table_name, user)) return temp_db_cursor.fetchone()
return temp_db_cursor.fetchone()
def check_placeid_and_centroid_indexes(temp_db_conn, phrase_class, phrase_type): def check_placeid_and_centroid_indexes(temp_db_cursor, phrase_class, phrase_type):
""" """
Check that the place_id index and centroid index exist for the Check that the place_id index and centroid index exist for the
place_classtype table of the given phrase_class and phrase_type. place_classtype table of the given phrase_class and phrase_type.
""" """
table_name = 'place_classtype_{}_{}'.format(phrase_class, phrase_type)
index_prefix = 'idx_place_classtype_{}_{}_'.format(phrase_class, phrase_type) index_prefix = 'idx_place_classtype_{}_{}_'.format(phrase_class, phrase_type)
return ( return (
temp_db_conn.index_exists(index_prefix + 'centroid') temp_db_cursor.index_exists(table_name, index_prefix + 'centroid')
and and
temp_db_conn.index_exists(index_prefix + 'place_id') temp_db_cursor.index_exists(table_name, index_prefix + 'place_id')
) )

View File

@@ -12,6 +12,7 @@ import psycopg2.extras
from nominatim_db.tools import migration from nominatim_db.tools import migration
from nominatim_db.errors import UsageError from nominatim_db.errors import UsageError
from nominatim_db.db.connection import server_version_tuple
import nominatim_db.version import nominatim_db.version
from mock_legacy_word_table import MockLegacyWordTable from mock_legacy_word_table import MockLegacyWordTable
@@ -63,7 +64,7 @@ def test_set_up_migration_for_36(temp_db_with_extensions, temp_db_cursor,
WHERE property = 'database_version'""") WHERE property = 'database_version'""")
def test_already_at_version(def_config, property_table): def test_already_at_version(temp_db_with_extensions, def_config, property_table):
property_table.set('database_version', property_table.set('database_version',
str(nominatim_db.version.NOMINATIM_VERSION)) str(nominatim_db.version.NOMINATIM_VERSION))
@@ -71,8 +72,8 @@ def test_already_at_version(def_config, property_table):
assert migration.migrate(def_config, {}) == 0 assert migration.migrate(def_config, {}) == 0
def test_run_single_migration(def_config, temp_db_cursor, property_table, def test_run_single_migration(temp_db_with_extensions, def_config, temp_db_cursor,
monkeypatch, postprocess_mock): property_table, monkeypatch, postprocess_mock):
oldversion = [x for x in nominatim_db.version.NOMINATIM_VERSION] oldversion = [x for x in nominatim_db.version.NOMINATIM_VERSION]
oldversion[0] -= 1 oldversion[0] -= 1
property_table.set('database_version', property_table.set('database_version',
@@ -226,7 +227,7 @@ def test_create_tiger_housenumber_index(temp_db_conn, temp_db_cursor, table_fact
migration.create_tiger_housenumber_index(temp_db_conn) migration.create_tiger_housenumber_index(temp_db_conn)
temp_db_conn.commit() temp_db_conn.commit()
if temp_db_conn.server_version_tuple() >= (11, 0, 0): if server_version_tuple(temp_db_conn) >= (11, 0, 0):
assert temp_db_cursor.index_exists('location_property_tiger', assert temp_db_cursor.index_exists('location_property_tiger',
'idx_location_property_tiger_housenumber_migrated') 'idx_location_property_tiger_housenumber_migrated')

View File

@@ -12,6 +12,7 @@ from pathlib import Path
import pytest import pytest
from nominatim_db.tools import refresh from nominatim_db.tools import refresh
from nominatim_db.db.connection import postgis_version_tuple
def test_refresh_import_wikipedia_not_existing(dsn): def test_refresh_import_wikipedia_not_existing(dsn):
assert refresh.import_wikipedia_articles(dsn, Path('.')) == 1 assert refresh.import_wikipedia_articles(dsn, Path('.')) == 1
@@ -23,13 +24,13 @@ def test_refresh_import_secondary_importance_non_existing(dsn):
def test_refresh_import_secondary_importance_testdb(dsn, src_dir, temp_db_conn, temp_db_cursor): def test_refresh_import_secondary_importance_testdb(dsn, src_dir, temp_db_conn, temp_db_cursor):
temp_db_cursor.execute('CREATE EXTENSION postgis') temp_db_cursor.execute('CREATE EXTENSION postgis')
if temp_db_conn.postgis_version_tuple()[0] < 3: if postgis_version_tuple(temp_db_conn)[0] < 3:
assert refresh.import_secondary_importance(dsn, src_dir / 'test' / 'testdb') > 0 assert refresh.import_secondary_importance(dsn, src_dir / 'test' / 'testdb') > 0
else: else:
temp_db_cursor.execute('CREATE EXTENSION postgis_raster') temp_db_cursor.execute('CREATE EXTENSION postgis_raster')
assert refresh.import_secondary_importance(dsn, src_dir / 'test' / 'testdb') == 0 assert refresh.import_secondary_importance(dsn, src_dir / 'test' / 'testdb') == 0
assert temp_db_conn.table_exists('secondary_importance') assert temp_db_cursor.table_exists('secondary_importance')
@pytest.mark.parametrize("replace", (True, False)) @pytest.mark.parametrize("replace", (True, False))

View File

@@ -12,6 +12,7 @@ from textwrap import dedent
import pytest import pytest
from nominatim_db.db.connection import execute_scalar
from nominatim_db.tools import tiger_data, freeze from nominatim_db.tools import tiger_data, freeze
from nominatim_db.errors import UsageError from nominatim_db.errors import UsageError
@@ -31,8 +32,7 @@ class MockTigerTable:
cur.execute("CREATE TABLE place (number INTEGER)") cur.execute("CREATE TABLE place (number INTEGER)")
def count(self): def count(self):
with self.conn.cursor() as cur: return execute_scalar(self.conn, "SELECT count(*) FROM tiger")
return cur.scalar("SELECT count(*) FROM tiger")
def row(self): def row(self):
with self.conn.cursor() as cur: with self.conn.cursor() as cur: