diff --git a/.github/workflows/ci-tests.yml b/.github/workflows/ci-tests.yml
index 53b76a03..481ec767 100644
--- a/.github/workflows/ci-tests.yml
+++ b/.github/workflows/ci-tests.yml
@@ -98,8 +98,8 @@ jobs:
           run: sudo apt-get install -y -qq python3-pytest
           if: matrix.ubuntu == 22

-      - name: Install latest pylint
-        run: pip3 install pylint
+      - name: Install latest pylint/mypy
+        run: pip3 install -U pylint mypy types-PyYAML types-jinja2 types-psycopg2 types-psutil typing-extensions

       - name: PHP linting
         run: phpcs --report-width=120 .
@@ -109,6 +109,11 @@
         run: pylint nominatim
         working-directory: Nominatim

+      - name: Python static typechecking
+        run: mypy --strict nominatim
+        working-directory: Nominatim
+
+
       - name: PHP unit tests
         run: phpunit ./
         working-directory: Nominatim/test/php
diff --git a/.mypy.ini b/.mypy.ini
new file mode 100644
index 00000000..81a5c2e7
--- /dev/null
+++ b/.mypy.ini
@@ -0,0 +1,13 @@
+[mypy]
+
+[mypy-icu.*]
+ignore_missing_imports = True
+
+[mypy-osmium.*]
+ignore_missing_imports = True
+
+[mypy-datrie.*]
+ignore_missing_imports = True
+
+[mypy-dotenv.*]
+ignore_missing_imports = True
diff --git a/.pylintrc b/.pylintrc
index 52d9fcf9..e8609407 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -11,6 +11,8 @@ ignored-modules=icu,datrie
 # 'with' statements.
 ignored-classes=NominatimArgs,closing
 # 'too-many-ancestors' is triggered already by deriving from UserDict
-disable=too-few-public-methods,duplicate-code,too-many-ancestors,bad-option-value,no-self-use
+# 'not-context-manager' disabled because it causes false positives once
+# typed Python is enabled. See also https://github.com/PyCQA/pylint/issues/5273
+disable=too-few-public-methods,duplicate-code,too-many-ancestors,bad-option-value,no-self-use,not-context-manager

 good-names=i,x,y,fd,db,cc
diff --git a/docs/develop/Development-Environment.md b/docs/develop/Development-Environment.md
index 3cda610e..65dc7990 100644
--- a/docs/develop/Development-Environment.md
+++ b/docs/develop/Development-Environment.md
@@ -33,6 +33,8 @@ It has the following additional requirements:
 * [phpunit](https://phpunit.de) (9.5 is known to work)
 * [PHP CodeSniffer](https://github.com/squizlabs/PHP_CodeSniffer)
 * [Pylint](https://pylint.org/) (CI always runs the latest version from pip)
+* [mypy](http://mypy-lang.org/) (plus typing information for external libs)
+* [Python Typing Extensions](https://github.com/python/typing_extensions) (for Python < 3.9)
 * [pytest](https://pytest.org)

 The documentation is built with mkdocs:
@@ -50,9 +52,10 @@ To install all necessary packages run:

 ```sh
 sudo apt install php-cgi phpunit php-codesniffer \
-                 python3-pip python3-setuptools python3-dev pylint
+                 python3-pip python3-setuptools python3-dev

-pip3 install --user behave mkdocs mkdocstrings pytest
+pip3 install --user behave mkdocs mkdocstrings pytest \
+                    pylint mypy types-PyYAML types-jinja2 types-psycopg2
 ```

 The `mkdocs` executable will be located in `.local/bin`.
 You may have to add
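Reviewer note: the new CI step runs mypy in strict mode over the whole `nominatim` package, and the `types-*` packages installed above supply stubs for third-party libraries, while the `.mypy.ini` sections silence import errors for the few dependencies that ship no stubs at all. As a quick illustration of what `--strict` buys — a hypothetical two-line module, not part of this patch — every function must be fully annotated, so dynamic helpers like the first one below are flagged immediately:

```python
# typed_demo.py -- hypothetical example, not from the patch.

def add(a, b):          # mypy --strict: error, function is missing type annotations
    return a + b

def add_typed(a: int, b: int) -> int:   # accepted under --strict
    return a + b
```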
diff --git a/lib-sql/tables.sql b/lib-sql/tables.sql
index 538286b8..03431d95 100644
--- a/lib-sql/tables.sql
+++ b/lib-sql/tables.sql
@@ -45,7 +45,7 @@ GRANT SELECT ON TABLE country_name TO "{{config.DATABASE_WEBUSER}}";

 DROP TABLE IF EXISTS nominatim_properties;
 CREATE TABLE nominatim_properties (
-    property TEXT,
+    property TEXT NOT NULL,
     value TEXT
 );
 GRANT SELECT ON TABLE nominatim_properties TO "{{config.DATABASE_WEBUSER}}";
""" - @staticmethod - def add_args(parser): + def add_args(self, parser: argparse.ArgumentParser) -> None: group = parser.add_argument_group('Output arguments') group.add_argument('--output-type', default='street', choices=('continent', 'country', 'state', 'county', @@ -175,11 +178,10 @@ class QueryExport: help='Export only children of this OSM relation') - @staticmethod - def run(args): - params = ['export.php', - '--output-type', args.output_type, - '--output-format', args.output_format] + def run(self, args: NominatimArgs) -> int: + params: List[Union[int, str]] = [ + '--output-type', args.output_type, + '--output-format', args.output_format] if args.output_all_postcodes: params.append('--output-all-postcodes') if args.language: @@ -193,7 +195,7 @@ class QueryExport: if args.restrict_to_osm_relation: params.extend(('--restrict-to-osm-relation', args.restrict_to_osm_relation)) - return run_legacy_script(*params, nominatim_env=args) + return run_legacy_script('export.php', *params, nominatim_env=args) class AdminServe: @@ -207,51 +209,52 @@ class AdminServe: By the default, the webserver can be accessed at: http://127.0.0.1:8088 """ - @staticmethod - def add_args(parser): + def add_args(self, parser: argparse.ArgumentParser) -> None: group = parser.add_argument_group('Server arguments') group.add_argument('--server', default='127.0.0.1:8088', help='The address the server will listen to.') - @staticmethod - def run(args): - run_php_server(args.server, args.project_dir / 'website') -def get_set_parser(**kwargs): + def run(self, args: NominatimArgs) -> int: + run_php_server(args.server, args.project_dir / 'website') + return 0 + + +def get_set_parser(**kwargs: Any) -> CommandlineParser: """\ Initializes the parser and adds various subcommands for nominatim cli. """ parser = CommandlineParser('nominatim', nominatim.__doc__) - parser.add_subcommand('import', clicmd.SetupAll) - parser.add_subcommand('freeze', clicmd.SetupFreeze) - parser.add_subcommand('replication', clicmd.UpdateReplication) + parser.add_subcommand('import', clicmd.SetupAll()) + parser.add_subcommand('freeze', clicmd.SetupFreeze()) + parser.add_subcommand('replication', clicmd.UpdateReplication()) - parser.add_subcommand('special-phrases', clicmd.ImportSpecialPhrases) + parser.add_subcommand('special-phrases', clicmd.ImportSpecialPhrases()) - parser.add_subcommand('add-data', clicmd.UpdateAddData) - parser.add_subcommand('index', clicmd.UpdateIndex) + parser.add_subcommand('add-data', clicmd.UpdateAddData()) + parser.add_subcommand('index', clicmd.UpdateIndex()) parser.add_subcommand('refresh', clicmd.UpdateRefresh()) - parser.add_subcommand('admin', clicmd.AdminFuncs) + parser.add_subcommand('admin', clicmd.AdminFuncs()) - parser.add_subcommand('export', QueryExport) - parser.add_subcommand('serve', AdminServe) + parser.add_subcommand('export', QueryExport()) + parser.add_subcommand('serve', AdminServe()) if kwargs.get('phpcgi_path'): - parser.add_subcommand('search', clicmd.APISearch) - parser.add_subcommand('reverse', clicmd.APIReverse) - parser.add_subcommand('lookup', clicmd.APILookup) - parser.add_subcommand('details', clicmd.APIDetails) - parser.add_subcommand('status', clicmd.APIStatus) + parser.add_subcommand('search', clicmd.APISearch()) + parser.add_subcommand('reverse', clicmd.APIReverse()) + parser.add_subcommand('lookup', clicmd.APILookup()) + parser.add_subcommand('details', clicmd.APIDetails()) + parser.add_subcommand('status', clicmd.APIStatus()) else: parser.parser.epilog = 'php-cgi not found. 
diff --git a/nominatim/clicmd/__init__.py b/nominatim/clicmd/__init__.py
index de541134..bdd9bafe 100644
--- a/nominatim/clicmd/__init__.py
+++ b/nominatim/clicmd/__init__.py
@@ -7,13 +7,20 @@
 """
 Subcommand definitions for the command-line tool.
 """
+# mypy and pylint disagree about the style of explicit exports,
+# see https://github.com/PyCQA/pylint/issues/6006.
+# pylint: disable=useless-import-alias

-from nominatim.clicmd.setup import SetupAll
-from nominatim.clicmd.replication import UpdateReplication
-from nominatim.clicmd.api import APISearch, APIReverse, APILookup, APIDetails, APIStatus
-from nominatim.clicmd.index import UpdateIndex
-from nominatim.clicmd.refresh import UpdateRefresh
-from nominatim.clicmd.add_data import UpdateAddData
-from nominatim.clicmd.admin import AdminFuncs
-from nominatim.clicmd.freeze import SetupFreeze
-from nominatim.clicmd.special_phrases import ImportSpecialPhrases
+from nominatim.clicmd.setup import SetupAll as SetupAll
+from nominatim.clicmd.replication import UpdateReplication as UpdateReplication
+from nominatim.clicmd.api import (APISearch as APISearch,
+                                  APIReverse as APIReverse,
+                                  APILookup as APILookup,
+                                  APIDetails as APIDetails,
+                                  APIStatus as APIStatus)
+from nominatim.clicmd.index import UpdateIndex as UpdateIndex
+from nominatim.clicmd.refresh import UpdateRefresh as UpdateRefresh
+from nominatim.clicmd.add_data import UpdateAddData as UpdateAddData
+from nominatim.clicmd.admin import AdminFuncs as AdminFuncs
+from nominatim.clicmd.freeze import SetupFreeze as SetupFreeze
+from nominatim.clicmd.special_phrases import ImportSpecialPhrases as ImportSpecialPhrases
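Reviewer note: the `X as X` form looks redundant but is mypy's explicit re-export syntax. `--strict` implies `--no-implicit-reexport`, which treats a plain import in an `__init__.py` as private to the module. A tiny sketch, with a standard-library name standing in for the real submodules:

```python
# __init__.py of a hypothetical package.
# pylint: disable=useless-import-alias
from os.path import join as join   # re-exported: `from pkg import join` now type-checks
from os.path import split          # private under --no-implicit-reexport
```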
""" - @staticmethod - def add_args(parser): + def add_args(self, parser: argparse.ArgumentParser) -> None: group_name = parser.add_argument_group('Source') - group = group_name.add_mutually_exclusive_group(required=True) - group.add_argument('--file', metavar='FILE', - help='Import data from an OSM file or diff file') - group.add_argument('--diff', metavar='FILE', - help='Import data from an OSM diff file (deprecated: use --file)') - group.add_argument('--node', metavar='ID', type=int, - help='Import a single node from the API') - group.add_argument('--way', metavar='ID', type=int, - help='Import a single way from the API') - group.add_argument('--relation', metavar='ID', type=int, - help='Import a single relation from the API') - group.add_argument('--tiger-data', metavar='DIR', - help='Add housenumbers from the US TIGER census database') - group = parser.add_argument_group('Extra arguments') - group.add_argument('--use-main-api', action='store_true', - help='Use OSM API instead of Overpass to download objects') - group.add_argument('--osm2pgsql-cache', metavar='SIZE', type=int, - help='Size of cache to be used by osm2pgsql (in MB)') - group.add_argument('--socket-timeout', dest='socket_timeout', type=int, default=60, - help='Set timeout for file downloads') + group1 = group_name.add_mutually_exclusive_group(required=True) + group1.add_argument('--file', metavar='FILE', + help='Import data from an OSM file or diff file') + group1.add_argument('--diff', metavar='FILE', + help='Import data from an OSM diff file (deprecated: use --file)') + group1.add_argument('--node', metavar='ID', type=int, + help='Import a single node from the API') + group1.add_argument('--way', metavar='ID', type=int, + help='Import a single way from the API') + group1.add_argument('--relation', metavar='ID', type=int, + help='Import a single relation from the API') + group1.add_argument('--tiger-data', metavar='DIR', + help='Add housenumbers from the US TIGER census database') + group2 = parser.add_argument_group('Extra arguments') + group2.add_argument('--use-main-api', action='store_true', + help='Use OSM API instead of Overpass to download objects') + group2.add_argument('--osm2pgsql-cache', metavar='SIZE', type=int, + help='Size of cache to be used by osm2pgsql (in MB)') + group2.add_argument('--socket-timeout', dest='socket_timeout', type=int, default=60, + help='Set timeout for file downloads') - @staticmethod - def run(args): + + def run(self, args: NominatimArgs) -> int: from nominatim.tokenizer import factory as tokenizer_factory from nominatim.tools import tiger_data, add_osm_data @@ -73,7 +76,7 @@ class UpdateAddData: osm2pgsql_params = args.osm2pgsql_options(default_cache=1000, default_threads=1) if args.file or args.diff: - return add_osm_data.add_data_from_file(args.file or args.diff, + return add_osm_data.add_data_from_file(cast(str, args.file or args.diff), osm2pgsql_params) if args.node: diff --git a/nominatim/clicmd/admin.py b/nominatim/clicmd/admin.py index 1ed0ac9b..ad900579 100644 --- a/nominatim/clicmd/admin.py +++ b/nominatim/clicmd/admin.py @@ -8,8 +8,10 @@ Implementation of the 'admin' subcommand. """ import logging +import argparse from nominatim.tools.exec_utils import run_legacy_script +from nominatim.clicmd.args import NominatimArgs # Do not repeat documentation of subcommand classes. # pylint: disable=C0111 @@ -23,8 +25,7 @@ class AdminFuncs: Analyse and maintain the database. 
""" - @staticmethod - def add_args(parser): + def add_args(self, parser: argparse.ArgumentParser) -> None: group = parser.add_argument_group('Admin tasks') objs = group.add_mutually_exclusive_group(required=True) objs.add_argument('--warm', action='store_true', @@ -49,10 +50,9 @@ class AdminFuncs: mgroup.add_argument('--place-id', type=int, help='Analyse indexing of the given Nominatim object') - @staticmethod - def run(args): + def run(self, args: NominatimArgs) -> int: if args.warm: - return AdminFuncs._warm(args) + return self._warm(args) if args.check_database: LOG.warning('Checking database') @@ -73,8 +73,7 @@ class AdminFuncs: return 1 - @staticmethod - def _warm(args): + def _warm(self, args: NominatimArgs) -> int: LOG.warning('Warming database caches') params = ['warm.php'] if args.target == 'reverse': diff --git a/nominatim/clicmd/api.py b/nominatim/clicmd/api.py index ab7d1658..b899afad 100644 --- a/nominatim/clicmd/api.py +++ b/nominatim/clicmd/api.py @@ -7,10 +7,13 @@ """ Subcommand definitions for API calls from the command line. """ +from typing import Mapping, Dict +import argparse import logging from nominatim.tools.exec_utils import run_api_script from nominatim.errors import UsageError +from nominatim.clicmd.args import NominatimArgs # Do not repeat documentation of subcommand classes. # pylint: disable=C0111 @@ -42,7 +45,7 @@ DETAILS_SWITCHES = ( ('polygon_geojson', 'Include geometry of result') ) -def _add_api_output_arguments(parser): +def _add_api_output_arguments(parser: argparse.ArgumentParser) -> None: group = parser.add_argument_group('Output arguments') group.add_argument('--format', default='jsonv2', choices=['xml', 'json', 'jsonv2', 'geojson', 'geocodejson'], @@ -60,7 +63,7 @@ def _add_api_output_arguments(parser): "Parameter is difference tolerance in degrees.")) -def _run_api(endpoint, args, params): +def _run_api(endpoint: str, args: NominatimArgs, params: Mapping[str, object]) -> int: script_file = args.project_dir / 'website' / (endpoint + '.php') if not script_file.exists(): @@ -82,8 +85,7 @@ class APISearch: https://nominatim.org/release-docs/latest/api/Search/ """ - @staticmethod - def add_args(parser): + def add_args(self, parser: argparse.ArgumentParser) -> None: group = parser.add_argument_group('Query arguments') group.add_argument('--query', help='Free-form query string') @@ -109,8 +111,8 @@ class APISearch: help='Do not remove duplicates from the result list') - @staticmethod - def run(args): + def run(self, args: NominatimArgs) -> int: + params: Dict[str, object] if args.query: params = dict(q=args.query) else: @@ -145,8 +147,7 @@ class APIReverse: https://nominatim.org/release-docs/latest/api/Reverse/ """ - @staticmethod - def add_args(parser): + def add_args(self, parser: argparse.ArgumentParser) -> None: group = parser.add_argument_group('Query arguments') group.add_argument('--lat', type=float, required=True, help='Latitude of coordinate to look up (in WGS84)') @@ -158,8 +159,7 @@ class APIReverse: _add_api_output_arguments(parser) - @staticmethod - def run(args): + def run(self, args: NominatimArgs) -> int: params = dict(lat=args.lat, lon=args.lon, format=args.format) if args.zoom is not None: params['zoom'] = args.zoom @@ -187,8 +187,7 @@ class APILookup: https://nominatim.org/release-docs/latest/api/Lookup/ """ - @staticmethod - def add_args(parser): + def add_args(self, parser: argparse.ArgumentParser) -> None: group = parser.add_argument_group('Query arguments') group.add_argument('--id', metavar='OSMID', action='append', required=True, 
diff --git a/nominatim/clicmd/api.py b/nominatim/clicmd/api.py
index ab7d1658..b899afad 100644
--- a/nominatim/clicmd/api.py
+++ b/nominatim/clicmd/api.py
@@ -7,10 +7,13 @@
 """
 Subcommand definitions for API calls from the command line.
 """
+from typing import Mapping, Dict
+import argparse
 import logging

 from nominatim.tools.exec_utils import run_api_script
 from nominatim.errors import UsageError
+from nominatim.clicmd.args import NominatimArgs

 # Do not repeat documentation of subcommand classes.
 # pylint: disable=C0111
@@ -42,7 +45,7 @@
     ('polygon_geojson', 'Include geometry of result')
 )

-def _add_api_output_arguments(parser):
+def _add_api_output_arguments(parser: argparse.ArgumentParser) -> None:
     group = parser.add_argument_group('Output arguments')
     group.add_argument('--format', default='jsonv2',
                        choices=['xml', 'json', 'jsonv2', 'geojson', 'geocodejson'],
@@ -60,7 +63,7 @@
                              "Parameter is difference tolerance in degrees."))


-def _run_api(endpoint, args, params):
+def _run_api(endpoint: str, args: NominatimArgs, params: Mapping[str, object]) -> int:
     script_file = args.project_dir / 'website' / (endpoint + '.php')

     if not script_file.exists():
@@ -82,8 +85,7 @@ class APISearch:
        https://nominatim.org/release-docs/latest/api/Search/
     """

-    @staticmethod
-    def add_args(parser):
+    def add_args(self, parser: argparse.ArgumentParser) -> None:
         group = parser.add_argument_group('Query arguments')
         group.add_argument('--query',
                            help='Free-form query string')
@@ -109,8 +111,8 @@ class APISearch:
                            help='Do not remove duplicates from the result list')


-    @staticmethod
-    def run(args):
+    def run(self, args: NominatimArgs) -> int:
+        params: Dict[str, object]
         if args.query:
             params = dict(q=args.query)
         else:
@@ -145,8 +147,7 @@ class APIReverse:
        https://nominatim.org/release-docs/latest/api/Reverse/
     """

-    @staticmethod
-    def add_args(parser):
+    def add_args(self, parser: argparse.ArgumentParser) -> None:
         group = parser.add_argument_group('Query arguments')
         group.add_argument('--lat', type=float, required=True,
                            help='Latitude of coordinate to look up (in WGS84)')
@@ -158,8 +159,7 @@ class APIReverse:

         _add_api_output_arguments(parser)

-    @staticmethod
-    def run(args):
+    def run(self, args: NominatimArgs) -> int:
         params = dict(lat=args.lat, lon=args.lon, format=args.format)
         if args.zoom is not None:
             params['zoom'] = args.zoom
@@ -187,8 +187,7 @@ class APILookup:
        https://nominatim.org/release-docs/latest/api/Lookup/
     """

-    @staticmethod
-    def add_args(parser):
+    def add_args(self, parser: argparse.ArgumentParser) -> None:
         group = parser.add_argument_group('Query arguments')
         group.add_argument('--id', metavar='OSMID', action='append', required=True,
                            dest='ids',
@@ -197,9 +196,8 @@ class APILookup:

         _add_api_output_arguments(parser)

-    @staticmethod
-    def run(args):
-        params = dict(osm_ids=','.join(args.ids), format=args.format)
+    def run(self, args: NominatimArgs) -> int:
+        params: Dict[str, object] = dict(osm_ids=','.join(args.ids), format=args.format)

         for param, _ in EXTRADATA_PARAMS:
             if getattr(args, param):
@@ -224,8 +222,7 @@ class APIDetails:
        https://nominatim.org/release-docs/latest/api/Details/
     """

-    @staticmethod
-    def add_args(parser):
+    def add_args(self, parser: argparse.ArgumentParser) -> None:
         group = parser.add_argument_group('Query arguments')
         objs = group.add_mutually_exclusive_group(required=True)
         objs.add_argument('--node', '-n', type=int,
@@ -246,8 +243,8 @@ class APIDetails:
         group.add_argument('--lang', '--accept-language', metavar='LANGS',
                            help='Preferred language order for presenting search results')

-    @staticmethod
-    def run(args):
+
+    def run(self, args: NominatimArgs) -> int:
         if args.node:
             params = dict(osmtype='N', osmid=args.node)
         elif args.way:
@@ -276,12 +273,11 @@ class APIStatus:
        https://nominatim.org/release-docs/latest/api/Status/
     """

-    @staticmethod
-    def add_args(parser):
+    def add_args(self, parser: argparse.ArgumentParser) -> None:
         group = parser.add_argument_group('API parameters')
         group.add_argument('--format', default='text', choices=['text', 'json'],
                            help='Format of result')

-    @staticmethod
-    def run(args):
+
+    def run(self, args: NominatimArgs) -> int:
         return _run_api('status', args, dict(format=args.format))
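Reviewer note: several `run()` methods above pre-declare `params: Dict[str, object]` before a branch. The reason, in miniature (query values hypothetical):

```python
from typing import Dict

use_query = True   # stand-in for `if args.query:`

params: Dict[str, object]   # declared once, assigned in both branches
if use_query:
    params = dict(q='pub in dublin')
else:
    params = dict(street='Baker Street')

params['limit'] = 20   # fine: values are typed as `object`, not `str`
```

Without the up-front annotation, mypy infers `Dict[str, str]` from the first assignment and then rejects the later integer value.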
diff --git a/nominatim/clicmd/args.py b/nominatim/clicmd/args.py
index d1f47ba0..c976f394 100644
--- a/nominatim/clicmd/args.py
+++ b/nominatim/clicmd/args.py
@@ -7,19 +7,174 @@
 """
 Provides custom functions over command-line arguments.
 """
+from typing import Optional, List, Dict, Any, Sequence, Tuple
+import argparse
 import logging
 from pathlib import Path

 from nominatim.errors import UsageError
+from nominatim.config import Configuration
+from nominatim.typing import Protocol

 LOG = logging.getLogger()

+class Subcommand(Protocol):
+    """
+    Interface to be implemented by classes implementing a CLI subcommand.
+    """
+
+    def add_args(self, parser: argparse.ArgumentParser) -> None:
+        """
+        Fill the given parser for the subcommand with the appropriate
+        parameters.
+        """
+
+    def run(self, args: 'NominatimArgs') -> int:
+        """
+        Run the subcommand with the given parsed arguments.
+        """
+
+
 class NominatimArgs:
     """ Customized namespace class for the nominatim command line tool
         to receive the command-line arguments.
     """
+    # Basic environment set by root program.
+    config: Configuration
+    project_dir: Path
+    module_dir: Path
+    osm2pgsql_path: Path
+    phplib_dir: Path
+    sqllib_dir: Path
+    data_dir: Path
+    config_dir: Path
+    phpcgi_path: Path

-    def osm2pgsql_options(self, default_cache, default_threads):
+    # Global switches
+    version: bool
+    subcommand: Optional[str]
+    command: Subcommand
+
+    # Shared parameters
+    osm2pgsql_cache: Optional[int]
+    socket_timeout: int
+
+    # Arguments added to all subcommands.
+    verbose: int
+    threads: Optional[int]
+
+    # Arguments to 'add-data'
+    file: Optional[str]
+    diff: Optional[str]
+    node: Optional[int]
+    way: Optional[int]
+    relation: Optional[int]
+    tiger_data: Optional[str]
+    use_main_api: bool
+
+    # Arguments to 'admin'
+    warm: bool
+    check_database: bool
+    migrate: bool
+    analyse_indexing: bool
+    target: Optional[str]
+    osm_id: Optional[str]
+    place_id: Optional[int]
+
+    # Arguments to 'import'
+    osm_file: List[str]
+    continue_at: Optional[str]
+    reverse_only: bool
+    no_partitions: bool
+    no_updates: bool
+    offline: bool
+    ignore_errors: bool
+    index_noanalyse: bool
+
+    # Arguments to 'index'
+    boundaries_only: bool
+    no_boundaries: bool
+    minrank: int
+    maxrank: int
+
+    # Arguments to 'export'
+    output_type: str
+    output_format: str
+    output_all_postcodes: bool
+    language: Optional[str]
+    restrict_to_country: Optional[str]
+    restrict_to_osm_node: Optional[int]
+    restrict_to_osm_way: Optional[int]
+    restrict_to_osm_relation: Optional[int]
+
+    # Arguments to 'refresh'
+    postcodes: bool
+    word_tokens: bool
+    word_counts: bool
+    address_levels: bool
+    functions: bool
+    wiki_data: bool
+    importance: bool
+    website: bool
+    diffs: bool
+    enable_debug_statements: bool
+    data_object: Sequence[Tuple[str, int]]
+    data_area: Sequence[Tuple[str, int]]
+
+    # Arguments to 'replication'
+    init: bool
+    update_functions: bool
+    check_for_updates: bool
+    once: bool
+    catch_up: bool
+    do_index: bool
+
+    # Arguments to 'serve'
+    server: str
+
+    # Arguments to 'special-phrases
+    import_from_wiki: bool
+    import_from_csv: Optional[str]
+    no_replace: bool
+
+    # Arguments to all query functions
+    format: str
+    addressdetails: bool
+    extratags: bool
+    namedetails: bool
+    lang: Optional[str]
+    polygon_output: Optional[str]
+    polygon_threshold: Optional[float]
+
+    # Arguments to 'search'
+    query: Optional[str]
+    street: Optional[str]
+    city: Optional[str]
+    county: Optional[str]
+    state: Optional[str]
+    country: Optional[str]
+    postalcode: Optional[str]
+    countrycodes: Optional[str]
+    exclude_place_ids: Optional[str]
+    limit: Optional[int]
+    viewbox: Optional[str]
+    bounded: bool
+    dedupe: bool
+
+    # Arguments to 'reverse'
+    lat: float
+    lon: float
+    zoom: Optional[int]
+
+    # Arguments to 'lookup'
+    ids: Sequence[str]
+
+    # Arguments to 'details'
+    object_class: Optional[str]
+
+
+    def osm2pgsql_options(self, default_cache: int,
+                          default_threads: int) -> Dict[str, Any]:
         """ Return the standard osm2pgsql options that can be derived
             from the command line arguments. The resulting dict can be
             further customized and then used in `run_osm2pgsql()`.
@@ -29,7 +184,7 @@ class NominatimArgs:
                     osm2pgsql_style=self.config.get_import_style_file(),
                     threads=self.threads or default_threads,
                     dsn=self.config.get_libpq_dsn(),
-                    flatnode_file=str(self.config.get_path('FLATNODE_FILE')),
+                    flatnode_file=str(self.config.get_path('FLATNODE_FILE') or ''),
                     tablespaces=dict(slim_data=self.config.TABLESPACE_OSM_DATA,
                                      slim_index=self.config.TABLESPACE_OSM_INDEX,
                                      main_data=self.config.TABLESPACE_PLACE_DATA,
@@ -38,7 +193,7 @@ class NominatimArgs:
                    )


-    def get_osm_file_list(self):
+    def get_osm_file_list(self) -> Optional[List[Path]]:
         """ Return the --osm-file argument as a list of Paths or None
             if no argument was given. The function also checks if the files
             exist and raises a UsageError if one cannot be found.
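Reviewer note: the long block of annotated attributes works because argparse can populate an arbitrary namespace object; the declarations provide the types, `parse_args()` provides the values. A reduced sketch of the trick, assuming the parser is asked to use such a namespace (as `CommandlineParser.run()` presumably does with `NominatimArgs()`; that call is outside this excerpt):

```python
import argparse

class Args:
    # Declarations only -- no values; argparse fills them in at parse time.
    verbose: int
    server: str

parser = argparse.ArgumentParser()
parser.add_argument('--verbose', type=int, default=0)
parser.add_argument('--server', default='127.0.0.1:8088')

args = parser.parse_args([], namespace=Args())
print(args.verbose + 1)   # mypy knows `verbose` is an int, not Any
```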
diff --git a/nominatim/clicmd/freeze.py b/nominatim/clicmd/freeze.py
index 85eb1b4a..5dfdd255 100644
--- a/nominatim/clicmd/freeze.py
+++ b/nominatim/clicmd/freeze.py
@@ -7,8 +7,10 @@
 """
 Implementation of the 'freeze' subcommand.
 """
+import argparse

 from nominatim.db.connection import connect
+from nominatim.clicmd.args import NominatimArgs

 # Do not repeat documentation of subcommand classes.
 # pylint: disable=C0111
@@ -27,16 +29,15 @@ class SetupFreeze:
        This command has the same effect as the `--no-updates` option for
        imports.
     """

-    @staticmethod
-    def add_args(parser):
+    def add_args(self, parser: argparse.ArgumentParser) -> None:
         pass # No options

-    @staticmethod
-    def run(args):
+
+    def run(self, args: NominatimArgs) -> int:
         from ..tools import freeze

         with connect(args.config.get_libpq_dsn()) as conn:
             freeze.drop_update_tables(conn)
-        freeze.drop_flatnode_file(str(args.config.get_path('FLATNODE_FILE')))
+        freeze.drop_flatnode_file(args.config.get_path('FLATNODE_FILE'))

         return 0
diff --git a/nominatim/clicmd/index.py b/nominatim/clicmd/index.py
index 73258be2..16b5311c 100644
--- a/nominatim/clicmd/index.py
+++ b/nominatim/clicmd/index.py
@@ -7,10 +7,13 @@
 """
 Implementation of the 'index' subcommand.
 """
+import argparse
+
 import psutil

 from nominatim.db import status
 from nominatim.db.connection import connect
+from nominatim.clicmd.args import NominatimArgs

 # Do not repeat documentation of subcommand classes.
 # pylint: disable=C0111
@@ -28,8 +31,7 @@ class UpdateIndex:
        of indexing. For other cases, this function allows to run indexing manually.
     """

-    @staticmethod
-    def add_args(parser):
+    def add_args(self, parser: argparse.ArgumentParser) -> None:
         group = parser.add_argument_group('Filter arguments')
         group.add_argument('--boundaries-only', action='store_true',
                            help="""Index only administrative boundaries.""")
@@ -40,8 +42,8 @@ class UpdateIndex:
         group.add_argument('--maxrank', '-R', type=int, metavar='RANK', default=30,
                            help='Maximum/finishing rank')

-    @staticmethod
-    def run(args):
+
+    def run(self, args: NominatimArgs) -> int:
         from ..indexer.indexer import Indexer
         from ..tokenizer import factory as tokenizer_factory

diff --git a/nominatim/clicmd/refresh.py b/nominatim/clicmd/refresh.py
index ecc7498e..dce28d98 100644
--- a/nominatim/clicmd/refresh.py
+++ b/nominatim/clicmd/refresh.py
@@ -7,11 +7,15 @@
 """
 Implementation of 'refresh' subcommand.
 """
-from argparse import ArgumentTypeError
+from typing import Tuple, Optional
+import argparse
 import logging
 from pathlib import Path

+from nominatim.config import Configuration
 from nominatim.db.connection import connect
+from nominatim.tokenizer.base import AbstractTokenizer
+from nominatim.clicmd.args import NominatimArgs

 # Do not repeat documentation of subcommand classes.
 # pylint: disable=C0111
@@ -20,12 +24,12 @@
 LOG = logging.getLogger()

-def _parse_osm_object(obj):
+def _parse_osm_object(obj: str) -> Tuple[str, int]:
     """ Parse the given argument into a tuple of OSM type and ID.
        Raises an ArgumentError if the format is not recognized.
     """
     if len(obj) < 2 or obj[0].lower() not in 'nrw' or not obj[1:].isdigit():
-        raise ArgumentTypeError("Cannot parse OSM ID. Expect format: [N|W|R].")
+        raise argparse.ArgumentTypeError("Cannot parse OSM ID. Expect format: [N|W|R].")

     return (obj[0].upper(), int(obj[1:]))

@@ -42,11 +46,10 @@ class UpdateRefresh:
        Warning: the 'update' command must not be run in parallel with other update
        commands like 'replication' or 'add-data'.
     """
-    def __init__(self):
-        self.tokenizer = None
+    def __init__(self) -> None:
+        self.tokenizer: Optional[AbstractTokenizer] = None

-    @staticmethod
-    def add_args(parser):
+    def add_args(self, parser: argparse.ArgumentParser) -> None:
         group = parser.add_argument_group('Data arguments')
         group.add_argument('--postcodes', action='store_true',
                            help='Update postcode centroid table')
@@ -80,7 +83,7 @@ class UpdateRefresh:
                            help='Enable debug warning statements in functions')


-    def run(self, args): #pylint: disable=too-many-branches
+    def run(self, args: NominatimArgs) -> int: #pylint: disable=too-many-branches
         from ..tools import refresh, postcodes
         from ..indexer.indexer import Indexer

@@ -155,7 +158,7 @@ class UpdateRefresh:
         return 0


-    def _get_tokenizer(self, config):
+    def _get_tokenizer(self, config: Configuration) -> AbstractTokenizer:
         if self.tokenizer is None:
             from ..tokenizer import factory as tokenizer_factory

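Reviewer note: `UpdateRefresh` now spells out that `self.tokenizer` starts as None and is created on demand. The lazy-initialisation pattern in isolation (class names here are stand-ins, not the real tokenizer API):

```python
from typing import Optional

class Tokenizer:
    """ Stand-in for an object that is expensive to create. """

class Cmd:
    def __init__(self) -> None:
        self.tokenizer: Optional[Tokenizer] = None

    def _get_tokenizer(self) -> Tokenizer:
        if self.tokenizer is None:
            self.tokenizer = Tokenizer()   # created on first use only
        return self.tokenizer              # mypy: narrowed to Tokenizer here
```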
""" - def __init__(self): - self.tokenizer = None + def __init__(self) -> None: + self.tokenizer: Optional[AbstractTokenizer] = None - @staticmethod - def add_args(parser): + def add_args(self, parser: argparse.ArgumentParser) -> None: group = parser.add_argument_group('Data arguments') group.add_argument('--postcodes', action='store_true', help='Update postcode centroid table') @@ -80,7 +83,7 @@ class UpdateRefresh: help='Enable debug warning statements in functions') - def run(self, args): #pylint: disable=too-many-branches + def run(self, args: NominatimArgs) -> int: #pylint: disable=too-many-branches from ..tools import refresh, postcodes from ..indexer.indexer import Indexer @@ -155,7 +158,7 @@ class UpdateRefresh: return 0 - def _get_tokenizer(self, config): + def _get_tokenizer(self, config: Configuration) -> AbstractTokenizer: if self.tokenizer is None: from ..tokenizer import factory as tokenizer_factory diff --git a/nominatim/clicmd/replication.py b/nominatim/clicmd/replication.py index 9d946304..2d6396a1 100644 --- a/nominatim/clicmd/replication.py +++ b/nominatim/clicmd/replication.py @@ -7,6 +7,8 @@ """ Implementation of the 'replication' sub-command. """ +from typing import Optional +import argparse import datetime as dt import logging import socket @@ -15,6 +17,7 @@ import time from nominatim.db import status from nominatim.db.connection import connect from nominatim.errors import UsageError +from nominatim.clicmd.args import NominatimArgs LOG = logging.getLogger() @@ -41,8 +44,7 @@ class UpdateReplication: downloads and imports the next batch of updates. """ - @staticmethod - def add_args(parser): + def add_args(self, parser: argparse.ArgumentParser) -> None: group = parser.add_argument_group('Arguments for initialisation') group.add_argument('--init', action='store_true', help='Initialise the update process') @@ -68,8 +70,8 @@ class UpdateReplication: group.add_argument('--socket-timeout', dest='socket_timeout', type=int, default=60, help='Set timeout for file downloads') - @staticmethod - def _init_replication(args): + + def _init_replication(self, args: NominatimArgs) -> int: from ..tools import replication, refresh LOG.warning("Initialising replication updates") @@ -81,16 +83,17 @@ class UpdateReplication: return 0 - @staticmethod - def _check_for_updates(args): + def _check_for_updates(self, args: NominatimArgs) -> int: from ..tools import replication with connect(args.config.get_libpq_dsn()) as conn: return replication.check_for_updates(conn, base_url=args.config.REPLICATION_URL) - @staticmethod - def _report_update(batchdate, start_import, start_index): - def round_time(delta): + + def _report_update(self, batchdate: dt.datetime, + start_import: dt.datetime, + start_index: Optional[dt.datetime]) -> None: + def round_time(delta: dt.timedelta) -> dt.timedelta: return dt.timedelta(seconds=int(delta.total_seconds())) end = dt.datetime.now(dt.timezone.utc) @@ -101,8 +104,7 @@ class UpdateReplication: round_time(end - batchdate)) - @staticmethod - def _compute_update_interval(args): + def _compute_update_interval(self, args: NominatimArgs) -> int: if args.catch_up: return 0 @@ -119,13 +121,13 @@ class UpdateReplication: return update_interval - @staticmethod - def _update(args): + def _update(self, args: NominatimArgs) -> None: + # pylint: disable=too-many-locals from ..tools import replication from ..indexer.indexer import Indexer from ..tokenizer import factory as tokenizer_factory - update_interval = UpdateReplication._compute_update_interval(args) + update_interval = 
self._compute_update_interval(args) params = args.osm2pgsql_options(default_cache=2000, default_threads=1) params.update(base_url=args.config.REPLICATION_URL, @@ -169,7 +171,8 @@ class UpdateReplication: indexer.index_full(analyse=False) if LOG.isEnabledFor(logging.WARNING): - UpdateReplication._report_update(batchdate, start, index_start) + assert batchdate is not None + self._report_update(batchdate, start, index_start) if args.once or (args.catch_up and state is replication.UpdateState.NO_CHANGES): break @@ -179,15 +182,14 @@ class UpdateReplication: time.sleep(recheck_interval) - @staticmethod - def run(args): + def run(self, args: NominatimArgs) -> int: socket.setdefaulttimeout(args.socket_timeout) if args.init: - return UpdateReplication._init_replication(args) + return self._init_replication(args) if args.check_for_updates: - return UpdateReplication._check_for_updates(args) + return self._check_for_updates(args) - UpdateReplication._update(args) + self._update(args) return 0 diff --git a/nominatim/clicmd/setup.py b/nominatim/clicmd/setup.py index 73095468..6ffa7afb 100644 --- a/nominatim/clicmd/setup.py +++ b/nominatim/clicmd/setup.py @@ -7,14 +7,20 @@ """ Implementation of the 'import' subcommand. """ +from typing import Optional +import argparse import logging from pathlib import Path import psutil -from nominatim.db.connection import connect +from nominatim.config import Configuration +from nominatim.db.connection import connect, Connection from nominatim.db import status, properties +from nominatim.tokenizer.base import AbstractTokenizer from nominatim.version import version_str +from nominatim.clicmd.args import NominatimArgs +from nominatim.errors import UsageError # Do not repeat documentation of subcommand classes. # pylint: disable=C0111 @@ -32,38 +38,36 @@ class SetupAll: needs superuser rights on the database. 
""" - @staticmethod - def add_args(parser): + def add_args(self, parser: argparse.ArgumentParser) -> None: group_name = parser.add_argument_group('Required arguments') - group = group_name.add_mutually_exclusive_group(required=True) - group.add_argument('--osm-file', metavar='FILE', action='append', + group1 = group_name.add_mutually_exclusive_group(required=True) + group1.add_argument('--osm-file', metavar='FILE', action='append', help='OSM file to be imported' ' (repeat for importing multiple files)') - group.add_argument('--continue', dest='continue_at', + group1.add_argument('--continue', dest='continue_at', choices=['load-data', 'indexing', 'db-postprocess'], help='Continue an import that was interrupted') - group = parser.add_argument_group('Optional arguments') - group.add_argument('--osm2pgsql-cache', metavar='SIZE', type=int, + group2 = parser.add_argument_group('Optional arguments') + group2.add_argument('--osm2pgsql-cache', metavar='SIZE', type=int, help='Size of cache to be used by osm2pgsql (in MB)') - group.add_argument('--reverse-only', action='store_true', + group2.add_argument('--reverse-only', action='store_true', help='Do not create tables and indexes for searching') - group.add_argument('--no-partitions', action='store_true', + group2.add_argument('--no-partitions', action='store_true', help=("Do not partition search indices " "(speeds up import of single country extracts)")) - group.add_argument('--no-updates', action='store_true', + group2.add_argument('--no-updates', action='store_true', help="Do not keep tables that are only needed for " "updating the database later") - group.add_argument('--offline', action='store_true', + group2.add_argument('--offline', action='store_true', help="Do not attempt to load any additional data from the internet") - group = parser.add_argument_group('Expert options') - group.add_argument('--ignore-errors', action='store_true', + group3 = parser.add_argument_group('Expert options') + group3.add_argument('--ignore-errors', action='store_true', help='Continue import even when errors in SQL are present') - group.add_argument('--index-noanalyse', action='store_true', + group3.add_argument('--index-noanalyse', action='store_true', help='Do not perform analyse operations during index (expert only)') - @staticmethod - def run(args): # pylint: disable=too-many-statements + def run(self, args: NominatimArgs) -> int: # pylint: disable=too-many-statements from ..data import country_info from ..tools import database_import, refresh, postcodes, freeze from ..indexer.indexer import Indexer @@ -72,6 +76,8 @@ class SetupAll: if args.continue_at is None: files = args.get_osm_file_list() + if not files: + raise UsageError("No input files (use --osm-file).") LOG.warning('Creating database') database_import.setup_database_skeleton(args.config.get_libpq_dsn(), @@ -88,7 +94,7 @@ class SetupAll: drop=args.no_updates, ignore_errors=args.ignore_errors) - SetupAll._setup_tables(args.config, args.reverse_only) + self._setup_tables(args.config, args.reverse_only) LOG.warning('Importing wikipedia importance data') data_path = Path(args.config.WIKIPEDIA_DATA_PATH or args.project_dir) @@ -107,7 +113,7 @@ class SetupAll: args.threads or psutil.cpu_count() or 1) LOG.warning("Setting up tokenizer") - tokenizer = SetupAll._get_tokenizer(args.continue_at, args.config) + tokenizer = self._get_tokenizer(args.continue_at, args.config) if args.continue_at is None or args.continue_at == 'load-data': LOG.warning('Calculate postcodes') @@ -117,7 +123,7 @@ class SetupAll: if 
args.continue_at is None or args.continue_at in ('load-data', 'indexing'): if args.continue_at is not None and args.continue_at != 'load-data': with connect(args.config.get_libpq_dsn()) as conn: - SetupAll._create_pending_index(conn, args.config.TABLESPACE_ADDRESS_INDEX) + self._create_pending_index(conn, args.config.TABLESPACE_ADDRESS_INDEX) LOG.warning('Indexing places') indexer = Indexer(args.config.get_libpq_dsn(), tokenizer, args.threads or psutil.cpu_count() or 1) @@ -142,13 +148,12 @@ class SetupAll: with connect(args.config.get_libpq_dsn()) as conn: refresh.setup_website(webdir, args.config, conn) - SetupAll._finalize_database(args.config.get_libpq_dsn(), args.offline) + self._finalize_database(args.config.get_libpq_dsn(), args.offline) return 0 - @staticmethod - def _setup_tables(config, reverse_only): + def _setup_tables(self, config: Configuration, reverse_only: bool) -> None: """ Set up the basic database layout: tables, indexes and functions. """ from ..tools import database_import, refresh @@ -169,8 +174,8 @@ class SetupAll: refresh.create_functions(conn, config, False, False) - @staticmethod - def _get_tokenizer(continue_at, config): + def _get_tokenizer(self, continue_at: Optional[str], + config: Configuration) -> AbstractTokenizer: """ Set up a new tokenizer or load an already initialised one. """ from ..tokenizer import factory as tokenizer_factory @@ -182,8 +187,8 @@ class SetupAll: # just load the tokenizer return tokenizer_factory.get_tokenizer_for_db(config) - @staticmethod - def _create_pending_index(conn, tablespace): + + def _create_pending_index(self, conn: Connection, tablespace: str) -> None: """ Add a supporting index for finding places still to be indexed. This index is normally created at the end of the import process @@ -204,8 +209,7 @@ class SetupAll: conn.commit() - @staticmethod - def _finalize_database(dsn, offline): + def _finalize_database(self, dsn: str, offline: bool) -> None: """ Determine the database date and set the status accordingly. """ with connect(dsn) as conn: diff --git a/nominatim/clicmd/special_phrases.py b/nominatim/clicmd/special_phrases.py index a2c346de..beac0c84 100644 --- a/nominatim/clicmd/special_phrases.py +++ b/nominatim/clicmd/special_phrases.py @@ -7,13 +7,16 @@ """ Implementation of the 'special-phrases' command. """ +import argparse import logging from pathlib import Path + from nominatim.errors import UsageError from nominatim.db.connection import connect -from nominatim.tools.special_phrases.sp_importer import SPImporter +from nominatim.tools.special_phrases.sp_importer import SPImporter, SpecialPhraseLoader from nominatim.tools.special_phrases.sp_wiki_loader import SPWikiLoader from nominatim.tools.special_phrases.sp_csv_loader import SPCsvLoader +from nominatim.clicmd.args import NominatimArgs LOG = logging.getLogger() @@ -49,8 +52,8 @@ class ImportSpecialPhrases: with custom rules into the project directory or by using the `--config` option to point to another configuration file. 
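Reviewer note: the new `if not files: raise UsageError(...)` in `SetupAll.run()` doubles as a type guard. `get_osm_file_list()` now returns an Optional, and raising on the falsy case lets mypy use the list without complaint afterwards. In miniature (signatures simplified):

```python
from typing import List, Optional

class UsageError(Exception):
    """ Stand-in for nominatim.errors.UsageError. """

def get_osm_file_list() -> Optional[List[str]]:
    return None   # pretend no --osm-file was given

def run() -> None:
    files = get_osm_file_list()
    if not files:
        raise UsageError("No input files (use --osm-file).")
    # From here on mypy treats `files` as List[str], not Optional[List[str]].
    for fname in files:
        print(fname)
```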
""" - @staticmethod - def add_args(parser): + + def add_args(self, parser: argparse.ArgumentParser) -> None: group = parser.add_argument_group('Input arguments') group.add_argument('--import-from-wiki', action='store_true', help='Import special phrases from the OSM wiki to the database') @@ -58,26 +61,24 @@ class ImportSpecialPhrases: help='Import special phrases from a CSV file') group.add_argument('--no-replace', action='store_true', help='Keep the old phrases and only add the new ones') - group.add_argument('--config', action='store', - help='Configuration file for black/white listing ' - '(default: phrase-settings.json)') - @staticmethod - def run(args): + + def run(self, args: NominatimArgs) -> int: + if args.import_from_wiki: - ImportSpecialPhrases.start_import(args, SPWikiLoader(args.config)) + self.start_import(args, SPWikiLoader(args.config)) if args.import_from_csv: if not Path(args.import_from_csv).is_file(): LOG.fatal("CSV file '%s' does not exist.", args.import_from_csv) raise UsageError('Cannot access file.') - ImportSpecialPhrases.start_import(args, SPCsvLoader(args.import_from_csv)) + self.start_import(args, SPCsvLoader(args.import_from_csv)) return 0 - @staticmethod - def start_import(args, loader): + + def start_import(self, args: NominatimArgs, loader: SpecialPhraseLoader) -> None: """ Create the SPImporter object containing the right sp loader and then start the import of special phrases. diff --git a/nominatim/config.py b/nominatim/config.py index b3934b49..78496550 100644 --- a/nominatim/config.py +++ b/nominatim/config.py @@ -7,6 +7,7 @@ """ Nominatim configuration accessor. """ +from typing import Dict, Any, List, Mapping, Optional import logging import os from pathlib import Path @@ -15,12 +16,13 @@ import yaml from dotenv import dotenv_values +from nominatim.typing import StrPath from nominatim.errors import UsageError LOG = logging.getLogger() -CONFIG_CACHE = {} +CONFIG_CACHE : Dict[str, Any] = {} -def flatten_config_list(content, section=''): +def flatten_config_list(content: Any, section: str = '') -> List[Any]: """ Flatten YAML configuration lists that contain include sections which are lists themselves. """ @@ -54,7 +56,8 @@ class Configuration: avoid conflicts with other environment variables. """ - def __init__(self, project_dir, config_dir, environ=None): + def __init__(self, project_dir: Path, config_dir: Path, + environ: Optional[Mapping[str, str]] = None) -> None: self.environ = environ or os.environ self.project_dir = project_dir self.config_dir = config_dir @@ -63,25 +66,32 @@ class Configuration: self._config.update(dotenv_values(str((project_dir / '.env').resolve()))) class _LibDirs: - pass + module: Path + osm2pgsql: Path + php: Path + sql: Path + data: Path self.lib_dir = _LibDirs() - def set_libdirs(self, **kwargs): + + def set_libdirs(self, **kwargs: StrPath) -> None: """ Set paths to library functions and data. """ for key, value in kwargs.items(): setattr(self.lib_dir, key, Path(value).resolve()) - def __getattr__(self, name): + + def __getattr__(self, name: str) -> str: name = 'NOMINATIM_' + name if name in self.environ: return self.environ[name] - return self._config[name] + return self._config[name] or '' - def get_bool(self, name): + + def get_bool(self, name: str) -> bool: """ Return the given configuration parameter as a boolean. Values of '1', 'yes' and 'true' are accepted as truthy values, everything else is interpreted as false. 
diff --git a/nominatim/config.py b/nominatim/config.py
index b3934b49..78496550 100644
--- a/nominatim/config.py
+++ b/nominatim/config.py
@@ -7,6 +7,7 @@
 """
 Nominatim configuration accessor.
 """
+from typing import Dict, Any, List, Mapping, Optional
 import logging
 import os
 from pathlib import Path
@@ -15,12 +16,13 @@
 import yaml
 from dotenv import dotenv_values

+from nominatim.typing import StrPath
 from nominatim.errors import UsageError

 LOG = logging.getLogger()

-CONFIG_CACHE = {}
+CONFIG_CACHE : Dict[str, Any] = {}

-def flatten_config_list(content, section=''):
+def flatten_config_list(content: Any, section: str = '') -> List[Any]:
     """ Flatten YAML configuration lists that contain include sections
        which are lists themselves.
     """
@@ -54,7 +56,8 @@ class Configuration:
        avoid conflicts with other environment variables.
     """

-    def __init__(self, project_dir, config_dir, environ=None):
+    def __init__(self, project_dir: Path, config_dir: Path,
+                 environ: Optional[Mapping[str, str]] = None) -> None:
         self.environ = environ or os.environ
         self.project_dir = project_dir
         self.config_dir = config_dir
@@ -63,25 +66,32 @@ class Configuration:
             self._config.update(dotenv_values(str((project_dir / '.env').resolve())))

         class _LibDirs:
-            pass
+            module: Path
+            osm2pgsql: Path
+            php: Path
+            sql: Path
+            data: Path

         self.lib_dir = _LibDirs()

-    def set_libdirs(self, **kwargs):
+
+    def set_libdirs(self, **kwargs: StrPath) -> None:
         """ Set paths to library functions and data.
         """
         for key, value in kwargs.items():
             setattr(self.lib_dir, key, Path(value).resolve())

-    def __getattr__(self, name):
+
+    def __getattr__(self, name: str) -> str:
         name = 'NOMINATIM_' + name

         if name in self.environ:
             return self.environ[name]

-        return self._config[name]
+        return self._config[name] or ''

-    def get_bool(self, name):
+
+    def get_bool(self, name: str) -> bool:
         """ Return the given configuration parameter as a boolean.
            Values of '1', 'yes' and 'true' are accepted as truthy values,
            everything else is interpreted as false.
@@ -89,7 +99,7 @@ class Configuration:
         return getattr(self, name).lower() in ('1', 'yes', 'true')


-    def get_int(self, name):
+    def get_int(self, name: str) -> int:
         """ Return the given configuration parameter as an int.
         """
         try:
@@ -99,7 +109,7 @@ class Configuration:
             raise UsageError("Configuration error.") from exp


-    def get_str_list(self, name):
+    def get_str_list(self, name: str) -> Optional[List[str]]:
         """ Return the given configuration parameter as a list of strings.
            The values are assumed to be given as a comma-sparated list and
            will be stripped before returning them. On empty values None
@@ -110,30 +120,31 @@ class Configuration:

         return [v.strip() for v in raw.split(',')] if raw else None

-    def get_path(self, name):
+
+    def get_path(self, name: str) -> Optional[Path]:
         """ Return the given configuration parameter as a Path.

            If a relative path is configured, then the function converts this
            into an absolute path with the project directory as root path.
-            If the configuration is unset, a falsy value is returned.
+            If the configuration is unset, None is returned.
         """
         value = getattr(self, name)
-        if value:
-            value = Path(value)
+        if not value:
+            return None

-            if not value.is_absolute():
-                value = self.project_dir / value
+        cfgpath = Path(value)

-            value = value.resolve()
+        if not cfgpath.is_absolute():
+            cfgpath = self.project_dir / cfgpath

-        return value
+        return cfgpath.resolve()

-    def get_libpq_dsn(self):
+
+    def get_libpq_dsn(self) -> str:
         """ Get configured database DSN converted into the key/value format
            understood by libpq and psycopg.
         """
         dsn = self.DATABASE_DSN

-        def quote_param(param):
+        def quote_param(param: str) -> str:
             key, val = param.split('=')
             val = val.replace('\\', '\\\\').replace("'", "\\'")
             if ' ' in val:
@@ -147,7 +158,7 @@ class Configuration:
         return dsn


-    def get_import_style_file(self):
+    def get_import_style_file(self) -> Path:
         """ Return the import style file as a path object. Translates the
            name of the standard styles automatically into a file in the
            config style.
@@ -160,7 +171,7 @@ class Configuration:
         return self.find_config_file('', 'IMPORT_STYLE')


-    def get_os_env(self):
+    def get_os_env(self) -> Dict[str, Optional[str]]:
         """ Return a copy of the OS environment with the Nominatim
            configuration merged in.
         """
@@ -170,7 +181,8 @@ class Configuration:
         return env


-    def load_sub_configuration(self, filename, config=None):
+    def load_sub_configuration(self, filename: StrPath,
+                               config: Optional[str] = None) -> Any:
         """ Load additional configuration from a file. `filename` is the name
            of the configuration file. The file is first searched in the
            project directory and then in the global settings dirctory.
@@ -207,16 +219,17 @@ class Configuration:
         return result


-    def find_config_file(self, filename, config=None):
+    def find_config_file(self, filename: StrPath,
+                         config: Optional[str] = None) -> Path:
         """ Resolve the location of a configuration file given a filename and
            an optional configuration option with the file name.
            Raises a UsageError when the file cannot be found or is not
            a regular file.
         """
         if config is not None:
-            cfg_filename = getattr(self, config)
-            if cfg_filename:
-                cfg_filename = Path(cfg_filename)
+            cfg_value = getattr(self, config)
+            if cfg_value:
+                cfg_filename = Path(cfg_value)

                 if cfg_filename.is_absolute():
                     cfg_filename = cfg_filename.resolve()
@@ -240,7 +253,7 @@ class Configuration:

         raise UsageError("Config file not found.")


-    def _load_from_yaml(self, cfgfile):
+    def _load_from_yaml(self, cfgfile: Path) -> Any:
         """ Load a YAML configuration file. This installs a special handler that
            allows to include other YAML files using the '!include' operator.
         """
@@ -249,7 +262,7 @@ class Configuration:
         return yaml.safe_load(cfgfile.read_text(encoding='utf-8'))


-    def _yaml_include_representer(self, loader, node):
+    def _yaml_include_representer(self, loader: Any, node: yaml.Node) -> Any:
         """ Handler for the '!include' operator in YAML files.

            When the filename is relative, then the file is first searched in the
diff --git a/nominatim/data/country_info.py b/nominatim/data/country_info.py
index d754b4dd..eb0190b5 100644
--- a/nominatim/data/country_info.py
+++ b/nominatim/data/country_info.py
@@ -7,13 +7,17 @@
 """
 Functions for importing and managing static country information.
 """
+from typing import Dict, Any, Iterable, Tuple, Optional, Container, overload
+from pathlib import Path
 import psycopg2.extras

 from nominatim.db import utils as db_utils
-from nominatim.db.connection import connect
+from nominatim.db.connection import connect, Connection
 from nominatim.errors import UsageError
+from nominatim.config import Configuration
+from nominatim.tokenizer.base import AbstractTokenizer

-def _flatten_name_list(names):
+def _flatten_name_list(names: Any) -> Dict[str, str]:
     if names is None:
         return {}

@@ -41,11 +45,11 @@ class _CountryInfo:
     """ Caches country-specific properties from the configuration file.
     """

-    def __init__(self):
-        self._info = {}
+    def __init__(self) -> None:
+        self._info: Dict[str, Dict[str, Any]] = {}

-    def load(self, config):
+
+    def load(self, config: Configuration) -> None:
         """ Load the country properties from the configuration files,
            if they are not loaded yet.
         """
@@ -61,12 +65,12 @@ class _CountryInfo:
             prop['names'] = _flatten_name_list(prop.get('names'))


-    def items(self):
+    def items(self) -> Iterable[Tuple[str, Dict[str, Any]]]:
         """ Return tuples of (country_code, property dict) as iterable.
         """
         return self._info.items()

-    def get(self, country_code):
+    def get(self, country_code: str) -> Dict[str, Any]:
         """ Get country information for the country with the given country code.
         """
         return self._info.get(country_code, {})
@@ -76,15 +80,22 @@
 _COUNTRY_INFO = _CountryInfo()


-def setup_country_config(config):
+def setup_country_config(config: Configuration) -> None:
     """ Load country properties from the configuration file.
        Needs to be called before using any other functions in this
        file.
     """
     _COUNTRY_INFO.load(config)


+@overload
+def iterate() -> Iterable[Tuple[str, Dict[str, Any]]]:
+    ...

-def iterate(prop=None):
+@overload
+def iterate(prop: str) -> Iterable[Tuple[str, Any]]:
+    ...
+
+def iterate(prop: Optional[str] = None) -> Iterable[Tuple[str, Dict[str, Any]]]:
     """ Iterate over country code and properties.

        When `prop` is None, all countries are returned with their complete
@@ -100,7 +111,7 @@
     return ((c, p[prop]) for c, p in _COUNTRY_INFO.items() if prop in p)


-def setup_country_tables(dsn, sql_dir, ignore_partitions=False):
+def setup_country_tables(dsn: str, sql_dir: Path, ignore_partitions: bool = False) -> None:
     """ Create and populate the tables with basic static data that provides
        the background for geocoding. Data is assumed to not yet exist.
     """
@@ -112,7 +123,7 @@
         if ignore_partitions:
             partition = 0
         else:
-            partition = props.get('partition')
+            partition = props.get('partition', 0)

         lang = props['languages'][0] if len(props['languages']) == 1 else None
@@ -135,13 +146,14 @@
     conn.commit()


-def create_country_names(conn, tokenizer, languages=None):
+def create_country_names(conn: Connection, tokenizer: AbstractTokenizer,
+                         languages: Optional[Container[str]] = None) -> None:
     """ Add default country names to search index. `languages` is a comma-
        separated list of language codes as used in OSM. If `languages` is
        not empty then only name translations for the given languages are
        added to the index.
     """
-    def _include_key(key):
+    def _include_key(key: str) -> bool:
         return ':' not in key or not languages or \
            key[key.index(':') + 1:] in languages
diff --git a/nominatim/data/place_info.py b/nominatim/data/place_info.py
index d2ba3979..96912a61 100644
--- a/nominatim/data/place_info.py
+++ b/nominatim/data/place_info.py
@@ -8,18 +8,19 @@
 Wrapper around place information the indexer gets from the database and
 hands to the tokenizer.
 """
+from typing import Optional, Mapping, Any

 class PlaceInfo:
     """ Data class containing all information the tokenizer gets about a
        place it should process the names for.
     """

-    def __init__(self, info):
+    def __init__(self, info: Mapping[str, Any]) -> None:
         self._info = info


     @property
-    def name(self):
+    def name(self) -> Optional[Mapping[str, str]]:
         """ A dictionary with the names of the place or None if the place
            has no names.
         """
@@ -27,7 +28,7 @@ class PlaceInfo:


     @property
-    def address(self):
+    def address(self) -> Optional[Mapping[str, str]]:
         """ A dictionary with the address elements of the place
            or None if no address information is available.
         """
@@ -35,7 +36,7 @@ class PlaceInfo:


     @property
-    def country_code(self):
+    def country_code(self) -> Optional[str]:
         """ The country code of the country the place is in. Guaranteed
            to be a two-letter lower-case string or None, if no country
            could be found.
@@ -44,20 +45,20 @@ class PlaceInfo:


     @property
-    def rank_address(self):
+    def rank_address(self) -> int:
         """ The computed rank address before rank correction.
         """
-        return self._info.get('rank_address')
+        return self._info.get('rank_address', 0)


-    def is_a(self, key, value):
+    def is_a(self, key: str, value: str) -> bool:
         """ Check if the place's primary tag corresponds to the given
            key and value.
         """
         return self._info.get('class') == key and self._info.get('type') == value


-    def is_country(self):
+    def is_country(self) -> bool:
         """ Check if the place is a valid country boundary.
         """
         return self.rank_address == 4 \
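Reviewer note: `iterate()` is the one place this patch reaches for `@overload` — the result type depends on whether `prop` is passed. The mechanism in a runnable miniature (sample data invented):

```python
from typing import Any, Dict, Iterable, Optional, Tuple, overload

DATA: Dict[str, Dict[str, Any]] = {'de': {'partition': 3}, 'fr': {'partition': 1}}

@overload
def iterate() -> Iterable[Tuple[str, Dict[str, Any]]]: ...
@overload
def iterate(prop: str) -> Iterable[Tuple[str, Any]]: ...

def iterate(prop: Optional[str] = None) -> Iterable[Tuple[str, Any]]:
    # Single implementation; callers see the signature matching their call.
    if prop is None:
        return DATA.items()
    return ((c, p[prop]) for c, p in DATA.items() if prop in p)

for code, props in iterate():          # props: Dict[str, Any]
    print(code, props)
for code, part in iterate('partition'):  # part: Any (the property value)
    print(code, part)
```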
""" +from typing import Any, Mapping, Optional, Set, Match import re from nominatim.errors import UsageError @@ -17,7 +18,7 @@ class CountryPostcodeMatcher: """ Matches and formats a postcode according to a format definition of the given country. """ - def __init__(self, country_code, config): + def __init__(self, country_code: str, config: Mapping[str, Any]) -> None: if 'pattern' not in config: raise UsageError("Field 'pattern' required for 'postcode' " f"for country '{country_code}'") @@ -30,7 +31,7 @@ class CountryPostcodeMatcher: self.output = config.get('output', r'\g<0>') - def match(self, postcode): + def match(self, postcode: str) -> Optional[Match[str]]: """ Match the given postcode against the postcode pattern for this matcher. Returns a `re.Match` object if the match was successful and None otherwise. @@ -44,7 +45,7 @@ class CountryPostcodeMatcher: return None - def normalize(self, match): + def normalize(self, match: Match[str]) -> str: """ Return the default format of the postcode for the given match. `match` must be a `re.Match` object previously returned by `match()` @@ -56,9 +57,9 @@ class PostcodeFormatter: """ Container for different postcode formats of the world and access functions. """ - def __init__(self): + def __init__(self) -> None: # Objects without a country code can't have a postcode per definition. - self.country_without_postcode = {None} + self.country_without_postcode: Set[Optional[str]] = {None} self.country_matcher = {} self.default_matcher = CountryPostcodeMatcher('', {'pattern': '.*'}) @@ -71,14 +72,14 @@ class PostcodeFormatter: raise UsageError(f"Invalid entry 'postcode' for country '{ccode}'") - def set_default_pattern(self, pattern): + def set_default_pattern(self, pattern: str) -> None: """ Set the postcode match pattern to use, when a country does not - have a specific pattern or is marked as country without postcode. + have a specific pattern. """ self.default_matcher = CountryPostcodeMatcher('', {'pattern': pattern}) - def get_matcher(self, country_code): + def get_matcher(self, country_code: Optional[str]) -> Optional[CountryPostcodeMatcher]: """ Return the CountryPostcodeMatcher for the given country. Returns None if the country doesn't have a postcode and the default matcher if there is no specific matcher configured for @@ -87,10 +88,12 @@ class PostcodeFormatter: if country_code in self.country_without_postcode: return None + assert country_code is not None + return self.country_matcher.get(country_code, self.default_matcher) - def match(self, country_code, postcode): + def match(self, country_code: Optional[str], postcode: str) -> Optional[Match[str]]: """ Match the given postcode against the postcode pattern for this matcher. Returns a `re.Match` object if the country has a pattern and the match was successful or None if the match failed. @@ -98,10 +101,12 @@ class PostcodeFormatter: if country_code in self.country_without_postcode: return None + assert country_code is not None + return self.country_matcher.get(country_code, self.default_matcher).match(postcode) - def normalize(self, country_code, match): + def normalize(self, country_code: str, match: Match[str]) -> str: """ Return the default format of the postcode for the given match. 
`match` must be a `re.Match` object previously returned by `match()` diff --git a/nominatim/db/async_connection.py b/nominatim/db/async_connection.py index 285463a5..a2c8fe4d 100644 --- a/nominatim/db/async_connection.py +++ b/nominatim/db/async_connection.py @@ -4,8 +4,9 @@ # # Copyright (C) 2022 by the Nominatim developer community. # For a full list of authors see the git log. -""" Database helper functions for the indexer. +""" Non-blocking database connections. """ +from typing import Callable, Any, Optional, Iterator, Sequence import logging import select import time @@ -21,6 +22,8 @@ try: except ImportError: __has_psycopg2_errors__ = False +from nominatim.typing import T_cursor, Query + LOG = logging.getLogger() class DeadlockHandler: @@ -29,14 +32,14 @@ class DeadlockHandler: normally. """ - def __init__(self, handler, ignore_sql_errors=False): + def __init__(self, handler: Callable[[], None], ignore_sql_errors: bool = False) -> None: self.handler = handler self.ignore_sql_errors = ignore_sql_errors - def __enter__(self): + def __enter__(self) -> 'DeadlockHandler': return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> bool: if __has_psycopg2_errors__: if exc_type == psycopg2.errors.DeadlockDetected: # pylint: disable=E1101 self.handler() @@ -57,26 +60,31 @@ class DBConnection: """ A single non-blocking database connection. """ - def __init__(self, dsn, cursor_factory=None, ignore_sql_errors=False): - self.current_query = None - self.current_params = None + def __init__(self, dsn: str, + cursor_factory: Optional[Callable[..., T_cursor]] = None, + ignore_sql_errors: bool = False) -> None: self.dsn = dsn + + self.current_query: Optional[Query] = None + self.current_params: Optional[Sequence[Any]] = None self.ignore_sql_errors = ignore_sql_errors - self.conn = None - self.cursor = None + self.conn: Optional['psycopg2.connection'] = None + self.cursor: Optional['psycopg2.cursor'] = None self.connect(cursor_factory=cursor_factory) - def close(self): + def close(self) -> None: """ Close all open connections. Does not wait for pending requests. """ if self.conn is not None: - self.cursor.close() + if self.cursor is not None: + self.cursor.close() # type: ignore[no-untyped-call] + self.cursor = None self.conn.close() self.conn = None - def connect(self, cursor_factory=None): + def connect(self, cursor_factory: Optional[Callable[..., T_cursor]] = None) -> None: """ (Re)connect to the database. Creates an asynchronous connection with JIT and parallel processing disabled. If a connection was already open, it is closed and a new connection established. @@ -89,7 +97,10 @@ class DBConnection: self.conn = psycopg2.connect(**{'dsn': self.dsn, 'async': True}) self.wait() - self.cursor = self.conn.cursor(cursor_factory=cursor_factory) + if cursor_factory is not None: + self.cursor = self.conn.cursor(cursor_factory=cursor_factory) + else: + self.cursor = self.conn.cursor() # Disable JIT and parallel workers as they are known to cause problems. 
# Update pg_settings instead of using SET because it does not yield # errors on older versions of Postgres where the settings are not @@ -100,11 +111,15 @@ class DBConnection: WHERE name = 'max_parallel_workers_per_gather';""") self.wait() - def _deadlock_handler(self): + def _deadlock_handler(self) -> None: LOG.info("Deadlock detected (params = %s), retry.", str(self.current_params)) + assert self.cursor is not None + assert self.current_query is not None + assert self.current_params is not None + self.cursor.execute(self.current_query, self.current_params) - def wait(self): + def wait(self) -> None: """ Block until any pending operation is done. """ while True: @@ -113,25 +128,29 @@ class DBConnection: self.current_query = None return - def perform(self, sql, args=None): + def perform(self, sql: Query, args: Optional[Sequence[Any]] = None) -> None: """ Send SQL query to the server. Returns immediately without blocking. """ + assert self.cursor is not None self.current_query = sql self.current_params = args self.cursor.execute(sql, args) - def fileno(self): + def fileno(self) -> int: """ File descriptor to wait for. (Makes this class select()able.) """ + assert self.conn is not None return self.conn.fileno() - def is_done(self): + def is_done(self) -> bool: """ Check if the connection is available for a new query. Also checks if the previous query has run into a deadlock. If so, then the previous query is repeated. """ + assert self.conn is not None + if self.current_query is None: return True @@ -150,14 +169,14 @@ class WorkerPool: """ REOPEN_CONNECTIONS_AFTER = 100000 - def __init__(self, dsn, pool_size, ignore_sql_errors=False): + def __init__(self, dsn: str, pool_size: int, ignore_sql_errors: bool = False) -> None: self.threads = [DBConnection(dsn, ignore_sql_errors=ignore_sql_errors) for _ in range(pool_size)] self.free_workers = self._yield_free_worker() - self.wait_time = 0 + self.wait_time = 0.0 - def finish_all(self): + def finish_all(self) -> None: """ Wait for all connection to finish. """ for thread in self.threads: @@ -166,22 +185,22 @@ class WorkerPool: self.free_workers = self._yield_free_worker() - def close(self): + def close(self) -> None: """ Close all connections and clear the pool. """ for thread in self.threads: thread.close() self.threads = [] - self.free_workers = None + self.free_workers = iter([]) - def next_free_worker(self): + def next_free_worker(self) -> DBConnection: """ Get the next free connection. """ return next(self.free_workers) - def _yield_free_worker(self): + def _yield_free_worker(self) -> Iterator[DBConnection]: ready = self.threads command_stat = 0 while True: @@ -200,17 +219,17 @@ class WorkerPool: self.wait_time += time.time() - tstart - def _reconnect_threads(self): + def _reconnect_threads(self) -> None: for thread in self.threads: while not thread.is_done(): thread.wait() thread.connect() - def __enter__(self): + def __enter__(self) -> 'WorkerPool': return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: self.finish_all() self.close() diff --git a/nominatim/db/connection.py b/nominatim/db/connection.py index c60bcfdd..4f32dfce 100644 --- a/nominatim/db/connection.py +++ b/nominatim/db/connection.py @@ -7,6 +7,7 @@ """ Specialised connection and cursor functions. 
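The DBConnection rework above keeps psycopg2's asynchronous mode but gives every piece of state an explicit Optional type. For readers unfamiliar with the underlying machinery, this is a minimal sketch of the poll/select loop that wait() is built on; the DSN is a placeholder and the sketch uses the documented async_ keyword rather than the dict-splat form in the patch:

import select
import psycopg2
import psycopg2.extensions

def wait_ready(conn: 'psycopg2.extensions.connection') -> None:
    """ Block until the asynchronous connection finishes its current work. """
    while True:
        state = conn.poll()
        if state == psycopg2.extensions.POLL_OK:
            return
        if state == psycopg2.extensions.POLL_READ:
            select.select([conn.fileno()], [], [])
        elif state == psycopg2.extensions.POLL_WRITE:
            select.select([], [conn.fileno()], [])
        else:
            raise psycopg2.OperationalError(f'bad poll state: {state}')

conn = psycopg2.connect('dbname=nominatim', async_=True)  # placeholder DSN
wait_ready(conn)
cur = conn.cursor()
cur.execute('SELECT version()')
wait_ready(conn)
print(cur.fetchone())
conn.close()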
""" +from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple, Iterable import contextlib import logging import os @@ -16,25 +17,27 @@ import psycopg2.extensions import psycopg2.extras from psycopg2 import sql as pysql +from nominatim.typing import SysEnv, Query, T_cursor from nominatim.errors import UsageError LOG = logging.getLogger() -class _Cursor(psycopg2.extras.DictCursor): +class Cursor(psycopg2.extras.DictCursor): """ A cursor returning dict-like objects and providing specialised execution functions. """ - # pylint: disable=arguments-renamed,arguments-differ - def execute(self, query, args=None): + def execute(self, query: Query, args: Any = None) -> None: """ Query execution that logs the SQL query when debugging is enabled. """ - LOG.debug(self.mogrify(query, args).decode('utf-8')) + if LOG.isEnabledFor(logging.DEBUG): + LOG.debug(self.mogrify(query, args).decode('utf-8')) # type: ignore[no-untyped-call] super().execute(query, args) - def execute_values(self, sql, argslist, template=None): + def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]], + template: Optional[Query] = None) -> None: """ Wrapper for the psycopg2 convenience function to execute SQL for a list of values. """ @@ -43,7 +46,7 @@ class _Cursor(psycopg2.extras.DictCursor): psycopg2.extras.execute_values(self, sql, argslist, template=template) - def scalar(self, sql, args=None): + def scalar(self, sql: Query, args: Any = None) -> Any: """ Execute query that returns a single value. The value is returned. If the query yields more than one row, a ValueError is raised. """ @@ -52,10 +55,13 @@ class _Cursor(psycopg2.extras.DictCursor): if self.rowcount != 1: raise RuntimeError("Query did not return a single row.") - return self.fetchone()[0] + result = self.fetchone() # type: ignore[no-untyped-call] + assert result is not None + + return result[0] - def drop_table(self, name, if_exists=True, cascade=False): + def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None: """ Drop the table with the given name. Set `if_exists` to False if a non-existant table should raise an exception instead of just being ignored. If 'cascade' is set @@ -71,27 +77,38 @@ class _Cursor(psycopg2.extras.DictCursor): self.execute(pysql.SQL(sql).format(pysql.Identifier(name))) -class _Connection(psycopg2.extensions.connection): +class Connection(psycopg2.extensions.connection): """ A connection that provides the specialised cursor by default and adds convenience functions for administrating the database. """ + @overload # type: ignore[override] + def cursor(self) -> Cursor: + ... - def cursor(self, cursor_factory=_Cursor, **kwargs): + @overload + def cursor(self, name: str) -> Cursor: + ... + + @overload + def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor: + ... + + def cursor(self, cursor_factory = Cursor, **kwargs): # type: ignore """ Return a new cursor. By default the specialised cursor is returned. """ return super().cursor(cursor_factory=cursor_factory, **kwargs) - def table_exists(self, table): + def table_exists(self, table: str) -> bool: """ Check that a table with the given name exists in the database. 
""" with self.cursor() as cur: num = cur.scalar("""SELECT count(*) FROM pg_tables WHERE tablename = %s and schemaname = 'public'""", (table, )) - return num == 1 + return num == 1 if isinstance(num, int) else False - def table_has_column(self, table, column): + def table_has_column(self, table: str, column: str) -> bool: """ Check if the table 'table' exists and has a column with name 'column'. """ with self.cursor() as cur: @@ -99,10 +116,10 @@ class _Connection(psycopg2.extensions.connection): WHERE table_name = %s and column_name = %s""", (table, column)) - return has_column > 0 + return has_column > 0 if isinstance(has_column, int) else False - def index_exists(self, index, table=None): + def index_exists(self, index: str, table: Optional[str] = None) -> bool: """ Check that an index with the given name exists in the database. If table is not None then the index must relate to the given table. @@ -114,13 +131,15 @@ class _Connection(psycopg2.extensions.connection): return False if table is not None: - row = cur.fetchone() + row = cur.fetchone() # type: ignore[no-untyped-call] + if row is None or not isinstance(row[0], str): + return False return row[0] == table return True - def drop_table(self, name, if_exists=True, cascade=False): + def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None: """ Drop the table with the given name. Set `if_exists` to False if a non-existant table should raise an exception instead of just being ignored. @@ -130,18 +149,18 @@ class _Connection(psycopg2.extensions.connection): self.commit() - def server_version_tuple(self): + def server_version_tuple(self) -> Tuple[int, int]: """ Return the server version as a tuple of (major, minor). Converts correctly for pre-10 and post-10 PostgreSQL versions. """ version = self.server_version if version < 100000: - return (int(version / 10000), (version % 10000) / 100) + return (int(version / 10000), int((version % 10000) / 100)) return (int(version / 10000), version % 10000) - def postgis_version_tuple(self): + def postgis_version_tuple(self) -> Tuple[int, int]: """ Return the postgis version installed in the database as a tuple of (major, minor). Assumes that the PostGIS extension has been installed already. @@ -149,19 +168,28 @@ class _Connection(psycopg2.extensions.connection): with self.cursor() as cur: version = cur.scalar('SELECT postgis_lib_version()') - return tuple((int(x) for x in version.split('.')[:2])) + version_parts = version.split('.') + if len(version_parts) < 2: + raise UsageError(f"Error fetching Postgis version. Bad format: {version}") + return (int(version_parts[0]), int(version_parts[1])) -def connect(dsn): +class ConnectionContext(ContextManager[Connection]): + """ Context manager of the connection that also provides direct access + to the underlying connection. + """ + connection: Connection + +def connect(dsn: str) -> ConnectionContext: """ Open a connection to the database using the specialised connection factory. The returned object may be used in conjunction with 'with'. When used outside a context manager, use the `connection` attribute to get the connection. 
""" try: - conn = psycopg2.connect(dsn, connection_factory=_Connection) - ctxmgr = contextlib.closing(conn) - ctxmgr.connection = conn + conn = psycopg2.connect(dsn, connection_factory=Connection) + ctxmgr = cast(ConnectionContext, contextlib.closing(conn)) + ctxmgr.connection = cast(Connection, conn) return ctxmgr except psycopg2.OperationalError as err: raise UsageError(f"Cannot connect to database: {err}") from err @@ -199,7 +227,8 @@ _PG_CONNECTION_STRINGS = { } -def get_pg_env(dsn, base_env=None): +def get_pg_env(dsn: str, + base_env: Optional[SysEnv] = None) -> Dict[str, str]: """ Return a copy of `base_env` with the environment variables for PostgresSQL set up from the given database connection string. If `base_env` is None, then the OS environment is used as a base @@ -207,7 +236,7 @@ def get_pg_env(dsn, base_env=None): """ env = dict(base_env if base_env is not None else os.environ) - for param, value in psycopg2.extensions.parse_dsn(dsn).items(): + for param, value in psycopg2.extensions.parse_dsn(dsn).items(): # type: ignore if param in _PG_CONNECTION_STRINGS: env[_PG_CONNECTION_STRINGS[param]] = value else: diff --git a/nominatim/db/properties.py b/nominatim/db/properties.py index 27020487..3624c950 100644 --- a/nominatim/db/properties.py +++ b/nominatim/db/properties.py @@ -7,8 +7,11 @@ """ Query and access functions for the in-database property table. """ +from typing import Optional, cast -def set_property(conn, name, value): +from nominatim.db.connection import Connection + +def set_property(conn: Connection, name: str, value: str) -> None: """ Add or replace the propery with the given name. """ with conn.cursor() as cur: @@ -23,8 +26,9 @@ def set_property(conn, name, value): cur.execute(sql, (value, name)) conn.commit() -def get_property(conn, name): - """ Return the current value of the given propery or None if the property + +def get_property(conn: Connection, name: str) -> Optional[str]: + """ Return the current value of the given property or None if the property is not set. """ if not conn.table_exists('nominatim_properties'): @@ -34,4 +38,7 @@ def get_property(conn, name): cur.execute('SELECT value FROM nominatim_properties WHERE property = %s', (name, )) - return cur.fetchone()[0] if cur.rowcount > 0 else None + if cur.rowcount == 0: + return None + + return cast(Optional[str], cur.fetchone()[0]) # type: ignore[no-untyped-call] diff --git a/nominatim/db/sql_preprocessor.py b/nominatim/db/sql_preprocessor.py index 4de53886..b450422d 100644 --- a/nominatim/db/sql_preprocessor.py +++ b/nominatim/db/sql_preprocessor.py @@ -7,10 +7,13 @@ """ Preprocessing of SQL files. """ +from typing import Set, Dict, Any import jinja2 +from nominatim.db.connection import Connection +from nominatim.config import Configuration -def _get_partitions(conn): +def _get_partitions(conn: Connection) -> Set[int]: """ Get the set of partitions currently in use. """ with conn.cursor() as cur: @@ -22,7 +25,7 @@ def _get_partitions(conn): return partitions -def _get_tables(conn): +def _get_tables(conn: Connection) -> Set[str]: """ Return the set of tables currently in use. Only includes non-partitioned """ @@ -32,7 +35,7 @@ def _get_tables(conn): return set((row[0] for row in list(cur))) -def _setup_tablespace_sql(config): +def _setup_tablespace_sql(config: Configuration) -> Dict[str, str]: """ Returns a dict with tablespace expressions for the different tablespace kinds depending on whether a tablespace is configured or not. 
""" @@ -47,7 +50,7 @@ def _setup_tablespace_sql(config): return out -def _setup_postgresql_features(conn): +def _setup_postgresql_features(conn: Connection) -> Dict[str, Any]: """ Set up a dictionary with various optional Postgresql/Postgis features that depend on the database version. """ @@ -69,11 +72,11 @@ class SQLPreprocessor: and follows its syntax. """ - def __init__(self, conn, config): + def __init__(self, conn: Connection, config: Configuration) -> None: self.env = jinja2.Environment(autoescape=False, loader=jinja2.FileSystemLoader(str(config.lib_dir.sql))) - db_info = {} + db_info: Dict[str, Any] = {} db_info['partitions'] = _get_partitions(conn) db_info['tables'] = _get_tables(conn) db_info['reverse_only'] = 'search_name' not in db_info['tables'] @@ -84,7 +87,7 @@ class SQLPreprocessor: self.env.globals['postgres'] = _setup_postgresql_features(conn) - def run_sql_file(self, conn, name, **kwargs): + def run_sql_file(self, conn: Connection, name: str, **kwargs: Any) -> None: """ Execute the given SQL file on the connection. The keyword arguments may supply additional parameters for preprocessing. """ diff --git a/nominatim/db/status.py b/nominatim/db/status.py index d31196b3..aea25a97 100644 --- a/nominatim/db/status.py +++ b/nominatim/db/status.py @@ -7,17 +7,29 @@ """ Access and helper functions for the status and status log table. """ +from typing import Optional, Tuple, cast import datetime as dt import logging import re +from nominatim.db.connection import Connection from nominatim.tools.exec_utils import get_url from nominatim.errors import UsageError +from nominatim.typing import TypedDict LOG = logging.getLogger() ISODATE_FORMAT = '%Y-%m-%dT%H:%M:%S' -def compute_database_date(conn): + +class StatusRow(TypedDict): + """ Dictionary of columns of the import_status table. + """ + lastimportdate: dt.datetime + sequence_id: Optional[int] + indexed: Optional[bool] + + +def compute_database_date(conn: Connection) -> dt.datetime: """ Determine the date of the database from the newest object in the data base. """ @@ -49,10 +61,12 @@ def compute_database_date(conn): return dt.datetime.strptime(match.group(1), ISODATE_FORMAT).replace(tzinfo=dt.timezone.utc) -def set_status(conn, date, seq=None, indexed=True): +def set_status(conn: Connection, date: Optional[dt.datetime], + seq: Optional[int] = None, indexed: bool = True) -> None: """ Replace the current status with the given status. If date is `None` then only sequence and indexed will be updated as given. Otherwise the whole status is replaced. + The change will be committed to the database. """ assert date is None or date.tzinfo == dt.timezone.utc with conn.cursor() as cur: @@ -67,7 +81,7 @@ def set_status(conn, date, seq=None, indexed=True): conn.commit() -def get_status(conn): +def get_status(conn: Connection) -> Tuple[Optional[dt.datetime], Optional[int], Optional[bool]]: """ Return the current status as a triple of (date, sequence, indexed). If status has not been set up yet, a triple of None is returned. """ @@ -76,11 +90,11 @@ def get_status(conn): if cur.rowcount < 1: return None, None, None - row = cur.fetchone() + row = cast(StatusRow, cur.fetchone()) # type: ignore[no-untyped-call] return row['lastimportdate'], row['sequence_id'], row['indexed'] -def set_indexed(conn, state): +def set_indexed(conn: Connection, state: bool) -> None: """ Set the indexed flag in the status table to the given state. 
""" with conn.cursor() as cur: @@ -88,7 +102,8 @@ def set_indexed(conn, state): conn.commit() -def log_status(conn, start, event, batchsize=None): +def log_status(conn: Connection, start: dt.datetime, + event: str, batchsize: Optional[int] = None) -> None: """ Write a new status line to the `import_osmosis_log` table. """ with conn.cursor() as cur: @@ -96,3 +111,4 @@ def log_status(conn, start, event, batchsize=None): (batchend, batchseq, batchsize, starttime, endtime, event) SELECT lastimportdate, sequence_id, %s, %s, now(), %s FROM import_status""", (batchsize, start, event)) + conn.commit() diff --git a/nominatim/db/utils.py b/nominatim/db/utils.py index b859afa8..9a7b4f16 100644 --- a/nominatim/db/utils.py +++ b/nominatim/db/utils.py @@ -7,17 +7,21 @@ """ Helper functions for handling DB accesses. """ +from typing import IO, Optional, Union, Any, Iterable import subprocess import logging import gzip import io +from pathlib import Path -from nominatim.db.connection import get_pg_env +from nominatim.db.connection import get_pg_env, Cursor from nominatim.errors import UsageError LOG = logging.getLogger() -def _pipe_to_proc(proc, fdesc): +def _pipe_to_proc(proc: 'subprocess.Popen[bytes]', + fdesc: Union[IO[bytes], gzip.GzipFile]) -> int: + assert proc.stdin is not None chunk = fdesc.read(2048) while chunk and proc.poll() is None: try: @@ -28,7 +32,10 @@ def _pipe_to_proc(proc, fdesc): return len(chunk) -def execute_file(dsn, fname, ignore_errors=False, pre_code=None, post_code=None): +def execute_file(dsn: str, fname: Path, + ignore_errors: bool = False, + pre_code: Optional[str] = None, + post_code: Optional[str] = None) -> None: """ Read an SQL file and run its contents against the given database using psql. Use `pre_code` and `post_code` to run extra commands before or after executing the file. The commands are run within the @@ -42,6 +49,7 @@ def execute_file(dsn, fname, ignore_errors=False, pre_code=None, post_code=None) cmd.append('--quiet') with subprocess.Popen(cmd, env=get_pg_env(dsn), stdin=subprocess.PIPE) as proc: + assert proc.stdin is not None try: if not LOG.isEnabledFor(logging.INFO): proc.stdin.write('set client_min_messages to WARNING;'.encode('utf-8')) @@ -76,20 +84,20 @@ class CopyBuffer: """ Data collector for the copy_from command. """ - def __init__(self): + def __init__(self) -> None: self.buffer = io.StringIO() - def __enter__(self): + def __enter__(self) -> 'CopyBuffer': return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: if self.buffer is not None: self.buffer.close() - def add(self, *data): + def add(self, *data: Any) -> None: """ Add another row of data to the copy buffer. """ first = True @@ -105,9 +113,9 @@ class CopyBuffer: self.buffer.write('\n') - def copy_out(self, cur, table, columns=None): + def copy_out(self, cur: Cursor, table: str, columns: Optional[Iterable[str]] = None) -> None: """ Copy all collected data into the given table. """ if self.buffer.tell() > 0: self.buffer.seek(0) - cur.copy_from(self.buffer, table, columns=columns) + cur.copy_from(self.buffer, table, columns=columns) # type: ignore[no-untyped-call] diff --git a/nominatim/indexer/indexer.py b/nominatim/indexer/indexer.py index 555f8704..4f767530 100644 --- a/nominatim/indexer/indexer.py +++ b/nominatim/indexer/indexer.py @@ -7,15 +7,18 @@ """ Main work horse for indexing (computing addresses) the database. 
""" +from typing import Optional, Any, cast import logging import time import psycopg2.extras +from nominatim.tokenizer.base import AbstractTokenizer from nominatim.indexer.progress import ProgressLogger from nominatim.indexer import runners from nominatim.db.async_connection import DBConnection, WorkerPool -from nominatim.db.connection import connect +from nominatim.db.connection import connect, Connection, Cursor +from nominatim.typing import DictCursorResults LOG = logging.getLogger() @@ -23,10 +26,11 @@ LOG = logging.getLogger() class PlaceFetcher: """ Asynchronous connection that fetches place details for processing. """ - def __init__(self, dsn, setup_conn): - self.wait_time = 0 - self.current_ids = None - self.conn = DBConnection(dsn, cursor_factory=psycopg2.extras.DictCursor) + def __init__(self, dsn: str, setup_conn: Connection) -> None: + self.wait_time = 0.0 + self.current_ids: Optional[DictCursorResults] = None + self.conn: Optional[DBConnection] = DBConnection(dsn, + cursor_factory=psycopg2.extras.DictCursor) with setup_conn.cursor() as cur: # need to fetch those manually because register_hstore cannot @@ -37,7 +41,7 @@ class PlaceFetcher: psycopg2.extras.register_hstore(self.conn.conn, oid=hstore_oid, array_oid=hstore_array_oid) - def close(self): + def close(self) -> None: """ Close the underlying asynchronous connection. """ if self.conn: @@ -45,44 +49,46 @@ class PlaceFetcher: self.conn = None - def fetch_next_batch(self, cur, runner): + def fetch_next_batch(self, cur: Cursor, runner: runners.Runner) -> bool: """ Send a request for the next batch of places. If details for the places are required, they will be fetched asynchronously. Returns true if there is still data available. """ - ids = cur.fetchmany(100) + ids = cast(Optional[DictCursorResults], cur.fetchmany(100)) if not ids: self.current_ids = None return False - if hasattr(runner, 'get_place_details'): - runner.get_place_details(self.conn, ids) - self.current_ids = [] - else: - self.current_ids = ids + assert self.conn is not None + self.current_ids = runner.get_place_details(self.conn, ids) return True - def get_batch(self): + def get_batch(self) -> DictCursorResults: """ Get the next batch of data, previously requested with `fetch_next_batch`. """ + assert self.conn is not None + assert self.conn.cursor is not None + if self.current_ids is not None and not self.current_ids: tstart = time.time() self.conn.wait() self.wait_time += time.time() - tstart - self.current_ids = self.conn.cursor.fetchall() + self.current_ids = cast(Optional[DictCursorResults], + self.conn.cursor.fetchall()) - return self.current_ids + return self.current_ids if self.current_ids is not None else [] - def __enter__(self): + def __enter__(self) -> 'PlaceFetcher': return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + assert self.conn is not None self.conn.wait() self.close() @@ -91,13 +97,13 @@ class Indexer: """ Main indexing routine. """ - def __init__(self, dsn, tokenizer, num_threads): + def __init__(self, dsn: str, tokenizer: AbstractTokenizer, num_threads: int): self.dsn = dsn self.tokenizer = tokenizer self.num_threads = num_threads - def has_pending(self): + def has_pending(self) -> bool: """ Check if any data still needs indexing. This function must only be used after the import has finished. Otherwise it will be very expensive. 
@@ -108,7 +114,7 @@ class Indexer: return cur.rowcount > 0 - def index_full(self, analyse=True): + def index_full(self, analyse: bool = True) -> None: """ Index the complete database. This will first index boundaries followed by all other objects. When `analyse` is True, then the database will be analysed at the appropriate places to @@ -117,7 +123,7 @@ class Indexer: with connect(self.dsn) as conn: conn.autocommit = True - def _analyze(): + def _analyze() -> None: if analyse: with conn.cursor() as cur: cur.execute('ANALYZE') @@ -138,7 +144,7 @@ class Indexer: _analyze() - def index_boundaries(self, minrank, maxrank): + def index_boundaries(self, minrank: int, maxrank: int) -> None: """ Index only administrative boundaries within the given rank range. """ LOG.warning("Starting indexing boundaries using %s threads", @@ -148,7 +154,7 @@ class Indexer: for rank in range(max(minrank, 4), min(maxrank, 26)): self._index(runners.BoundaryRunner(rank, analyzer)) - def index_by_rank(self, minrank, maxrank): + def index_by_rank(self, minrank: int, maxrank: int) -> None: """ Index all entries of placex in the given rank range (inclusive) in order of their address rank. @@ -168,7 +174,7 @@ class Indexer: self._index(runners.InterpolationRunner(analyzer), 20) - def index_postcodes(self): + def index_postcodes(self) -> None: """Index the entries ofthe location_postcode table. """ LOG.warning("Starting indexing postcodes using %s threads", self.num_threads) @@ -176,7 +182,7 @@ class Indexer: self._index(runners.PostcodeRunner(), 20) - def update_status_table(self): + def update_status_table(self) -> None: """ Update the status in the status table to 'indexed'. """ with connect(self.dsn) as conn: @@ -185,7 +191,7 @@ class Indexer: conn.commit() - def _index(self, runner, batch=1): + def _index(self, runner: runners.Runner, batch: int = 1) -> None: """ Index a single rank or table. `runner` describes the SQL to use for indexing. `batch` describes the number of objects that should be processed with a single SQL statement diff --git a/nominatim/indexer/progress.py b/nominatim/indexer/progress.py index b758e10d..bc1d68a3 100644 --- a/nominatim/indexer/progress.py +++ b/nominatim/indexer/progress.py @@ -22,7 +22,7 @@ class ProgressLogger: should be reported. """ - def __init__(self, name, total, log_interval=1): + def __init__(self, name: str, total: int, log_interval: int = 1) -> None: self.name = name self.total_places = total self.done_places = 0 @@ -30,7 +30,7 @@ class ProgressLogger: self.log_interval = log_interval self.next_info = INITIAL_PROGRESS if LOG.isEnabledFor(logging.WARNING) else total + 1 - def add(self, num=1): + def add(self, num: int = 1) -> None: """ Mark `num` places as processed. Print a log message if the logging is at least info and the log interval has passed. """ @@ -55,14 +55,14 @@ class ProgressLogger: self.next_info += int(places_per_sec) * self.log_interval - def done(self): + def done(self) -> None: """ Print final statistics about the progress. 
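The `diff_seconds = 0.0` change in the hunk below (like `wait_time = 0.0` in WorkerPool earlier) exists because mypy fixes a variable's type from its first assignment. A two-method demonstration:

import time

class Pool:
    def __init__(self) -> None:
        # With `= 0` mypy would infer int and reject the float += below.
        self.wait_time = 0.0

    def tick(self) -> None:
        start = time.time()
        self.wait_time += time.time() - start

pool = Pool()
pool.tick()
print(pool.wait_time >= 0.0)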
""" rank_end_time = datetime.now() if rank_end_time == self.rank_start_time: - diff_seconds = 0 - places_per_sec = self.done_places + diff_seconds = 0.0 + places_per_sec = float(self.done_places) else: diff_seconds = (rank_end_time - self.rank_start_time).total_seconds() places_per_sec = self.done_places / diff_seconds diff --git a/nominatim/indexer/runners.py b/nominatim/indexer/runners.py index c8495ee4..bbadd282 100644 --- a/nominatim/indexer/runners.py +++ b/nominatim/indexer/runners.py @@ -8,35 +8,48 @@ Mix-ins that provide the actual commands for the indexer for various indexing tasks. """ +from typing import Any, List import functools from psycopg2 import sql as pysql import psycopg2.extras from nominatim.data.place_info import PlaceInfo +from nominatim.tokenizer.base import AbstractAnalyzer +from nominatim.db.async_connection import DBConnection +from nominatim.typing import Query, DictCursorResult, DictCursorResults, Protocol # pylint: disable=C0111 -def _mk_valuelist(template, num): +def _mk_valuelist(template: str, num: int) -> pysql.Composed: return pysql.SQL(',').join([pysql.SQL(template)] * num) -def _analyze_place(place, analyzer): +def _analyze_place(place: DictCursorResult, analyzer: AbstractAnalyzer) -> psycopg2.extras.Json: return psycopg2.extras.Json(analyzer.process_place(PlaceInfo(place))) + +class Runner(Protocol): + def name(self) -> str: ... + def sql_count_objects(self) -> Query: ... + def sql_get_objects(self) -> Query: ... + def get_place_details(self, worker: DBConnection, + ids: DictCursorResults) -> DictCursorResults: ... + def index_places(self, worker: DBConnection, places: DictCursorResults) -> None: ... + + class AbstractPlacexRunner: """ Returns SQL commands for indexing of the placex table. """ SELECT_SQL = pysql.SQL('SELECT place_id FROM placex ') UPDATE_LINE = "(%s, %s::hstore, %s::hstore, %s::int, %s::jsonb)" - def __init__(self, rank, analyzer): + def __init__(self, rank: int, analyzer: AbstractAnalyzer) -> None: self.rank = rank self.analyzer = analyzer - @staticmethod @functools.lru_cache(maxsize=1) - def _index_sql(num_places): + def _index_sql(self, num_places: int) -> pysql.Composed: return pysql.SQL( """ UPDATE placex SET indexed_status = 0, address = v.addr, token_info = v.ti, @@ -46,16 +59,17 @@ class AbstractPlacexRunner: """).format(_mk_valuelist(AbstractPlacexRunner.UPDATE_LINE, num_places)) - @staticmethod - def get_place_details(worker, ids): + def get_place_details(self, worker: DBConnection, ids: DictCursorResults) -> DictCursorResults: worker.perform("""SELECT place_id, extra.* FROM placex, LATERAL placex_indexing_prepare(placex) as extra WHERE place_id IN %s""", (tuple((p[0] for p in ids)), )) + return [] - def index_places(self, worker, places): - values = [] + + def index_places(self, worker: DBConnection, places: DictCursorResults) -> None: + values: List[Any] = [] for place in places: for field in ('place_id', 'name', 'address', 'linked_place_id'): values.append(place[field]) @@ -68,15 +82,15 @@ class RankRunner(AbstractPlacexRunner): """ Returns SQL commands for indexing one rank within the placex table. 
""" - def name(self): + def name(self) -> str: return f"rank {self.rank}" - def sql_count_objects(self): + def sql_count_objects(self) -> pysql.Composed: return pysql.SQL("""SELECT count(*) FROM placex WHERE rank_address = {} and indexed_status > 0 """).format(pysql.Literal(self.rank)) - def sql_get_objects(self): + def sql_get_objects(self) -> pysql.Composed: return self.SELECT_SQL + pysql.SQL( """WHERE indexed_status > 0 and rank_address = {} ORDER BY geometry_sector @@ -88,17 +102,17 @@ class BoundaryRunner(AbstractPlacexRunner): of a certain rank. """ - def name(self): + def name(self) -> str: return f"boundaries rank {self.rank}" - def sql_count_objects(self): + def sql_count_objects(self) -> pysql.Composed: return pysql.SQL("""SELECT count(*) FROM placex WHERE indexed_status > 0 AND rank_search = {} AND class = 'boundary' and type = 'administrative' """).format(pysql.Literal(self.rank)) - def sql_get_objects(self): + def sql_get_objects(self) -> pysql.Composed: return self.SELECT_SQL + pysql.SQL( """WHERE indexed_status > 0 and rank_search = {} and class = 'boundary' and type = 'administrative' @@ -111,37 +125,33 @@ class InterpolationRunner: location_property_osmline. """ - def __init__(self, analyzer): + def __init__(self, analyzer: AbstractAnalyzer) -> None: self.analyzer = analyzer - @staticmethod - def name(): + def name(self) -> str: return "interpolation lines (location_property_osmline)" - @staticmethod - def sql_count_objects(): + def sql_count_objects(self) -> str: return """SELECT count(*) FROM location_property_osmline WHERE indexed_status > 0""" - @staticmethod - def sql_get_objects(): + def sql_get_objects(self) -> str: return """SELECT place_id FROM location_property_osmline WHERE indexed_status > 0 ORDER BY geometry_sector""" - @staticmethod - def get_place_details(worker, ids): + def get_place_details(self, worker: DBConnection, ids: DictCursorResults) -> DictCursorResults: worker.perform("""SELECT place_id, get_interpolation_address(address, osm_id) as address FROM location_property_osmline WHERE place_id IN %s""", (tuple((p[0] for p in ids)), )) + return [] - @staticmethod @functools.lru_cache(maxsize=1) - def _index_sql(num_places): + def _index_sql(self, num_places: int) -> pysql.Composed: return pysql.SQL("""UPDATE location_property_osmline SET indexed_status = 0, address = v.addr, token_info = v.ti FROM (VALUES {}) as v(id, addr, ti) @@ -149,8 +159,8 @@ class InterpolationRunner: """).format(_mk_valuelist("(%s, %s::hstore, %s::jsonb)", num_places)) - def index_places(self, worker, places): - values = [] + def index_places(self, worker: DBConnection, places: DictCursorResults) -> None: + values: List[Any] = [] for place in places: values.extend((place[x] for x in ('place_id', 'address'))) values.append(_analyze_place(place, self.analyzer)) @@ -159,26 +169,28 @@ class InterpolationRunner: -class PostcodeRunner: +class PostcodeRunner(Runner): """ Provides the SQL commands for indexing the location_postcode table. 
""" - @staticmethod - def name(): + def name(self) -> str: return "postcodes (location_postcode)" - @staticmethod - def sql_count_objects(): + + def sql_count_objects(self) -> str: return 'SELECT count(*) FROM location_postcode WHERE indexed_status > 0' - @staticmethod - def sql_get_objects(): + + def sql_get_objects(self) -> str: return """SELECT place_id FROM location_postcode WHERE indexed_status > 0 ORDER BY country_code, postcode""" - @staticmethod - def index_places(worker, ids): + + def get_place_details(self, worker: DBConnection, ids: DictCursorResults) -> DictCursorResults: + return ids + + def index_places(self, worker: DBConnection, places: DictCursorResults) -> None: worker.perform(pysql.SQL("""UPDATE location_postcode SET indexed_status = 0 WHERE place_id IN ({})""") - .format(pysql.SQL(',').join((pysql.Literal(i[0]) for i in ids)))) + .format(pysql.SQL(',').join((pysql.Literal(i[0]) for i in places)))) diff --git a/nominatim/tokenizer/base.py b/nominatim/tokenizer/base.py index 70a54bfd..dbc4cfad 100644 --- a/nominatim/tokenizer/base.py +++ b/nominatim/tokenizer/base.py @@ -9,12 +9,12 @@ Abstract class defintions for tokenizers. These base classes are here mainly for documentation purposes. """ from abc import ABC, abstractmethod -from typing import List, Tuple, Dict, Any +from typing import List, Tuple, Dict, Any, Optional, Iterable +from pathlib import Path from nominatim.config import Configuration from nominatim.data.place_info import PlaceInfo - -# pylint: disable=unnecessary-pass +from nominatim.typing import Protocol class AbstractAnalyzer(ABC): """ The analyzer provides the functions for analysing names and building @@ -28,7 +28,7 @@ class AbstractAnalyzer(ABC): return self - def __exit__(self, exc_type, exc_value, traceback) -> None: + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: self.close() @@ -80,7 +80,8 @@ class AbstractAnalyzer(ABC): @abstractmethod - def update_special_phrases(self, phrases: List[Tuple[str, str, str, str]], + def update_special_phrases(self, + phrases: Iterable[Tuple[str, str, str, str]], should_replace: bool) -> None: """ Update the tokenizer's special phrase tokens from the given list of special phrases. @@ -95,7 +96,7 @@ class AbstractAnalyzer(ABC): @abstractmethod - def add_country_names(self, country_code: str, names: Dict[str, str]): + def add_country_names(self, country_code: str, names: Dict[str, str]) -> None: """ Add the given names to the tokenizer's list of country tokens. Arguments: @@ -186,7 +187,7 @@ class AbstractTokenizer(ABC): @abstractmethod - def check_database(self, config: Configuration) -> str: + def check_database(self, config: Configuration) -> Optional[str]: """ Check that the database is set up correctly and ready for being queried. @@ -230,3 +231,13 @@ class AbstractTokenizer(ABC): When used outside the with construct, the caller must ensure to call the close() function before destructing the analyzer. """ + + +class TokenizerModule(Protocol): + """ Interface that must be exported by modules that implement their + own tokenizer. + """ + + def create(self, dsn: str, data_dir: Path) -> AbstractTokenizer: + """ Factory for new tokenizers. + """ diff --git a/nominatim/tokenizer/factory.py b/nominatim/tokenizer/factory.py index 108c7841..67e22194 100644 --- a/nominatim/tokenizer/factory.py +++ b/nominatim/tokenizer/factory.py @@ -19,17 +19,20 @@ database. A tokenizer usually also includes PHP code for querying. 
The appropriate PHP normalizer module is installed, when the tokenizer is created. """ +from typing import Optional import logging import importlib from pathlib import Path -from ..errors import UsageError -from ..db import properties -from ..db.connection import connect +from nominatim.errors import UsageError +from nominatim.db import properties +from nominatim.db.connection import connect +from nominatim.config import Configuration +from nominatim.tokenizer.base import AbstractTokenizer, TokenizerModule LOG = logging.getLogger() -def _import_tokenizer(name): +def _import_tokenizer(name: str) -> TokenizerModule: """ Load the tokenizer.py module from project directory. """ src_file = Path(__file__).parent / (name + '_tokenizer.py') @@ -41,7 +44,8 @@ def _import_tokenizer(name): return importlib.import_module('nominatim.tokenizer.' + name + '_tokenizer') -def create_tokenizer(config, init_db=True, module_name=None): +def create_tokenizer(config: Configuration, init_db: bool = True, + module_name: Optional[str] = None) -> AbstractTokenizer: """ Create a new tokenizer as defined by the given configuration. The tokenizer data and code is copied into the 'tokenizer' directory @@ -70,7 +74,7 @@ def create_tokenizer(config, init_db=True, module_name=None): return tokenizer -def get_tokenizer_for_db(config): +def get_tokenizer_for_db(config: Configuration) -> AbstractTokenizer: """ Instantiate a tokenizer for an existing database. The function looks up the appropriate tokenizer in the database diff --git a/nominatim/tokenizer/icu_rule_loader.py b/nominatim/tokenizer/icu_rule_loader.py index 035b6698..84040ddc 100644 --- a/nominatim/tokenizer/icu_rule_loader.py +++ b/nominatim/tokenizer/icu_rule_loader.py @@ -7,16 +7,19 @@ """ Helper class to create ICU rules from a configuration file. """ +from typing import Mapping, Any, Dict, Optional import importlib import io import json import logging -from nominatim.config import flatten_config_list +from nominatim.config import flatten_config_list, Configuration from nominatim.db.properties import set_property, get_property +from nominatim.db.connection import Connection from nominatim.errors import UsageError from nominatim.tokenizer.place_sanitizer import PlaceSanitizer from nominatim.tokenizer.icu_token_analysis import ICUTokenAnalysis +from nominatim.tokenizer.token_analysis.base import AnalysisModule, Analyser import nominatim.data.country_info LOG = logging.getLogger() @@ -26,7 +29,7 @@ DBCFG_IMPORT_TRANS_RULES = "tokenizer_import_transliteration" DBCFG_IMPORT_ANALYSIS_RULES = "tokenizer_import_analysis_rules" -def _get_section(rules, section): +def _get_section(rules: Mapping[str, Any], section: str) -> Any: """ Get the section named 'section' from the rules. If the section does not exist, raise a usage error with a meaningful message. """ @@ -41,7 +44,7 @@ class ICURuleLoader: """ Compiler for ICU rules from a tokenizer configuration file. """ - def __init__(self, config): + def __init__(self, config: Configuration) -> None: rules = config.load_sub_configuration('icu_tokenizer.yaml', config='TOKENIZER_CONFIG') @@ -57,17 +60,27 @@ class ICURuleLoader: self.sanitizer_rules = rules.get('sanitizers', []) - def load_config_from_db(self, conn): + def load_config_from_db(self, conn: Connection) -> None: """ Get previously saved parts of the configuration from the database. 
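_import_tokenizer() above is declared to return TokenizerModule, the Protocol added in base.py, even though importlib hands back a plain module object. A sketch of the plug-in pattern; this version bridges ModuleType and the Protocol with an explicit cast, and uses the stdlib json module as a stand-in plug-in:

import importlib
from typing import Protocol, cast

class LoaderModule(Protocol):
    """ Interface a plug-in module must export. """
    def loads(self, s: str) -> object: ...

def load_plugin(name: str) -> LoaderModule:
    # import_module() is typed as ModuleType; the cast promises mypy that
    # the plug-in honours the protocol. Nothing is verified at runtime.
    return cast(LoaderModule, importlib.import_module(name))

plugin = load_plugin('json')   # json happens to fit the protocol
print(plugin.loads('{"ok": true}'))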
""" - self.normalization_rules = get_property(conn, DBCFG_IMPORT_NORM_RULES) - self.transliteration_rules = get_property(conn, DBCFG_IMPORT_TRANS_RULES) - self.analysis_rules = json.loads(get_property(conn, DBCFG_IMPORT_ANALYSIS_RULES)) + rules = get_property(conn, DBCFG_IMPORT_NORM_RULES) + if rules is not None: + self.normalization_rules = rules + + rules = get_property(conn, DBCFG_IMPORT_TRANS_RULES) + if rules is not None: + self.transliteration_rules = rules + + rules = get_property(conn, DBCFG_IMPORT_ANALYSIS_RULES) + if rules: + self.analysis_rules = json.loads(rules) + else: + self.analysis_rules = [] self._setup_analysis() - def save_config_to_db(self, conn): + def save_config_to_db(self, conn: Connection) -> None: """ Save the part of the configuration that cannot be changed into the database. """ @@ -76,20 +89,20 @@ class ICURuleLoader: set_property(conn, DBCFG_IMPORT_ANALYSIS_RULES, json.dumps(self.analysis_rules)) - def make_sanitizer(self): + def make_sanitizer(self) -> PlaceSanitizer: """ Create a place sanitizer from the configured rules. """ return PlaceSanitizer(self.sanitizer_rules) - def make_token_analysis(self): + def make_token_analysis(self) -> ICUTokenAnalysis: """ Create a token analyser from the reviouly loaded rules. """ return ICUTokenAnalysis(self.normalization_rules, self.transliteration_rules, self.analysis) - def get_search_rules(self): + def get_search_rules(self) -> str: """ Return the ICU rules to be used during search. The rules combine normalization and transliteration. """ @@ -102,22 +115,22 @@ class ICURuleLoader: return rules.getvalue() - def get_normalization_rules(self): + def get_normalization_rules(self) -> str: """ Return rules for normalisation of a term. """ return self.normalization_rules - def get_transliteration_rules(self): + def get_transliteration_rules(self) -> str: """ Return the rules for converting a string into its asciii representation. """ return self.transliteration_rules - def _setup_analysis(self): + def _setup_analysis(self) -> None: """ Process the rules used for creating the various token analyzers. """ - self.analysis = {} + self.analysis: Dict[Optional[str], TokenAnalyzerRule] = {} if not isinstance(self.analysis_rules, list): raise UsageError("Configuration section 'token-analysis' must be a list.") @@ -135,7 +148,7 @@ class ICURuleLoader: @staticmethod - def _cfg_to_icu_rules(rules, section): + def _cfg_to_icu_rules(rules: Mapping[str, Any], section: str) -> str: """ Load an ICU ruleset from the given section. If the section is a simple string, it is interpreted as a file name and the rules are loaded verbatim from the given file. The filename is expected to be @@ -155,12 +168,16 @@ class TokenAnalyzerRule: and creates a new token analyzer on request. """ - def __init__(self, rules, normalization_rules): + def __init__(self, rules: Mapping[str, Any], normalization_rules: str) -> None: # Find the analysis module module_name = 'nominatim.tokenizer.token_analysis.' \ + _get_section(rules, 'analyzer').replace('-', '_') - analysis_mod = importlib.import_module(module_name) - self.create = analysis_mod.create + self._analysis_mod: AnalysisModule = importlib.import_module(module_name) # Load the configuration. - self.config = analysis_mod.configure(rules, normalization_rules) + self.config = self._analysis_mod.configure(rules, normalization_rules) + + def create(self, normalizer: Any, transliterator: Any) -> Analyser: + """ Create a new analyser instance for the given rule. 
+ """ + return self._analysis_mod.create(normalizer, transliterator, self.config) diff --git a/nominatim/tokenizer/icu_token_analysis.py b/nominatim/tokenizer/icu_token_analysis.py index 68fc82e3..3c4d7298 100644 --- a/nominatim/tokenizer/icu_token_analysis.py +++ b/nominatim/tokenizer/icu_token_analysis.py @@ -8,15 +8,22 @@ Container class collecting all components required to transform an OSM name into a Nominatim token. """ - +from typing import Mapping, Optional, TYPE_CHECKING from icu import Transliterator +from nominatim.tokenizer.token_analysis.base import Analyser + +if TYPE_CHECKING: + from typing import Any + from nominatim.tokenizer.icu_rule_loader import TokenAnalyzerRule # pylint: disable=cyclic-import + class ICUTokenAnalysis: """ Container class collecting the transliterators and token analysis modules for a single NameAnalyser instance. """ - def __init__(self, norm_rules, trans_rules, analysis_rules): + def __init__(self, norm_rules: str, trans_rules: str, + analysis_rules: Mapping[Optional[str], 'TokenAnalyzerRule']): self.normalizer = Transliterator.createFromRules("icu_normalization", norm_rules) trans_rules += ";[:Space:]+ > ' '" @@ -25,11 +32,11 @@ class ICUTokenAnalysis: self.search = Transliterator.createFromRules("icu_search", norm_rules + trans_rules) - self.analysis = {name: arules.create(self.normalizer, self.to_ascii, arules.config) + self.analysis = {name: arules.create(self.normalizer, self.to_ascii) for name, arules in analysis_rules.items()} - def get_analyzer(self, name): + def get_analyzer(self, name: Optional[str]) -> Analyser: """ Return the given named analyzer. If no analyzer with that name exists, return the default analyzer. """ diff --git a/nominatim/tokenizer/icu_tokenizer.py b/nominatim/tokenizer/icu_tokenizer.py index 171d4392..31eaaf29 100644 --- a/nominatim/tokenizer/icu_tokenizer.py +++ b/nominatim/tokenizer/icu_tokenizer.py @@ -8,41 +8,48 @@ Tokenizer implementing normalisation as used before Nominatim 4 but using libICU instead of the PostgreSQL module. """ +from typing import Optional, Sequence, List, Tuple, Mapping, Any, cast, \ + Dict, Set, Iterable import itertools import json import logging +from pathlib import Path from textwrap import dedent -from nominatim.db.connection import connect +from nominatim.db.connection import connect, Connection, Cursor +from nominatim.config import Configuration from nominatim.db.utils import CopyBuffer from nominatim.db.sql_preprocessor import SQLPreprocessor from nominatim.data.place_info import PlaceInfo from nominatim.tokenizer.icu_rule_loader import ICURuleLoader +from nominatim.tokenizer.place_sanitizer import PlaceSanitizer +from nominatim.tokenizer.sanitizers.base import PlaceName +from nominatim.tokenizer.icu_token_analysis import ICUTokenAnalysis from nominatim.tokenizer.base import AbstractAnalyzer, AbstractTokenizer DBCFG_TERM_NORMALIZATION = "tokenizer_term_normalization" LOG = logging.getLogger() -def create(dsn, data_dir): +def create(dsn: str, data_dir: Path) -> 'ICUTokenizer': """ Create a new instance of the tokenizer provided by this module. """ - return LegacyICUTokenizer(dsn, data_dir) + return ICUTokenizer(dsn, data_dir) -class LegacyICUTokenizer(AbstractTokenizer): +class ICUTokenizer(AbstractTokenizer): """ This tokenizer uses libICU to covert names and queries to ASCII. Otherwise it uses the same algorithms and data structures as the normalization routines in Nominatim 3. 
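The TYPE_CHECKING import above breaks a real cycle: icu_rule_loader imports this module, while the annotation on __init__ needs TokenAnalyzerRule from icu_rule_loader. A guarded import plus a string annotation resolves it with no runtime cost. The shape, with illustrative module names:

from typing import TYPE_CHECKING, Mapping, Optional

if TYPE_CHECKING:
    # Seen only by the type checker, so no circular import at runtime.
    from rule_loader import TokenAnalyzerRule   # illustrative module

class TokenAnalysis:
    def __init__(self,
                 rules: 'Mapping[Optional[str], TokenAnalyzerRule]') -> None:
        self.rules = rules

print(TokenAnalysis({}).rules)   # the string annotation is never evaluated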
""" - def __init__(self, dsn, data_dir): + def __init__(self, dsn: str, data_dir: Path) -> None: self.dsn = dsn self.data_dir = data_dir - self.loader = None + self.loader: Optional[ICURuleLoader] = None - def init_new_db(self, config, init_db=True): + def init_new_db(self, config: Configuration, init_db: bool = True) -> None: """ Set up a new tokenizer for the database. This copies all necessary data in the project directory to make @@ -58,7 +65,7 @@ class LegacyICUTokenizer(AbstractTokenizer): self._init_db_tables(config) - def init_from_project(self, config): + def init_from_project(self, config: Configuration) -> None: """ Initialise the tokenizer from the project directory. """ self.loader = ICURuleLoader(config) @@ -69,7 +76,7 @@ class LegacyICUTokenizer(AbstractTokenizer): self._install_php(config.lib_dir.php, overwrite=False) - def finalize_import(self, config): + def finalize_import(self, config: Configuration) -> None: """ Do any required postprocessing to make the tokenizer data ready for use. """ @@ -78,7 +85,7 @@ class LegacyICUTokenizer(AbstractTokenizer): sqlp.run_sql_file(conn, 'tokenizer/legacy_tokenizer_indices.sql') - def update_sql_functions(self, config): + def update_sql_functions(self, config: Configuration) -> None: """ Reimport the SQL functions for this tokenizer. """ with connect(self.dsn) as conn: @@ -86,14 +93,14 @@ class LegacyICUTokenizer(AbstractTokenizer): sqlp.run_sql_file(conn, 'tokenizer/icu_tokenizer.sql') - def check_database(self, config): + def check_database(self, config: Configuration) -> None: """ Check that the tokenizer is set up correctly. """ # Will throw an error if there is an issue. self.init_from_project(config) - def update_statistics(self): + def update_statistics(self) -> None: """ Recompute frequencies for all name words. """ with connect(self.dsn) as conn: @@ -113,7 +120,7 @@ class LegacyICUTokenizer(AbstractTokenizer): conn.commit() - def _cleanup_housenumbers(self): + def _cleanup_housenumbers(self) -> None: """ Remove unused house numbers. """ with connect(self.dsn) as conn: @@ -148,7 +155,7 @@ class LegacyICUTokenizer(AbstractTokenizer): - def update_word_tokens(self): + def update_word_tokens(self) -> None: """ Remove unused tokens. """ LOG.warning("Cleaning up housenumber tokens.") @@ -156,7 +163,7 @@ class LegacyICUTokenizer(AbstractTokenizer): LOG.warning("Tokenizer house-keeping done.") - def name_analyzer(self): + def name_analyzer(self) -> 'ICUNameAnalyzer': """ Create a new analyzer for tokenizing names and queries using this tokinzer. Analyzers are context managers and should be used accordingly: @@ -171,13 +178,15 @@ class LegacyICUTokenizer(AbstractTokenizer): Analyzers are not thread-safe. You need to instantiate one per thread. """ - return LegacyICUNameAnalyzer(self.dsn, self.loader.make_sanitizer(), - self.loader.make_token_analysis()) + assert self.loader is not None + return ICUNameAnalyzer(self.dsn, self.loader.make_sanitizer(), + self.loader.make_token_analysis()) - def _install_php(self, phpdir, overwrite=True): + def _install_php(self, phpdir: Path, overwrite: bool = True) -> None: """ Install the php script for the tokenizer. 
""" + assert self.loader is not None php_file = self.data_dir / "tokenizer.php" if not php_file.exists() or overwrite: @@ -189,15 +198,16 @@ class LegacyICUTokenizer(AbstractTokenizer): require_once('{phpdir}/tokenizer/icu_tokenizer.php');"""), encoding='utf-8') - def _save_config(self): + def _save_config(self) -> None: """ Save the configuration that needs to remain stable for the given database as database properties. """ + assert self.loader is not None with connect(self.dsn) as conn: self.loader.save_config_to_db(conn) - def _init_db_tables(self, config): + def _init_db_tables(self, config: Configuration) -> None: """ Set up the word table and fill it with pre-computed word frequencies. """ @@ -207,15 +217,16 @@ class LegacyICUTokenizer(AbstractTokenizer): conn.commit() -class LegacyICUNameAnalyzer(AbstractAnalyzer): - """ The legacy analyzer uses the ICU library for splitting names. +class ICUNameAnalyzer(AbstractAnalyzer): + """ The ICU analyzer uses the ICU library for splitting names. Each instance opens a connection to the database to request the normalization. """ - def __init__(self, dsn, sanitizer, token_analysis): - self.conn = connect(dsn).connection + def __init__(self, dsn: str, sanitizer: PlaceSanitizer, + token_analysis: ICUTokenAnalysis) -> None: + self.conn: Optional[Connection] = connect(dsn).connection self.conn.autocommit = True self.sanitizer = sanitizer self.token_analysis = token_analysis @@ -223,7 +234,7 @@ class LegacyICUNameAnalyzer(AbstractAnalyzer): self._cache = _TokenCache() - def close(self): + def close(self) -> None: """ Free all resources used by the analyzer. """ if self.conn: @@ -231,20 +242,20 @@ class LegacyICUNameAnalyzer(AbstractAnalyzer): self.conn = None - def _search_normalized(self, name): + def _search_normalized(self, name: str) -> str: """ Return the search token transliteration of the given name. """ - return self.token_analysis.search.transliterate(name).strip() + return cast(str, self.token_analysis.search.transliterate(name)).strip() - def _normalized(self, name): + def _normalized(self, name: str) -> str: """ Return the normalized version of the given name with all non-relevant information removed. """ - return self.token_analysis.normalizer.transliterate(name).strip() + return cast(str, self.token_analysis.normalizer.transliterate(name)).strip() - def get_word_token_info(self, words): + def get_word_token_info(self, words: Sequence[str]) -> List[Tuple[str, str, int]]: """ Return token information for the given list of words. If a word starts with # it is assumed to be a full name otherwise is a partial name. @@ -255,6 +266,7 @@ class LegacyICUNameAnalyzer(AbstractAnalyzer): The function is used for testing and debugging only and not necessarily efficient. """ + assert self.conn is not None full_tokens = {} partial_tokens = {} for word in words: @@ -277,7 +289,7 @@ class LegacyICUNameAnalyzer(AbstractAnalyzer): + [(k, v, part_ids.get(v, None)) for k, v in partial_tokens.items()] - def normalize_postcode(self, postcode): + def normalize_postcode(self, postcode: str) -> str: """ Convert the postcode to a standardized form. This function must yield exactly the same result as the SQL function @@ -286,10 +298,11 @@ class LegacyICUNameAnalyzer(AbstractAnalyzer): return postcode.strip().upper() - def update_postcodes_from_db(self): + def update_postcodes_from_db(self) -> None: """ Update postcode tokens in the word table from the location_postcode table. 
""" + assert self.conn is not None analyzer = self.token_analysis.analysis.get('@postcode') with self.conn.cursor() as cur: @@ -324,13 +337,15 @@ class LegacyICUNameAnalyzer(AbstractAnalyzer): self._delete_unused_postcode_words(word_entries - needed_entries) self._add_missing_postcode_words(needed_entries - word_entries) - def _delete_unused_postcode_words(self, tokens): + def _delete_unused_postcode_words(self, tokens: Iterable[str]) -> None: + assert self.conn is not None if tokens: with self.conn.cursor() as cur: cur.execute("DELETE FROM word WHERE type = 'P' and word = any(%s)", (list(tokens), )) - def _add_missing_postcode_words(self, tokens): + def _add_missing_postcode_words(self, tokens: Iterable[str]) -> None: + assert self.conn is not None if not tokens: return @@ -341,10 +356,12 @@ class LegacyICUNameAnalyzer(AbstractAnalyzer): if '@' in postcode_name: term, variant = postcode_name.split('@', 2) term = self._search_normalized(term) - variants = {term} - if analyzer is not None: - variants.update(analyzer.get_variants_ascii(variant)) - variants = list(variants) + if analyzer is None: + variants = [term] + else: + variants = analyzer.get_variants_ascii(variant) + if term not in variants: + variants.append(term) else: variants = [self._search_normalized(postcode_name)] terms.append((postcode_name, variants)) @@ -358,12 +375,14 @@ class LegacyICUNameAnalyzer(AbstractAnalyzer): - def update_special_phrases(self, phrases, should_replace): + def update_special_phrases(self, phrases: Iterable[Tuple[str, str, str, str]], + should_replace: bool) -> None: """ Replace the search index for special phrases with the new phrases. If `should_replace` is True, then the previous set of will be completely replaced. Otherwise the phrases are added to the already existing ones. """ + assert self.conn is not None norm_phrases = set(((self._normalized(p[0]), p[1], p[2], p[3]) for p in phrases)) @@ -386,7 +405,9 @@ class LegacyICUNameAnalyzer(AbstractAnalyzer): len(norm_phrases), added, deleted) - def _add_special_phrases(self, cursor, new_phrases, existing_phrases): + def _add_special_phrases(self, cursor: Cursor, + new_phrases: Set[Tuple[str, str, str, str]], + existing_phrases: Set[Tuple[str, str, str, str]]) -> int: """ Add all phrases to the database that are not yet there. """ to_add = new_phrases - existing_phrases @@ -407,8 +428,9 @@ class LegacyICUNameAnalyzer(AbstractAnalyzer): return added - @staticmethod - def _remove_special_phrases(cursor, new_phrases, existing_phrases): + def _remove_special_phrases(self, cursor: Cursor, + new_phrases: Set[Tuple[str, str, str, str]], + existing_phrases: Set[Tuple[str, str, str, str]]) -> int: """ Remove all phrases from the databse that are no longer in the new phrase list. """ @@ -425,7 +447,7 @@ class LegacyICUNameAnalyzer(AbstractAnalyzer): return len(to_delete) - def add_country_names(self, country_code, names): + def add_country_names(self, country_code: str, names: Mapping[str, str]) -> None: """ Add default names for the given country to the search index. """ # Make sure any name preprocessing for country names applies. @@ -437,10 +459,12 @@ class LegacyICUNameAnalyzer(AbstractAnalyzer): internal=True) - def _add_country_full_names(self, country_code, names, internal=False): + def _add_country_full_names(self, country_code: str, names: Sequence[PlaceName], + internal: bool = False) -> None: """ Add names for the given country from an already sanitized name list. 
""" + assert self.conn is not None word_tokens = set() for name in names: norm_name = self._search_normalized(name.name) @@ -453,7 +477,8 @@ class LegacyICUNameAnalyzer(AbstractAnalyzer): FROM word WHERE type = 'C' and word = %s""", (country_code, )) - existing_tokens = {True: set(), False: set()} # internal/external names + # internal/external names + existing_tokens: Dict[bool, Set[str]] = {True: set(), False: set()} for word in cur: existing_tokens[word[1]].add(word[0]) @@ -486,7 +511,7 @@ class LegacyICUNameAnalyzer(AbstractAnalyzer): cur.execute(sql, (country_code, list(new_tokens))) - def process_place(self, place): + def process_place(self, place: PlaceInfo) -> Mapping[str, Any]: """ Determine tokenizer information about the given place. Returns a JSON-serializable structure that will be handed into @@ -500,6 +525,7 @@ class LegacyICUNameAnalyzer(AbstractAnalyzer): token_info.set_names(*self._compute_name_tokens(names)) if place.is_country(): + assert place.country_code is not None self._add_country_full_names(place.country_code, names) if address: @@ -508,7 +534,8 @@ class LegacyICUNameAnalyzer(AbstractAnalyzer): return token_info.to_dict() - def _process_place_address(self, token_info, address): + def _process_place_address(self, token_info: '_TokenInfo', + address: Sequence[PlaceName]) -> None: for item in address: if item.kind == 'postcode': token_info.set_postcode(self._add_postcode(item)) @@ -524,12 +551,13 @@ class LegacyICUNameAnalyzer(AbstractAnalyzer): token_info.add_address_term(item.kind, self._compute_partial_tokens(item.name)) - def _compute_housenumber_token(self, hnr): + def _compute_housenumber_token(self, hnr: PlaceName) -> Tuple[Optional[int], Optional[str]]: """ Normalize the housenumber and return the word token and the canonical form. """ + assert self.conn is not None analyzer = self.token_analysis.analysis.get('@housenumber') - result = None, None + result: Tuple[Optional[int], Optional[str]] = (None, None) if analyzer is None: # When no custom analyzer is set, simply normalize and transliterate @@ -539,7 +567,7 @@ class LegacyICUNameAnalyzer(AbstractAnalyzer): if result[0] is None: with self.conn.cursor() as cur: cur.execute("SELECT getorcreate_hnr_id(%s)", (norm_name, )) - result = cur.fetchone()[0], norm_name + result = cur.fetchone()[0], norm_name # type: ignore[no-untyped-call] self._cache.housenumbers[norm_name] = result else: # Otherwise use the analyzer to determine the canonical name. @@ -554,16 +582,17 @@ class LegacyICUNameAnalyzer(AbstractAnalyzer): with self.conn.cursor() as cur: cur.execute("SELECT create_analyzed_hnr_id(%s, %s)", (norm_name, list(variants))) - result = cur.fetchone()[0], variants[0] + result = cur.fetchone()[0], variants[0] # type: ignore[no-untyped-call] self._cache.housenumbers[norm_name] = result return result - def _compute_partial_tokens(self, name): + def _compute_partial_tokens(self, name: str) -> List[int]: """ Normalize the given term, split it into partial words and return then token list for them. """ + assert self.conn is not None norm_name = self._search_normalized(name) tokens = [] @@ -582,16 +611,18 @@ class LegacyICUNameAnalyzer(AbstractAnalyzer): (need_lookup, )) for partial, token in cur: + assert token is not None tokens.append(token) self._cache.partials[partial] = token return tokens - def _retrieve_full_tokens(self, name): + def _retrieve_full_tokens(self, name: str) -> List[int]: """ Get the full name token for the given name, if it exists. The name is only retrived for the standard analyser. 
""" + assert self.conn is not None norm_name = self._search_normalized(name) # return cached if possible @@ -608,12 +639,13 @@ class LegacyICUNameAnalyzer(AbstractAnalyzer): return full - def _compute_name_tokens(self, names): + def _compute_name_tokens(self, names: Sequence[PlaceName]) -> Tuple[Set[int], Set[int]]: """ Computes the full name and partial name tokens for the given dictionary of names. """ - full_tokens = set() - partial_tokens = set() + assert self.conn is not None + full_tokens: Set[int] = set() + partial_tokens: Set[int] = set() for name in names: analyzer_id = name.get_attr('analyzer') @@ -633,19 +665,23 @@ class LegacyICUNameAnalyzer(AbstractAnalyzer): with self.conn.cursor() as cur: cur.execute("SELECT * FROM getorcreate_full_word(%s, %s)", (token_id, variants)) - full, part = cur.fetchone() + full, part = cast(Tuple[int, List[int]], + cur.fetchone()) # type: ignore[no-untyped-call] self._cache.names[token_id] = (full, part) + assert part is not None + full_tokens.add(full) partial_tokens.update(part) return full_tokens, partial_tokens - def _add_postcode(self, item): + def _add_postcode(self, item: PlaceName) -> Optional[str]: """ Make sure the normalized postcode is present in the word table. """ + assert self.conn is not None analyzer = self.token_analysis.analysis.get('@postcode') if analyzer is None: @@ -680,25 +716,24 @@ class LegacyICUNameAnalyzer(AbstractAnalyzer): class _TokenInfo: """ Collect token information to be sent back to the database. """ - def __init__(self): - self.names = None - self.housenumbers = set() - self.housenumber_tokens = set() - self.street_tokens = set() - self.place_tokens = set() - self.address_tokens = {} - self.postcode = None + def __init__(self) -> None: + self.names: Optional[str] = None + self.housenumbers: Set[str] = set() + self.housenumber_tokens: Set[int] = set() + self.street_tokens: Set[int] = set() + self.place_tokens: Set[int] = set() + self.address_tokens: Dict[str, str] = {} + self.postcode: Optional[str] = None - @staticmethod - def _mk_array(tokens): + def _mk_array(self, tokens: Iterable[Any]) -> str: return f"{{{','.join((str(s) for s in tokens))}}}" - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: """ Return the token information in database importable format. """ - out = {} + out: Dict[str, Any] = {} if self.names: out['names'] = self.names @@ -722,40 +757,41 @@ class _TokenInfo: return out - def set_names(self, fulls, partials): + def set_names(self, fulls: Iterable[int], partials: Iterable[int]) -> None: """ Adds token information for the normalised names. """ self.names = self._mk_array(itertools.chain(fulls, partials)) - def add_housenumber(self, token, hnr): + def add_housenumber(self, token: Optional[int], hnr: Optional[str]) -> None: """ Extract housenumber information from a list of normalised housenumbers. """ if token: + assert hnr is not None self.housenumbers.add(hnr) self.housenumber_tokens.add(token) - def add_street(self, tokens): + def add_street(self, tokens: Iterable[int]) -> None: """ Add addr:street match terms. """ self.street_tokens.update(tokens) - def add_place(self, tokens): + def add_place(self, tokens: Iterable[int]) -> None: """ Add addr:place search and match terms. """ self.place_tokens.update(tokens) - def add_address_term(self, key, partials): + def add_address_term(self, key: str, partials: Iterable[int]) -> None: """ Add additional address terms. 
""" if partials: self.address_tokens[key] = self._mk_array(partials) - def set_postcode(self, postcode): + def set_postcode(self, postcode: Optional[str]) -> None: """ Set the postcode to the given one. """ self.postcode = postcode @@ -767,9 +803,9 @@ class _TokenCache: This cache is not thread-safe and needs to be instantiated per analyzer. """ - def __init__(self): - self.names = {} - self.partials = {} - self.fulls = {} - self.postcodes = set() - self.housenumbers = {} + def __init__(self) -> None: + self.names: Dict[str, Tuple[int, List[int]]] = {} + self.partials: Dict[str, int] = {} + self.fulls: Dict[str, List[int]] = {} + self.postcodes: Set[str] = set() + self.housenumbers: Dict[str, Tuple[Optional[int], Optional[str]]] = {} diff --git a/nominatim/tokenizer/legacy_tokenizer.py b/nominatim/tokenizer/legacy_tokenizer.py index 36fd5722..f52eaada 100644 --- a/nominatim/tokenizer/legacy_tokenizer.py +++ b/nominatim/tokenizer/legacy_tokenizer.py @@ -7,8 +7,11 @@ """ Tokenizer implementing normalisation as used before Nominatim 4. """ +from typing import Optional, Sequence, List, Tuple, Mapping, Any, Callable, \ + cast, Dict, Set, Iterable from collections import OrderedDict import logging +from pathlib import Path import re import shutil from textwrap import dedent @@ -17,10 +20,12 @@ from icu import Transliterator import psycopg2 import psycopg2.extras -from nominatim.db.connection import connect +from nominatim.db.connection import connect, Connection +from nominatim.config import Configuration from nominatim.db import properties from nominatim.db import utils as db_utils from nominatim.db.sql_preprocessor import SQLPreprocessor +from nominatim.data.place_info import PlaceInfo from nominatim.errors import UsageError from nominatim.tokenizer.base import AbstractAnalyzer, AbstractTokenizer @@ -29,13 +34,13 @@ DBCFG_MAXWORDFREQ = "tokenizer_maxwordfreq" LOG = logging.getLogger() -def create(dsn, data_dir): +def create(dsn: str, data_dir: Path) -> 'LegacyTokenizer': """ Create a new instance of the tokenizer provided by this module. """ return LegacyTokenizer(dsn, data_dir) -def _install_module(config_module_path, src_dir, module_dir): +def _install_module(config_module_path: str, src_dir: Path, module_dir: Path) -> str: """ Copies the PostgreSQL normalisation module into the project directory if necessary. For historical reasons the module is saved in the '/module' subdirectory and not with the other tokenizer @@ -52,7 +57,7 @@ def _install_module(config_module_path, src_dir, module_dir): # Compatibility mode for builddir installations. if module_dir.exists() and src_dir.samefile(module_dir): LOG.info('Running from build directory. Leaving database module as is.') - return module_dir + return str(module_dir) # In any other case install the module in the project directory. if not module_dir.exists(): @@ -64,10 +69,10 @@ def _install_module(config_module_path, src_dir, module_dir): LOG.info('Database module installed at %s', str(destfile)) - return module_dir + return str(module_dir) -def _check_module(module_dir, conn): +def _check_module(module_dir: str, conn: Connection) -> None: """ Try to use the PostgreSQL module to confirm that it is correctly installed and accessible from PostgreSQL. """ @@ -89,13 +94,13 @@ class LegacyTokenizer(AbstractTokenizer): calls to the database. 
""" - def __init__(self, dsn, data_dir): + def __init__(self, dsn: str, data_dir: Path) -> None: self.dsn = dsn self.data_dir = data_dir - self.normalization = None + self.normalization: Optional[str] = None - def init_new_db(self, config, init_db=True): + def init_new_db(self, config: Configuration, init_db: bool = True) -> None: """ Set up a new tokenizer for the database. This copies all necessary data in the project directory to make @@ -119,7 +124,7 @@ class LegacyTokenizer(AbstractTokenizer): self._init_db_tables(config) - def init_from_project(self, config): + def init_from_project(self, config: Configuration) -> None: """ Initialise the tokenizer from the project directory. """ with connect(self.dsn) as conn: @@ -132,7 +137,7 @@ class LegacyTokenizer(AbstractTokenizer): self._install_php(config, overwrite=False) - def finalize_import(self, config): + def finalize_import(self, config: Configuration) -> None: """ Do any required postprocessing to make the tokenizer data ready for use. """ @@ -141,7 +146,7 @@ class LegacyTokenizer(AbstractTokenizer): sqlp.run_sql_file(conn, 'tokenizer/legacy_tokenizer_indices.sql') - def update_sql_functions(self, config): + def update_sql_functions(self, config: Configuration) -> None: """ Reimport the SQL functions for this tokenizer. """ with connect(self.dsn) as conn: @@ -154,7 +159,7 @@ class LegacyTokenizer(AbstractTokenizer): modulepath=modulepath) - def check_database(self, _): + def check_database(self, _: Configuration) -> Optional[str]: """ Check that the tokenizer is set up correctly. """ hint = """\ @@ -181,7 +186,7 @@ class LegacyTokenizer(AbstractTokenizer): return None - def migrate_database(self, config): + def migrate_database(self, config: Configuration) -> None: """ Initialise the project directory of an existing database for use with this tokenizer. @@ -198,7 +203,7 @@ class LegacyTokenizer(AbstractTokenizer): self._save_config(conn, config) - def update_statistics(self): + def update_statistics(self) -> None: """ Recompute the frequency of full words. """ with connect(self.dsn) as conn: @@ -218,13 +223,13 @@ class LegacyTokenizer(AbstractTokenizer): conn.commit() - def update_word_tokens(self): + def update_word_tokens(self) -> None: """ No house-keeping implemented for the legacy tokenizer. """ LOG.info("No tokenizer clean-up available.") - def name_analyzer(self): + def name_analyzer(self) -> 'LegacyNameAnalyzer': """ Create a new analyzer for tokenizing names and queries using this tokinzer. Analyzers are context managers and should be used accordingly: @@ -244,7 +249,7 @@ class LegacyTokenizer(AbstractTokenizer): return LegacyNameAnalyzer(self.dsn, normalizer) - def _install_php(self, config, overwrite=True): + def _install_php(self, config: Configuration, overwrite: bool = True) -> None: """ Install the php script for the tokenizer. """ php_file = self.data_dir / "tokenizer.php" @@ -258,7 +263,7 @@ class LegacyTokenizer(AbstractTokenizer): """), encoding='utf-8') - def _init_db_tables(self, config): + def _init_db_tables(self, config: Configuration) -> None: """ Set up the word table and fill it with pre-computed word frequencies. """ @@ -271,10 +276,12 @@ class LegacyTokenizer(AbstractTokenizer): db_utils.execute_file(self.dsn, config.lib_dir.data / 'words.sql') - def _save_config(self, conn, config): + def _save_config(self, conn: Connection, config: Configuration) -> None: """ Save the configuration that needs to remain stable for the given database as database properties. 
""" + assert self.normalization is not None + properties.set_property(conn, DBCFG_NORMALIZATION, self.normalization) properties.set_property(conn, DBCFG_MAXWORDFREQ, config.MAX_WORD_FREQUENCY) @@ -287,8 +294,8 @@ class LegacyNameAnalyzer(AbstractAnalyzer): normalization. """ - def __init__(self, dsn, normalizer): - self.conn = connect(dsn).connection + def __init__(self, dsn: str, normalizer: Any): + self.conn: Optional[Connection] = connect(dsn).connection self.conn.autocommit = True self.normalizer = normalizer psycopg2.extras.register_hstore(self.conn) @@ -296,7 +303,7 @@ class LegacyNameAnalyzer(AbstractAnalyzer): self._cache = _TokenCache(self.conn) - def close(self): + def close(self) -> None: """ Free all resources used by the analyzer. """ if self.conn: @@ -304,7 +311,7 @@ class LegacyNameAnalyzer(AbstractAnalyzer): self.conn = None - def get_word_token_info(self, words): + def get_word_token_info(self, words: Sequence[str]) -> List[Tuple[str, str, int]]: """ Return token information for the given list of words. If a word starts with # it is assumed to be a full name otherwise is a partial name. @@ -315,6 +322,7 @@ class LegacyNameAnalyzer(AbstractAnalyzer): The function is used for testing and debugging only and not necessarily efficient. """ + assert self.conn is not None with self.conn.cursor() as cur: cur.execute("""SELECT t.term, word_token, word_id FROM word, (SELECT unnest(%s::TEXT[]) as term) t @@ -330,14 +338,14 @@ class LegacyNameAnalyzer(AbstractAnalyzer): return [(r[0], r[1], r[2]) for r in cur] - def normalize(self, phrase): + def normalize(self, phrase: str) -> str: """ Normalize the given phrase, i.e. remove all properties that are irrelevant for search. """ - return self.normalizer.transliterate(phrase) + return cast(str, self.normalizer.transliterate(phrase)) - def normalize_postcode(self, postcode): + def normalize_postcode(self, postcode: str) -> str: """ Convert the postcode to a standardized form. This function must yield exactly the same result as the SQL function @@ -346,10 +354,12 @@ class LegacyNameAnalyzer(AbstractAnalyzer): return postcode.strip().upper() - def update_postcodes_from_db(self): + def update_postcodes_from_db(self) -> None: """ Update postcode tokens in the word table from the location_postcode table. """ + assert self.conn is not None + with self.conn.cursor() as cur: # This finds us the rows in location_postcode and word that are # missing in the other table. @@ -383,9 +393,12 @@ class LegacyNameAnalyzer(AbstractAnalyzer): - def update_special_phrases(self, phrases, should_replace): + def update_special_phrases(self, phrases: Iterable[Tuple[str, str, str, str]], + should_replace: bool) -> None: """ Replace the search index for special phrases with the new phrases. """ + assert self.conn is not None + norm_phrases = set(((self.normalize(p[0]), p[1], p[2], p[3]) for p in phrases)) @@ -422,9 +435,11 @@ class LegacyNameAnalyzer(AbstractAnalyzer): len(norm_phrases), len(to_add), len(to_delete)) - def add_country_names(self, country_code, names): + def add_country_names(self, country_code: str, names: Mapping[str, str]) -> None: """ Add names for the given country to the search index. 
""" + assert self.conn is not None + with self.conn.cursor() as cur: cur.execute( """INSERT INTO word (word_id, word_token, country_code) @@ -436,12 +451,14 @@ class LegacyNameAnalyzer(AbstractAnalyzer): """, (country_code, list(names.values()), country_code)) - def process_place(self, place): + def process_place(self, place: PlaceInfo) -> Mapping[str, Any]: """ Determine tokenizer information about the given place. Returns a JSON-serialisable structure that will be handed into the database via the token_info field. """ + assert self.conn is not None + token_info = _TokenInfo(self._cache) names = place.name @@ -450,6 +467,7 @@ class LegacyNameAnalyzer(AbstractAnalyzer): token_info.add_names(self.conn, names) if place.is_country(): + assert place.country_code is not None self.add_country_names(place.country_code, names) address = place.address @@ -459,7 +477,8 @@ class LegacyNameAnalyzer(AbstractAnalyzer): return token_info.data - def _process_place_address(self, token_info, address): + def _process_place_address(self, token_info: '_TokenInfo', address: Mapping[str, str]) -> None: + assert self.conn is not None hnrs = [] addr_terms = [] @@ -491,12 +510,12 @@ class LegacyNameAnalyzer(AbstractAnalyzer): class _TokenInfo: """ Collect token information to be sent back to the database. """ - def __init__(self, cache): + def __init__(self, cache: '_TokenCache') -> None: self.cache = cache - self.data = {} + self.data: Dict[str, Any] = {} - def add_names(self, conn, names): + def add_names(self, conn: Connection, names: Mapping[str, str]) -> None: """ Add token information for the names of the place. """ with conn.cursor() as cur: @@ -505,7 +524,7 @@ class _TokenInfo: (names, )) - def add_housenumbers(self, conn, hnrs): + def add_housenumbers(self, conn: Connection, hnrs: Sequence[str]) -> None: """ Extract housenumber information from the address. """ if len(hnrs) == 1: @@ -516,7 +535,7 @@ class _TokenInfo: return # split numbers if necessary - simple_list = [] + simple_list: List[str] = [] for hnr in hnrs: simple_list.extend((x.strip() for x in re.split(r'[;,]', hnr))) @@ -525,49 +544,53 @@ class _TokenInfo: with conn.cursor() as cur: cur.execute("SELECT * FROM create_housenumbers(%s)", (simple_list, )) - self.data['hnr_tokens'], self.data['hnr'] = cur.fetchone() + self.data['hnr_tokens'], self.data['hnr'] = \ + cur.fetchone() # type: ignore[no-untyped-call] - def set_postcode(self, postcode): + def set_postcode(self, postcode: str) -> None: """ Set or replace the postcode token with the given value. """ self.data['postcode'] = postcode - def add_street(self, conn, street): + def add_street(self, conn: Connection, street: str) -> None: """ Add addr:street match terms. """ - def _get_street(name): + def _get_street(name: str) -> List[int]: with conn.cursor() as cur: - return cur.scalar("SELECT word_ids_from_name(%s)::text", (name, )) + return cast(List[int], + cur.scalar("SELECT word_ids_from_name(%s)::text", (name, ))) tokens = self.cache.streets.get(street, _get_street) if tokens: self.data['street'] = tokens - def add_place(self, conn, place): + def add_place(self, conn: Connection, place: str) -> None: """ Add addr:place search and match terms. 
""" - def _get_place(name): + def _get_place(name: str) -> Tuple[List[int], List[int]]: with conn.cursor() as cur: cur.execute("""SELECT make_keywords(hstore('name' , %s))::text, word_ids_from_name(%s)::text""", (name, name)) - return cur.fetchone() + return cast(Tuple[List[int], List[int]], + cur.fetchone()) # type: ignore[no-untyped-call] self.data['place_search'], self.data['place_match'] = \ self.cache.places.get(place, _get_place) - def add_address_terms(self, conn, terms): + def add_address_terms(self, conn: Connection, terms: Sequence[Tuple[str, str]]) -> None: """ Add additional address terms. """ - def _get_address_term(name): + def _get_address_term(name: str) -> Tuple[List[int], List[int]]: with conn.cursor() as cur: cur.execute("""SELECT addr_ids_from_name(%s)::text, word_ids_from_name(%s)::text""", (name, name)) - return cur.fetchone() + return cast(Tuple[List[int], List[int]], + cur.fetchone()) # type: ignore[no-untyped-call] tokens = {} for key, value in terms: @@ -584,13 +607,12 @@ class _LRU: produce the item when there is a cache miss. """ - def __init__(self, maxsize=128, init_data=None): - self.data = init_data or OrderedDict() + def __init__(self, maxsize: int = 128): + self.data: 'OrderedDict[str, Any]' = OrderedDict() self.maxsize = maxsize - if init_data is not None and len(init_data) > maxsize: - self.maxsize = len(init_data) - def get(self, key, generator): + + def get(self, key: str, generator: Callable[[str], Any]) -> Any: """ Get the item with the given key from the cache. If nothing is found in the cache, generate the value through the generator function and store it in the cache. @@ -613,7 +635,7 @@ class _TokenCache: This cache is not thread-safe and needs to be instantiated per analyzer. """ - def __init__(self, conn): + def __init__(self, conn: Connection): # various LRU caches self.streets = _LRU(maxsize=256) self.places = _LRU(maxsize=128) @@ -623,18 +645,18 @@ class _TokenCache: with conn.cursor() as cur: cur.execute("""SELECT i, ARRAY[getorcreate_housenumber_id(i::text)]::text FROM generate_series(1, 100) as i""") - self._cached_housenumbers = {str(r[0]): r[1] for r in cur} + self._cached_housenumbers: Dict[str, str] = {str(r[0]): r[1] for r in cur} # For postcodes remember the ones that have already been added - self.postcodes = set() + self.postcodes: Set[str] = set() - def get_housenumber(self, number): + def get_housenumber(self, number: str) -> Optional[str]: """ Get a housenumber token from the cache. """ return self._cached_housenumbers.get(number) - def add_postcode(self, conn, postcode): + def add_postcode(self, conn: Connection, postcode: str) -> None: """ Make sure the given postcode is in the database. """ if postcode not in self.postcodes: diff --git a/nominatim/tokenizer/place_sanitizer.py b/nominatim/tokenizer/place_sanitizer.py index 913b363c..3f548e06 100644 --- a/nominatim/tokenizer/place_sanitizer.py +++ b/nominatim/tokenizer/place_sanitizer.py @@ -8,100 +8,13 @@ Handler for cleaning name and address tags in place information before it is handed to the token analysis. """ +from typing import Optional, List, Mapping, Sequence, Callable, Any, Tuple import importlib from nominatim.errors import UsageError from nominatim.tokenizer.sanitizers.config import SanitizerConfig - -class PlaceName: - """ A searchable name for a place together with properties. - Every name object saves the name proper and two basic properties: - * 'kind' describes the name of the OSM key used without any suffixes - (i.e. 
the part after the colon removed) - * 'suffix' contains the suffix of the OSM tag, if any. The suffix - is the part of the key after the first colon. - In addition to that, the name may have arbitrary additional attributes. - Which attributes are used, depends on the token analyser. - """ - - def __init__(self, name, kind, suffix): - self.name = name - self.kind = kind - self.suffix = suffix - self.attr = {} - - - def __repr__(self): - return f"PlaceName(name='{self.name}',kind='{self.kind}',suffix='{self.suffix}')" - - - def clone(self, name=None, kind=None, suffix=None, attr=None): - """ Create a deep copy of the place name, optionally with the - given parameters replaced. In the attribute list only the given - keys are updated. The list is not replaced completely. - In particular, the function cannot to be used to remove an - attribute from a place name. - """ - newobj = PlaceName(name or self.name, - kind or self.kind, - suffix or self.suffix) - - newobj.attr.update(self.attr) - if attr: - newobj.attr.update(attr) - - return newobj - - - def set_attr(self, key, value): - """ Add the given property to the name. If the property was already - set, then the value is overwritten. - """ - self.attr[key] = value - - - def get_attr(self, key, default=None): - """ Return the given property or the value of 'default' if it - is not set. - """ - return self.attr.get(key, default) - - - def has_attr(self, key): - """ Check if the given attribute is set. - """ - return key in self.attr - - -class _ProcessInfo: - """ Container class for information handed into to handler functions. - The 'names' and 'address' members are mutable. A handler must change - them by either modifying the lists place or replacing the old content - with a new list. - """ - - def __init__(self, place): - self.place = place - self.names = self._convert_name_dict(place.name) - self.address = self._convert_name_dict(place.address) - - - @staticmethod - def _convert_name_dict(names): - """ Convert a dictionary of names into a list of PlaceNames. - The dictionary key is split into the primary part of the key - and the suffix (the part after an optional colon). - """ - out = [] - - if names: - for key, value in names.items(): - parts = key.split(':', 1) - out.append(PlaceName(value.strip(), - parts[0].strip(), - parts[1].strip() if len(parts) > 1 else None)) - - return out +from nominatim.tokenizer.sanitizers.base import SanitizerHandler, ProcessInfo, PlaceName +from nominatim.data.place_info import PlaceInfo class PlaceSanitizer: @@ -109,24 +22,24 @@ class PlaceSanitizer: names and address before they are used by the token analysers. """ - def __init__(self, rules): - self.handlers = [] + def __init__(self, rules: Optional[Sequence[Mapping[str, Any]]]) -> None: + self.handlers: List[Callable[[ProcessInfo], None]] = [] if rules: for func in rules: if 'step' not in func: raise UsageError("Sanitizer rule is missing the 'step' attribute.") module_name = 'nominatim.tokenizer.sanitizers.' + func['step'].replace('-', '_') - handler_module = importlib.import_module(module_name) + handler_module: SanitizerHandler = importlib.import_module(module_name) self.handlers.append(handler_module.create(SanitizerConfig(func))) - def process_names(self, place): + def process_names(self, place: PlaceInfo) -> Tuple[List[PlaceName], List[PlaceName]]: """ Extract a sanitized list of names and address parts from the given place. 
The function returns a tuple
            (list of names, list of address names)
        """
-        obj = _ProcessInfo(place)
+        obj = ProcessInfo(place)

         for func in self.handlers:
             func(obj)
diff --git a/nominatim/tokenizer/sanitizers/base.py b/nominatim/tokenizer/sanitizers/base.py
new file mode 100644
index 00000000..692c6d5f
--- /dev/null
+++ b/nominatim/tokenizer/sanitizers/base.py
@@ -0,0 +1,119 @@
+# SPDX-License-Identifier: GPL-2.0-only
+#
+# This file is part of Nominatim. (https://nominatim.org)
+#
+# Copyright (C) 2022 by the Nominatim developer community.
+# For a full list of authors see the git log.
+"""
+Common data types and protocols for sanitizers.
+"""
+from typing import Optional, Dict, List, Mapping, Callable
+
+from nominatim.tokenizer.sanitizers.config import SanitizerConfig
+from nominatim.data.place_info import PlaceInfo
+from nominatim.typing import Protocol, Final
+
+class PlaceName:
+    """ A searchable name for a place together with properties.
+        Every name object saves the name proper and two basic properties:
+        * 'kind' describes the name of the OSM key used without any suffixes
+          (i.e. the part after the colon removed)
+        * 'suffix' contains the suffix of the OSM tag, if any. The suffix
+          is the part of the key after the first colon.
+        In addition to that, the name may have arbitrary additional attributes.
+        Which attributes are used depends on the token analyser.
+    """
+
+    def __init__(self, name: str, kind: str, suffix: Optional[str]):
+        self.name = name
+        self.kind = kind
+        self.suffix = suffix
+        self.attr: Dict[str, str] = {}
+
+
+    def __repr__(self) -> str:
+        return f"PlaceName(name='{self.name}',kind='{self.kind}',suffix='{self.suffix}')"
+
+
+    def clone(self, name: Optional[str] = None,
+              kind: Optional[str] = None,
+              suffix: Optional[str] = None,
+              attr: Optional[Mapping[str, str]] = None) -> 'PlaceName':
+        """ Create a deep copy of the place name, optionally with the
+            given parameters replaced. In the attribute list only the given
+            keys are updated. The list is not replaced completely.
+            In particular, the function cannot be used to remove an
+            attribute from a place name.
+        """
+        newobj = PlaceName(name or self.name,
+                           kind or self.kind,
+                           suffix or self.suffix)
+
+        newobj.attr.update(self.attr)
+        if attr:
+            newobj.attr.update(attr)
+
+        return newobj
+
+
+    def set_attr(self, key: str, value: str) -> None:
+        """ Add the given property to the name. If the property was already
+            set, then the value is overwritten.
+        """
+        self.attr[key] = value
+
+
+    def get_attr(self, key: str, default: Optional[str] = None) -> Optional[str]:
+        """ Return the given property or the value of 'default' if it
+            is not set.
+        """
+        return self.attr.get(key, default)
+
+
+    def has_attr(self, key: str) -> bool:
+        """ Check if the given attribute is set.
+        """
+        return key in self.attr
+
+
+class ProcessInfo:
+    """ Container class for information handed to the handler functions.
+        The 'names' and 'address' members are mutable. A handler must change
+        them by either modifying the lists in place or replacing the old
+        content with a new list.
+    """
+
+    def __init__(self, place: PlaceInfo):
+        self.place: Final = place
+        self.names = self._convert_name_dict(place.name)
+        self.address = self._convert_name_dict(place.address)
+
+
+    @staticmethod
+    def _convert_name_dict(names: Optional[Mapping[str, str]]) -> List[PlaceName]:
+        """ Convert a dictionary of names into a list of PlaceNames.
+            The dictionary key is split into the primary part of the key
+            and the suffix (the part after an optional colon).
+ """ + out = [] + + if names: + for key, value in names.items(): + parts = key.split(':', 1) + out.append(PlaceName(value.strip(), + parts[0].strip(), + parts[1].strip() if len(parts) > 1 else None)) + + return out + + +class SanitizerHandler(Protocol): + """ Protocol for sanitizer modules. + """ + + def create(self, config: SanitizerConfig) -> Callable[[ProcessInfo], None]: + """ + A sanitizer must define a single function `create`. It takes the + dictionary with the configuration information for the sanitizer and + returns a function that transforms name and address. + """ diff --git a/nominatim/tokenizer/sanitizers/clean_housenumbers.py b/nominatim/tokenizer/sanitizers/clean_housenumbers.py index c229716f..5df057d0 100644 --- a/nominatim/tokenizer/sanitizers/clean_housenumbers.py +++ b/nominatim/tokenizer/sanitizers/clean_housenumbers.py @@ -24,11 +24,15 @@ Arguments: or a list of strings, where each string is a regular expression that must match the full house number value. """ +from typing import Callable, Iterator, List import re +from nominatim.tokenizer.sanitizers.base import ProcessInfo, PlaceName +from nominatim.tokenizer.sanitizers.config import SanitizerConfig + class _HousenumberSanitizer: - def __init__(self, config): + def __init__(self, config: SanitizerConfig) -> None: self.filter_kind = config.get_filter_kind('housenumber') self.split_regexp = config.get_delimiter() @@ -37,13 +41,13 @@ class _HousenumberSanitizer: - def __call__(self, obj): + def __call__(self, obj: ProcessInfo) -> None: if not obj.address: return - new_address = [] + new_address: List[PlaceName] = [] for item in obj.address: - if self.filter_kind(item): + if self.filter_kind(item.kind): if self._treat_as_name(item.name): obj.names.append(item.clone(kind='housenumber')) else: @@ -56,7 +60,7 @@ class _HousenumberSanitizer: obj.address = new_address - def sanitize(self, value): + def sanitize(self, value: str) -> Iterator[str]: """ Extract housenumbers in a regularized format from an OSM value. The function works as a generator that yields all valid housenumbers @@ -67,16 +71,15 @@ class _HousenumberSanitizer: yield from self._regularize(hnr) - @staticmethod - def _regularize(hnr): + def _regularize(self, hnr: str) -> Iterator[str]: yield hnr - def _treat_as_name(self, housenumber): + def _treat_as_name(self, housenumber: str) -> bool: return any(r.fullmatch(housenumber) is not None for r in self.is_name_regexp) -def create(config): +def create(config: SanitizerConfig) -> Callable[[ProcessInfo], None]: """ Create a housenumber processing function. """ diff --git a/nominatim/tokenizer/sanitizers/clean_postcodes.py b/nominatim/tokenizer/sanitizers/clean_postcodes.py index 05e90ca1..cabacff4 100644 --- a/nominatim/tokenizer/sanitizers/clean_postcodes.py +++ b/nominatim/tokenizer/sanitizers/clean_postcodes.py @@ -20,11 +20,15 @@ Arguments: objects that have no country assigned. These are always assumed to have no postcode. 
""" +from typing import Callable, Optional, Tuple + from nominatim.data.postcode_format import PostcodeFormatter +from nominatim.tokenizer.sanitizers.base import ProcessInfo +from nominatim.tokenizer.sanitizers.config import SanitizerConfig class _PostcodeSanitizer: - def __init__(self, config): + def __init__(self, config: SanitizerConfig) -> None: self.convert_to_address = config.get_bool('convert-to-address', True) self.matcher = PostcodeFormatter() @@ -33,7 +37,7 @@ class _PostcodeSanitizer: self.matcher.set_default_pattern(default_pattern) - def __call__(self, obj): + def __call__(self, obj: ProcessInfo) -> None: if not obj.address: return @@ -52,7 +56,7 @@ class _PostcodeSanitizer: postcode.set_attr('variant', formatted[1]) - def scan(self, postcode, country): + def scan(self, postcode: str, country: Optional[str]) -> Optional[Tuple[str, str]]: """ Check the postcode for correct formatting and return the normalized version. Returns None if the postcode does not correspond to the oficial format of the given country. @@ -61,13 +65,15 @@ class _PostcodeSanitizer: if match is None: return None + assert country is not None + return self.matcher.normalize(country, match),\ ' '.join(filter(lambda p: p is not None, match.groups())) -def create(config): +def create(config: SanitizerConfig) -> Callable[[ProcessInfo], None]: """ Create a housenumber processing function. """ diff --git a/nominatim/tokenizer/sanitizers/config.py b/nominatim/tokenizer/sanitizers/config.py index ce5ce1eb..fd05848b 100644 --- a/nominatim/tokenizer/sanitizers/config.py +++ b/nominatim/tokenizer/sanitizers/config.py @@ -7,20 +7,28 @@ """ Configuration for Sanitizers. """ +from typing import Sequence, Optional, Pattern, Callable, Any, TYPE_CHECKING from collections import UserDict import re from nominatim.errors import UsageError -class SanitizerConfig(UserDict): +# working around missing generics in Python < 3.8 +# See https://github.com/python/typing/issues/60#issuecomment-869757075 +if TYPE_CHECKING: + _BaseUserDict = UserDict[str, Any] +else: + _BaseUserDict = UserDict + +class SanitizerConfig(_BaseUserDict): """ Dictionary with configuration options for a sanitizer. - In addition to the usualy dictionary function, the class provides + In addition to the usual dictionary function, the class provides accessors to standard sanatizer options that are used by many of the sanitizers. """ - def get_string_list(self, param, default=tuple()): + def get_string_list(self, param: str, default: Sequence[str] = tuple()) -> Sequence[str]: """ Extract a configuration parameter as a string list. If the parameter value is a simple string, it is returned as a one-item list. If the parameter value does not exist, the given @@ -44,7 +52,7 @@ class SanitizerConfig(UserDict): return values - def get_bool(self, param, default=None): + def get_bool(self, param: str, default: Optional[bool] = None) -> bool: """ Extract a configuration parameter as a boolean. The parameter must be one of the yaml boolean values or an user error will be raised. If `default` is given, then the parameter @@ -58,7 +66,7 @@ class SanitizerConfig(UserDict): return value - def get_delimiter(self, default=',;'): + def get_delimiter(self, default: str = ',;') -> Pattern[str]: """ Return the 'delimiter' parameter in the configuration as a compiled regular expression that can be used to split the names on the delimiters. 
The regular expression makes sure that the resulting names @@ -76,7 +84,7 @@ class SanitizerConfig(UserDict): return re.compile('\\s*[{}]+\\s*'.format(''.join('\\' + d for d in delimiter_set))) - def get_filter_kind(self, *default): + def get_filter_kind(self, *default: str) -> Callable[[str], bool]: """ Return a filter function for the name kind from the 'filter-kind' config parameter. The filter functions takes a name item and returns True when the item passes the filter. @@ -93,4 +101,4 @@ class SanitizerConfig(UserDict): regexes = [re.compile(regex) for regex in filters] - return lambda name: any(regex.fullmatch(name.kind) for regex in regexes) + return lambda name: any(regex.fullmatch(name) for regex in regexes) diff --git a/nominatim/tokenizer/sanitizers/split_name_list.py b/nominatim/tokenizer/sanitizers/split_name_list.py index c9db0a9d..7d0667b4 100644 --- a/nominatim/tokenizer/sanitizers/split_name_list.py +++ b/nominatim/tokenizer/sanitizers/split_name_list.py @@ -11,13 +11,18 @@ Arguments: delimiters: Define the set of characters to be used for splitting the list. (default: ',;') """ -def create(config): +from typing import Callable + +from nominatim.tokenizer.sanitizers.base import ProcessInfo +from nominatim.tokenizer.sanitizers.config import SanitizerConfig + +def create(config: SanitizerConfig) -> Callable[[ProcessInfo], None]: """ Create a name processing function that splits name values with multiple values into their components. """ regexp = config.get_delimiter() - def _process(obj): + def _process(obj: ProcessInfo) -> None: if not obj.names: return diff --git a/nominatim/tokenizer/sanitizers/strip_brace_terms.py b/nominatim/tokenizer/sanitizers/strip_brace_terms.py index f8cdd035..119d5693 100644 --- a/nominatim/tokenizer/sanitizers/strip_brace_terms.py +++ b/nominatim/tokenizer/sanitizers/strip_brace_terms.py @@ -9,12 +9,17 @@ This sanitizer creates additional name variants for names that have addendums in brackets (e.g. "Halle (Saale)"). The additional variant contains only the main name part with the bracket part removed. """ +from typing import Callable -def create(_): +from nominatim.tokenizer.sanitizers.base import ProcessInfo +from nominatim.tokenizer.sanitizers.config import SanitizerConfig + + +def create(_: SanitizerConfig) -> Callable[[ProcessInfo], None]: """ Create a name processing function that creates additional name variants for bracket addendums. """ - def _process(obj): + def _process(obj: ProcessInfo) -> None: """ Add variants for names that have a bracket extension. """ if obj.names: diff --git a/nominatim/tokenizer/sanitizers/tag_analyzer_by_language.py b/nominatim/tokenizer/sanitizers/tag_analyzer_by_language.py index d3413c1a..6d6430f0 100644 --- a/nominatim/tokenizer/sanitizers/tag_analyzer_by_language.py +++ b/nominatim/tokenizer/sanitizers/tag_analyzer_by_language.py @@ -30,13 +30,17 @@ Arguments: any analyzer tagged) is retained. (default: replace) """ +from typing import Callable, Dict, Optional, List + from nominatim.data import country_info +from nominatim.tokenizer.sanitizers.base import ProcessInfo +from nominatim.tokenizer.sanitizers.config import SanitizerConfig class _AnalyzerByLanguage: """ Processor for tagging the language of names in a place. 
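Note the behavioural edge in `get_filter_kind` below: the returned predicate now takes the kind string itself rather than the whole name object, which is why `clean_housenumbers.py` and `tag_analyzer_by_language.py` switch to calling `self.filter_kind(item.kind)`. A simplified version of the predicate factory (ignoring the default-filter case):

```python
import re
from typing import Callable

def get_filter_kind(*filters: str) -> Callable[[str], bool]:
    regexes = [re.compile(f) for f in filters]
    return lambda kind: any(r.fullmatch(kind) for r in regexes)

is_hnr = get_filter_kind('housenumber', 'conscriptionnumber')
print(is_hnr('housenumber'))  # True
print(is_hnr('name'))         # False
```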
""" - def __init__(self, config): + def __init__(self, config: SanitizerConfig) -> None: self.filter_kind = config.get_filter_kind() self.replace = config.get('mode', 'replace') != 'append' self.whitelist = config.get('whitelist') @@ -44,8 +48,8 @@ class _AnalyzerByLanguage: self._compute_default_languages(config.get('use-defaults', 'no')) - def _compute_default_languages(self, use_defaults): - self.deflangs = {} + def _compute_default_languages(self, use_defaults: str) -> None: + self.deflangs: Dict[Optional[str], List[str]] = {} if use_defaults in ('mono', 'all'): for ccode, clangs in country_info.iterate('languages'): @@ -56,21 +60,21 @@ class _AnalyzerByLanguage: self.deflangs[ccode] = clangs - def _suffix_matches(self, suffix): + def _suffix_matches(self, suffix: str) -> bool: if self.whitelist is None: return len(suffix) in (2, 3) and suffix.islower() return suffix in self.whitelist - def __call__(self, obj): + def __call__(self, obj: ProcessInfo) -> None: if not obj.names: return more_names = [] for name in (n for n in obj.names - if not n.has_attr('analyzer') and self.filter_kind(n)): + if not n.has_attr('analyzer') and self.filter_kind(n.kind)): if name.suffix: langs = [name.suffix] if self._suffix_matches(name.suffix) else None else: @@ -88,7 +92,7 @@ class _AnalyzerByLanguage: obj.names.extend(more_names) -def create(config): +def create(config: SanitizerConfig) -> Callable[[ProcessInfo], None]: """ Create a function that sets the analyzer property depending on the language of the tag. """ diff --git a/nominatim/tokenizer/token_analysis/base.py b/nominatim/tokenizer/token_analysis/base.py new file mode 100644 index 00000000..b2a4386c --- /dev/null +++ b/nominatim/tokenizer/token_analysis/base.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# This file is part of Nominatim. (https://nominatim.org) +# +# Copyright (C) 2022 by the Nominatim developer community. +# For a full list of authors see the git log. +""" +Common data types and protocols for analysers. +""" +from typing import Mapping, List, Any + +from nominatim.typing import Protocol + +class Analyser(Protocol): + """ Instance of the token analyser. + """ + + def normalize(self, name: str) -> str: + """ Return the normalized form of the name. This is the standard form + from which possible variants for the name can be derived. + """ + + def get_variants_ascii(self, norm_name: str) -> List[str]: + """ Compute the spelling variants for the given normalized name + and transliterate the result. + """ + +class AnalysisModule(Protocol): + """ Protocol for analysis modules. + """ + + def configure(self, rules: Mapping[str, Any], normalization_rules: str) -> Any: + """ Prepare the configuration of the analysis module. + This function should prepare all data that can be shared + between instances of this analyser. + """ + + def create(self, normalizer: Any, transliterator: Any, config: Any) -> Analyser: + """ Create a new instance of the analyser. + A separate instance of the analyser is created for each thread + when used in multi-threading context. + """ diff --git a/nominatim/tokenizer/token_analysis/config_variants.py b/nominatim/tokenizer/token_analysis/config_variants.py index 067c4b5b..e0d1579d 100644 --- a/nominatim/tokenizer/token_analysis/config_variants.py +++ b/nominatim/tokenizer/token_analysis/config_variants.py @@ -7,7 +7,8 @@ """ Parser for configuration for variants. 
""" -from collections import defaultdict, namedtuple +from typing import Any, Iterator, Tuple, List, Optional, Set, NamedTuple +from collections import defaultdict import itertools import re @@ -16,9 +17,15 @@ from icu import Transliterator from nominatim.config import flatten_config_list from nominatim.errors import UsageError -ICUVariant = namedtuple('ICUVariant', ['source', 'replacement']) +class ICUVariant(NamedTuple): + """ A single replacement rule for variant creation. + """ + source: str + replacement: str -def get_variant_config(rules, normalization_rules): + +def get_variant_config(in_rules: Any, + normalization_rules: str) -> Tuple[List[Tuple[str, List[str]]], str]: """ Convert the variant definition from the configuration into replacement sets. @@ -26,11 +33,11 @@ def get_variant_config(rules, normalization_rules): used in the replacements. """ immediate = defaultdict(list) - chars = set() + chars: Set[str] = set() - if rules: - vset = set() - rules = flatten_config_list(rules, 'variants') + if in_rules: + vset: Set[ICUVariant] = set() + rules = flatten_config_list(in_rules, 'variants') vmaker = _VariantMaker(normalization_rules) @@ -56,12 +63,12 @@ class _VariantMaker: All text in rules is normalized to make sure the variants match later. """ - def __init__(self, norm_rules): + def __init__(self, norm_rules: Any) -> None: self.norm = Transliterator.createFromRules("rule_loader_normalization", norm_rules) - def compute(self, rule): + def compute(self, rule: Any) -> Iterator[ICUVariant]: """ Generator for all ICUVariant tuples from a single variant rule. """ parts = re.split(r'(\|)?([=-])>', rule) @@ -85,7 +92,7 @@ class _VariantMaker: yield ICUVariant(froms, tos) - def _parse_variant_word(self, name): + def _parse_variant_word(self, name: str) -> Optional[Tuple[str, str, str]]: name = name.strip() match = re.fullmatch(r'([~^]?)([^~$^]*)([~$]?)', name) if match is None or (match.group(1) == '~' and match.group(3) == '~'): @@ -102,7 +109,8 @@ _FLAG_MATCH = {'^': '^ ', '': ' '} -def _create_variants(src, preflag, postflag, repl, decompose): +def _create_variants(src: str, preflag: str, postflag: str, + repl: str, decompose: bool) -> Iterator[Tuple[str, str]]: if preflag == '~': postfix = _FLAG_MATCH[postflag] # suffix decomposition diff --git a/nominatim/tokenizer/token_analysis/generic.py b/nominatim/tokenizer/token_analysis/generic.py index 3de915ba..e14f844c 100644 --- a/nominatim/tokenizer/token_analysis/generic.py +++ b/nominatim/tokenizer/token_analysis/generic.py @@ -7,6 +7,7 @@ """ Generic processor for names that creates abbreviation variants. """ +from typing import Mapping, Dict, Any, Iterable, Iterator, Optional, List, cast import itertools import datrie @@ -17,10 +18,10 @@ from nominatim.tokenizer.token_analysis.generic_mutation import MutationVariantG ### Configuration section -def configure(rules, normalization_rules): +def configure(rules: Mapping[str, Any], normalization_rules: str) -> Dict[str, Any]: """ Extract and preprocess the configuration for this module. """ - config = {} + config: Dict[str, Any] = {} config['replacements'], config['chars'] = get_variant_config(rules.get('variants'), normalization_rules) @@ -47,7 +48,8 @@ def configure(rules, normalization_rules): ### Analysis section -def create(normalizer, transliterator, config): +def create(normalizer: Any, transliterator: Any, + config: Mapping[str, Any]) -> 'GenericTokenAnalysis': """ Create a new token analysis instance for this module. 
""" return GenericTokenAnalysis(normalizer, transliterator, config) @@ -58,7 +60,7 @@ class GenericTokenAnalysis: and provides the functions to apply the transformations. """ - def __init__(self, norm, to_ascii, config): + def __init__(self, norm: Any, to_ascii: Any, config: Mapping[str, Any]) -> None: self.norm = norm self.to_ascii = to_ascii self.variant_only = config['variant_only'] @@ -75,14 +77,14 @@ class GenericTokenAnalysis: self.mutations = [MutationVariantGenerator(*cfg) for cfg in config['mutations']] - def normalize(self, name): + def normalize(self, name: str) -> str: """ Return the normalized form of the name. This is the standard form from which possible variants for the name can be derived. """ - return self.norm.transliterate(name).strip() + return cast(str, self.norm.transliterate(name)).strip() - def get_variants_ascii(self, norm_name): + def get_variants_ascii(self, norm_name: str) -> List[str]: """ Compute the spelling variants for the given normalized name and transliterate the result. """ @@ -94,7 +96,8 @@ class GenericTokenAnalysis: return [name for name in self._transliterate_unique_list(norm_name, variants) if name] - def _transliterate_unique_list(self, norm_name, iterable): + def _transliterate_unique_list(self, norm_name: str, + iterable: Iterable[str]) -> Iterator[Optional[str]]: seen = set() if self.variant_only: seen.add(norm_name) @@ -105,7 +108,7 @@ class GenericTokenAnalysis: yield self.to_ascii.transliterate(variant).strip() - def _generate_word_variants(self, norm_name): + def _generate_word_variants(self, norm_name: str) -> Iterable[str]: baseform = '^ ' + norm_name + ' ^' baselen = len(baseform) partials = [''] diff --git a/nominatim/tokenizer/token_analysis/generic_mutation.py b/nominatim/tokenizer/token_analysis/generic_mutation.py index d23d5cd4..47154537 100644 --- a/nominatim/tokenizer/token_analysis/generic_mutation.py +++ b/nominatim/tokenizer/token_analysis/generic_mutation.py @@ -7,6 +7,7 @@ """ Creator for mutation variants for the generic token analysis. """ +from typing import Sequence, Iterable, Iterator, Tuple import itertools import logging import re @@ -15,7 +16,7 @@ from nominatim.errors import UsageError LOG = logging.getLogger() -def _zigzag(outer, inner): +def _zigzag(outer: Iterable[str], inner: Iterable[str]) -> Iterator[str]: return itertools.chain.from_iterable(itertools.zip_longest(outer, inner, fillvalue='')) @@ -26,7 +27,7 @@ class MutationVariantGenerator: patterns. """ - def __init__(self, pattern, replacements): + def __init__(self, pattern: str, replacements: Sequence[str]): self.pattern = re.compile(pattern) self.replacements = replacements @@ -36,7 +37,7 @@ class MutationVariantGenerator: raise UsageError("Bad mutation pattern in configuration.") - def generate(self, names): + def generate(self, names: Iterable[str]) -> Iterator[str]: """ Generator function for the name variants. 'names' is an iterable over a set of names for which the variants are to be generated. """ @@ -49,7 +50,7 @@ class MutationVariantGenerator: yield ''.join(_zigzag(parts, seps)) - def _fillers(self, num_parts): + def _fillers(self, num_parts: int) -> Iterator[Tuple[str, ...]]: """ Returns a generator for strings to join the given number of string parts in all possible combinations. 
""" diff --git a/nominatim/tokenizer/token_analysis/housenumbers.py b/nominatim/tokenizer/token_analysis/housenumbers.py index 96e86b28..a0f4214d 100644 --- a/nominatim/tokenizer/token_analysis/housenumbers.py +++ b/nominatim/tokenizer/token_analysis/housenumbers.py @@ -8,6 +8,7 @@ Specialized processor for housenumbers. Analyses common housenumber patterns and creates variants for them. """ +from typing import Mapping, Any, List, cast import re from nominatim.tokenizer.token_analysis.generic_mutation import MutationVariantGenerator @@ -19,14 +20,14 @@ RE_NAMED_PART = re.compile(r'[a-z]{4}') ### Configuration section -def configure(rules, normalization_rules): # pylint: disable=W0613 +def configure(rules: Mapping[str, Any], normalization_rules: str) -> None: # pylint: disable=W0613 """ All behaviour is currently hard-coded. """ return None ### Analysis section -def create(normalizer, transliterator, config): # pylint: disable=W0613 +def create(normalizer: Any, transliterator: Any, config: None) -> 'HousenumberTokenAnalysis': # pylint: disable=W0613 """ Create a new token analysis instance for this module. """ return HousenumberTokenAnalysis(normalizer, transliterator) @@ -35,20 +36,20 @@ def create(normalizer, transliterator, config): # pylint: disable=W0613 class HousenumberTokenAnalysis: """ Detects common housenumber patterns and normalizes them. """ - def __init__(self, norm, trans): + def __init__(self, norm: Any, trans: Any) -> None: self.norm = norm self.trans = trans self.mutator = MutationVariantGenerator('␣', (' ', '')) - def normalize(self, name): + def normalize(self, name: str) -> str: """ Return the normalized form of the housenumber. """ # shortcut for number-only numbers, which make up 90% of the data. if RE_NON_DIGIT.search(name) is None: return name - norm = self.trans.transliterate(self.norm.transliterate(name)) + norm = cast(str, self.trans.transliterate(self.norm.transliterate(name))) # If there is a significant non-numeric part, use as is. if RE_NAMED_PART.search(norm) is None: # Otherwise add optional spaces between digits and letters. @@ -60,7 +61,7 @@ class HousenumberTokenAnalysis: return norm - def get_variants_ascii(self, norm_name): + def get_variants_ascii(self, norm_name: str) -> List[str]: """ Compute the spelling variants for the given normalized housenumber. Generates variants for optional spaces (marked with '␣'). diff --git a/nominatim/tokenizer/token_analysis/postcodes.py b/nominatim/tokenizer/token_analysis/postcodes.py index 18fc2a8d..15b20bf9 100644 --- a/nominatim/tokenizer/token_analysis/postcodes.py +++ b/nominatim/tokenizer/token_analysis/postcodes.py @@ -8,19 +8,20 @@ Specialized processor for postcodes. Supports a 'lookup' variant of the token, which produces variants with optional spaces. """ +from typing import Mapping, Any, List from nominatim.tokenizer.token_analysis.generic_mutation import MutationVariantGenerator ### Configuration section -def configure(rules, normalization_rules): # pylint: disable=W0613 +def configure(rules: Mapping[str, Any], normalization_rules: str) -> None: # pylint: disable=W0613 """ All behaviour is currently hard-coded. """ return None ### Analysis section -def create(normalizer, transliterator, config): # pylint: disable=W0613 +def create(normalizer: Any, transliterator: Any, config: None) -> 'PostcodeTokenAnalysis': # pylint: disable=W0613 """ Create a new token analysis instance for this module. 
""" return PostcodeTokenAnalysis(normalizer, transliterator) @@ -38,20 +39,20 @@ class PostcodeTokenAnalysis: and transliteration, so that postcodes are correctly recognised by the search algorithm. """ - def __init__(self, norm, trans): + def __init__(self, norm: Any, trans: Any) -> None: self.norm = norm self.trans = trans self.mutator = MutationVariantGenerator(' ', (' ', '')) - def normalize(self, name): + def normalize(self, name: str) -> str: """ Return the standard form of the postcode. """ return name.strip().upper() - def get_variants_ascii(self, norm_name): + def get_variants_ascii(self, norm_name: str) -> List[str]: """ Compute the spelling variants for the given normalized postcode. Takes the canonical form of the postcode, normalizes it using the diff --git a/nominatim/tools/add_osm_data.py b/nominatim/tools/add_osm_data.py index b4e77b21..fc016fec 100644 --- a/nominatim/tools/add_osm_data.py +++ b/nominatim/tools/add_osm_data.py @@ -7,6 +7,7 @@ """ Function to add additional OSM data from a file or the API into the database. """ +from typing import Any, MutableMapping from pathlib import Path import logging import urllib @@ -15,7 +16,7 @@ from nominatim.tools.exec_utils import run_osm2pgsql, get_url LOG = logging.getLogger() -def add_data_from_file(fname, options): +def add_data_from_file(fname: str, options: MutableMapping[str, Any]) -> int: """ Adds data from a OSM file to the database. The file may be a normal OSM file or a diff file in all formats supported by libosmium. """ @@ -27,7 +28,8 @@ def add_data_from_file(fname, options): return 0 -def add_osm_object(osm_type, osm_id, use_main_api, options): +def add_osm_object(osm_type: str, osm_id: int, use_main_api: bool, + options: MutableMapping[str, Any]) -> int: """ Add or update a single OSM object from the latest version of the API. """ @@ -50,3 +52,5 @@ def add_osm_object(osm_type, osm_id, use_main_api, options): options['import_data'] = get_url(base_url).encode('utf-8') run_osm2pgsql(options) + + return 0 diff --git a/nominatim/tools/admin.py b/nominatim/tools/admin.py index 1bf217e2..9fb944d3 100644 --- a/nominatim/tools/admin.py +++ b/nominatim/tools/admin.py @@ -7,22 +7,27 @@ """ Functions for database analysis and maintenance. """ +from typing import Optional, Tuple, Any, cast import logging from psycopg2.extras import Json, register_hstore -from nominatim.db.connection import connect +from nominatim.config import Configuration +from nominatim.db.connection import connect, Cursor from nominatim.tokenizer import factory as tokenizer_factory from nominatim.errors import UsageError from nominatim.data.place_info import PlaceInfo +from nominatim.typing import DictCursorResult LOG = logging.getLogger() -def _get_place_info(cursor, osm_id, place_id): +def _get_place_info(cursor: Cursor, osm_id: Optional[str], + place_id: Optional[int]) -> DictCursorResult: sql = """SELECT place_id, extra.* FROM placex, LATERAL placex_indexing_prepare(placex) as extra """ + values: Tuple[Any, ...] 
if osm_id: osm_type = osm_id[0].upper() if osm_type not in 'NWR' or not osm_id[1:].isdigit(): @@ -44,10 +49,11 @@ def _get_place_info(cursor, osm_id, place_id): LOG.fatal("OSM object %s not found in database.", osm_id) raise UsageError("OSM object not found") - return cursor.fetchone() + return cast(DictCursorResult, cursor.fetchone()) # type: ignore[no-untyped-call] -def analyse_indexing(config, osm_id=None, place_id=None): +def analyse_indexing(config: Configuration, osm_id: Optional[str] = None, + place_id: Optional[int] = None) -> None: """ Analyse indexing of a single Nominatim object. """ with connect(config.get_libpq_dsn()) as conn: diff --git a/nominatim/tools/check_database.py b/nominatim/tools/check_database.py index 7ac31271..e5cefe4f 100644 --- a/nominatim/tools/check_database.py +++ b/nominatim/tools/check_database.py @@ -7,10 +7,12 @@ """ Collection of functions that check if the database is complete and functional. """ +from typing import Callable, Optional, Any, Union, Tuple, Mapping, List from enum import Enum from textwrap import dedent -from nominatim.db.connection import connect +from nominatim.config import Configuration +from nominatim.db.connection import connect, Connection from nominatim.errors import UsageError from nominatim.tokenizer import factory as tokenizer_factory @@ -25,14 +27,17 @@ class CheckState(Enum): NOT_APPLICABLE = 3 WARN = 4 -def _check(hint=None): +CheckResult = Union[CheckState, Tuple[CheckState, Mapping[str, Any]]] +CheckFunc = Callable[[Connection, Configuration], CheckResult] + +def _check(hint: Optional[str] = None) -> Callable[[CheckFunc], CheckFunc]: """ Decorator for checks. It adds the function to the list of checks to execute and adds the code for printing progress messages. """ - def decorator(func): - title = func.__doc__.split('\n', 1)[0].strip() + def decorator(func: CheckFunc) -> CheckFunc: + title = (func.__doc__ or '').split('\n', 1)[0].strip() - def run_check(conn, config): + def run_check(conn: Connection, config: Configuration) -> CheckState: print(title, end=' ... ') ret = func(conn, config) if isinstance(ret, tuple): @@ -61,20 +66,20 @@ def _check(hint=None): class _BadConnection: - def __init__(self, msg): + def __init__(self, msg: str) -> None: self.msg = msg - def close(self): + def close(self) -> None: """ Dummy function to provide the implementation. """ -def check_database(config): +def check_database(config: Configuration) -> int: """ Run a number of checks on the database and return the status. 
""" try: conn = connect(config.get_libpq_dsn()).connection except UsageError as err: - conn = _BadConnection(str(err)) + conn = _BadConnection(str(err)) # type: ignore[assignment] overall_result = 0 for check in CHECKLIST: @@ -89,7 +94,7 @@ def check_database(config): return overall_result -def _get_indexes(conn): +def _get_indexes(conn: Connection) -> List[str]: indexes = ['idx_place_addressline_address_place_id', 'idx_placex_rank_search', 'idx_placex_rank_address', @@ -131,7 +136,7 @@ def _get_indexes(conn): Project directory: {config.project_dir} Current setting of NOMINATIM_DATABASE_DSN: {config.DATABASE_DSN} """) -def check_connection(conn, config): +def check_connection(conn: Any, config: Configuration) -> CheckResult: """ Checking database connection """ if isinstance(conn, _BadConnection): @@ -149,7 +154,7 @@ def check_connection(conn, config): Project directory: {config.project_dir} Current setting of NOMINATIM_DATABASE_DSN: {config.DATABASE_DSN} """) -def check_placex_table(conn, config): +def check_placex_table(conn: Connection, config: Configuration) -> CheckResult: """ Checking for placex table """ if conn.table_exists('placex'): @@ -159,7 +164,7 @@ def check_placex_table(conn, config): @_check(hint="""placex table has no data. Did the import finish sucessfully?""") -def check_placex_size(conn, _): +def check_placex_size(conn: Connection, _: Configuration) -> CheckResult: """ Checking for placex content """ with conn.cursor() as cur: @@ -169,7 +174,7 @@ def check_placex_size(conn, _): @_check(hint="""{msg}""") -def check_tokenizer(_, config): +def check_tokenizer(_: Connection, config: Configuration) -> CheckResult: """ Checking that tokenizer works """ try: @@ -191,7 +196,7 @@ def check_tokenizer(_, config): Quality of search results may be degraded. Reverse geocoding is unaffected. See https://nominatim.org/release-docs/latest/admin/Import/#wikipediawikidata-rankings """) -def check_existance_wikipedia(conn, _): +def check_existance_wikipedia(conn: Connection, _: Configuration) -> CheckResult: """ Checking for wikipedia/wikidata data """ if not conn.table_exists('search_name'): @@ -208,7 +213,7 @@ def check_existance_wikipedia(conn, _): To index the remaining entries, run: {index_cmd} """) -def check_indexing(conn, _): +def check_indexing(conn: Connection, _: Configuration) -> CheckResult: """ Checking indexing status """ with conn.cursor() as cur: @@ -233,7 +238,7 @@ def check_indexing(conn, _): Rerun the index creation with: nominatim import --continue db-postprocess """) -def check_database_indexes(conn, _): +def check_database_indexes(conn: Connection, _: Configuration) -> CheckResult: """ Checking that database indexes are complete """ missing = [] @@ -255,7 +260,7 @@ def check_database_indexes(conn, _): Invalid indexes: {indexes} """) -def check_database_index_valid(conn, _): +def check_database_index_valid(conn: Connection, _: Configuration) -> CheckResult: """ Checking that all database indexes are valid """ with conn.cursor() as cur: @@ -275,7 +280,7 @@ def check_database_index_valid(conn, _): {error} Run TIGER import again: nominatim add-data --tiger-data """) -def check_tiger_table(conn, config): +def check_tiger_table(conn: Connection, config: Configuration) -> CheckResult: """ Checking TIGER external data table. 
""" if not config.get_bool('USE_US_TIGER_DATA'): diff --git a/nominatim/tools/database_import.py b/nominatim/tools/database_import.py index 6195b44a..fa60abf2 100644 --- a/nominatim/tools/database_import.py +++ b/nominatim/tools/database_import.py @@ -7,6 +7,7 @@ """ Functions for setting up and importing a new Nominatim database. """ +from typing import Tuple, Optional, Union, Sequence, MutableMapping, Any import logging import os import selectors @@ -16,7 +17,8 @@ from pathlib import Path import psutil from psycopg2 import sql as pysql -from nominatim.db.connection import connect, get_pg_env +from nominatim.config import Configuration +from nominatim.db.connection import connect, get_pg_env, Connection from nominatim.db.async_connection import DBConnection from nominatim.db.sql_preprocessor import SQLPreprocessor from nominatim.tools.exec_utils import run_osm2pgsql @@ -25,7 +27,7 @@ from nominatim.version import POSTGRESQL_REQUIRED_VERSION, POSTGIS_REQUIRED_VERS LOG = logging.getLogger() -def _require_version(module, actual, expected): +def _require_version(module: str, actual: Tuple[int, int], expected: Tuple[int, int]) -> None: """ Compares the version for the given module and raises an exception if the actual version is too old. """ @@ -36,7 +38,7 @@ def _require_version(module, actual, expected): raise UsageError(f'{module} is too old.') -def setup_database_skeleton(dsn, rouser=None): +def setup_database_skeleton(dsn: str, rouser: Optional[str] = None) -> None: """ Create a new database for Nominatim and populate it with the essential extensions. @@ -80,7 +82,9 @@ def setup_database_skeleton(dsn, rouser=None): POSTGIS_REQUIRED_VERSION) -def import_osm_data(osm_files, options, drop=False, ignore_errors=False): +def import_osm_data(osm_files: Union[Path, Sequence[Path]], + options: MutableMapping[str, Any], + drop: bool = False, ignore_errors: bool = False) -> None: """ Import the given OSM files. 'options' contains the list of default settings for osm2pgsql. """ @@ -91,7 +95,7 @@ def import_osm_data(osm_files, options, drop=False, ignore_errors=False): if not options['flatnode_file'] and options['osm2pgsql_cache'] == 0: # Make some educated guesses about cache size based on the size # of the import file and the available memory. - mem = psutil.virtual_memory() + mem = psutil.virtual_memory() # type: ignore[no-untyped-call] fsize = 0 if isinstance(osm_files, list): for fname in osm_files: @@ -117,7 +121,7 @@ def import_osm_data(osm_files, options, drop=False, ignore_errors=False): Path(options['flatnode_file']).unlink() -def create_tables(conn, config, reverse_only=False): +def create_tables(conn: Connection, config: Configuration, reverse_only: bool = False) -> None: """ Create the set of basic tables. When `reverse_only` is True, then the main table for searching will be skipped and only reverse search is possible. @@ -128,7 +132,7 @@ def create_tables(conn, config, reverse_only=False): sql.run_sql_file(conn, 'tables.sql') -def create_table_triggers(conn, config): +def create_table_triggers(conn: Connection, config: Configuration) -> None: """ Create the triggers for the tables. The trigger functions must already have been imported with refresh.create_functions(). """ @@ -136,14 +140,14 @@ def create_table_triggers(conn, config): sql.run_sql_file(conn, 'table-triggers.sql') -def create_partition_tables(conn, config): +def create_partition_tables(conn: Connection, config: Configuration) -> None: """ Create tables that have explicit partitioning. 
""" sql = SQLPreprocessor(conn, config) sql.run_sql_file(conn, 'partition-tables.src.sql') -def truncate_data_tables(conn): +def truncate_data_tables(conn: Connection) -> None: """ Truncate all data tables to prepare for a fresh load. """ with conn.cursor() as cur: @@ -174,7 +178,7 @@ _COPY_COLUMNS = pysql.SQL(',').join(map(pysql.Identifier, 'extratags', 'geometry'))) -def load_data(dsn, threads): +def load_data(dsn: str, threads: int) -> None: """ Copy data into the word and placex table. """ sel = selectors.DefaultSelector() @@ -216,12 +220,12 @@ def load_data(dsn, threads): print('.', end='', flush=True) print('\n') - with connect(dsn) as conn: - with conn.cursor() as cur: + with connect(dsn) as syn_conn: + with syn_conn.cursor() as cur: cur.execute('ANALYSE') -def create_search_indices(conn, config, drop=False): +def create_search_indices(conn: Connection, config: Configuration, drop: bool = False) -> None: """ Create tables that have explicit partitioning. """ diff --git a/nominatim/tools/exec_utils.py b/nominatim/tools/exec_utils.py index a81a8d6b..610e2182 100644 --- a/nominatim/tools/exec_utils.py +++ b/nominatim/tools/exec_utils.py @@ -7,17 +7,22 @@ """ Helper functions for executing external programs. """ +from typing import Any, Union, Optional, Mapping, IO +from pathlib import Path import logging import subprocess import urllib.request as urlrequest from urllib.parse import urlencode +from nominatim.typing import StrPath from nominatim.version import version_str from nominatim.db.connection import get_pg_env LOG = logging.getLogger() -def run_legacy_script(script, *args, nominatim_env=None, throw_on_fail=False): +def run_legacy_script(script: StrPath, *args: Union[int, str], + nominatim_env: Any, + throw_on_fail: bool = False) -> int: """ Run a Nominatim PHP script with the given arguments. Returns the exit code of the script. If `throw_on_fail` is True @@ -40,8 +45,10 @@ def run_legacy_script(script, *args, nominatim_env=None, throw_on_fail=False): return proc.returncode -def run_api_script(endpoint, project_dir, extra_env=None, phpcgi_bin=None, - params=None): +def run_api_script(endpoint: str, project_dir: Path, + extra_env: Optional[Mapping[str, str]] = None, + phpcgi_bin: Optional[Path] = None, + params: Optional[Mapping[str, Any]] = None) -> int: """ Execute a Nominatim API function. The function needs a project directory that contains the website @@ -96,14 +103,14 @@ def run_api_script(endpoint, project_dir, extra_env=None, phpcgi_bin=None, return 0 -def run_php_server(server_address, base_dir): +def run_php_server(server_address: str, base_dir: StrPath) -> None: """ Run the built-in server from the given directory. """ subprocess.run(['/usr/bin/env', 'php', '-S', server_address], cwd=str(base_dir), check=True) -def run_osm2pgsql(options): +def run_osm2pgsql(options: Mapping[str, Any]) -> None: """ Run osm2pgsql with the given options. """ env = get_pg_env(options['dsn']) @@ -147,13 +154,14 @@ def run_osm2pgsql(options): env=env, check=True) -def get_url(url): +def get_url(url: str) -> str: """ Get the contents from the given URL and return it as a UTF-8 string. 
""" headers = {"User-Agent": f"Nominatim/{version_str()}"} try: - with urlrequest.urlopen(urlrequest.Request(url, headers=headers)) as response: + request = urlrequest.Request(url, headers=headers) + with urlrequest.urlopen(request) as response: # type: IO[bytes] return response.read().decode('utf-8') except Exception: LOG.fatal('Failed to load URL: %s', url) diff --git a/nominatim/tools/freeze.py b/nominatim/tools/freeze.py index e502c963..39c3279d 100644 --- a/nominatim/tools/freeze.py +++ b/nominatim/tools/freeze.py @@ -7,10 +7,13 @@ """ Functions for removing unnecessary data from the database. """ +from typing import Optional from pathlib import Path from psycopg2 import sql as pysql +from nominatim.db.connection import Connection + UPDATE_TABLES = [ 'address_levels', 'gb_postcode', @@ -25,7 +28,7 @@ UPDATE_TABLES = [ 'wikipedia_%' ] -def drop_update_tables(conn): +def drop_update_tables(conn: Connection) -> None: """ Drop all tables only necessary for updating the database from OSM replication data. """ @@ -42,10 +45,8 @@ def drop_update_tables(conn): conn.commit() -def drop_flatnode_file(fname): +def drop_flatnode_file(fpath: Optional[Path]) -> None: """ Remove the flatnode file if it exists. """ - if fname: - fpath = Path(fname) - if fpath.exists(): - fpath.unlink() + if fpath and fpath.exists(): + fpath.unlink() diff --git a/nominatim/tools/migration.py b/nominatim/tools/migration.py index 28a14455..aa86bcc8 100644 --- a/nominatim/tools/migration.py +++ b/nominatim/tools/migration.py @@ -7,12 +7,14 @@ """ Functions for database migration to newer software versions. """ +from typing import List, Tuple, Callable, Any import logging from psycopg2 import sql as pysql +from nominatim.config import Configuration from nominatim.db import properties -from nominatim.db.connection import connect +from nominatim.db.connection import connect, Connection from nominatim.version import NOMINATIM_VERSION, version_str from nominatim.tools import refresh from nominatim.tokenizer import factory as tokenizer_factory @@ -20,9 +22,11 @@ from nominatim.errors import UsageError LOG = logging.getLogger() -_MIGRATION_FUNCTIONS = [] +VersionTuple = Tuple[int, int, int, int] -def migrate(config, paths): +_MIGRATION_FUNCTIONS : List[Tuple[VersionTuple, Callable[..., None]]] = [] + +def migrate(config: Configuration, paths: Any) -> int: """ Check for the current database version and execute migrations, if necesssary. """ @@ -48,7 +52,8 @@ def migrate(config, paths): has_run_migration = False for version, func in _MIGRATION_FUNCTIONS: if db_version <= version: - LOG.warning("Runnning: %s (%s)", func.__doc__.split('\n', 1)[0], + title = func.__doc__ or '' + LOG.warning("Runnning: %s (%s)", title.split('\n', 1)[0], version_str(version)) kwargs = dict(conn=conn, config=config, paths=paths) func(**kwargs) @@ -68,7 +73,7 @@ def migrate(config, paths): return 0 -def _guess_version(conn): +def _guess_version(conn: Connection) -> VersionTuple: """ Guess a database version when there is no property table yet. Only migrations for 3.6 and later are supported, so bail out when the version seems older. @@ -88,7 +93,8 @@ def _guess_version(conn): -def _migration(major, minor, patch=0, dbpatch=0): +def _migration(major: int, minor: int, patch: int = 0, + dbpatch: int = 0) -> Callable[[Callable[..., None]], Callable[..., None]]: """ Decorator for a single migration step. 
The parameters describe the version after which the migration is applicable, i.e before changing from the given version to the next, the migration is required. @@ -101,7 +107,7 @@ def _migration(major, minor, patch=0, dbpatch=0): process, so the migration functions may leave a temporary state behind there. """ - def decorator(func): + def decorator(func: Callable[..., None]) -> Callable[..., None]: _MIGRATION_FUNCTIONS.append(((major, minor, patch, dbpatch), func)) return func @@ -109,7 +115,7 @@ def _migration(major, minor, patch=0, dbpatch=0): @_migration(3, 5, 0, 99) -def import_status_timestamp_change(conn, **_): +def import_status_timestamp_change(conn: Connection, **_: Any) -> None: """ Add timezone to timestamp in status table. The import_status table has been changed to include timezone information @@ -121,7 +127,7 @@ def import_status_timestamp_change(conn, **_): @_migration(3, 5, 0, 99) -def add_nominatim_property_table(conn, config, **_): +def add_nominatim_property_table(conn: Connection, config: Configuration, **_: Any) -> None: """ Add nominatim_property table. """ if not conn.table_exists('nominatim_properties'): @@ -133,7 +139,7 @@ def add_nominatim_property_table(conn, config, **_): """).format(pysql.Identifier(config.DATABASE_WEBUSER))) @_migration(3, 6, 0, 0) -def change_housenumber_transliteration(conn, **_): +def change_housenumber_transliteration(conn: Connection, **_: Any) -> None: """ Transliterate housenumbers. The database schema switched from saving raw housenumbers in @@ -164,7 +170,7 @@ def change_housenumber_transliteration(conn, **_): @_migration(3, 7, 0, 0) -def switch_placenode_geometry_index(conn, **_): +def switch_placenode_geometry_index(conn: Connection, **_: Any) -> None: """ Replace idx_placex_geometry_reverse_placeNode index. Make the index slightly more permissive, so that it can also be used @@ -181,7 +187,7 @@ def switch_placenode_geometry_index(conn, **_): @_migration(3, 7, 0, 1) -def install_legacy_tokenizer(conn, config, **_): +def install_legacy_tokenizer(conn: Connection, config: Configuration, **_: Any) -> None: """ Setup legacy tokenizer. If no other tokenizer has been configured yet, then create the @@ -200,11 +206,11 @@ def install_legacy_tokenizer(conn, config, **_): tokenizer = tokenizer_factory.create_tokenizer(config, init_db=False, module_name='legacy') - tokenizer.migrate_database(config) + tokenizer.migrate_database(config) # type: ignore[attr-defined] @_migration(4, 0, 99, 0) -def create_tiger_housenumber_index(conn, **_): +def create_tiger_housenumber_index(conn: Connection, **_: Any) -> None: """ Create idx_location_property_tiger_parent_place_id with included house number. @@ -221,7 +227,7 @@ def create_tiger_housenumber_index(conn, **_): @_migration(4, 0, 99, 1) -def create_interpolation_index_on_place(conn, **_): +def create_interpolation_index_on_place(conn: Connection, **_: Any) -> None: """ Create idx_place_interpolations for lookup of interpolation lines on updates. """ @@ -232,7 +238,7 @@ def create_interpolation_index_on_place(conn, **_): @_migration(4, 0, 99, 2) -def add_step_column_for_interpolation(conn, **_): +def add_step_column_for_interpolation(conn: Connection, **_: Any) -> None: """ Add a new column 'step' to the interpolations table. 
Also convers the data into the stricter format which requires that @@ -267,7 +273,7 @@ def add_step_column_for_interpolation(conn, **_): @_migration(4, 0, 99, 3) -def add_step_column_for_tiger(conn, **_): +def add_step_column_for_tiger(conn: Connection, **_: Any) -> None: """ Add a new column 'step' to the tiger data table. """ if conn.table_has_column('location_property_tiger', 'step'): @@ -282,7 +288,7 @@ def add_step_column_for_tiger(conn, **_): @_migration(4, 0, 99, 4) -def add_derived_name_column_for_country_names(conn, **_): +def add_derived_name_column_for_country_names(conn: Connection, **_: Any) -> None: """ Add a new column 'derived_name' which in the future takes the country names as imported from OSM data. """ @@ -292,7 +298,7 @@ def add_derived_name_column_for_country_names(conn, **_): @_migration(4, 0, 99, 5) -def mark_internal_country_names(conn, config, **_): +def mark_internal_country_names(conn: Connection, config: Configuration, **_: Any) -> None: """ Names from the country table should be marked as internal to prevent them from being deleted. Only necessary for ICU tokenizer. """ diff --git a/nominatim/tools/postcodes.py b/nominatim/tools/postcodes.py index 9c66719b..7171e25d 100644 --- a/nominatim/tools/postcodes.py +++ b/nominatim/tools/postcodes.py @@ -8,7 +8,9 @@ Functions for importing, updating and otherwise maintaining the table of artificial postcode centroids. """ +from typing import Optional, Tuple, Dict, List, TextIO from collections import defaultdict +from pathlib import Path import csv import gzip import logging @@ -16,18 +18,19 @@ from math import isfinite from psycopg2 import sql as pysql -from nominatim.db.connection import connect +from nominatim.db.connection import connect, Connection from nominatim.utils.centroid import PointsCentroid -from nominatim.data.postcode_format import PostcodeFormatter +from nominatim.data.postcode_format import PostcodeFormatter, CountryPostcodeMatcher +from nominatim.tokenizer.base import AbstractAnalyzer, AbstractTokenizer LOG = logging.getLogger() -def _to_float(num, max_value): +def _to_float(numstr: str, max_value: float) -> float: """ Convert the number in string into a float. The number is expected to be in the range of [-max_value, max_value]. Otherwise rises a ValueError. """ - num = float(num) + num = float(numstr) if not isfinite(num) or num <= -max_value or num >= max_value: raise ValueError() @@ -37,18 +40,19 @@ class _PostcodeCollector: """ Collector for postcodes of a single country. """ - def __init__(self, country, matcher): + def __init__(self, country: str, matcher: Optional[CountryPostcodeMatcher]): self.country = country self.matcher = matcher - self.collected = defaultdict(PointsCentroid) - self.normalization_cache = None + self.collected: Dict[str, PointsCentroid] = defaultdict(PointsCentroid) + self.normalization_cache: Optional[Tuple[str, Optional[str]]] = None - def add(self, postcode, x, y): + def add(self, postcode: str, x: float, y: float) -> None: """ Add the given postcode to the collection cache. If the postcode already existed, it is overwritten with the new centroid. 
""" if self.matcher is not None: + normalized: Optional[str] if self.normalization_cache and self.normalization_cache[0] == postcode: normalized = self.normalization_cache[1] else: @@ -60,7 +64,7 @@ class _PostcodeCollector: self.collected[normalized] += (x, y) - def commit(self, conn, analyzer, project_dir): + def commit(self, conn: Connection, analyzer: AbstractAnalyzer, project_dir: Path) -> None: """ Update postcodes for the country from the postcodes selected so far as well as any externally supplied postcodes. """ @@ -94,7 +98,8 @@ class _PostcodeCollector: """).format(pysql.Literal(self.country)), to_update) - def _compute_changes(self, conn): + def _compute_changes(self, conn: Connection) \ + -> Tuple[List[Tuple[str, float, float]], List[str], List[Tuple[str, float, float]]]: """ Compute which postcodes from the collected postcodes have to be added or modified and which from the location_postcode table have to be deleted. @@ -116,12 +121,12 @@ class _PostcodeCollector: to_delete.append(postcode) to_add = [(k, *v.centroid()) for k, v in self.collected.items()] - self.collected = None + self.collected = defaultdict(PointsCentroid) return to_add, to_delete, to_update - def _update_from_external(self, analyzer, project_dir): + def _update_from_external(self, analyzer: AbstractAnalyzer, project_dir: Path) -> None: """ Look for an external postcode file for the active country in the project directory and add missing postcodes when found. """ @@ -151,7 +156,7 @@ class _PostcodeCollector: csvfile.close() - def _open_external(self, project_dir): + def _open_external(self, project_dir: Path) -> Optional[TextIO]: fname = project_dir / f'{self.country}_postcodes.csv' if fname.is_file(): @@ -167,7 +172,7 @@ class _PostcodeCollector: return None -def update_postcodes(dsn, project_dir, tokenizer): +def update_postcodes(dsn: str, project_dir: Path, tokenizer: AbstractTokenizer) -> None: """ Update the table of artificial postcodes. Computes artificial postcode centroids from the placex table, @@ -220,7 +225,7 @@ def update_postcodes(dsn, project_dir, tokenizer): analyzer.update_postcodes_from_db() -def can_compute(dsn): +def can_compute(dsn: str) -> bool: """ Check that the place table exists so that postcodes can be computed. diff --git a/nominatim/tools/refresh.py b/nominatim/tools/refresh.py index 561bcf83..9c5b7b08 100644 --- a/nominatim/tools/refresh.py +++ b/nominatim/tools/refresh.py @@ -7,12 +7,15 @@ """ Functions for bringing auxiliary data in the database up-to-date. """ +from typing import MutableSequence, Tuple, Any, Type, Mapping, Sequence, List, cast import logging from textwrap import dedent from pathlib import Path from psycopg2 import sql as pysql +from nominatim.config import Configuration +from nominatim.db.connection import Connection from nominatim.db.utils import execute_file from nominatim.db.sql_preprocessor import SQLPreprocessor from nominatim.version import version_str @@ -21,7 +24,8 @@ LOG = logging.getLogger() OSM_TYPE = {'N': 'node', 'W': 'way', 'R': 'relation'} -def _add_address_level_rows_from_entry(rows, entry): +def _add_address_level_rows_from_entry(rows: MutableSequence[Tuple[Any, ...]], + entry: Mapping[str, Any]) -> None: """ Converts a single entry from the JSON format for address rank descriptions into a flat format suitable for inserting into a PostgreSQL table and adds these lines to `rows`. 
@@ -38,14 +42,15 @@ def _add_address_level_rows_from_entry(rows, entry): for country in countries: rows.append((country, key, value, rank_search, rank_address)) -def load_address_levels(conn, table, levels): + +def load_address_levels(conn: Connection, table: str, levels: Sequence[Mapping[str, Any]]) -> None: """ Replace the `address_levels` table with the contents of `levels'. A new table is created any previously existing table is dropped. The table has the following columns: country, class, type, rank_search, rank_address """ - rows = [] + rows: List[Tuple[Any, ...]] = [] for entry in levels: _add_address_level_rows_from_entry(rows, entry) @@ -69,7 +74,7 @@ def load_address_levels(conn, table, levels): conn.commit() -def load_address_levels_from_config(conn, config): +def load_address_levels_from_config(conn: Connection, config: Configuration) -> None: """ Replace the `address_levels` table with the content as defined in the given configuration. Uses the parameter NOMINATIM_ADDRESS_LEVEL_CONFIG to determine the location of the @@ -79,7 +84,9 @@ def load_address_levels_from_config(conn, config): load_address_levels(conn, 'address_levels', cfg) -def create_functions(conn, config, enable_diff_updates=True, enable_debug=False): +def create_functions(conn: Connection, config: Configuration, + enable_diff_updates: bool = True, + enable_debug: bool = False) -> None: """ (Re)create the PL/pgSQL functions. """ sql = SQLPreprocessor(conn, config) @@ -116,7 +123,7 @@ PHP_CONST_DEFS = ( ) -def import_wikipedia_articles(dsn, data_path, ignore_errors=False): +def import_wikipedia_articles(dsn: str, data_path: Path, ignore_errors: bool = False) -> int: """ Replaces the wikipedia importance tables with new data. The import is run in a single transaction so that the new data is replace seemlessly. @@ -140,7 +147,7 @@ def import_wikipedia_articles(dsn, data_path, ignore_errors=False): return 0 -def recompute_importance(conn): +def recompute_importance(conn: Connection) -> None: """ Recompute wikipedia links and importance for all entries in placex. This is a long-running operations that must not be executed in parallel with updates. @@ -163,18 +170,19 @@ def recompute_importance(conn): conn.commit() -def _quote_php_variable(var_type, config, conf_name): +def _quote_php_variable(var_type: Type[Any], config: Configuration, + conf_name: str) -> str: if var_type == bool: return 'true' if config.get_bool(conf_name) else 'false' if var_type == int: - return getattr(config, conf_name) + return cast(str, getattr(config, conf_name)) if not getattr(config, conf_name): return 'false' if var_type == Path: - value = str(config.get_path(conf_name)) + value = str(config.get_path(conf_name) or '') else: value = getattr(config, conf_name) @@ -182,7 +190,7 @@ def _quote_php_variable(var_type, config, conf_name): return f"'{quoted}'" -def setup_website(basedir, config, conn): +def setup_website(basedir: Path, config: Configuration, conn: Connection) -> None: """ Create the website script stubs. """ if not basedir.exists(): @@ -215,7 +223,8 @@ def setup_website(basedir, config, conn): (basedir / script).write_text(template.format(script), 'utf-8') -def invalidate_osm_object(osm_type, osm_id, conn, recursive=True): +def invalidate_osm_object(osm_type: str, osm_id: int, conn: Connection, + recursive: bool = True) -> None: """ Mark the given OSM object for reindexing. When 'recursive' is set to True (the default), then all dependent objects are marked for reindexing as well. 
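Note the `cast(str, getattr(config, conf_name))` introduced in `_quote_php_variable()` above. `getattr()` is typed as returning `Any`, and under `mypy --strict` the `warn_return_any` check flags an `Any` flowing out of a function declared to return `str`. `cast()` is the standard escape hatch: it asserts the type to the checker and is a no-op at runtime. A toy reproduction of the situation, assuming a made-up `FakeConfig` as a stand-in for `nominatim.config.Configuration`:

```python
from typing import Any, cast

class FakeConfig:
    """ Illustrative stand-in for the real Configuration object. """
    DATABASE_DSN = 'pgsql:dbname=nominatim'

def quote_setting(config: Any, conf_name: str) -> str:
    # getattr() returns Any; returning it directly from a '-> str' function
    # trips warn_return_any under mypy --strict. cast() records what we know
    # about the dynamic attribute without any runtime cost.
    return cast(str, getattr(config, conf_name))

assert quote_setting(FakeConfig(), 'DATABASE_DSN') == 'pgsql:dbname=nominatim'
```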
diff --git a/nominatim/tools/replication.py b/nominatim/tools/replication.py index 53571706..db706bf6 100644 --- a/nominatim/tools/replication.py +++ b/nominatim/tools/replication.py @@ -7,6 +7,7 @@ """ Functions for updating a database from a replication source. """ +from typing import ContextManager, MutableMapping, Any, Generator, cast from contextlib import contextmanager import datetime as dt from enum import Enum @@ -14,6 +15,7 @@ import logging import time from nominatim.db import status +from nominatim.db.connection import Connection from nominatim.tools.exec_utils import run_osm2pgsql from nominatim.errors import UsageError @@ -21,13 +23,13 @@ try: from osmium.replication.server import ReplicationServer from osmium import WriteHandler except ImportError as exc: - logging.getLogger().fatal("pyosmium not installed. Replication functions not available.\n" - "To install pyosmium via pip: pip3 install osmium") + logging.getLogger().critical("pyosmium not installed. Replication functions not available.\n" + "To install pyosmium via pip: pip3 install osmium") raise UsageError("replication tools not available") from exc LOG = logging.getLogger() -def init_replication(conn, base_url): +def init_replication(conn: Connection, base_url: str) -> None: """ Set up replication for the server at the given base URL. """ LOG.info("Using replication source: %s", base_url) @@ -51,7 +53,7 @@ def init_replication(conn, base_url): LOG.warning("Updates initialised at sequence %s (%s)", seq, date) -def check_for_updates(conn, base_url): +def check_for_updates(conn: Connection, base_url: str) -> int: """ Check if new data is available from the replication service at the given base URL. """ @@ -84,7 +86,7 @@ class UpdateState(Enum): NO_CHANGES = 3 -def update(conn, options): +def update(conn: Connection, options: MutableMapping[str, Any]) -> UpdateState: """ Update database from the next batch of data. Returns the state of updates according to `UpdateState`. """ @@ -95,6 +97,8 @@ def update(conn, options): "Please run 'nominatim replication --init' first.") raise UsageError("Replication not set up.") + assert startdate is not None + if not indexed and options['indexed_only']: LOG.info("Skipping update. There is data that needs indexing.") return UpdateState.MORE_PENDING @@ -132,17 +136,17 @@ def update(conn, options): return UpdateState.UP_TO_DATE -def _make_replication_server(url): +def _make_replication_server(url: str) -> ContextManager[ReplicationServer]: """ Returns a ReplicationServer in form of a context manager. Creates a light wrapper around older versions of pyosmium that did not support the context manager interface. """ if hasattr(ReplicationServer, '__enter__'): - return ReplicationServer(url) + return cast(ContextManager[ReplicationServer], ReplicationServer(url)) @contextmanager - def get_cm(): + def get_cm() -> Generator[ReplicationServer, None, None]: yield ReplicationServer(url) return get_cm() diff --git a/nominatim/tools/special_phrases/importer_statistics.py b/nominatim/tools/special_phrases/importer_statistics.py index b1a9c438..0bb118c8 100644 --- a/nominatim/tools/special_phrases/importer_statistics.py +++ b/nominatim/tools/special_phrases/importer_statistics.py @@ -12,15 +12,14 @@ import logging LOG = logging.getLogger() class SpecialPhrasesImporterStatistics(): - # pylint: disable-msg=too-many-instance-attributes """ Class handling statistics of the import process of special phrases. 
""" - def __init__(self): + def __init__(self) -> None: self._intialize_values() - def _intialize_values(self): + def _intialize_values(self) -> None: """ Set all counts for the global import to 0. @@ -30,32 +29,32 @@ class SpecialPhrasesImporterStatistics(): self.tables_ignored = 0 self.invalids = 0 - def notify_one_phrase_invalid(self): + def notify_one_phrase_invalid(self) -> None: """ Add +1 to the count of invalid entries fetched from the wiki. """ self.invalids += 1 - def notify_one_table_created(self): + def notify_one_table_created(self) -> None: """ Add +1 to the count of created tables. """ self.tables_created += 1 - def notify_one_table_deleted(self): + def notify_one_table_deleted(self) -> None: """ Add +1 to the count of deleted tables. """ self.tables_deleted += 1 - def notify_one_table_ignored(self): + def notify_one_table_ignored(self) -> None: """ Add +1 to the count of ignored tables. """ self.tables_ignored += 1 - def notify_import_done(self): + def notify_import_done(self) -> None: """ Print stats for the whole import process and reset all values. diff --git a/nominatim/tools/special_phrases/sp_csv_loader.py b/nominatim/tools/special_phrases/sp_csv_loader.py index 0bd93c00..400f9fa9 100644 --- a/nominatim/tools/special_phrases/sp_csv_loader.py +++ b/nominatim/tools/special_phrases/sp_csv_loader.py @@ -9,6 +9,7 @@ The class allows to load phrases from a csv file. """ +from typing import Iterable import csv import os from nominatim.tools.special_phrases.special_phrase import SpecialPhrase @@ -18,12 +19,11 @@ class SPCsvLoader: """ Handles loading of special phrases from external csv file. """ - def __init__(self, csv_path): - super().__init__() + def __init__(self, csv_path: str) -> None: self.csv_path = csv_path - def generate_phrases(self): + def generate_phrases(self) -> Iterable[SpecialPhrase]: """ Open and parse the given csv file. Create the corresponding SpecialPhrases. """ @@ -35,7 +35,7 @@ class SPCsvLoader: yield SpecialPhrase(row['phrase'], row['class'], row['type'], row['operator']) - def _check_csv_validity(self): + def _check_csv_validity(self) -> None: """ Check that the csv file has the right extension. """ diff --git a/nominatim/tools/special_phrases/sp_importer.py b/nominatim/tools/special_phrases/sp_importer.py index 31bbc355..8906e03e 100644 --- a/nominatim/tools/special_phrases/sp_importer.py +++ b/nominatim/tools/special_phrases/sp_importer.py @@ -13,19 +13,36 @@ The phrases already present in the database which are not valids anymore are removed. """ +from typing import Iterable, Tuple, Mapping, Sequence, Optional, Set import logging import re from psycopg2.sql import Identifier, SQL + +from nominatim.config import Configuration +from nominatim.db.connection import Connection from nominatim.tools.special_phrases.importer_statistics import SpecialPhrasesImporterStatistics +from nominatim.tools.special_phrases.special_phrase import SpecialPhrase +from nominatim.tokenizer.base import AbstractTokenizer +from nominatim.typing import Protocol LOG = logging.getLogger() -def _classtype_table(phrase_class, phrase_type): +def _classtype_table(phrase_class: str, phrase_type: str) -> str: """ Return the name of the table for the given class and type. """ return f'place_classtype_{phrase_class}_{phrase_type}' + +class SpecialPhraseLoader(Protocol): + """ Protocol for classes implementing a loader for special phrases. + """ + + def generate_phrases(self) -> Iterable[SpecialPhrase]: + """ Generates all special phrase terms this loader can produce. 
+ """ + + class SPImporter(): # pylint: disable-msg=too-many-instance-attributes """ @@ -33,21 +50,22 @@ class SPImporter(): Take a sp loader which load the phrases from an external source. """ - def __init__(self, config, db_connection, sp_loader) -> None: + def __init__(self, config: Configuration, conn: Connection, + sp_loader: SpecialPhraseLoader) -> None: self.config = config - self.db_connection = db_connection + self.db_connection = conn self.sp_loader = sp_loader self.statistics_handler = SpecialPhrasesImporterStatistics() self.black_list, self.white_list = self._load_white_and_black_lists() self.sanity_check_pattern = re.compile(r'^\w+$') # This set will contain all existing phrases to be added. # It contains tuples with the following format: (lable, class, type, operator) - self.word_phrases = set() + self.word_phrases: Set[Tuple[str, str, str, str]] = set() # This set will contain all existing place_classtype tables which doesn't match any # special phrases class/type on the wiki. - self.table_phrases_to_delete = set() + self.table_phrases_to_delete: Set[str] = set() - def import_phrases(self, tokenizer, should_replace): + def import_phrases(self, tokenizer: AbstractTokenizer, should_replace: bool) -> None: """ Iterate through all SpecialPhrases extracted from the loader and import them into the database. @@ -67,7 +85,7 @@ class SPImporter(): if result: class_type_pairs.add(result) - self._create_place_classtype_table_and_indexes(class_type_pairs) + self._create_classtype_table_and_indexes(class_type_pairs) if should_replace: self._remove_non_existent_tables_from_db() self.db_connection.commit() @@ -79,7 +97,7 @@ class SPImporter(): self.statistics_handler.notify_import_done() - def _fetch_existing_place_classtype_tables(self): + def _fetch_existing_place_classtype_tables(self) -> None: """ Fetch existing place_classtype tables. Fill the table_phrases_to_delete set of the class. @@ -95,7 +113,8 @@ class SPImporter(): for row in db_cursor: self.table_phrases_to_delete.add(row[0]) - def _load_white_and_black_lists(self): + def _load_white_and_black_lists(self) \ + -> Tuple[Mapping[str, Sequence[str]], Mapping[str, Sequence[str]]]: """ Load white and black lists from phrases-settings.json. """ @@ -103,7 +122,7 @@ class SPImporter(): return settings['blackList'], settings['whiteList'] - def _check_sanity(self, phrase): + def _check_sanity(self, phrase: SpecialPhrase) -> bool: """ Check sanity of given inputs in case somebody added garbage in the wiki. If a bad class/type is detected the system will exit with an error. @@ -117,7 +136,7 @@ class SPImporter(): return False return True - def _process_phrase(self, phrase): + def _process_phrase(self, phrase: SpecialPhrase) -> Optional[Tuple[str, str]]: """ Processes the given phrase by checking black and white list and sanity. @@ -145,7 +164,8 @@ class SPImporter(): return (phrase.p_class, phrase.p_type) - def _create_place_classtype_table_and_indexes(self, class_type_pairs): + def _create_classtype_table_and_indexes(self, + class_type_pairs: Iterable[Tuple[str, str]]) -> None: """ Create table place_classtype for each given pair. Also create indexes on place_id and centroid. 
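The `SpecialPhraseLoader` protocol introduced in this file gives `SPImporter` a typed interface that both `SPCsvLoader` and `SPWikiLoader` satisfy purely structurally, without sharing a base class. The sketch below shows how such a protocol behaves; all class names are invented for the example. `typing.Protocol` requires Python 3.8 (or `typing_extensions`), which is why the patch routes the import through `nominatim.typing`.

```python
from typing import Iterable, Iterator, Protocol  # typing_extensions on Python < 3.8

class Phrase:
    def __init__(self, label: str) -> None:
        self.label = label

class PhraseLoader(Protocol):
    """ Structural interface: any class with this method matches. """
    def generate_phrases(self) -> Iterable[Phrase]: ...

class ListLoader:
    """ Matches PhraseLoader without inheriting from it. """
    def __init__(self, labels: Iterable[str]) -> None:
        self.labels = list(labels)

    def generate_phrases(self) -> Iterator[Phrase]:
        return (Phrase(label) for label in self.labels)

def count_phrases(loader: PhraseLoader) -> int:
    return sum(1 for _ in loader.generate_phrases())

assert count_phrases(ListLoader(['bar', 'pub'])) == 2
```

mypy accepts `ListLoader` as a `PhraseLoader` because the method signature matches; a class missing `generate_phrases()` would be rejected at the call site.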
@@ -188,7 +208,8 @@ class SPImporter(): db_cursor.execute("DROP INDEX idx_placex_classtype") - def _create_place_classtype_table(self, sql_tablespace, phrase_class, phrase_type): + def _create_place_classtype_table(self, sql_tablespace: str, + phrase_class: str, phrase_type: str) -> None: """ Create table place_classtype of the given phrase_class/phrase_type if doesn't exit. @@ -204,7 +225,8 @@ class SPImporter(): (phrase_class, phrase_type)) - def _create_place_classtype_indexes(self, sql_tablespace, phrase_class, phrase_type): + def _create_place_classtype_indexes(self, sql_tablespace: str, + phrase_class: str, phrase_type: str) -> None: """ Create indexes on centroid and place_id for the place_classtype table. """ @@ -227,7 +249,7 @@ class SPImporter(): SQL(sql_tablespace))) - def _grant_access_to_webuser(self, phrase_class, phrase_type): + def _grant_access_to_webuser(self, phrase_class: str, phrase_type: str) -> None: """ Grant access on read to the table place_classtype for the webuser. """ @@ -237,7 +259,7 @@ class SPImporter(): .format(Identifier(table_name), Identifier(self.config.DATABASE_WEBUSER))) - def _remove_non_existent_tables_from_db(self): + def _remove_non_existent_tables_from_db(self) -> None: """ Remove special phrases which doesn't exist on the wiki anymore. Delete the place_classtype tables. diff --git a/nominatim/tools/special_phrases/sp_wiki_loader.py b/nominatim/tools/special_phrases/sp_wiki_loader.py index ca4758ac..e71c2ec0 100644 --- a/nominatim/tools/special_phrases/sp_wiki_loader.py +++ b/nominatim/tools/special_phrases/sp_wiki_loader.py @@ -7,14 +7,17 @@ """ Module containing the SPWikiLoader class. """ +from typing import Iterable import re import logging + +from nominatim.config import Configuration from nominatim.tools.special_phrases.special_phrase import SpecialPhrase from nominatim.tools.exec_utils import get_url LOG = logging.getLogger() -def _get_wiki_content(lang): +def _get_wiki_content(lang: str) -> str: """ Request and return the wiki page's content corresponding to special phrases for a given lang. @@ -30,8 +33,7 @@ class SPWikiLoader: """ Handles loading of special phrases from the wiki. """ - def __init__(self, config): - super().__init__() + def __init__(self, config: Configuration) -> None: self.config = config # Compile the regex here to increase performances. self.occurence_pattern = re.compile( @@ -39,10 +41,15 @@ class SPWikiLoader: ) # Hack around a bug where building=yes was imported with quotes into the wiki self.type_fix_pattern = re.compile(r'\"|"') - self._load_languages() + + self.languages = self.config.get_str_list('LANGUAGES') or \ + ['af', 'ar', 'br', 'ca', 'cs', 'de', 'en', 'es', + 'et', 'eu', 'fa', 'fi', 'fr', 'gl', 'hr', 'hu', + 'ia', 'is', 'it', 'ja', 'mk', 'nl', 'no', 'pl', + 'ps', 'pt', 'ru', 'sk', 'sl', 'sv', 'uk', 'vi'] - def generate_phrases(self): + def generate_phrases(self) -> Iterable[SpecialPhrase]: """ Download the wiki pages for the configured languages and extract the phrases from the page. """ @@ -58,19 +65,3 @@ class SPWikiLoader: match[1], self.type_fix_pattern.sub('', match[2]), match[3]) - - - def _load_languages(self): - """ - Get list of all languages from env config file - or default if there is no languages configured. - The system will extract special phrases only from all specified languages. 
- """ - if self.config.LANGUAGES: - self.languages = self.config.get_str_list('LANGUAGES') - else: - self.languages = [ - 'af', 'ar', 'br', 'ca', 'cs', 'de', 'en', 'es', - 'et', 'eu', 'fa', 'fi', 'fr', 'gl', 'hr', 'hu', - 'ia', 'is', 'it', 'ja', 'mk', 'nl', 'no', 'pl', - 'ps', 'pt', 'ru', 'sk', 'sl', 'sv', 'uk', 'vi'] diff --git a/nominatim/tools/special_phrases/special_phrase.py b/nominatim/tools/special_phrases/special_phrase.py index 16935ccf..40f6a9e4 100644 --- a/nominatim/tools/special_phrases/special_phrase.py +++ b/nominatim/tools/special_phrases/special_phrase.py @@ -10,20 +10,21 @@ This class is a model used to transfer a special phrase through the process of load and importation. """ +from typing import Any + class SpecialPhrase: """ Model representing a special phrase. """ - def __init__(self, p_label, p_class, p_type, p_operator): + def __init__(self, p_label: str, p_class: str, p_type: str, p_operator: str) -> None: self.p_label = p_label.strip() self.p_class = p_class.strip() - # Hack around a bug where building=yes was imported with quotes into the wiki self.p_type = p_type.strip() # Needed if some operator in the wiki are not written in english p_operator = p_operator.strip().lower() self.p_operator = '-' if p_operator not in ('near', 'in') else p_operator - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if not isinstance(other, SpecialPhrase): return False @@ -32,5 +33,5 @@ class SpecialPhrase: and self.p_type == other.p_type \ and self.p_operator == other.p_operator - def __hash__(self): + def __hash__(self) -> int: return hash((self.p_label, self.p_class, self.p_type, self.p_operator)) diff --git a/nominatim/tools/tiger_data.py b/nominatim/tools/tiger_data.py index e78dcd8f..4a32bb1e 100644 --- a/nominatim/tools/tiger_data.py +++ b/nominatim/tools/tiger_data.py @@ -7,6 +7,7 @@ """ Functions for importing tiger data and handling tarbar and directory files """ +from typing import Any, TextIO, List, Union, cast import csv import io import logging @@ -15,11 +16,13 @@ import tarfile from psycopg2.extras import Json +from nominatim.config import Configuration from nominatim.db.connection import connect from nominatim.db.async_connection import WorkerPool from nominatim.db.sql_preprocessor import SQLPreprocessor from nominatim.errors import UsageError from nominatim.data.place_info import PlaceInfo +from nominatim.tokenizer.base import AbstractAnalyzer, AbstractTokenizer LOG = logging.getLogger() @@ -28,9 +31,9 @@ class TigerInput: either be in a directory or gzipped together in a tar file. """ - def __init__(self, data_dir): + def __init__(self, data_dir: str) -> None: self.tar_handle = None - self.files = [] + self.files: List[Union[str, tarfile.TarInfo]] = [] if data_dir.endswith('.tar.gz'): try: @@ -50,33 +53,36 @@ class TigerInput: LOG.warning("Tiger data import selected but no files found at %s", data_dir) - def __enter__(self): + def __enter__(self) -> 'TigerInput': return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: if self.tar_handle: self.tar_handle.close() self.tar_handle = None - def next_file(self): + def next_file(self) -> TextIO: """ Return a file handle to the next file to be processed. Raises an IndexError if there is no file left. 
""" fname = self.files.pop(0) if self.tar_handle is not None: - return io.TextIOWrapper(self.tar_handle.extractfile(fname)) + extracted = self.tar_handle.extractfile(fname) + assert extracted is not None + return io.TextIOWrapper(extracted) - return open(fname, encoding='utf-8') + return open(cast(str, fname), encoding='utf-8') - def __len__(self): + def __len__(self) -> int: return len(self.files) -def handle_threaded_sql_statements(pool, fd, analyzer): +def handle_threaded_sql_statements(pool: WorkerPool, fd: TextIO, + analyzer: AbstractAnalyzer) -> None: """ Handles sql statement with multiplexing """ lines = 0 @@ -101,14 +107,15 @@ def handle_threaded_sql_statements(pool, fd, analyzer): lines = 0 -def add_tiger_data(data_dir, config, threads, tokenizer): +def add_tiger_data(data_dir: str, config: Configuration, threads: int, + tokenizer: AbstractTokenizer) -> int: """ Import tiger data from directory or tar file `data dir`. """ dsn = config.get_libpq_dsn() with TigerInput(data_dir) as tar: if not tar: - return + return 1 with connect(dsn) as conn: sql = SQLPreprocessor(conn, config) @@ -130,3 +137,5 @@ def add_tiger_data(data_dir, config, threads, tokenizer): with connect(dsn) as conn: sql = SQLPreprocessor(conn, config) sql.run_sql_file(conn, 'tiger_import_finish.sql') + + return 0 diff --git a/nominatim/typing.py b/nominatim/typing.py new file mode 100644 index 00000000..308f3e6a --- /dev/null +++ b/nominatim/typing.py @@ -0,0 +1,52 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# This file is part of Nominatim. (https://nominatim.org) +# +# Copyright (C) 2022 by the Nominatim developer community. +# For a full list of authors see the git log. +""" +Type definitions for typing annotations. + +Complex type definitions are moved here, to keep the source files readable. +""" +from typing import Any, Union, Mapping, TypeVar, Sequence, TYPE_CHECKING + +# Generics varaible names do not confirm to naming styles, ignore globally here. +# pylint: disable=invalid-name,abstract-method,multiple-statements +# pylint: disable=missing-class-docstring,useless-import-alias + +if TYPE_CHECKING: + import psycopg2.sql + import psycopg2.extensions + import psycopg2.extras + import os + +StrPath = Union[str, 'os.PathLike[str]'] + +SysEnv = Mapping[str, str] + +# psycopg2-related types + +Query = Union[str, bytes, 'psycopg2.sql.Composable'] + +T_ResultKey = TypeVar('T_ResultKey', int, str) + +class DictCursorResult(Mapping[str, Any]): + def __getitem__(self, x: Union[int, str]) -> Any: ... + +DictCursorResults = Sequence[DictCursorResult] + +T_cursor = TypeVar('T_cursor', bound='psycopg2.extensions.cursor') + +# The following typing features require typing_extensions to work +# on all supported Python versions. +# Only require this for type checking but not for normal operations. + +if TYPE_CHECKING: + from typing_extensions import (Protocol as Protocol, + Final as Final, + TypedDict as TypedDict) +else: + Protocol = object + Final = 'Final' + TypedDict = dict diff --git a/nominatim/utils/centroid.py b/nominatim/utils/centroid.py index c2bd6192..21823176 100644 --- a/nominatim/utils/centroid.py +++ b/nominatim/utils/centroid.py @@ -7,6 +7,7 @@ """ Functions for computation of centroids. """ +from typing import Tuple, Any from collections.abc import Collection class PointsCentroid: @@ -17,12 +18,12 @@ class PointsCentroid: (i.e. in OSM style). 
""" - def __init__(self): + def __init__(self) -> None: self.sum_x = 0 self.sum_y = 0 self.count = 0 - def centroid(self): + def centroid(self) -> Tuple[float, float]: """ Return the centroid of all points collected so far. """ if self.count == 0: @@ -32,11 +33,11 @@ class PointsCentroid: float(self.sum_y/self.count)/10000000) - def __len__(self): + def __len__(self) -> int: return self.count - def __iadd__(self, other): + def __iadd__(self, other: Any) -> 'PointsCentroid': if isinstance(other, Collection) and len(other) == 2: if all(isinstance(p, (float, int)) for p in other): x, y = other diff --git a/nominatim/version.py b/nominatim/version.py index 88d42af9..f950b8ef 100644 --- a/nominatim/version.py +++ b/nominatim/version.py @@ -7,6 +7,7 @@ """ Version information for Nominatim. """ +from typing import Optional, Tuple # Version information: major, minor, patch level, database patch level # @@ -33,11 +34,11 @@ POSTGIS_REQUIRED_VERSION = (2, 2) # on every execution of 'make'. # cmake/tool-installed.tmpl is used to build the binary 'nominatim'. Inside # there is a call to set the variable value below. -GIT_COMMIT_HASH = None +GIT_COMMIT_HASH : Optional[str] = None # pylint: disable=consider-using-f-string -def version_str(version=NOMINATIM_VERSION): +def version_str(version:Tuple[int, int, int, int] = NOMINATIM_VERSION) -> str: """ Return a human-readable string of the version. """ diff --git a/test/bdd/steps/nominatim_environment.py b/test/bdd/steps/nominatim_environment.py index 6b83c2e4..e7234788 100644 --- a/test/bdd/steps/nominatim_environment.py +++ b/test/bdd/steps/nominatim_environment.py @@ -15,7 +15,7 @@ sys.path.insert(1, str((Path(__file__) / '..' / '..' / '..' / '..').resolve())) from nominatim import cli from nominatim.config import Configuration -from nominatim.db.connection import _Connection +from nominatim.db.connection import Connection from nominatim.tools import refresh from nominatim.tokenizer import factory as tokenizer_factory from steps.utils import run_script @@ -61,7 +61,7 @@ class NominatimEnvironment: dbargs['user'] = self.db_user if self.db_pass: dbargs['password'] = self.db_pass - conn = psycopg2.connect(connection_factory=_Connection, **dbargs) + conn = psycopg2.connect(connection_factory=Connection, **dbargs) return conn def next_code_coverage_file(self): diff --git a/test/python/tokenizer/sanitizers/test_sanitizer_config.py b/test/python/tokenizer/sanitizers/test_sanitizer_config.py index 02794776..0dbbc7a0 100644 --- a/test/python/tokenizer/sanitizers/test_sanitizer_config.py +++ b/test/python/tokenizer/sanitizers/test_sanitizer_config.py @@ -82,32 +82,32 @@ def test_create_split_regex_empty_delimiter(): def test_create_kind_filter_no_params(inp): filt = SanitizerConfig().get_filter_kind() - assert filt(PlaceName('something', inp, '')) + assert filt(inp) @pytest.mark.parametrize('kind', ('de', 'name:de', 'ende')) def test_create_kind_filter_custom_regex_positive(kind): filt = SanitizerConfig({'filter-kind': '.*de'}).get_filter_kind() - assert filt(PlaceName('something', kind, '')) + assert filt(kind) @pytest.mark.parametrize('kind', ('de ', '123', '', 'bedece')) def test_create_kind_filter_custom_regex_negative(kind): filt = SanitizerConfig({'filter-kind': '.*de'}).get_filter_kind() - assert not filt(PlaceName('something', kind, '')) + assert not filt(kind) @pytest.mark.parametrize('kind', ('name', 'fr', 'name:fr', 'frfr', '34')) def test_create_kind_filter_many_positive(kind): filt = SanitizerConfig({'filter-kind': ['.*fr', 'name', 
r'\d+']}).get_filter_kind() - assert filt(PlaceName('something', kind, '')) + assert filt(kind) @pytest.mark.parametrize('kind', ('name:de', 'fridge', 'a34', '.*', '\\')) def test_create_kind_filter_many_negative(kind): filt = SanitizerConfig({'filter-kind': ['.*fr', 'name', r'\d+']}).get_filter_kind() - assert not filt(PlaceName('something', kind, '')) + assert not filt(kind) diff --git a/test/python/tools/test_freeze.py b/test/python/tools/test_freeze.py index 30b673ff..3ebb1730 100644 --- a/test/python/tools/test_freeze.py +++ b/test/python/tools/test_freeze.py @@ -39,17 +39,17 @@ def test_drop_tables(temp_db_conn, temp_db_cursor, table_factory): assert not temp_db_cursor.table_exists(table) def test_drop_flatnode_file_no_file(): - freeze.drop_flatnode_file('') + freeze.drop_flatnode_file(None) def test_drop_flatnode_file_file_already_gone(tmp_path): - freeze.drop_flatnode_file(str(tmp_path / 'something.store')) + freeze.drop_flatnode_file(tmp_path / 'something.store') def test_drop_flatnode_file_delte(tmp_path): flatfile = tmp_path / 'flatnode.store' flatfile.write_text('Some content') - freeze.drop_flatnode_file(str(flatfile)) + freeze.drop_flatnode_file(flatfile) assert not flatfile.exists() diff --git a/test/python/tools/test_import_special_phrases.py b/test/python/tools/test_import_special_phrases.py index 0dcf549c..75a6a066 100644 --- a/test/python/tools/test_import_special_phrases.py +++ b/test/python/tools/test_import_special_phrases.py @@ -128,7 +128,7 @@ def test_create_place_classtype_table_and_indexes( """ pairs = set([('class1', 'type1'), ('class2', 'type2')]) - sp_importer._create_place_classtype_table_and_indexes(pairs) + sp_importer._create_classtype_table_and_indexes(pairs) for pair in pairs: assert check_table_exist(temp_db_conn, pair[0], pair[1])
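A closing note on the conditional import block at the bottom of the new `nominatim/typing.py`: `typing_extensions` is imported only under `TYPE_CHECKING`, with trivial runtime substitutes, so the type checker sees the real `Protocol`, `Final` and `TypedDict` while production installs need no extra dependency. A minimal sketch of the same trick; the `Closeable` and `Conn` names are illustrative only.

```python
from typing import TYPE_CHECKING, Iterator

if TYPE_CHECKING:
    # Only the type checker ever follows this import, so typing_extensions
    # remains a development-time dependency.
    from typing_extensions import Protocol as Protocol
else:
    # Cheap runtime stand-in: 'class C(Protocol)' becomes 'class C(object)'.
    Protocol = object

class Closeable(Protocol):
    def close(self) -> None: ...

def shutdown(handles: Iterator[Closeable]) -> None:
    for handle in handles:
        handle.close()

class Conn:
    def close(self) -> None:
        print('closed')

# Conn matches Closeable structurally under mypy; at runtime Protocol is
# just 'object', so nothing extra is imported or checked.
shutdown(iter([Conn()]))
```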