add type annotations to tool functions

This commit is contained in:
Sarah Hoffmann
2022-07-16 23:28:02 +02:00
parent 6c6bbe5747
commit 17bbe2637a
6 changed files with 65 additions and 39 deletions

View File

@@ -37,7 +37,7 @@ class Cursor(psycopg2.extras.DictCursor):
def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]],
template: Optional[str] = None) -> None:
template: Optional[Query] = None) -> None:
""" Wrapper for the psycopg2 convenience function to execute
SQL for a list of values.
"""

View File

@@ -7,6 +7,7 @@
"""
Function to add additional OSM data from a file or the API into the database.
"""
from typing import Any, MutableMapping
from pathlib import Path
import logging
import urllib
@@ -15,7 +16,7 @@ from nominatim.tools.exec_utils import run_osm2pgsql, get_url
LOG = logging.getLogger()
def add_data_from_file(fname, options):
def add_data_from_file(fname: str, options: MutableMapping[str, Any]) -> int:
""" Adds data from a OSM file to the database. The file may be a normal
OSM file or a diff file in all formats supported by libosmium.
"""
@@ -27,7 +28,8 @@ def add_data_from_file(fname, options):
return 0
def add_osm_object(osm_type, osm_id, use_main_api, options):
def add_osm_object(osm_type: str, osm_id: int, use_main_api: bool,
options: MutableMapping[str, Any]) -> None:
""" Add or update a single OSM object from the latest version of the
API.
"""

View File

@@ -7,22 +7,27 @@
"""
Functions for database analysis and maintenance.
"""
from typing import Optional, Tuple, Any
import logging
from psycopg2.extras import Json, register_hstore
from nominatim.db.connection import connect
from nominatim.config import Configuration
from nominatim.db.connection import connect, Cursor
from nominatim.tokenizer import factory as tokenizer_factory
from nominatim.errors import UsageError
from nominatim.data.place_info import PlaceInfo
from nominatim.typing import DictCursorResult
LOG = logging.getLogger()
def _get_place_info(cursor, osm_id, place_id):
def _get_place_info(cursor: Cursor, osm_id: Optional[str],
place_id: Optional[int]) -> DictCursorResult:
sql = """SELECT place_id, extra.*
FROM placex, LATERAL placex_indexing_prepare(placex) as extra
"""
values: Tuple[Any, ...]
if osm_id:
osm_type = osm_id[0].upper()
if osm_type not in 'NWR' or not osm_id[1:].isdigit():
@@ -44,10 +49,11 @@ def _get_place_info(cursor, osm_id, place_id):
LOG.fatal("OSM object %s not found in database.", osm_id)
raise UsageError("OSM object not found")
return cursor.fetchone()
return cursor.fetchone() # type: ignore[no-untyped-call]
def analyse_indexing(config, osm_id=None, place_id=None):
def analyse_indexing(config: Configuration, osm_id: Optional[str] = None,
place_id: Optional[int] = None) -> None:
""" Analyse indexing of a single Nominatim object.
"""
with connect(config.get_libpq_dsn()) as conn:

View File

