add type annotations for indexer

Sarah Hoffmann
2022-07-12 18:40:51 +02:00
parent 8adab2c6ca
commit 5617bffe2f
6 changed files with 106 additions and 78 deletions


@@ -7,15 +7,18 @@
"""
Main work horse for indexing (computing addresses) the database.
"""
from typing import Optional, Any, cast
import logging
import time
import psycopg2.extras
from nominatim.tokenizer.base import AbstractTokenizer
from nominatim.indexer.progress import ProgressLogger
from nominatim.indexer import runners
from nominatim.db.async_connection import DBConnection, WorkerPool
from nominatim.db.connection import connect
from nominatim.db.connection import connect, Connection, Cursor
from nominatim.typing import DictCursorResults
LOG = logging.getLogger()
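The new annotations lean on the DictCursorResults alias imported from nominatim.typing above. As a rough sketch of what that alias describes, assuming it simply names the row mappings that a psycopg2 DictCursor yields (the real definition in nominatim/typing.py may be more involved):

    # Hedged sketch only: the actual alias in nominatim/typing.py may differ,
    # e.g. it may reference psycopg2's DictRow type under TYPE_CHECKING.
    from typing import Any, Mapping, Sequence

    DictCursorResult = Mapping[str, Any]            # one row from a DictCursor
    DictCursorResults = Sequence[DictCursorResult]  # what fetchmany()/fetchall() return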
@@ -23,10 +26,11 @@ LOG = logging.getLogger()
class PlaceFetcher:
""" Asynchronous connection that fetches place details for processing.
"""
def __init__(self, dsn, setup_conn):
self.wait_time = 0
self.current_ids = None
self.conn = DBConnection(dsn, cursor_factory=psycopg2.extras.DictCursor)
def __init__(self, dsn: str, setup_conn: Connection) -> None:
self.wait_time = 0.0
self.current_ids: Optional[DictCursorResults] = None
self.conn: Optional[DBConnection] = DBConnection(dsn,
cursor_factory=psycopg2.extras.DictCursor)
with setup_conn.cursor() as cur:
# need to fetch those manually because register_hstore cannot
@@ -37,7 +41,7 @@ class PlaceFetcher:
psycopg2.extras.register_hstore(self.conn.conn, oid=hstore_oid,
array_oid=hstore_array_oid)
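The context above is cut short by the hunk boundary; the OIDs handed to register_hstore come from the synchronous setup connection because the asynchronous connection cannot look them up itself. A minimal sketch of such a lookup, assuming a plain psycopg2-style cursor (the actual query in the file may differ):

    # Assumption: manual OID lookup on the synchronous setup connection,
    # since register_hstore() cannot resolve the types on the async one.
    with setup_conn.cursor() as cur:
        cur.execute("SELECT 'hstore'::regtype::oid, 'hstore[]'::regtype::oid")
        hstore_oid, hstore_array_oid = cur.fetchone()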
def close(self):
def close(self) -> None:
""" Close the underlying asynchronous connection.
"""
if self.conn:
@@ -45,44 +49,46 @@ class PlaceFetcher:
self.conn = None
def fetch_next_batch(self, cur, runner):
def fetch_next_batch(self, cur: Cursor, runner: runners.Runner) -> bool:
""" Send a request for the next batch of places.
If details for the places are required, they will be fetched
asynchronously.
Returns true if there is still data available.
"""
ids = cur.fetchmany(100)
ids = cast(Optional[DictCursorResults], cur.fetchmany(100))
if not ids:
self.current_ids = None
return False
if hasattr(runner, 'get_place_details'):
runner.get_place_details(self.conn, ids)
self.current_ids = []
else:
self.current_ids = ids
assert self.conn is not None
self.current_ids = runner.get_place_details(self.conn, ids)
return True
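Typing runner as runners.Runner is what lets the old hasattr(runner, 'get_place_details') branch collapse into an unconditional call: every runner is now assumed to expose the method. A minimal sketch of the protocol this presumes, with illustrative parameter types (the real definition in nominatim/indexer/runners.py may differ):

    # Hedged sketch of the Runner interface assumed by the annotation above;
    # only the method used in this diff is shown, signatures are illustrative.
    from typing import Any, Mapping, Protocol, Sequence

    class Runner(Protocol):
        def get_place_details(self, worker: Any,
                              ids: Sequence[Mapping[str, Any]]
                              ) -> Sequence[Mapping[str, Any]]: ...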
def get_batch(self):
def get_batch(self) -> DictCursorResults:
""" Get the next batch of data, previously requested with
`fetch_next_batch`.
"""
assert self.conn is not None
assert self.conn.cursor is not None
if self.current_ids is not None and not self.current_ids:
tstart = time.time()
self.conn.wait()
self.wait_time += time.time() - tstart
self.current_ids = self.conn.cursor.fetchall()
self.current_ids = cast(Optional[DictCursorResults],
self.conn.cursor.fetchall())
return self.current_ids
return self.current_ids if self.current_ids is not None else []
def __enter__(self):
def __enter__(self) -> 'PlaceFetcher':
return self
def __exit__(self, exc_type, exc_value, traceback):
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
assert self.conn is not None
self.conn.wait()
self.close()
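With __enter__/__exit__ annotated, a consumer drives the fetcher through its context-manager interface. An illustrative sketch of the intended loop, not the actual _index implementation; dsn, conn, cur and runner are placeholders here:

    # Usage sketch only: the real driver (Indexer._index below) adds a worker
    # pool and progress logging around this pattern.
    with PlaceFetcher(dsn, conn) as fetcher:
        while fetcher.fetch_next_batch(cur, runner):   # asks for up to 100 places
            places = fetcher.get_batch()               # DictCursorResults, never None
            # hand `places` over to the indexing workers here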
@@ -91,13 +97,13 @@ class Indexer:
""" Main indexing routine.
"""
def __init__(self, dsn, tokenizer, num_threads):
def __init__(self, dsn: str, tokenizer: AbstractTokenizer, num_threads: int):
self.dsn = dsn
self.tokenizer = tokenizer
self.num_threads = num_threads
def has_pending(self):
def has_pending(self) -> bool:
""" Check if any data still needs indexing.
This function must only be used after the import has finished.
Otherwise it will be very expensive.
@@ -108,7 +114,7 @@ class Indexer:
return cur.rowcount > 0
def index_full(self, analyse=True):
def index_full(self, analyse: bool = True) -> None:
""" Index the complete database. This will first index boundaries
followed by all other objects. When `analyse` is True, then the
database will be analysed at the appropriate places to
@@ -117,7 +123,7 @@ class Indexer:
with connect(self.dsn) as conn:
conn.autocommit = True
def _analyze():
def _analyze() -> None:
if analyse:
with conn.cursor() as cur:
cur.execute('ANALYZE')
@@ -138,7 +144,7 @@ class Indexer:
_analyze()
def index_boundaries(self, minrank, maxrank):
def index_boundaries(self, minrank: int, maxrank: int) -> None:
""" Index only administrative boundaries within the given rank range.
"""
LOG.warning("Starting indexing boundaries using %s threads",
@@ -148,7 +154,7 @@ class Indexer:
for rank in range(max(minrank, 4), min(maxrank, 26)):
self._index(runners.BoundaryRunner(rank, analyzer))
def index_by_rank(self, minrank, maxrank):
def index_by_rank(self, minrank: int, maxrank: int) -> None:
""" Index all entries of placex in the given rank range (inclusive)
in order of their address rank.
@@ -168,7 +174,7 @@ class Indexer:
self._index(runners.InterpolationRunner(analyzer), 20)
def index_postcodes(self):
def index_postcodes(self) -> None:
"""Index the entries ofthe location_postcode table.
"""
LOG.warning("Starting indexing postcodes using %s threads", self.num_threads)
@@ -176,7 +182,7 @@ class Indexer:
self._index(runners.PostcodeRunner(), 20)
def update_status_table(self):
def update_status_table(self) -> None:
""" Update the status in the status table to 'indexed'.
"""
with connect(self.dsn) as conn:
@@ -185,7 +191,7 @@ class Indexer:
conn.commit()
def _index(self, runner, batch=1):
def _index(self, runner: runners.Runner, batch: int = 1) -> None:
""" Index a single rank or table. `runner` describes the SQL to use
for indexing. `batch` describes the number of objects that
should be processed with a single SQL statement
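As a usage note, batch is what the rank helpers above tune per table: the postcode and interpolation runners hand 20 places to each SQL statement, while ordinary placex ranks use the default of one. A hypothetical call sequence (dsn and tokenizer are placeholders, not part of this diff):

    # Usage sketch; dsn and tokenizer stand in for real configuration objects.
    indexer = Indexer(dsn, tokenizer, num_threads=1)
    indexer.index_full(analyse=True)   # boundaries first, then the rest
    indexer.index_postcodes()          # wraps self._index(runners.PostcodeRunner(), 20)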