add type annotations for indexer

This commit is contained in:
Sarah Hoffmann
2022-07-12 18:40:51 +02:00
parent 8adab2c6ca
commit 5617bffe2f
6 changed files with 106 additions and 78 deletions

View File

@@ -6,7 +6,7 @@
# For a full list of authors see the git log. # For a full list of authors see the git log.
""" Non-blocking database connections. """ Non-blocking database connections.
""" """
from typing import Callable, Any, Optional, List, Iterator from typing import Callable, Any, Optional, Iterator, Sequence
import logging import logging
import select import select
import time import time
@@ -22,7 +22,7 @@ try:
except ImportError: except ImportError:
__has_psycopg2_errors__ = False __has_psycopg2_errors__ = False
from nominatim.typing import T_cursor from nominatim.typing import T_cursor, Query
LOG = logging.getLogger() LOG = logging.getLogger()
@@ -65,8 +65,8 @@ class DBConnection:
ignore_sql_errors: bool = False) -> None: ignore_sql_errors: bool = False) -> None:
self.dsn = dsn self.dsn = dsn
self.current_query: Optional[str] = None self.current_query: Optional[Query] = None
self.current_params: Optional[List[Any]] = None self.current_params: Optional[Sequence[Any]] = None
self.ignore_sql_errors = ignore_sql_errors self.ignore_sql_errors = ignore_sql_errors
self.conn: Optional['psycopg2.connection'] = None self.conn: Optional['psycopg2.connection'] = None
@@ -128,7 +128,7 @@ class DBConnection:
self.current_query = None self.current_query = None
return return
def perform(self, sql: str, args: Optional[List[Any]] = None) -> None: def perform(self, sql: Query, args: Optional[Sequence[Any]] = None) -> None:
""" Send SQL query to the server. Returns immediately without """ Send SQL query to the server. Returns immediately without
blocking. blocking.
""" """

View File

@@ -74,7 +74,7 @@ class Cursor(psycopg2.extras.DictCursor):
if cascade: if cascade:
sql += ' CASCADE' sql += ' CASCADE'
self.execute(pysql.SQL(sql).format(pysql.Identifier(name))) # type: ignore[no-untyped-call] self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
class Connection(psycopg2.extensions.connection): class Connection(psycopg2.extensions.connection):

View File

