port code to psycopg3

This commit is contained in:
Sarah Hoffmann
2024-07-05 10:43:10 +02:00
parent 3742fa2929
commit 9659afbade
57 changed files with 800 additions and 1330 deletions

View File

@@ -10,8 +10,8 @@ Functions for database analysis and maintenance.
from typing import Optional, Tuple, Any, cast
import logging
from psycopg2.extras import Json
from psycopg2 import DataError
import psycopg
from psycopg.types.json import Json
from ..typing import DictCursorResult
from ..config import Configuration
@@ -59,7 +59,7 @@ def analyse_indexing(config: Configuration, osm_id: Optional[str] = None,
"""
with connect(config.get_libpq_dsn()) as conn:
register_hstore(conn)
with conn.cursor() as cur:
with conn.cursor(row_factory=psycopg.rows.dict_row) as cur:
place = _get_place_info(cur, osm_id, place_id)
cur.execute("update placex set indexed_status = 2 where place_id = %s",
@@ -74,6 +74,9 @@ def analyse_indexing(config: Configuration, osm_id: Optional[str] = None,
tokenizer = tokenizer_factory.get_tokenizer_for_db(config)
# Enable printing of messages.
conn.add_notice_handler(lambda diag: print(diag.message_primary))
with tokenizer.name_analyzer() as analyzer:
cur.execute("""UPDATE placex
SET indexed_status = 0, address = %s, token_info = %s,
@@ -86,9 +89,6 @@ def analyse_indexing(config: Configuration, osm_id: Optional[str] = None,
# we do not want to keep the results
conn.rollback()
for msg in conn.notices:
print(msg)
def clean_deleted_relations(config: Configuration, age: str) -> None:
""" Clean deleted relations older than a given age
@@ -101,6 +101,6 @@ def clean_deleted_relations(config: Configuration, age: str) -> None:
WHERE p.osm_type = d.osm_type AND p.osm_id = d.osm_id
AND age(p.indexed_date) > %s::interval""",
(age, ))
except DataError as exc:
except psycopg.DataError as exc:
raise UsageError('Invalid PostgreSQL time interval format') from exc
conn.commit()

View File

@@ -81,7 +81,7 @@ def check_database(config: Configuration) -> int:
""" Run a number of checks on the database and return the status.
"""
try:
conn = connect(config.get_libpq_dsn()).connection
conn = connect(config.get_libpq_dsn())
except UsageError as err:
conn = _BadConnection(str(err)) # type: ignore[assignment]

View File

@@ -15,7 +15,6 @@ from pathlib import Path
from typing import List, Optional, Union
import psutil
from psycopg2.extensions import make_dsn
from ..config import Configuration
from ..db.connection import connect, server_version_tuple, execute_scalar
@@ -97,7 +96,7 @@ def report_system_information(config: Configuration) -> None:
"""Generate a report about the host system including software versions, memory,
storage, and database configuration."""
with connect(make_dsn(config.get_libpq_dsn(), dbname='postgres')) as conn:
with connect(config.get_libpq_dsn(), dbname='postgres') as conn:
postgresql_ver: str = '.'.join(map(str, server_version_tuple(conn)))
with conn.cursor() as cur:

View File

@@ -10,19 +10,20 @@ Functions for setting up and importing a new Nominatim database.
from typing import Tuple, Optional, Union, Sequence, MutableMapping, Any
import logging
import os
import selectors
import subprocess
import asyncio
from pathlib import Path
import psutil
from psycopg2 import sql as pysql
import psycopg
from psycopg import sql as pysql
from ..errors import UsageError
from ..config import Configuration
from ..db.connection import connect, get_pg_env, Connection, server_version_tuple,\
postgis_version_tuple, drop_tables, table_exists, execute_scalar
from ..db.async_connection import DBConnection
from ..db.sql_preprocessor import SQLPreprocessor
from ..db.query_pool import QueryPool
from .exec_utils import run_osm2pgsql
from ..version import POSTGRESQL_REQUIRED_VERSION, POSTGIS_REQUIRED_VERSION
@@ -136,7 +137,7 @@ def import_osm_data(osm_files: Union[Path, Sequence[Path]],
with connect(options['dsn']) as conn:
if not ignore_errors:
with conn.cursor() as cur:
cur.execute('SELECT * FROM place LIMIT 1')
cur.execute('SELECT true FROM place LIMIT 1')
if cur.rowcount == 0:
raise UsageError('No data imported by osm2pgsql.')
@@ -205,54 +206,51 @@ _COPY_COLUMNS = pysql.SQL(',').join(map(pysql.Identifier,
'extratags', 'geometry')))
def load_data(dsn: str, threads: int) -> None:
async def load_data(dsn: str, threads: int) -> None:
""" Copy data into the word and placex table.
"""
sel = selectors.DefaultSelector()
# Then copy data from place to placex in <threads - 1> chunks.
place_threads = max(1, threads - 1)
for imod in range(place_threads):
conn = DBConnection(dsn)
conn.connect()
conn.perform(
pysql.SQL("""INSERT INTO placex ({columns})
SELECT {columns} FROM place
WHERE osm_id % {total} = {mod}
AND NOT (class='place' and (type='houses' or type='postcode'))
AND ST_IsValid(geometry)
""").format(columns=_COPY_COLUMNS,
total=pysql.Literal(place_threads),
mod=pysql.Literal(imod)))
sel.register(conn, selectors.EVENT_READ, conn)
placex_threads = max(1, threads - 1)
# Address interpolations go into another table.
conn = DBConnection(dsn)
conn.connect()
conn.perform("""INSERT INTO location_property_osmline (osm_id, address, linegeo)
SELECT osm_id, address, geometry FROM place
WHERE class='place' and type='houses' and osm_type='W'
and ST_GeometryType(geometry) = 'ST_LineString'
""")
sel.register(conn, selectors.EVENT_READ, conn)
progress = asyncio.create_task(_progress_print())
# Now wait for all of them to finish.
todo = place_threads + 1
while todo > 0:
for key, _ in sel.select(1):
conn = key.data
sel.unregister(conn)
conn.wait()
conn.close()
todo -= 1
async with QueryPool(dsn, placex_threads + 1) as pool:
# Copy data from place to placex in <threads - 1> chunks.
for imod in range(placex_threads):
await pool.put_query(
pysql.SQL("""INSERT INTO placex ({columns})
SELECT {columns} FROM place
WHERE osm_id % {total} = {mod}
AND NOT (class='place'
and (type='houses' or type='postcode'))
AND ST_IsValid(geometry)
""").format(columns=_COPY_COLUMNS,
total=pysql.Literal(placex_threads),
mod=pysql.Literal(imod)), None)
# Interpolations need to be copied seperately
await pool.put_query("""
INSERT INTO location_property_osmline (osm_id, address, linegeo)
SELECT osm_id, address, geometry FROM place
WHERE class='place' and type='houses' and osm_type='W'
and ST_GeometryType(geometry) = 'ST_LineString' """, None)
progress.cancel()
async with await psycopg.AsyncConnection.connect(dsn) as aconn:
await aconn.execute('ANALYSE')
async def _progress_print() -> None:
while True:
try:
await asyncio.sleep(1)
except asyncio.CancelledError:
print('', flush=True)
break
print('.', end='', flush=True)
print('\n')
with connect(dsn) as syn_conn:
with syn_conn.cursor() as cur:
cur.execute('ANALYSE')
def create_search_indices(conn: Connection, config: Configuration,
async def create_search_indices(conn: Connection, config: Configuration,
drop: bool = False, threads: int = 1) -> None:
""" Create tables that have explicit partitioning.
"""
@@ -271,5 +269,5 @@ def create_search_indices(conn: Connection, config: Configuration,
sql = SQLPreprocessor(conn, config)
sql.run_parallel_sql_file(config.get_libpq_dsn(),
'indices.sql', min(8, threads), drop=drop)
await sql.run_parallel_sql_file(config.get_libpq_dsn(),
'indices.sql', min(8, threads), drop=drop)

View File

@@ -10,7 +10,7 @@ Functions for removing unnecessary data from the database.
from typing import Optional
from pathlib import Path
from psycopg2 import sql as pysql
from psycopg import sql as pysql
from ..db.connection import Connection, drop_tables, table_exists

View File

@@ -10,7 +10,7 @@ Functions for database migration to newer software versions.
from typing import List, Tuple, Callable, Any
import logging
from psycopg2 import sql as pysql
from psycopg import sql as pysql
from ..errors import UsageError
from ..config import Configuration

View File

@@ -16,7 +16,7 @@ import gzip
import logging
from math import isfinite
from psycopg2 import sql as pysql
from psycopg import sql as pysql
from ..db.connection import connect, Connection, table_exists
from ..utils.centroid import PointsCentroid
@@ -76,30 +76,30 @@ class _PostcodeCollector:
with conn.cursor() as cur:
if to_add:
cur.execute_values(
cur.executemany(pysql.SQL(
"""INSERT INTO location_postcode
(place_id, indexed_status, country_code,
postcode, geometry) VALUES %s""",
to_add,
template=pysql.SQL("""(nextval('seq_place'), 1, {},
%s, 'SRID=4326;POINT(%s %s)')
""").format(pysql.Literal(self.country)))
postcode, geometry)
VALUES (nextval('seq_place'), 1, {}, %s,
ST_SetSRID(ST_MakePoint(%s, %s), 4326))
""").format(pysql.Literal(self.country)),
to_add)
if to_delete:
cur.execute("""DELETE FROM location_postcode
WHERE country_code = %s and postcode = any(%s)
""", (self.country, to_delete))
if to_update:
cur.execute_values(
cur.executemany(
pysql.SQL("""UPDATE location_postcode
SET indexed_status = 2,
geometry = ST_SetSRID(ST_Point(v.x, v.y), 4326)
FROM (VALUES %s) AS v (pc, x, y)
WHERE country_code = {} and postcode = pc
""").format(pysql.Literal(self.country)), to_update)
geometry = ST_SetSRID(ST_Point(%s, %s), 4326)
WHERE country_code = {} and postcode = %s
""").format(pysql.Literal(self.country)),
to_update)
def _compute_changes(self, conn: Connection) \
-> Tuple[List[Tuple[str, float, float]], List[str], List[Tuple[str, float, float]]]:
-> Tuple[List[Tuple[str, float, float]], List[str], List[Tuple[float, float, str]]]:
""" Compute which postcodes from the collected postcodes have to be
added or modified and which from the location_postcode table
have to be deleted.
@@ -116,7 +116,7 @@ class _PostcodeCollector:
if pcobj:
newx, newy = pcobj.centroid()
if (x - newx) > 0.0000001 or (y - newy) > 0.0000001:
to_update.append((postcode, newx, newy))
to_update.append((newx, newy, postcode))
else:
to_delete.append(postcode)

View File

@@ -14,12 +14,12 @@ import logging
from textwrap import dedent
from pathlib import Path
from psycopg2 import sql as pysql
from psycopg import sql as pysql
from ..config import Configuration
from ..db.connection import Connection, connect, postgis_version_tuple,\
drop_tables, table_exists
from ..db.utils import execute_file, CopyBuffer
from ..db.utils import execute_file
from ..db.sql_preprocessor import SQLPreprocessor
from ..version import NOMINATIM_VERSION
@@ -68,8 +68,8 @@ def load_address_levels(conn: Connection, table: str, levels: Sequence[Mapping[s
rank_address SMALLINT)
""").format(pysql.Identifier(table)))
cur.execute_values(pysql.SQL("INSERT INTO {} VALUES %s")
.format(pysql.Identifier(table)), rows)
cur.executemany(pysql.SQL("INSERT INTO {} VALUES (%s, %s, %s, %s, %s)")
.format(pysql.Identifier(table)), rows)
cur.execute(pysql.SQL('CREATE UNIQUE INDEX ON {} (country_code, class, type)')
.format(pysql.Identifier(table)))
@@ -155,7 +155,7 @@ def import_importance_csv(dsn: str, data_file: Path) -> int:
if not data_file.exists():
return 1
# Only import the first occurence of a wikidata ID.
# Only import the first occurrence of a wikidata ID.
# This keeps indexes and table small.
wd_done = set()
@@ -169,24 +169,17 @@ def import_importance_csv(dsn: str, data_file: Path) -> int:
wikidata TEXT
) """)
with gzip.open(str(data_file), 'rt') as fd, CopyBuffer() as buf:
for row in csv.DictReader(fd, delimiter='\t', quotechar='|'):
wd_id = int(row['wikidata_id'][1:])
buf.add(row['language'], row['title'], row['importance'],
None if wd_id in wd_done else row['wikidata_id'])
wd_done.add(wd_id)
copy_cmd = """COPY wikimedia_importance(language, title, importance, wikidata)
FROM STDIN"""
with gzip.open(str(data_file), 'rt') as fd, cur.copy(copy_cmd) as copy:
for row in csv.DictReader(fd, delimiter='\t', quotechar='|'):
wd_id = int(row['wikidata_id'][1:])
copy.write_row((row['language'],
row['title'],
row['importance'],
None if wd_id in wd_done else row['wikidata_id']))
wd_done.add(wd_id)
if buf.size() > 10000000:
with conn.cursor() as cur:
buf.copy_out(cur, 'wikimedia_importance',
columns=['language', 'title', 'importance',
'wikidata'])
with conn.cursor() as cur:
buf.copy_out(cur, 'wikimedia_importance',
columns=['language', 'title', 'importance', 'wikidata'])
with conn.cursor() as cur:
cur.execute("""CREATE INDEX IF NOT EXISTS idx_wikimedia_importance_title
ON wikimedia_importance (title)""")
cur.execute("""CREATE INDEX IF NOT EXISTS idx_wikimedia_importance_wikidata

