port code to psycopg3

This commit is contained in:
Sarah Hoffmann
2024-07-05 10:43:10 +02:00
parent 3742fa2929
commit 9659afbade
57 changed files with 800 additions and 1330 deletions

View File

@@ -7,92 +7,20 @@
"""
Main work horse for indexing (computing addresses) the database.
"""
from typing import Optional, Any, cast
from typing import cast, List, Any
import logging
import time
import psycopg2.extras
import psycopg
from ..typing import DictCursorResults
from ..db.async_connection import DBConnection, WorkerPool
from ..db.connection import connect, Connection, Cursor, execute_scalar, register_hstore
from ..db.connection import connect, execute_scalar
from ..db.query_pool import QueryPool
from ..tokenizer.base import AbstractTokenizer
from .progress import ProgressLogger
from . import runners
LOG = logging.getLogger()
class PlaceFetcher:
""" Asynchronous connection that fetches place details for processing.
"""
def __init__(self, dsn: str, setup_conn: Connection) -> None:
self.wait_time = 0.0
self.current_ids: Optional[DictCursorResults] = None
self.conn: Optional[DBConnection] = DBConnection(dsn,
cursor_factory=psycopg2.extras.DictCursor)
# need to fetch those manually because register_hstore cannot
# fetch them on an asynchronous connection below.
hstore_oid = execute_scalar(setup_conn, "SELECT 'hstore'::regtype::oid")
hstore_array_oid = execute_scalar(setup_conn, "SELECT 'hstore[]'::regtype::oid")
psycopg2.extras.register_hstore(self.conn.conn, oid=hstore_oid,
array_oid=hstore_array_oid)
def close(self) -> None:
""" Close the underlying asynchronous connection.
"""
if self.conn:
self.conn.close()
self.conn = None
def fetch_next_batch(self, cur: Cursor, runner: runners.Runner) -> bool:
""" Send a request for the next batch of places.
If details for the places are required, they will be fetched
asynchronously.
Returns true if there is still data available.
"""
ids = cast(Optional[DictCursorResults], cur.fetchmany(100))
if not ids:
self.current_ids = None
return False
assert self.conn is not None
self.current_ids = runner.get_place_details(self.conn, ids)
return True
def get_batch(self) -> DictCursorResults:
""" Get the next batch of data, previously requested with
`fetch_next_batch`.
"""
assert self.conn is not None
assert self.conn.cursor is not None
if self.current_ids is not None and not self.current_ids:
tstart = time.time()
self.conn.wait()
self.wait_time += time.time() - tstart
self.current_ids = cast(Optional[DictCursorResults],
self.conn.cursor.fetchall())
return self.current_ids if self.current_ids is not None else []
def __enter__(self) -> 'PlaceFetcher':
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
assert self.conn is not None
self.conn.wait()
self.close()
class Indexer:
""" Main indexing routine.
"""
@@ -114,7 +42,7 @@ class Indexer:
return cur.rowcount > 0
def index_full(self, analyse: bool = True) -> None:
async def index_full(self, analyse: bool = True) -> None:
""" Index the complete database. This will first index boundaries
followed by all other objects. When `analyse` is True, then the
database will be analysed at the appropriate places to
@@ -128,23 +56,27 @@ class Indexer:
with conn.cursor() as cur:
cur.execute('ANALYZE')
if self.index_by_rank(0, 4) > 0:
_analyze()
while True:
if await self.index_by_rank(0, 4) > 0:
_analyze()
if self.index_boundaries(0, 30) > 100:
_analyze()
if await self.index_boundaries(0, 30) > 100:
_analyze()
if self.index_by_rank(5, 25) > 100:
_analyze()
if await self.index_by_rank(5, 25) > 100:
_analyze()
if self.index_by_rank(26, 30) > 1000:
_analyze()
if await self.index_by_rank(26, 30) > 1000:
_analyze()
if self.index_postcodes() > 100:
_analyze()
if await self.index_postcodes() > 100:
_analyze()
if not self.has_pending():
break
def index_boundaries(self, minrank: int, maxrank: int) -> int:
async def index_boundaries(self, minrank: int, maxrank: int) -> int:
""" Index only administrative boundaries within the given rank range.
"""
total = 0
@@ -153,11 +85,11 @@ class Indexer:
with self.tokenizer.name_analyzer() as analyzer:
for rank in range(max(minrank, 4), min(maxrank, 26)):
total += self._index(runners.BoundaryRunner(rank, analyzer))
total += await self._index(runners.BoundaryRunner(rank, analyzer))
return total
def index_by_rank(self, minrank: int, maxrank: int) -> int:
async def index_by_rank(self, minrank: int, maxrank: int) -> int:
""" Index all entries of placex in the given rank range (inclusive)
in order of their address rank.
@@ -171,21 +103,27 @@ class Indexer:
with self.tokenizer.name_analyzer() as analyzer:
for rank in range(max(1, minrank), maxrank + 1):
total += self._index(runners.RankRunner(rank, analyzer), 20 if rank == 30 else 1)
if rank >= 30:
batch = 20
elif rank >= 26:
batch = 5
else:
batch = 1
total += await self._index(runners.RankRunner(rank, analyzer), batch)
if maxrank == 30:
total += self._index(runners.RankRunner(0, analyzer))
total += self._index(runners.InterpolationRunner(analyzer), 20)
total += await self._index(runners.RankRunner(0, analyzer))
total += await self._index(runners.InterpolationRunner(analyzer), 20)
return total
def index_postcodes(self) -> int:
async def index_postcodes(self) -> int:
"""Index the entries of the location_postcode table.
"""
LOG.warning("Starting indexing postcodes using %s threads", self.num_threads)
return self._index(runners.PostcodeRunner(), 20)
return await self._index(runners.PostcodeRunner(), 20)
def update_status_table(self) -> None:
@@ -197,45 +135,58 @@ class Indexer:
conn.commit()
def _index(self, runner: runners.Runner, batch: int = 1) -> int:
async def _index(self, runner: runners.Runner, batch: int = 1) -> int:
""" Index a single rank or table. `runner` describes the SQL to use
for indexing. `batch` describes the number of objects that
should be processed with a single SQL statement
"""
LOG.warning("Starting %s (using batch size %s)", runner.name(), batch)
with connect(self.dsn) as conn:
register_hstore(conn)
total_tuples = execute_scalar(conn, runner.sql_count_objects())
LOG.debug("Total number of rows: %i", total_tuples)
total_tuples = self._prepare_indexing(runner)
conn.commit()
progress = ProgressLogger(runner.name(), total_tuples)
progress = ProgressLogger(runner.name(), total_tuples)
if total_tuples > 0:
async with await psycopg.AsyncConnection.connect(
self.dsn, row_factory=psycopg.rows.dict_row) as aconn,\
QueryPool(self.dsn, self.num_threads, autocommit=True) as pool:
fetcher_time = 0.0
tstart = time.time()
async with aconn.cursor(name='places') as cur:
query = runner.index_places_query(batch)
params: List[Any] = []
num_places = 0
async for place in cur.stream(runner.sql_get_objects()):
fetcher_time += time.time() - tstart
if total_tuples > 0:
with conn.cursor(name='places') as cur:
cur.execute(runner.sql_get_objects())
params.extend(runner.index_places_params(place))
num_places += 1
with PlaceFetcher(self.dsn, conn) as fetcher:
with WorkerPool(self.dsn, self.num_threads) as pool:
has_more = fetcher.fetch_next_batch(cur, runner)
while has_more:
places = fetcher.get_batch()
if num_places >= batch:
LOG.debug("Processing places: %s", str(params))
await pool.put_query(query, params)
progress.add(num_places)
params = []
num_places = 0
# asynchronously get the next batch
has_more = fetcher.fetch_next_batch(cur, runner)
tstart = time.time()
# And insert the current batch
for idx in range(0, len(places), batch):
part = places[idx:idx + batch]
LOG.debug("Processing places: %s", str(part))
runner.index_places(pool.next_free_worker(), part)
progress.add(len(part))
if num_places > 0:
await pool.put_query(runner.index_places_query(num_places), params)
LOG.info("Wait time: fetcher: %.2fs, pool: %.2fs",
fetcher.wait_time, pool.wait_time)
conn.commit()
LOG.info("Wait time: fetcher: %.2fs, pool: %.2fs",
fetcher_time, pool.wait_time)
return progress.done()
def _prepare_indexing(self, runner: runners.Runner) -> int:
with connect(self.dsn) as conn:
hstore_info = psycopg.types.TypeInfo.fetch(conn, "hstore")
if hstore_info is None:
raise RuntimeError('Hstore extension is requested but not installed.')
psycopg.types.hstore.register_hstore(hstore_info)
total_tuples = execute_scalar(conn, runner.sql_count_objects())
LOG.debug("Total number of rows: %i", total_tuples)
return cast(int, total_tuples)

