add type annotations to tool functions

This commit is contained in:
Sarah Hoffmann
2022-07-16 23:28:02 +02:00
parent 6c6bbe5747
commit 17bbe2637a
6 changed files with 65 additions and 39 deletions

View File

@@ -37,7 +37,7 @@ class Cursor(psycopg2.extras.DictCursor):
def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]], def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]],
template: Optional[str] = None) -> None: template: Optional[Query] = None) -> None:
""" Wrapper for the psycopg2 convenience function to execute """ Wrapper for the psycopg2 convenience function to execute
SQL for a list of values. SQL for a list of values.
""" """

View File

@@ -7,6 +7,7 @@
""" """
Function to add additional OSM data from a file or the API into the database. Function to add additional OSM data from a file or the API into the database.
""" """
from typing import Any, MutableMapping
from pathlib import Path from pathlib import Path
import logging import logging
import urllib import urllib
@@ -15,7 +16,7 @@ from nominatim.tools.exec_utils import run_osm2pgsql, get_url
LOG = logging.getLogger() LOG = logging.getLogger()
def add_data_from_file(fname, options): def add_data_from_file(fname: str, options: MutableMapping[str, Any]) -> int:
""" Adds data from an OSM file to the database. The file may be a normal """ Adds data from an OSM file to the database. The file may be a normal
OSM file or a diff file in all formats supported by libosmium. OSM file or a diff file in all formats supported by libosmium.
""" """
@@ -27,7 +28,8 @@ def add_data_from_file(fname, options):
return 0 return 0
def add_osm_object(osm_type, osm_id, use_main_api, options): def add_osm_object(osm_type: str, osm_id: int, use_main_api: bool,
options: MutableMapping[str, Any]) -> None:
""" Add or update a single OSM object from the latest version of the """ Add or update a single OSM object from the latest version of the
API. API.
""" """

View File

@@ -7,22 +7,27 @@
""" """
Functions for database analysis and maintenance. Functions for database analysis and maintenance.
""" """
from typing import Optional, Tuple, Any
import logging import logging
from psycopg2.extras import Json, register_hstore from psycopg2.extras import Json, register_hstore
from nominatim.db.connection import connect from nominatim.config import Configuration
from nominatim.db.connection import connect, Cursor
from nominatim.tokenizer import factory as tokenizer_factory from nominatim.tokenizer import factory as tokenizer_factory
from nominatim.errors import UsageError from nominatim.errors import UsageError
from nominatim.data.place_info import PlaceInfo from nominatim.data.place_info import PlaceInfo
from nominatim.typing import DictCursorResult
LOG = logging.getLogger() LOG = logging.getLogger()
def _get_place_info(cursor, osm_id, place_id): def _get_place_info(cursor: Cursor, osm_id: Optional[str],
place_id: Optional[int]) -> DictCursorResult:
sql = """SELECT place_id, extra.* sql = """SELECT place_id, extra.*
FROM placex, LATERAL placex_indexing_prepare(placex) as extra FROM placex, LATERAL placex_indexing_prepare(placex) as extra
""" """
values: Tuple[Any, ...]
if osm_id: if osm_id:
osm_type = osm_id[0].upper() osm_type = osm_id[0].upper()
if osm_type not in 'NWR' or not osm_id[1:].isdigit(): if osm_type not in 'NWR' or not osm_id[1:].isdigit():
@@ -44,10 +49,11 @@ def _get_place_info(cursor, osm_id, place_id):
LOG.fatal("OSM object %s not found in database.", osm_id) LOG.fatal("OSM object %s not found in database.", osm_id)
raise UsageError("OSM object not found") raise UsageError("OSM object not found")
return cursor.fetchone() return cursor.fetchone() # type: ignore[no-untyped-call]
def analyse_indexing(config, osm_id=None, place_id=None): def analyse_indexing(config: Configuration, osm_id: Optional[str] = None,
place_id: Optional[int] = None) -> None:
""" Analyse indexing of a single Nominatim object. """ Analyse indexing of a single Nominatim object.
""" """
with connect(config.get_libpq_dsn()) as conn: with connect(config.get_libpq_dsn()) as conn:

View File

@@ -8,7 +8,9 @@
Functions for importing, updating and otherwise maintaining the table Functions for importing, updating and otherwise maintaining the table
of artificial postcode centroids. of artificial postcode centroids.
""" """
from typing import Optional, Tuple, Dict, List, TextIO
from collections import defaultdict from collections import defaultdict
from pathlib import Path
import csv import csv
import gzip import gzip
import logging import logging
@@ -16,18 +18,19 @@ from math import isfinite
from psycopg2 import sql as pysql from psycopg2 import sql as pysql
from nominatim.db.connection import connect from nominatim.db.connection import connect, Connection
from nominatim.utils.centroid import PointsCentroid from nominatim.utils.centroid import PointsCentroid
from nominatim.data.postcode_format import PostcodeFormatter from nominatim.data.postcode_format import PostcodeFormatter, CountryPostcodeMatcher
from nominatim.tokenizer.base import AbstractAnalyzer, AbstractTokenizer
LOG = logging.getLogger() LOG = logging.getLogger()
def _to_float(num, max_value): def _to_float(numstr: str, max_value: float) -> float:
""" Convert the number in string into a float. The number is expected """ Convert the number in string into a float. The number is expected
to be in the range of [-max_value, max_value]. Otherwise raises a to be in the range of [-max_value, max_value]. Otherwise raises a
ValueError. ValueError.
""" """
num = float(num) num = float(numstr)
if not isfinite(num) or num <= -max_value or num >= max_value: if not isfinite(num) or num <= -max_value or num >= max_value:
raise ValueError() raise ValueError()
@@ -37,18 +40,19 @@ class _PostcodeCollector:
""" Collector for postcodes of a single country. """ Collector for postcodes of a single country.
""" """
def __init__(self, country, matcher): def __init__(self, country: str, matcher: Optional[CountryPostcodeMatcher]):
self.country = country self.country = country
self.matcher = matcher self.matcher = matcher
self.collected = defaultdict(PointsCentroid) self.collected: Dict[str, PointsCentroid] = defaultdict(PointsCentroid)
self.normalization_cache = None self.normalization_cache: Optional[Tuple[str, Optional[str]]] = None
def add(self, postcode, x, y): def add(self, postcode: str, x: float, y: float) -> None:
""" Add the given postcode to the collection cache. If the postcode """ Add the given postcode to the collection cache. If the postcode
already existed, it is overwritten with the new centroid. already existed, it is overwritten with the new centroid.
""" """
if self.matcher is not None: if self.matcher is not None:
normalized: Optional[str]
if self.normalization_cache and self.normalization_cache[0] == postcode: if self.normalization_cache and self.normalization_cache[0] == postcode:
normalized = self.normalization_cache[1] normalized = self.normalization_cache[1]
else: else:
@@ -60,7 +64,7 @@ class _PostcodeCollector:
self.collected[normalized] += (x, y) self.collected[normalized] += (x, y)
def commit(self, conn, analyzer, project_dir): def commit(self, conn: Connection, analyzer: AbstractAnalyzer, project_dir: Path) -> None:
""" Update postcodes for the country from the postcodes selected so far """ Update postcodes for the country from the postcodes selected so far
as well as any externally supplied postcodes. as well as any externally supplied postcodes.
""" """
@@ -94,7 +98,8 @@ class _PostcodeCollector:
""").format(pysql.Literal(self.country)), to_update) """).format(pysql.Literal(self.country)), to_update)
def _compute_changes(self, conn): def _compute_changes(self, conn: Connection) \
-> Tuple[List[Tuple[str, float, float]], List[str], List[Tuple[str, float, float]]]:
""" Compute which postcodes from the collected postcodes have to be """ Compute which postcodes from the collected postcodes have to be
added or modified and which from the location_postcode table added or modified and which from the location_postcode table
have to be deleted. have to be deleted.
@@ -116,12 +121,12 @@ class _PostcodeCollector:
to_delete.append(postcode) to_delete.append(postcode)
to_add = [(k, *v.centroid()) for k, v in self.collected.items()] to_add = [(k, *v.centroid()) for k, v in self.collected.items()]
self.collected = None self.collected = defaultdict(PointsCentroid)
return to_add, to_delete, to_update return to_add, to_delete, to_update
def _update_from_external(self, analyzer, project_dir): def _update_from_external(self, analyzer: AbstractAnalyzer, project_dir: Path) -> None:
""" Look for an external postcode file for the active country in """ Look for an external postcode file for the active country in
the project directory and add missing postcodes when found. the project directory and add missing postcodes when found.
""" """
@@ -151,7 +156,7 @@ class _PostcodeCollector:
csvfile.close() csvfile.close()
def _open_external(self, project_dir): def _open_external(self, project_dir: Path) -> Optional[TextIO]:
fname = project_dir / f'{self.country}_postcodes.csv' fname = project_dir / f'{self.country}_postcodes.csv'
if fname.is_file(): if fname.is_file():
@@ -167,7 +172,7 @@ class _PostcodeCollector:
return None return None
def update_postcodes(dsn, project_dir, tokenizer): def update_postcodes(dsn: str, project_dir: Path, tokenizer: AbstractTokenizer) -> None:
""" Update the table of artificial postcodes. """ Update the table of artificial postcodes.
Computes artificial postcode centroids from the placex table, Computes artificial postcode centroids from the placex table,
@@ -220,7 +225,7 @@ def update_postcodes(dsn, project_dir, tokenizer):
analyzer.update_postcodes_from_db() analyzer.update_postcodes_from_db()
def can_compute(dsn): def can_compute(dsn: str) -> bool:
""" """
Check that the place table exists so that Check that the place table exists so that
postcodes can be computed. postcodes can be computed.