View File

@@ -17,7 +17,7 @@ from typing import Iterable, Tuple, Mapping, Sequence, Optional, Set
import logging
import re
from psycopg2.sql import Identifier, SQL
from psycopg.sql import Identifier, SQL
from ...typing import Protocol
from ...config import Configuration

View File

@@ -7,22 +7,22 @@
"""
Functions for importing tiger data and handling tarbar and directory files
"""
from typing import Any, TextIO, List, Union, cast
from typing import Any, TextIO, List, Union, cast, Iterator, Dict
import csv
import io
import logging
import os
import tarfile
from psycopg2.extras import Json
from psycopg.types.json import Json
from ..config import Configuration
from ..db.connection import connect
from ..db.async_connection import WorkerPool
from ..db.sql_preprocessor import SQLPreprocessor
from ..errors import UsageError
from ..db.query_pool import QueryPool
from ..data.place_info import PlaceInfo
from ..tokenizer.base import AbstractAnalyzer, AbstractTokenizer
from ..tokenizer.base import AbstractTokenizer
from . import freeze
LOG = logging.getLogger()
@@ -63,13 +63,13 @@ class TigerInput:
self.tar_handle.close()
self.tar_handle = None
def __bool__(self) -> bool:
return bool(self.files)
def next_file(self) -> TextIO:
def get_file(self, fname: Union[str, tarfile.TarInfo]) -> TextIO:
""" Return a file handle to the next file to be processed.
Raises an IndexError if there is no file left.
"""
fname = self.files.pop(0)
if self.tar_handle is not None:
extracted = self.tar_handle.extractfile(fname)
assert extracted is not None
@@ -78,47 +78,22 @@ class TigerInput:
return open(cast(str, fname), encoding='utf-8')
def __len__(self) -> int:
return len(self.files)
def __iter__(self) -> Iterator[Dict[str, Any]]:
""" Iterate over the lines in each file.
"""
for fname in self.files:
fd = self.get_file(fname)
yield from csv.DictReader(fd, delimiter=';')
def handle_threaded_sql_statements(pool: WorkerPool, fd: TextIO,
analyzer: AbstractAnalyzer) -> None:
""" Handles sql statement with multiplexing
"""
lines = 0
# Using pool of database connections to execute sql statements
sql = "SELECT tiger_line_import(%s, %s, %s, %s, %s, %s)"
for row in csv.DictReader(fd, delimiter=';'):
try:
address = dict(street=row['street'], postcode=row['postcode'])
args = ('SRID=4326;' + row['geometry'],
int(row['from']), int(row['to']), row['interpolation'],
Json(analyzer.process_place(PlaceInfo({'address': address}))),
analyzer.normalize_postcode(row['postcode']))
except ValueError:
continue
pool.next_free_worker().perform(sql, args=args)
lines += 1
if lines == 1000:
print('.', end='', flush=True)
lines = 0
def add_tiger_data(data_dir: str, config: Configuration, threads: int,
async def add_tiger_data(data_dir: str, config: Configuration, threads: int,
tokenizer: AbstractTokenizer) -> int:
""" Import tiger data from directory or tar file `data dir`.
"""
dsn = config.get_libpq_dsn()
with connect(dsn) as conn:
is_frozen = freeze.is_frozen(conn)
conn.close()
if is_frozen:
if freeze.is_frozen(conn):
raise UsageError("Tiger cannot be imported when database frozen (Github issue #3048)")
with TigerInput(data_dir) as tar:
@@ -133,13 +108,30 @@ def add_tiger_data(data_dir: str, config: Configuration, threads: int,
# sql_query in <threads - 1> chunks.
place_threads = max(1, threads - 1)
with WorkerPool(dsn, place_threads, ignore_sql_errors=True) as pool:
async with QueryPool(dsn, place_threads, autocommit=True) as pool:
with tokenizer.name_analyzer() as analyzer:
while tar:
with tar.next_file() as fd:
handle_threaded_sql_statements(pool, fd, analyzer)
lines = 0
for row in tar:
try:
address = dict(street=row['street'], postcode=row['postcode'])
args = ('SRID=4326;' + row['geometry'],
int(row['from']), int(row['to']), row['interpolation'],
Json(analyzer.process_place(PlaceInfo({'address': address}))),
analyzer.normalize_postcode(row['postcode']))
except ValueError:
continue
print('\n')
await pool.put_query(
"""SELECT tiger_line_import(%s::GEOMETRY, %s::INT,
%s::INT, %s::TEXT, %s::JSONB, %s::TEXT)""",
args)
lines += 1
if lines == 1000:
print('.', end='', flush=True)
lines = 0
print('', flush=True)
LOG.warning("Creating indexes on Tiger data")
with connect(dsn) as conn: