port code to psycopg3

This commit is contained in:
Sarah Hoffmann
2024-07-05 10:43:10 +02:00
parent 3742fa2929
commit 9659afbade
57 changed files with 800 additions and 1330 deletions

View File

@@ -7,92 +7,20 @@
"""
Main work horse for indexing (computing addresses) the database.
"""
from typing import Optional, Any, cast
from typing import cast, List, Any
import logging
import time
import psycopg2.extras
import psycopg
from ..typing import DictCursorResults
from ..db.async_connection import DBConnection, WorkerPool
from ..db.connection import connect, Connection, Cursor, execute_scalar, register_hstore
from ..db.connection import connect, execute_scalar
from ..db.query_pool import QueryPool
from ..tokenizer.base import AbstractTokenizer
from .progress import ProgressLogger
from . import runners
LOG = logging.getLogger()
class PlaceFetcher:
""" Asynchronous connection that fetches place details for processing.
"""
def __init__(self, dsn: str, setup_conn: Connection) -> None:
self.wait_time = 0.0
self.current_ids: Optional[DictCursorResults] = None
self.conn: Optional[DBConnection] = DBConnection(dsn,
cursor_factory=psycopg2.extras.DictCursor)
# need to fetch those manually because register_hstore cannot
# fetch them on an asynchronous connection below.
hstore_oid = execute_scalar(setup_conn, "SELECT 'hstore'::regtype::oid")
hstore_array_oid = execute_scalar(setup_conn, "SELECT 'hstore[]'::regtype::oid")
psycopg2.extras.register_hstore(self.conn.conn, oid=hstore_oid,
array_oid=hstore_array_oid)
def close(self) -> None:
""" Close the underlying asynchronous connection.
"""
if self.conn:
self.conn.close()
self.conn = None
def fetch_next_batch(self, cur: Cursor, runner: runners.Runner) -> bool:
""" Send a request for the next batch of places.
If details for the places are required, they will be fetched
asynchronously.
Returns true if there is still data available.
"""
ids = cast(Optional[DictCursorResults], cur.fetchmany(100))
if not ids:
self.current_ids = None
return False
assert self.conn is not None
self.current_ids = runner.get_place_details(self.conn, ids)
return True
def get_batch(self) -> DictCursorResults:
""" Get the next batch of data, previously requested with
`fetch_next_batch`.
"""
assert self.conn is not None
assert self.conn.cursor is not None
if self.current_ids is not None and not self.current_ids:
tstart = time.time()
self.conn.wait()
self.wait_time += time.time() - tstart
self.current_ids = cast(Optional[DictCursorResults],
self.conn.cursor.fetchall())
return self.current_ids if self.current_ids is not None else []
def __enter__(self) -> 'PlaceFetcher':
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
assert self.conn is not None
self.conn.wait()
self.close()
class Indexer:
""" Main indexing routine.
"""
@@ -114,7 +42,7 @@ class Indexer:
return cur.rowcount > 0
def index_full(self, analyse: bool = True) -> None:
async def index_full(self, analyse: bool = True) -> None:
""" Index the complete database. This will first index boundaries
followed by all other objects. When `analyse` is True, then the
database will be analysed at the appropriate places to
@@ -128,23 +56,27 @@ class Indexer:
with conn.cursor() as cur:
cur.execute('ANALYZE')
if self.index_by_rank(0, 4) > 0:
_analyze()
while True:
if await self.index_by_rank(0, 4) > 0:
_analyze()
if self.index_boundaries(0, 30) > 100:
_analyze()
if await self.index_boundaries(0, 30) > 100:
_analyze()
if self.index_by_rank(5, 25) > 100:
_analyze()
if await self.index_by_rank(5, 25) > 100:
_analyze()
if self.index_by_rank(26, 30) > 1000:
_analyze()
if await self.index_by_rank(26, 30) > 1000:
_analyze()
if self.index_postcodes() > 100:
_analyze()
if await self.index_postcodes() > 100:
_analyze()
if not self.has_pending():
break
def index_boundaries(self, minrank: int, maxrank: int) -> int:
async def index_boundaries(self, minrank: int, maxrank: int) -> int:
""" Index only administrative boundaries within the given rank range.
"""
total = 0
@@ -153,11 +85,11 @@ class Indexer:
with self.tokenizer.name_analyzer() as analyzer:
for rank in range(max(minrank, 4), min(maxrank, 26)):
total += self._index(runners.BoundaryRunner(rank, analyzer))
total += await self._index(runners.BoundaryRunner(rank, analyzer))
return total
def index_by_rank(self, minrank: int, maxrank: int) -> int:
async def index_by_rank(self, minrank: int, maxrank: int) -> int:
""" Index all entries of placex in the given rank range (inclusive)
in order of their address rank.
@@ -171,21 +103,27 @@ class Indexer:
with self.tokenizer.name_analyzer() as analyzer:
for rank in range(max(1, minrank), maxrank + 1):
total += self._index(runners.RankRunner(rank, analyzer), 20 if rank == 30 else 1)
if rank >= 30:
batch = 20
elif rank >= 26:
batch = 5
else:
batch = 1
total += await self._index(runners.RankRunner(rank, analyzer), batch)
if maxrank == 30:
total += self._index(runners.RankRunner(0, analyzer))
total += self._index(runners.InterpolationRunner(analyzer), 20)
total += await self._index(runners.RankRunner(0, analyzer))
total += await self._index(runners.InterpolationRunner(analyzer), 20)
return total
def index_postcodes(self) -> int:
async def index_postcodes(self) -> int:
"""Index the entries of the location_postcode table.
"""
LOG.warning("Starting indexing postcodes using %s threads", self.num_threads)
return self._index(runners.PostcodeRunner(), 20)
return await self._index(runners.PostcodeRunner(), 20)
def update_status_table(self) -> None:
@@ -197,45 +135,58 @@ class Indexer:
conn.commit()
def _index(self, runner: runners.Runner, batch: int = 1) -> int:
async def _index(self, runner: runners.Runner, batch: int = 1) -> int:
""" Index a single rank or table. `runner` describes the SQL to use
for indexing. `batch` describes the number of objects that
should be processed with a single SQL statement
"""
LOG.warning("Starting %s (using batch size %s)", runner.name(), batch)
with connect(self.dsn) as conn:
register_hstore(conn)
total_tuples = execute_scalar(conn, runner.sql_count_objects())
LOG.debug("Total number of rows: %i", total_tuples)
total_tuples = self._prepare_indexing(runner)
conn.commit()
progress = ProgressLogger(runner.name(), total_tuples)
progress = ProgressLogger(runner.name(), total_tuples)
if total_tuples > 0:
async with await psycopg.AsyncConnection.connect(
self.dsn, row_factory=psycopg.rows.dict_row) as aconn,\
QueryPool(self.dsn, self.num_threads, autocommit=True) as pool:
fetcher_time = 0.0
tstart = time.time()
async with aconn.cursor(name='places') as cur:
query = runner.index_places_query(batch)
params: List[Any] = []
num_places = 0
async for place in cur.stream(runner.sql_get_objects()):
fetcher_time += time.time() - tstart
if total_tuples > 0:
with conn.cursor(name='places') as cur:
cur.execute(runner.sql_get_objects())
params.extend(runner.index_places_params(place))
num_places += 1
with PlaceFetcher(self.dsn, conn) as fetcher:
with WorkerPool(self.dsn, self.num_threads) as pool:
has_more = fetcher.fetch_next_batch(cur, runner)
while has_more:
places = fetcher.get_batch()
if num_places >= batch:
LOG.debug("Processing places: %s", str(params))
await pool.put_query(query, params)
progress.add(num_places)
params = []
num_places = 0
# asynchronously get the next batch
has_more = fetcher.fetch_next_batch(cur, runner)
tstart = time.time()
# And insert the current batch
for idx in range(0, len(places), batch):
part = places[idx:idx + batch]
LOG.debug("Processing places: %s", str(part))
runner.index_places(pool.next_free_worker(), part)
progress.add(len(part))
if num_places > 0:
await pool.put_query(runner.index_places_query(num_places), params)
LOG.info("Wait time: fetcher: %.2fs, pool: %.2fs",
fetcher.wait_time, pool.wait_time)
conn.commit()
LOG.info("Wait time: fetcher: %.2fs, pool: %.2fs",
fetcher_time, pool.wait_time)
return progress.done()
def _prepare_indexing(self, runner: runners.Runner) -> int:
with connect(self.dsn) as conn:
hstore_info = psycopg.types.TypeInfo.fetch(conn, "hstore")
if hstore_info is None:
raise RuntimeError('Hstore extension is requested but not installed.')
psycopg.types.hstore.register_hstore(hstore_info)
total_tuples = execute_scalar(conn, runner.sql_count_objects())
LOG.debug("Total number of rows: %i", total_tuples)
return cast(int, total_tuples)