@@ -8,7 +8,9 @@
Functions for importing, updating and otherwise maintaining the table
of artificial postcode centroids.
"""
from typing import Optional, Tuple, Dict, List, TextIO
from collections import defaultdict
from pathlib import Path
import csv
import gzip
import logging
@@ -16,18 +18,19 @@ from math import isfinite
from psycopg2 import sql as pysql
from nominatim.db.connection import connect
from nominatim.db.connection import connect, Connection
from nominatim.utils.centroid import PointsCentroid
from nominatim.data.postcode_format import PostcodeFormatter
from nominatim.data.postcode_format import PostcodeFormatter, CountryPostcodeMatcher
from nominatim.tokenizer.base import AbstractAnalyzer, AbstractTokenizer
LOG = logging.getLogger()
def _to_float(num, max_value):
def _to_float(numstr: str, max_value: float) -> float:
""" Convert the number in string into a float. The number is expected
to be in the range of [-max_value, max_value]. Otherwise raises a
ValueError.
"""
num = float(num)
num = float(numstr)
if not isfinite(num) or num <= -max_value or num >= max_value:
raise ValueError()
@@ -37,18 +40,19 @@ class _PostcodeCollector:
""" Collector for postcodes of a single country.
"""
def __init__(self, country, matcher):
def __init__(self, country: str, matcher: Optional[CountryPostcodeMatcher]):
self.country = country
self.matcher = matcher
self.collected = defaultdict(PointsCentroid)
self.normalization_cache = None
self.collected: Dict[str, PointsCentroid] = defaultdict(PointsCentroid)
self.normalization_cache: Optional[Tuple[str, Optional[str]]] = None
def add(self, postcode, x, y):
def add(self, postcode: str, x: float, y: float) -> None:
""" Add the given postcode to the collection cache. If the postcode
already existed, it is overwritten with the new centroid.
"""
if self.matcher is not None:
normalized: Optional[str]
if self.normalization_cache and self.normalization_cache[0] == postcode:
normalized = self.normalization_cache[1]
else:
@@ -60,7 +64,7 @@ class _PostcodeCollector:
self.collected[normalized] += (x, y)
def commit(self, conn, analyzer, project_dir):
def commit(self, conn: Connection, analyzer: AbstractAnalyzer, project_dir: Path) -> None:
""" Update postcodes for the country from the postcodes selected so far
as well as any externally supplied postcodes.
"""
@@ -94,7 +98,8 @@ class _PostcodeCollector:
""").format(pysql.Literal(self.country)), to_update)
def _compute_changes(self, conn):
def _compute_changes(self, conn: Connection) \
-> Tuple[List[Tuple[str, float, float]], List[str], List[Tuple[str, float, float]]]:
""" Compute which postcodes from the collected postcodes have to be
added or modified and which from the location_postcode table
have to be deleted.
@@ -116,12 +121,12 @@ class _PostcodeCollector:
to_delete.append(postcode)
to_add = [(k, *v.centroid()) for k, v in self.collected.items()]
self.collected = None
self.collected = defaultdict(PointsCentroid)
return to_add, to_delete, to_update
def _update_from_external(self, analyzer, project_dir):
def _update_from_external(self, analyzer: AbstractAnalyzer, project_dir: Path) -> None:
""" Look for an external postcode file for the active country in
the project directory and add missing postcodes when found.
"""
@@ -151,7 +156,7 @@ class _PostcodeCollector:
csvfile.close()
def _open_external(self, project_dir):
def _open_external(self, project_dir: Path) -> Optional[TextIO]:
fname = project_dir / f'{self.country}_postcodes.csv'
if fname.is_file():
@@ -167,7 +172,7 @@ class _PostcodeCollector:
return None
def update_postcodes(dsn, project_dir, tokenizer):
def update_postcodes(dsn: str, project_dir: Path, tokenizer: AbstractTokenizer) -> None:
""" Update the table of artificial postcodes.
Computes artificial postcode centroids from the placex table,
@@ -220,7 +225,7 @@ def update_postcodes(dsn, project_dir, tokenizer):
analyzer.update_postcodes_from_db()
def can_compute(dsn):
def can_compute(dsn: str) -> bool:
"""
Check that the place table exists so that
postcodes can be computed.

View File

@@ -7,12 +7,15 @@
"""
Functions for bringing auxiliary data in the database up-to-date.
"""
from typing import MutableSequence, Tuple, Any, Type, Mapping, Sequence, List, cast
import logging
from textwrap import dedent
from pathlib import Path
from psycopg2 import sql as pysql
from nominatim.config import Configuration
from nominatim.db.connection import Connection
from nominatim.db.utils import execute_file
from nominatim.db.sql_preprocessor import SQLPreprocessor
from nominatim.version import version_str
@@ -21,7 +24,8 @@ LOG = logging.getLogger()
OSM_TYPE = {'N': 'node', 'W': 'way', 'R': 'relation'}
def _add_address_level_rows_from_entry(rows, entry):
def _add_address_level_rows_from_entry(rows: MutableSequence[Tuple[Any, ...]],
entry: Mapping[str, Any]) -> None:
""" Converts a single entry from the JSON format for address rank
descriptions into a flat format suitable for inserting into a
PostgreSQL table and adds these lines to `rows`.
@@ -38,14 +42,15 @@ def _add_address_level_rows_from_entry(rows, entry):
for country in countries:
rows.append((country, key, value, rank_search, rank_address))
def load_address_levels(conn, table, levels):
def load_address_levels(conn: Connection, table: str, levels: Sequence[Mapping[str, Any]]) -> None:
""" Replace the `address_levels` table with the contents of `levels'.
A new table is created any previously existing table is dropped.
The table has the following columns:
country, class, type, rank_search, rank_address
"""
rows = []
rows: List[Tuple[Any, ...]] = []
for entry in levels:
_add_address_level_rows_from_entry(rows, entry)
@@ -69,7 +74,7 @@ def load_address_levels(conn, table, levels):
conn.commit()
def load_address_levels_from_config(conn, config):
def load_address_levels_from_config(conn: Connection, config: Configuration) -> None:
""" Replace the `address_levels` table with the content as
defined in the given configuration. Uses the parameter
NOMINATIM_ADDRESS_LEVEL_CONFIG to determine the location of the
@@ -79,7 +84,9 @@ def load_address_levels_from_config(conn, config):
load_address_levels(conn, 'address_levels', cfg)
def create_functions(conn, config, enable_diff_updates=True, enable_debug=False):
def create_functions(conn: Connection, config: Configuration,
enable_diff_updates: bool = True,
enable_debug: bool = False) -> None:
""" (Re)create the PL/pgSQL functions.
"""
sql = SQLPreprocessor(conn, config)
@@ -116,7 +123,7 @@ PHP_CONST_DEFS = (
)
def import_wikipedia_articles(dsn, data_path, ignore_errors=False):
def import_wikipedia_articles(dsn: str, data_path: Path, ignore_errors: bool = False) -> int:
""" Replaces the wikipedia importance tables with new data.
The import is run in a single transaction so that the new data
is replaced seamlessly.
@@ -140,7 +147,7 @@ def import_wikipedia_articles(dsn, data_path, ignore_errors=False):
return 0
def recompute_importance(conn):
def recompute_importance(conn: Connection) -> None:
""" Recompute wikipedia links and importance for all entries in placex.
This is a long-running operation that must not be executed in
parallel with updates.
@@ -163,12 +170,13 @@ def recompute_importance(conn):
conn.commit()
def _quote_php_variable(var_type, config, conf_name):
def _quote_php_variable(var_type: Type[Any], config: Configuration,
conf_name: str) -> str:
if var_type == bool:
return 'true' if config.get_bool(conf_name) else 'false'
if var_type == int:
return getattr(config, conf_name)
return cast(str, getattr(config, conf_name))
if not getattr(config, conf_name):
return 'false'
@@ -182,7 +190,7 @@ def _quote_php_variable(var_type, config, conf_name):
return f"'{quoted}'"
def setup_website(basedir, config, conn):
def setup_website(basedir: Path, config: Configuration, conn: Connection) -> None:
""" Create the website script stubs.
"""
if not basedir.exists():
@@ -215,7 +223,8 @@ def setup_website(basedir, config, conn):
(basedir / script).write_text(template.format(script), 'utf-8')
def invalidate_osm_object(osm_type, osm_id, conn, recursive=True):
def invalidate_osm_object(osm_type: str, osm_id: int, conn: Connection,
recursive: bool = True) -> None:
""" Mark the given OSM object for reindexing. When 'recursive' is set
to True (the default), then all dependent objects are marked for
reindexing as well.