View File

@@ -7,12 +7,15 @@
""" """
Functions for bringing auxiliary data in the database up-to-date. Functions for bringing auxiliary data in the database up-to-date.
""" """
from typing import MutableSequence, Tuple, Any, Type, Mapping, Sequence, List, cast
import logging import logging
from textwrap import dedent from textwrap import dedent
from pathlib import Path from pathlib import Path
from psycopg2 import sql as pysql from psycopg2 import sql as pysql
from nominatim.config import Configuration
from nominatim.db.connection import Connection
from nominatim.db.utils import execute_file from nominatim.db.utils import execute_file
from nominatim.db.sql_preprocessor import SQLPreprocessor from nominatim.db.sql_preprocessor import SQLPreprocessor
from nominatim.version import version_str from nominatim.version import version_str
@@ -21,7 +24,8 @@ LOG = logging.getLogger()
OSM_TYPE = {'N': 'node', 'W': 'way', 'R': 'relation'} OSM_TYPE = {'N': 'node', 'W': 'way', 'R': 'relation'}
def _add_address_level_rows_from_entry(rows, entry): def _add_address_level_rows_from_entry(rows: MutableSequence[Tuple[Any, ...]],
entry: Mapping[str, Any]) -> None:
""" Converts a single entry from the JSON format for address rank """ Converts a single entry from the JSON format for address rank
descriptions into a flat format suitable for inserting into a descriptions into a flat format suitable for inserting into a
PostgreSQL table and adds these lines to `rows`. PostgreSQL table and adds these lines to `rows`.
@@ -38,14 +42,15 @@ def _add_address_level_rows_from_entry(rows, entry):
for country in countries: for country in countries:
rows.append((country, key, value, rank_search, rank_address)) rows.append((country, key, value, rank_search, rank_address))
def load_address_levels(conn, table, levels):
def load_address_levels(conn: Connection, table: str, levels: Sequence[Mapping[str, Any]]) -> None:
""" Replace the `address_levels` table with the contents of `levels'. """ Replace the `address_levels` table with the contents of `levels'.
A new table is created; any previously existing table is dropped. A new table is created; any previously existing table is dropped.
The table has the following columns: The table has the following columns:
country, class, type, rank_search, rank_address country, class, type, rank_search, rank_address
""" """
rows = [] rows: List[Tuple[Any, ...]] = []
for entry in levels: for entry in levels:
_add_address_level_rows_from_entry(rows, entry) _add_address_level_rows_from_entry(rows, entry)
@@ -69,7 +74,7 @@ def load_address_levels(conn, table, levels):
conn.commit() conn.commit()
def load_address_levels_from_config(conn, config): def load_address_levels_from_config(conn: Connection, config: Configuration) -> None:
""" Replace the `address_levels` table with the content as """ Replace the `address_levels` table with the content as
defined in the given configuration. Uses the parameter defined in the given configuration. Uses the parameter
NOMINATIM_ADDRESS_LEVEL_CONFIG to determine the location of the NOMINATIM_ADDRESS_LEVEL_CONFIG to determine the location of the
@@ -79,7 +84,9 @@ def load_address_levels_from_config(conn, config):
load_address_levels(conn, 'address_levels', cfg) load_address_levels(conn, 'address_levels', cfg)
def create_functions(conn, config, enable_diff_updates=True, enable_debug=False): def create_functions(conn: Connection, config: Configuration,
enable_diff_updates: bool = True,
enable_debug: bool = False) -> None:
""" (Re)create the PL/pgSQL functions. """ (Re)create the PL/pgSQL functions.
""" """
sql = SQLPreprocessor(conn, config) sql = SQLPreprocessor(conn, config)
@@ -116,7 +123,7 @@ PHP_CONST_DEFS = (
) )
def import_wikipedia_articles(dsn, data_path, ignore_errors=False): def import_wikipedia_articles(dsn: str, data_path: Path, ignore_errors: bool = False) -> int:
""" Replaces the wikipedia importance tables with new data. """ Replaces the wikipedia importance tables with new data.
The import is run in a single transaction so that the new data The import is run in a single transaction so that the new data
is replaced seamlessly. is replaced seamlessly.
@@ -140,7 +147,7 @@ def import_wikipedia_articles(dsn, data_path, ignore_errors=False):
return 0 return 0
def recompute_importance(conn): def recompute_importance(conn: Connection) -> None:
""" Recompute wikipedia links and importance for all entries in placex. """ Recompute wikipedia links and importance for all entries in placex.
This is a long-running operations that must not be executed in This is a long-running operations that must not be executed in
parallel with updates. parallel with updates.
@@ -163,12 +170,13 @@ def recompute_importance(conn):
conn.commit() conn.commit()
def _quote_php_variable(var_type, config, conf_name): def _quote_php_variable(var_type: Type[Any], config: Configuration,
conf_name: str) -> str:
if var_type == bool: if var_type == bool:
return 'true' if config.get_bool(conf_name) else 'false' return 'true' if config.get_bool(conf_name) else 'false'
if var_type == int: if var_type == int:
return getattr(config, conf_name) return cast(str, getattr(config, conf_name))
if not getattr(config, conf_name): if not getattr(config, conf_name):
return 'false' return 'false'
@@ -182,7 +190,7 @@ def _quote_php_variable(var_type, config, conf_name):
return f"'{quoted}'" return f"'{quoted}'"
def setup_website(basedir, config, conn): def setup_website(basedir: Path, config: Configuration, conn: Connection) -> None:
""" Create the website script stubs. """ Create the website script stubs.
""" """
if not basedir.exists(): if not basedir.exists():
@@ -215,7 +223,8 @@ def setup_website(basedir, config, conn):
(basedir / script).write_text(template.format(script), 'utf-8') (basedir / script).write_text(template.format(script), 'utf-8')
def invalidate_osm_object(osm_type, osm_id, conn, recursive=True): def invalidate_osm_object(osm_type: str, osm_id: int, conn: Connection,
recursive: bool = True) -> None:
""" Mark the given OSM object for reindexing. When 'recursive' is set """ Mark the given OSM object for reindexing. When 'recursive' is set
to True (the default), then all dependent objects are marked for to True (the default), then all dependent objects are marked for
reindexing as well. reindexing as well.

View File

@@ -7,6 +7,7 @@
""" """
Functions for updating a database from a replication source. Functions for updating a database from a replication source.
""" """
from typing import ContextManager, MutableMapping, Any, Generator, cast
from contextlib import contextmanager from contextlib import contextmanager
import datetime as dt import datetime as dt
from enum import Enum from enum import Enum
@@ -14,6 +15,7 @@ import logging
import time import time
from nominatim.db import status from nominatim.db import status
from nominatim.db.connection import Connection
from nominatim.tools.exec_utils import run_osm2pgsql from nominatim.tools.exec_utils import run_osm2pgsql
from nominatim.errors import UsageError from nominatim.errors import UsageError
@@ -27,7 +29,7 @@ except ImportError as exc:
LOG = logging.getLogger() LOG = logging.getLogger()
def init_replication(conn, base_url): def init_replication(conn: Connection, base_url: str) -> None:
""" Set up replication for the server at the given base URL. """ Set up replication for the server at the given base URL.
""" """
LOG.info("Using replication source: %s", base_url) LOG.info("Using replication source: %s", base_url)
@@ -51,7 +53,7 @@ def init_replication(conn, base_url):
LOG.warning("Updates initialised at sequence %s (%s)", seq, date) LOG.warning("Updates initialised at sequence %s (%s)", seq, date)
def check_for_updates(conn, base_url): def check_for_updates(conn: Connection, base_url: str) -> int:
""" Check if new data is available from the replication service at the """ Check if new data is available from the replication service at the
given base URL. given base URL.
""" """
@@ -84,7 +86,7 @@ class UpdateState(Enum):
NO_CHANGES = 3 NO_CHANGES = 3
def update(conn, options): def update(conn: Connection, options: MutableMapping[str, Any]) -> UpdateState:
""" Update database from the next batch of data. Returns the state of """ Update database from the next batch of data. Returns the state of
updates according to `UpdateState`. updates according to `UpdateState`.
""" """
@@ -95,6 +97,8 @@ def update(conn, options):
"Please run 'nominatim replication --init' first.") "Please run 'nominatim replication --init' first.")
raise UsageError("Replication not set up.") raise UsageError("Replication not set up.")
assert startdate is not None
if not indexed and options['indexed_only']: if not indexed and options['indexed_only']:
LOG.info("Skipping update. There is data that needs indexing.") LOG.info("Skipping update. There is data that needs indexing.")
return UpdateState.MORE_PENDING return UpdateState.MORE_PENDING
@@ -132,17 +136,17 @@ def update(conn, options):
return UpdateState.UP_TO_DATE return UpdateState.UP_TO_DATE
def _make_replication_server(url): def _make_replication_server(url: str) -> ContextManager[ReplicationServer]:
""" Returns a ReplicationServer in form of a context manager. """ Returns a ReplicationServer in form of a context manager.
Creates a light wrapper around older versions of pyosmium that did Creates a light wrapper around older versions of pyosmium that did
not support the context manager interface. not support the context manager interface.
""" """
if hasattr(ReplicationServer, '__enter__'): if hasattr(ReplicationServer, '__enter__'):
return ReplicationServer(url) return cast(ContextManager[ReplicationServer], ReplicationServer(url))
@contextmanager @contextmanager
def get_cm(): def get_cm() -> Generator[ReplicationServer, None, None]:
yield ReplicationServer(url) yield ReplicationServer(url)
return get_cm() return get_cm()