@@ -7,15 +7,18 @@
""" """
Main work horse for indexing (computing addresses) the database. Main work horse for indexing (computing addresses) the database.
""" """
from typing import Optional, Any, cast
import logging import logging
import time import time
import psycopg2.extras import psycopg2.extras
from nominatim.tokenizer.base import AbstractTokenizer
from nominatim.indexer.progress import ProgressLogger from nominatim.indexer.progress import ProgressLogger
from nominatim.indexer import runners from nominatim.indexer import runners
from nominatim.db.async_connection import DBConnection, WorkerPool from nominatim.db.async_connection import DBConnection, WorkerPool
from nominatim.db.connection import connect from nominatim.db.connection import connect, Connection, Cursor
from nominatim.typing import DictCursorResults
LOG = logging.getLogger() LOG = logging.getLogger()
@@ -23,10 +26,11 @@ LOG = logging.getLogger()
class PlaceFetcher: class PlaceFetcher:
""" Asynchronous connection that fetches place details for processing. """ Asynchronous connection that fetches place details for processing.
""" """
def __init__(self, dsn, setup_conn): def __init__(self, dsn: str, setup_conn: Connection) -> None:
self.wait_time = 0 self.wait_time = 0.0
self.current_ids = None self.current_ids: Optional[DictCursorResults] = None
self.conn = DBConnection(dsn, cursor_factory=psycopg2.extras.DictCursor) self.conn: Optional[DBConnection] = DBConnection(dsn,
cursor_factory=psycopg2.extras.DictCursor)
with setup_conn.cursor() as cur: with setup_conn.cursor() as cur:
# need to fetch those manually because register_hstore cannot # need to fetch those manually because register_hstore cannot
@@ -37,7 +41,7 @@ class PlaceFetcher:
psycopg2.extras.register_hstore(self.conn.conn, oid=hstore_oid, psycopg2.extras.register_hstore(self.conn.conn, oid=hstore_oid,
array_oid=hstore_array_oid) array_oid=hstore_array_oid)
def close(self): def close(self) -> None:
""" Close the underlying asynchronous connection. """ Close the underlying asynchronous connection.
""" """
if self.conn: if self.conn:
@@ -45,44 +49,46 @@ class PlaceFetcher:
self.conn = None self.conn = None
def fetch_next_batch(self, cur, runner): def fetch_next_batch(self, cur: Cursor, runner: runners.Runner) -> bool:
""" Send a request for the next batch of places. """ Send a request for the next batch of places.
If details for the places are required, they will be fetched If details for the places are required, they will be fetched
asynchronously. asynchronously.
Returns true if there is still data available. Returns true if there is still data available.
""" """
ids = cur.fetchmany(100) ids = cast(Optional[DictCursorResults], cur.fetchmany(100))
if not ids: if not ids:
self.current_ids = None self.current_ids = None
return False return False
if hasattr(runner, 'get_place_details'): assert self.conn is not None
runner.get_place_details(self.conn, ids) self.current_ids = runner.get_place_details(self.conn, ids)
self.current_ids = []
else:
self.current_ids = ids
return True return True
def get_batch(self): def get_batch(self) -> DictCursorResults:
""" Get the next batch of data, previously requested with """ Get the next batch of data, previously requested with
`fetch_next_batch`. `fetch_next_batch`.
""" """
assert self.conn is not None
assert self.conn.cursor is not None
if self.current_ids is not None and not self.current_ids: if self.current_ids is not None and not self.current_ids:
tstart = time.time() tstart = time.time()
self.conn.wait() self.conn.wait()
self.wait_time += time.time() - tstart self.wait_time += time.time() - tstart
self.current_ids = self.conn.cursor.fetchall() self.current_ids = cast(Optional[DictCursorResults],
self.conn.cursor.fetchall())
return self.current_ids return self.current_ids if self.current_ids is not None else []
def __enter__(self): def __enter__(self) -> 'PlaceFetcher':
return self return self
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
assert self.conn is not None
self.conn.wait() self.conn.wait()
self.close() self.close()
@@ -91,13 +97,13 @@ class Indexer:
""" Main indexing routine. """ Main indexing routine.
""" """
def __init__(self, dsn, tokenizer, num_threads): def __init__(self, dsn: str, tokenizer: AbstractTokenizer, num_threads: int):
self.dsn = dsn self.dsn = dsn
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.num_threads = num_threads self.num_threads = num_threads
def has_pending(self): def has_pending(self) -> bool:
""" Check if any data still needs indexing. """ Check if any data still needs indexing.
This function must only be used after the import has finished. This function must only be used after the import has finished.
Otherwise it will be very expensive. Otherwise it will be very expensive.
@@ -108,7 +114,7 @@ class Indexer:
return cur.rowcount > 0 return cur.rowcount > 0
def index_full(self, analyse=True): def index_full(self, analyse: bool = True) -> None:
""" Index the complete database. This will first index boundaries """ Index the complete database. This will first index boundaries
followed by all other objects. When `analyse` is True, then the followed by all other objects. When `analyse` is True, then the
database will be analysed at the appropriate places to database will be analysed at the appropriate places to
@@ -117,7 +123,7 @@ class Indexer:
with connect(self.dsn) as conn: with connect(self.dsn) as conn:
conn.autocommit = True conn.autocommit = True
def _analyze(): def _analyze() -> None:
if analyse: if analyse:
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute('ANALYZE') cur.execute('ANALYZE')
@@ -138,7 +144,7 @@ class Indexer:
_analyze() _analyze()
def index_boundaries(self, minrank, maxrank): def index_boundaries(self, minrank: int, maxrank: int) -> None:
""" Index only administrative boundaries within the given rank range. """ Index only administrative boundaries within the given rank range.
""" """
LOG.warning("Starting indexing boundaries using %s threads", LOG.warning("Starting indexing boundaries using %s threads",
@@ -148,7 +154,7 @@ class Indexer:
for rank in range(max(minrank, 4), min(maxrank, 26)): for rank in range(max(minrank, 4), min(maxrank, 26)):
self._index(runners.BoundaryRunner(rank, analyzer)) self._index(runners.BoundaryRunner(rank, analyzer))
def index_by_rank(self, minrank, maxrank): def index_by_rank(self, minrank: int, maxrank: int) -> None:
""" Index all entries of placex in the given rank range (inclusive) """ Index all entries of placex in the given rank range (inclusive)
in order of their address rank. in order of their address rank.
@@ -168,7 +174,7 @@ class Indexer:
self._index(runners.InterpolationRunner(analyzer), 20) self._index(runners.InterpolationRunner(analyzer), 20)
def index_postcodes(self): def index_postcodes(self) -> None:
"""Index the entries of the location_postcode table. """Index the entries of the location_postcode table.
""" """
LOG.warning("Starting indexing postcodes using %s threads", self.num_threads) LOG.warning("Starting indexing postcodes using %s threads", self.num_threads)
@@ -176,7 +182,7 @@ class Indexer:
self._index(runners.PostcodeRunner(), 20) self._index(runners.PostcodeRunner(), 20)
def update_status_table(self): def update_status_table(self) -> None:
""" Update the status in the status table to 'indexed'. """ Update the status in the status table to 'indexed'.
""" """
with connect(self.dsn) as conn: with connect(self.dsn) as conn:
@@ -185,7 +191,7 @@ class Indexer:
conn.commit() conn.commit()
def _index(self, runner, batch=1): def _index(self, runner: runners.Runner, batch: int = 1) -> None:
""" Index a single rank or table. `runner` describes the SQL to use """ Index a single rank or table. `runner` describes the SQL to use
for indexing. `batch` describes the number of objects that for indexing. `batch` describes the number of objects that
should be processed with a single SQL statement should be processed with a single SQL statement

View File

@@ -22,7 +22,7 @@ class ProgressLogger:
should be reported. should be reported.
""" """
def __init__(self, name, total, log_interval=1): def __init__(self, name: str, total: int, log_interval: int = 1) -> None:
self.name = name self.name = name
self.total_places = total self.total_places = total
self.done_places = 0 self.done_places = 0
@@ -30,7 +30,7 @@ class ProgressLogger:
self.log_interval = log_interval self.log_interval = log_interval
self.next_info = INITIAL_PROGRESS if LOG.isEnabledFor(logging.WARNING) else total + 1 self.next_info = INITIAL_PROGRESS if LOG.isEnabledFor(logging.WARNING) else total + 1
def add(self, num=1): def add(self, num: int = 1) -> None:
""" Mark `num` places as processed. Print a log message if the """ Mark `num` places as processed. Print a log message if the
logging is at least info and the log interval has passed. logging is at least info and the log interval has passed.
""" """
@@ -55,14 +55,14 @@ class ProgressLogger:
self.next_info += int(places_per_sec) * self.log_interval self.next_info += int(places_per_sec) * self.log_interval
def done(self): def done(self) -> None:
""" Print final statistics about the progress. """ Print final statistics about the progress.
""" """
rank_end_time = datetime.now() rank_end_time = datetime.now()
if rank_end_time == self.rank_start_time: if rank_end_time == self.rank_start_time:
diff_seconds = 0 diff_seconds = 0.0
places_per_sec = self.done_places places_per_sec = float(self.done_places)
else: else:
diff_seconds = (rank_end_time - self.rank_start_time).total_seconds() diff_seconds = (rank_end_time - self.rank_start_time).total_seconds()
places_per_sec = self.done_places / diff_seconds places_per_sec = self.done_places / diff_seconds

View File

@@ -8,35 +8,49 @@
Mix-ins that provide the actual commands for the indexer for various indexing Mix-ins that provide the actual commands for the indexer for various indexing
tasks. tasks.
""" """
from typing import Any, List
import functools import functools
from typing_extensions import Protocol
from psycopg2 import sql as pysql from psycopg2 import sql as pysql
import psycopg2.extras import psycopg2.extras
from nominatim.data.place_info import PlaceInfo from nominatim.data.place_info import PlaceInfo
from nominatim.tokenizer.base import AbstractAnalyzer
from nominatim.db.async_connection import DBConnection
from nominatim.typing import Query, DictCursorResult, DictCursorResults
# pylint: disable=C0111 # pylint: disable=C0111
def _mk_valuelist(template, num): def _mk_valuelist(template: str, num: int) -> pysql.Composed:
return pysql.SQL(',').join([pysql.SQL(template)] * num) return pysql.SQL(',').join([pysql.SQL(template)] * num)
def _analyze_place(place, analyzer): def _analyze_place(place: DictCursorResult, analyzer: AbstractAnalyzer) -> psycopg2.extras.Json:
return psycopg2.extras.Json(analyzer.process_place(PlaceInfo(place))) return psycopg2.extras.Json(analyzer.process_place(PlaceInfo(place)))
class Runner(Protocol):
def name(self) -> str: ...
def sql_count_objects(self) -> Query: ...
def sql_get_objects(self) -> Query: ...
def get_place_details(self, worker: DBConnection,
ids: DictCursorResults) -> DictCursorResults: ...
def index_places(self, worker: DBConnection, places: DictCursorResults) -> None: ...
class AbstractPlacexRunner: class AbstractPlacexRunner:
""" Returns SQL commands for indexing of the placex table. """ Returns SQL commands for indexing of the placex table.
""" """
SELECT_SQL = pysql.SQL('SELECT place_id FROM placex ') SELECT_SQL = pysql.SQL('SELECT place_id FROM placex ')
UPDATE_LINE = "(%s, %s::hstore, %s::hstore, %s::int, %s::jsonb)" UPDATE_LINE = "(%s, %s::hstore, %s::hstore, %s::int, %s::jsonb)"
def __init__(self, rank, analyzer): def __init__(self, rank: int, analyzer: AbstractAnalyzer) -> None:
self.rank = rank self.rank = rank
self.analyzer = analyzer self.analyzer = analyzer
@staticmethod
@functools.lru_cache(maxsize=1) @functools.lru_cache(maxsize=1)
def _index_sql(num_places): def _index_sql(self, num_places: int) -> pysql.Composed:
return pysql.SQL( return pysql.SQL(
""" UPDATE placex """ UPDATE placex
SET indexed_status = 0, address = v.addr, token_info = v.ti, SET indexed_status = 0, address = v.addr, token_info = v.ti,
@@ -46,16 +60,17 @@ class AbstractPlacexRunner:
""").format(_mk_valuelist(AbstractPlacexRunner.UPDATE_LINE, num_places)) """).format(_mk_valuelist(AbstractPlacexRunner.UPDATE_LINE, num_places))
@staticmethod def get_place_details(self, worker: DBConnection, ids: DictCursorResults) -> DictCursorResults:
def get_place_details(worker, ids):
worker.perform("""SELECT place_id, extra.* worker.perform("""SELECT place_id, extra.*
FROM placex, LATERAL placex_indexing_prepare(placex) as extra FROM placex, LATERAL placex_indexing_prepare(placex) as extra
WHERE place_id IN %s""", WHERE place_id IN %s""",
(tuple((p[0] for p in ids)), )) (tuple((p[0] for p in ids)), ))
return []
def index_places(self, worker, places):
values = [] def index_places(self, worker: DBConnection, places: DictCursorResults) -> None:
values: List[Any] = []
for place in places: for place in places:
for field in ('place_id', 'name', 'address', 'linked_place_id'): for field in ('place_id', 'name', 'address', 'linked_place_id'):
values.append(place[field]) values.append(place[field])
@@ -68,15 +83,15 @@ class RankRunner(AbstractPlacexRunner):
""" Returns SQL commands for indexing one rank within the placex table. """ Returns SQL commands for indexing one rank within the placex table.
""" """
def name(self): def name(self) -> str:
return f"rank {self.rank}" return f"rank {self.rank}"
def sql_count_objects(self): def sql_count_objects(self) -> pysql.Composed:
return pysql.SQL("""SELECT count(*) FROM placex return pysql.SQL("""SELECT count(*) FROM placex
WHERE rank_address = {} and indexed_status > 0 WHERE rank_address = {} and indexed_status > 0
""").format(pysql.Literal(self.rank)) """).format(pysql.Literal(self.rank))
def sql_get_objects(self): def sql_get_objects(self) -> pysql.Composed:
return self.SELECT_SQL + pysql.SQL( return self.SELECT_SQL + pysql.SQL(
"""WHERE indexed_status > 0 and rank_address = {} """WHERE indexed_status > 0 and rank_address = {}
ORDER BY geometry_sector ORDER BY geometry_sector
@@ -88,17 +103,17 @@ class BoundaryRunner(AbstractPlacexRunner):
of a certain rank. of a certain rank.
""" """
def name(self): def name(self) -> str:
return f"boundaries rank {self.rank}" return f"boundaries rank {self.rank}"
def sql_count_objects(self): def sql_count_objects(self) -> pysql.Composed:
return pysql.SQL("""SELECT count(*) FROM placex return pysql.SQL("""SELECT count(*) FROM placex
WHERE indexed_status > 0 WHERE indexed_status > 0
AND rank_search = {} AND rank_search = {}
AND class = 'boundary' and type = 'administrative' AND class = 'boundary' and type = 'administrative'
""").format(pysql.Literal(self.rank)) """).format(pysql.Literal(self.rank))
def sql_get_objects(self): def sql_get_objects(self) -> pysql.Composed:
return self.SELECT_SQL + pysql.SQL( return self.SELECT_SQL + pysql.SQL(
"""WHERE indexed_status > 0 and rank_search = {} """WHERE indexed_status > 0 and rank_search = {}
and class = 'boundary' and type = 'administrative' and class = 'boundary' and type = 'administrative'
@@ -111,37 +126,33 @@ class InterpolationRunner:
location_property_osmline. location_property_osmline.
""" """
def __init__(self, analyzer): def __init__(self, analyzer: AbstractAnalyzer) -> None:
self.analyzer = analyzer self.analyzer = analyzer
@staticmethod def name(self) -> str:
def name():
return "interpolation lines (location_property_osmline)" return "interpolation lines (location_property_osmline)"
@staticmethod def sql_count_objects(self) -> str:
def sql_count_objects():
return """SELECT count(*) FROM location_property_osmline return """SELECT count(*) FROM location_property_osmline
WHERE indexed_status > 0""" WHERE indexed_status > 0"""
@staticmethod def sql_get_objects(self) -> str:
def sql_get_objects():
return """SELECT place_id return """SELECT place_id
FROM location_property_osmline FROM location_property_osmline
WHERE indexed_status > 0 WHERE indexed_status > 0
ORDER BY geometry_sector""" ORDER BY geometry_sector"""
@staticmethod def get_place_details(self, worker: DBConnection, ids: DictCursorResults) -> DictCursorResults:
def get_place_details(worker, ids):
worker.perform("""SELECT place_id, get_interpolation_address(address, osm_id) as address worker.perform("""SELECT place_id, get_interpolation_address(address, osm_id) as address
FROM location_property_osmline WHERE place_id IN %s""", FROM location_property_osmline WHERE place_id IN %s""",
(tuple((p[0] for p in ids)), )) (tuple((p[0] for p in ids)), ))
return []
@staticmethod
@functools.lru_cache(maxsize=1) @functools.lru_cache(maxsize=1)
def _index_sql(num_places): def _index_sql(self, num_places: int) -> pysql.Composed:
return pysql.SQL("""UPDATE location_property_osmline return pysql.SQL("""UPDATE location_property_osmline
SET indexed_status = 0, address = v.addr, token_info = v.ti SET indexed_status = 0, address = v.addr, token_info = v.ti
FROM (VALUES {}) as v(id, addr, ti) FROM (VALUES {}) as v(id, addr, ti)
@@ -149,8 +160,8 @@ class InterpolationRunner:
""").format(_mk_valuelist("(%s, %s::hstore, %s::jsonb)", num_places)) """).format(_mk_valuelist("(%s, %s::hstore, %s::jsonb)", num_places))
def index_places(self, worker, places): def index_places(self, worker: DBConnection, places: DictCursorResults) -> None:
values = [] values: List[Any] = []
for place in places: for place in places:
values.extend((place[x] for x in ('place_id', 'address'))) values.extend((place[x] for x in ('place_id', 'address')))
values.append(_analyze_place(place, self.analyzer)) values.append(_analyze_place(place, self.analyzer))
@@ -159,26 +170,28 @@ class InterpolationRunner:
class PostcodeRunner: class PostcodeRunner(Runner):
""" Provides the SQL commands for indexing the location_postcode table. """ Provides the SQL commands for indexing the location_postcode table.
""" """
@staticmethod def name(self) -> str:
def name():
return "postcodes (location_postcode)" return "postcodes (location_postcode)"
@staticmethod
def sql_count_objects(): def sql_count_objects(self) -> str:
return 'SELECT count(*) FROM location_postcode WHERE indexed_status > 0' return 'SELECT count(*) FROM location_postcode WHERE indexed_status > 0'
@staticmethod
def sql_get_objects(): def sql_get_objects(self) -> str:
return """SELECT place_id FROM location_postcode return """SELECT place_id FROM location_postcode
WHERE indexed_status > 0 WHERE indexed_status > 0
ORDER BY country_code, postcode""" ORDER BY country_code, postcode"""
@staticmethod
def index_places(worker, ids): def get_place_details(self, worker: DBConnection, ids: DictCursorResults) -> DictCursorResults:
return ids
def index_places(self, worker: DBConnection, places: DictCursorResults) -> None:
worker.perform(pysql.SQL("""UPDATE location_postcode SET indexed_status = 0 worker.perform(pysql.SQL("""UPDATE location_postcode SET indexed_status = 0
WHERE place_id IN ({})""") WHERE place_id IN ({})""")
.format(pysql.SQL(',').join((pysql.Literal(i[0]) for i in ids)))) .format(pysql.SQL(',').join((pysql.Literal(i[0]) for i in places))))

View File

@@ -9,14 +9,15 @@ Type definitions for typing annotations.
Complex type definitions are moved here, to keep the source files readable. Complex type definitions are moved here, to keep the source files readable.
""" """
from typing import Union, Mapping, TypeVar, TYPE_CHECKING from typing import Any, Union, Mapping, TypeVar, Sequence, TYPE_CHECKING
# Generics variable names do not conform to naming styles, ignore globally here. # Generics variable names do not conform to naming styles, ignore globally here.
# pylint: disable=invalid-name # pylint: disable=invalid-name,abstract-method,multiple-statements,missing-class-docstring
if TYPE_CHECKING: if TYPE_CHECKING:
import psycopg2.sql import psycopg2.sql
import psycopg2.extensions import psycopg2.extensions
import psycopg2.extras
import os import os
StrPath = Union[str, 'os.PathLike[str]'] StrPath = Union[str, 'os.PathLike[str]']
@@ -26,4 +27,12 @@ SysEnv = Mapping[str, str]
# psycopg2-related types # psycopg2-related types
Query = Union[str, bytes, 'psycopg2.sql.Composable'] Query = Union[str, bytes, 'psycopg2.sql.Composable']
T_ResultKey = TypeVar('T_ResultKey', int, str)
class DictCursorResult(Mapping[str, Any]):
def __getitem__(self, x: Union[int, str]) -> Any: ...
DictCursorResults = Sequence[DictCursorResult]
T_cursor = TypeVar('T_cursor', bound='psycopg2.extensions.cursor') T_cursor = TypeVar('T_cursor', bound='psycopg2.extensions.cursor')