View File

@@ -7,6 +7,7 @@
"""
Functions for updating a database from a replication source.
"""
from typing import ContextManager, MutableMapping, Any, Generator, cast
from contextlib import contextmanager
import datetime as dt
from enum import Enum
@@ -14,6 +15,7 @@ import logging
import time
from nominatim.db import status
from nominatim.db.connection import Connection
from nominatim.tools.exec_utils import run_osm2pgsql
from nominatim.errors import UsageError
@@ -27,7 +29,7 @@ except ImportError as exc:
LOG = logging.getLogger()
def init_replication(conn, base_url):
def init_replication(conn: Connection, base_url: str) -> None:
""" Set up replication for the server at the given base URL.
"""
LOG.info("Using replication source: %s", base_url)
@@ -51,7 +53,7 @@ def init_replication(conn, base_url):
LOG.warning("Updates initialised at sequence %s (%s)", seq, date)
def check_for_updates(conn, base_url):
def check_for_updates(conn: Connection, base_url: str) -> int:
""" Check if new data is available from the replication service at the
given base URL.
"""
@@ -84,7 +86,7 @@ class UpdateState(Enum):
NO_CHANGES = 3
def update(conn, options):
def update(conn: Connection, options: MutableMapping[str, Any]) -> UpdateState:
""" Update database from the next batch of data. Returns the state of
updates according to `UpdateState`.
"""
@@ -95,6 +97,8 @@ def update(conn, options):
"Please run 'nominatim replication --init' first.")
raise UsageError("Replication not set up.")
assert startdate is not None
if not indexed and options['indexed_only']:
LOG.info("Skipping update. There is data that needs indexing.")
return UpdateState.MORE_PENDING
@@ -132,17 +136,17 @@ def update(conn, options):
return UpdateState.UP_TO_DATE
def _make_replication_server(url):
def _make_replication_server(url: str) -> ContextManager[ReplicationServer]:
""" Returns a ReplicationServer in form of a context manager.
Creates a light wrapper around older versions of pyosmium that did
not support the context manager interface.
"""
if hasattr(ReplicationServer, '__enter__'):
return ReplicationServer(url)
return cast(ContextManager[ReplicationServer], ReplicationServer(url))
@contextmanager
def get_cm():
def get_cm() -> Generator[ReplicationServer, None, None]:
yield ReplicationServer(url)
return get_cm()