View File

@@ -8,14 +8,14 @@
Mix-ins that provide the actual commands for the indexer for various indexing
tasks.
"""
from typing import Any, List
import functools
from typing import Any, Sequence
from psycopg2 import sql as pysql
import psycopg2.extras
from psycopg import sql as pysql
from psycopg.abc import Query
from psycopg.rows import DictRow
from psycopg.types.json import Json
from ..typing import Query, DictCursorResult, DictCursorResults, Protocol
from ..db.async_connection import DBConnection
from ..typing import Protocol
from ..data.place_info import PlaceInfo
from ..tokenizer.base import AbstractAnalyzer
@@ -24,58 +24,48 @@ from ..tokenizer.base import AbstractAnalyzer
def _mk_valuelist(template: str, num: int) -> pysql.Composed:
return pysql.SQL(',').join([pysql.SQL(template)] * num)
def _analyze_place(place: DictCursorResult, analyzer: AbstractAnalyzer) -> psycopg2.extras.Json:
return psycopg2.extras.Json(analyzer.process_place(PlaceInfo(place)))
def _analyze_place(place: DictRow, analyzer: AbstractAnalyzer) -> Json:
return Json(analyzer.process_place(PlaceInfo(place)))
class Runner(Protocol):
def name(self) -> str: ...
def sql_count_objects(self) -> Query: ...
def sql_get_objects(self) -> Query: ...
def get_place_details(self, worker: DBConnection,
ids: DictCursorResults) -> DictCursorResults: ...
def index_places(self, worker: DBConnection, places: DictCursorResults) -> None: ...
def index_places_query(self, batch_size: int) -> Query: ...
def index_places_params(self, place: DictRow) -> Sequence[Any]: ...
SELECT_SQL = pysql.SQL("""SELECT place_id, extra.*
FROM (SELECT * FROM placex {}) as px,
LATERAL placex_indexing_prepare(px) as extra """)
UPDATE_LINE = "(%s, %s::hstore, %s::hstore, %s::int, %s::jsonb)"
class AbstractPlacexRunner:
""" Returns SQL commands for indexing of the placex table.
"""
SELECT_SQL = pysql.SQL('SELECT place_id FROM placex ')
UPDATE_LINE = "(%s, %s::hstore, %s::hstore, %s::int, %s::jsonb)"
def __init__(self, rank: int, analyzer: AbstractAnalyzer) -> None:
self.rank = rank
self.analyzer = analyzer
@functools.lru_cache(maxsize=1)
def _index_sql(self, num_places: int) -> pysql.Composed:
def index_places_query(self, batch_size: int) -> Query:
return pysql.SQL(
""" UPDATE placex
SET indexed_status = 0, address = v.addr, token_info = v.ti,
name = v.name, linked_place_id = v.linked_place_id
FROM (VALUES {}) as v(id, name, addr, linked_place_id, ti)
WHERE place_id = v.id
""").format(_mk_valuelist(AbstractPlacexRunner.UPDATE_LINE, num_places))
""").format(_mk_valuelist(UPDATE_LINE, batch_size))
def get_place_details(self, worker: DBConnection, ids: DictCursorResults) -> DictCursorResults:
worker.perform("""SELECT place_id, extra.*
FROM placex, LATERAL placex_indexing_prepare(placex) as extra
WHERE place_id IN %s""",
(tuple((p[0] for p in ids)), ))
return []
def index_places(self, worker: DBConnection, places: DictCursorResults) -> None:
values: List[Any] = []
for place in places:
for field in ('place_id', 'name', 'address', 'linked_place_id'):
values.append(place[field])
values.append(_analyze_place(place, self.analyzer))
worker.perform(self._index_sql(len(places)), values)
def index_places_params(self, place: DictRow) -> Sequence[Any]:
return (place['place_id'],
place['name'],
place['address'],
place['linked_place_id'],
_analyze_place(place, self.analyzer))
class RankRunner(AbstractPlacexRunner):
@@ -91,10 +81,10 @@ class RankRunner(AbstractPlacexRunner):
""").format(pysql.Literal(self.rank))
def sql_get_objects(self) -> pysql.Composed:
return self.SELECT_SQL + pysql.SQL(
"""WHERE indexed_status > 0 and rank_address = {}
ORDER BY geometry_sector
""").format(pysql.Literal(self.rank))
return SELECT_SQL.format(pysql.SQL(
"""WHERE placex.indexed_status > 0 and placex.rank_address = {}
ORDER BY placex.geometry_sector
""").format(pysql.Literal(self.rank)))
class BoundaryRunner(AbstractPlacexRunner):
@@ -105,19 +95,19 @@ class BoundaryRunner(AbstractPlacexRunner):
def name(self) -> str:
return f"boundaries rank {self.rank}"
def sql_count_objects(self) -> pysql.Composed:
def sql_count_objects(self) -> Query:
return pysql.SQL("""SELECT count(*) FROM placex
WHERE indexed_status > 0
AND rank_search = {}
AND class = 'boundary' and type = 'administrative'
""").format(pysql.Literal(self.rank))
def sql_get_objects(self) -> pysql.Composed:
return self.SELECT_SQL + pysql.SQL(
"""WHERE indexed_status > 0 and rank_search = {}
and class = 'boundary' and type = 'administrative'
ORDER BY partition, admin_level
""").format(pysql.Literal(self.rank))
def sql_get_objects(self) -> Query:
return SELECT_SQL.format(pysql.SQL(
"""WHERE placex.indexed_status > 0 and placex.rank_search = {}
and placex.class = 'boundary' and placex.type = 'administrative'
ORDER BY placex.partition, placex.admin_level
""").format(pysql.Literal(self.rank)))
class InterpolationRunner:
@@ -132,40 +122,29 @@ class InterpolationRunner:
def name(self) -> str:
return "interpolation lines (location_property_osmline)"
def sql_count_objects(self) -> str:
def sql_count_objects(self) -> Query:
return """SELECT count(*) FROM location_property_osmline
WHERE indexed_status > 0"""
def sql_get_objects(self) -> str:
return """SELECT place_id
def sql_get_objects(self) -> Query:
return """SELECT place_id, get_interpolation_address(address, osm_id) as address
FROM location_property_osmline
WHERE indexed_status > 0
ORDER BY geometry_sector"""
def get_place_details(self, worker: DBConnection, ids: DictCursorResults) -> DictCursorResults:
worker.perform("""SELECT place_id, get_interpolation_address(address, osm_id) as address
FROM location_property_osmline WHERE place_id IN %s""",
(tuple((p[0] for p in ids)), ))
return []
@functools.lru_cache(maxsize=1)
def _index_sql(self, num_places: int) -> pysql.Composed:
def index_places_query(self, batch_size: int) -> Query:
return pysql.SQL("""UPDATE location_property_osmline
SET indexed_status = 0, address = v.addr, token_info = v.ti
FROM (VALUES {}) as v(id, addr, ti)
WHERE place_id = v.id
""").format(_mk_valuelist("(%s, %s::hstore, %s::jsonb)", num_places))
""").format(_mk_valuelist("(%s, %s::hstore, %s::jsonb)", batch_size))
def index_places(self, worker: DBConnection, places: DictCursorResults) -> None:
values: List[Any] = []
for place in places:
values.extend((place[x] for x in ('place_id', 'address')))
values.append(_analyze_place(place, self.analyzer))
worker.perform(self._index_sql(len(places)), values)
def index_places_params(self, place: DictRow) -> Sequence[Any]:
return (place['place_id'], place['address'],
_analyze_place(place, self.analyzer))
@@ -177,20 +156,21 @@ class PostcodeRunner(Runner):
return "postcodes (location_postcode)"
def sql_count_objects(self) -> str:
def sql_count_objects(self) -> Query:
return 'SELECT count(*) FROM location_postcode WHERE indexed_status > 0'
def sql_get_objects(self) -> str:
def sql_get_objects(self) -> Query:
return """SELECT place_id FROM location_postcode
WHERE indexed_status > 0
ORDER BY country_code, postcode"""
def get_place_details(self, worker: DBConnection, ids: DictCursorResults) -> DictCursorResults:
return ids
def index_places_query(self, batch_size: int) -> Query:
return pysql.SQL("""UPDATE location_postcode SET indexed_status = 0
WHERE place_id IN ({})""")\
.format(pysql.SQL(',').join((pysql.Placeholder() for _ in range(batch_size))))
def index_places(self, worker: DBConnection, places: DictCursorResults) -> None:
worker.perform(pysql.SQL("""UPDATE location_postcode SET indexed_status = 0
WHERE place_id IN ({})""")
.format(pysql.SQL(',').join((pysql.Literal(i[0]) for i in places))))
def index_places_params(self, place: DictRow) -> Sequence[Any]:
return (place['place_id'], )