add type annotations for legacy tokenizer

This commit is contained in:
Sarah Hoffmann
2022-07-15 22:52:26 +02:00
parent e37cfc64d2
commit 18b16e06ca
3 changed files with 83 additions and 62 deletions

View File

@@ -7,7 +7,7 @@
""" """
Specialised connection and cursor functions. Specialised connection and cursor functions.
""" """
from typing import List, Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple, Iterable
import contextlib import contextlib
import logging import logging
import os import os
@@ -36,7 +36,7 @@ class Cursor(psycopg2.extras.DictCursor):
super().execute(query, args) super().execute(query, args)
def execute_values(self, sql: Query, argslist: List[Any], def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]],
template: Optional[str] = None) -> None: template: Optional[str] = None) -> None:
""" Wrapper for the psycopg2 convenience function to execute """ Wrapper for the psycopg2 convenience function to execute
SQL for a list of values. SQL for a list of values.

View File

@@ -9,7 +9,7 @@ Abstract class definitions for tokenizers. These base classes are here
mainly for documentation purposes. mainly for documentation purposes.
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Tuple, Dict, Any from typing import List, Tuple, Dict, Any, Optional
from pathlib import Path from pathlib import Path
from typing_extensions import Protocol from typing_extensions import Protocol
@@ -187,7 +187,7 @@ class AbstractTokenizer(ABC):
@abstractmethod @abstractmethod
def check_database(self, config: Configuration) -> str: def check_database(self, config: Configuration) -> Optional[str]:
""" Check that the database is set up correctly and ready for being """ Check that the database is set up correctly and ready for being
queried. queried.

View File

@@ -7,8 +7,10 @@
""" """
Tokenizer implementing normalisation as used before Nominatim 4. Tokenizer implementing normalisation as used before Nominatim 4.
""" """
from typing import Optional, Sequence, List, Tuple, Mapping, Any, Callable, cast, Dict, Set
from collections import OrderedDict from collections import OrderedDict
import logging import logging
from pathlib import Path
import re import re
import shutil import shutil
from textwrap import dedent from textwrap import dedent
@@ -17,10 +19,12 @@ from icu import Transliterator
import psycopg2 import psycopg2
import psycopg2.extras import psycopg2.extras
from nominatim.db.connection import connect from nominatim.db.connection import connect, Connection
from nominatim.config import Configuration
from nominatim.db import properties from nominatim.db import properties
from nominatim.db import utils as db_utils from nominatim.db import utils as db_utils
from nominatim.db.sql_preprocessor import SQLPreprocessor from nominatim.db.sql_preprocessor import SQLPreprocessor
from nominatim.data.place_info import PlaceInfo
from nominatim.errors import UsageError from nominatim.errors import UsageError
from nominatim.tokenizer.base import AbstractAnalyzer, AbstractTokenizer from nominatim.tokenizer.base import AbstractAnalyzer, AbstractTokenizer
@@ -29,13 +33,13 @@ DBCFG_MAXWORDFREQ = "tokenizer_maxwordfreq"
LOG = logging.getLogger() LOG = logging.getLogger()
def create(dsn, data_dir): def create(dsn: str, data_dir: Path) -> 'LegacyTokenizer':
""" Create a new instance of the tokenizer provided by this module. """ Create a new instance of the tokenizer provided by this module.
""" """
return LegacyTokenizer(dsn, data_dir) return LegacyTokenizer(dsn, data_dir)
def _install_module(config_module_path, src_dir, module_dir): def _install_module(config_module_path: str, src_dir: Path, module_dir: Path) -> str:
""" Copies the PostgreSQL normalisation module into the project """ Copies the PostgreSQL normalisation module into the project
directory if necessary. For historical reasons the module is directory if necessary. For historical reasons the module is
saved in the '/module' subdirectory and not with the other tokenizer saved in the '/module' subdirectory and not with the other tokenizer
@@ -52,7 +56,7 @@ def _install_module(config_module_path, src_dir, module_dir):
# Compatibility mode for builddir installations. # Compatibility mode for builddir installations.
if module_dir.exists() and src_dir.samefile(module_dir): if module_dir.exists() and src_dir.samefile(module_dir):
LOG.info('Running from build directory. Leaving database module as is.') LOG.info('Running from build directory. Leaving database module as is.')
return module_dir return str(module_dir)
# In any other case install the module in the project directory. # In any other case install the module in the project directory.
if not module_dir.exists(): if not module_dir.exists():
@@ -64,10 +68,10 @@ def _install_module(config_module_path, src_dir, module_dir):
LOG.info('Database module installed at %s', str(destfile)) LOG.info('Database module installed at %s', str(destfile))
return module_dir return str(module_dir)
def _check_module(module_dir, conn): def _check_module(module_dir: str, conn: Connection) -> None:
""" Try to use the PostgreSQL module to confirm that it is correctly """ Try to use the PostgreSQL module to confirm that it is correctly
installed and accessible from PostgreSQL. installed and accessible from PostgreSQL.
""" """
@@ -89,13 +93,13 @@ class LegacyTokenizer(AbstractTokenizer):
calls to the database. calls to the database.
""" """
def __init__(self, dsn, data_dir): def __init__(self, dsn: str, data_dir: Path) -> None:
self.dsn = dsn self.dsn = dsn
self.data_dir = data_dir self.data_dir = data_dir
self.normalization = None self.normalization: Optional[str] = None
def init_new_db(self, config, init_db=True): def init_new_db(self, config: Configuration, init_db: bool = True) -> None:
""" Set up a new tokenizer for the database. """ Set up a new tokenizer for the database.
This copies all necessary data in the project directory to make This copies all necessary data in the project directory to make
@@ -119,7 +123,7 @@ class LegacyTokenizer(AbstractTokenizer):
self._init_db_tables(config) self._init_db_tables(config)
def init_from_project(self, config): def init_from_project(self, config: Configuration) -> None:
""" Initialise the tokenizer from the project directory. """ Initialise the tokenizer from the project directory.
""" """
with connect(self.dsn) as conn: with connect(self.dsn) as conn:
@@ -132,7 +136,7 @@ class LegacyTokenizer(AbstractTokenizer):
self._install_php(config, overwrite=False) self._install_php(config, overwrite=False)
def finalize_import(self, config): def finalize_import(self, config: Configuration) -> None:
""" Do any required postprocessing to make the tokenizer data ready """ Do any required postprocessing to make the tokenizer data ready
for use. for use.
""" """
@@ -141,7 +145,7 @@ class LegacyTokenizer(AbstractTokenizer):
sqlp.run_sql_file(conn, 'tokenizer/legacy_tokenizer_indices.sql') sqlp.run_sql_file(conn, 'tokenizer/legacy_tokenizer_indices.sql')
def update_sql_functions(self, config): def update_sql_functions(self, config: Configuration) -> None:
""" Reimport the SQL functions for this tokenizer. """ Reimport the SQL functions for this tokenizer.
""" """
with connect(self.dsn) as conn: with connect(self.dsn) as conn:
@@ -154,7 +158,7 @@ class LegacyTokenizer(AbstractTokenizer):
modulepath=modulepath) modulepath=modulepath)
def check_database(self, _): def check_database(self, _: Configuration) -> Optional[str]:
""" Check that the tokenizer is set up correctly. """ Check that the tokenizer is set up correctly.
""" """
hint = """\ hint = """\
@@ -181,7 +185,7 @@ class LegacyTokenizer(AbstractTokenizer):
return None return None
def migrate_database(self, config): def migrate_database(self, config: Configuration) -> None:
""" Initialise the project directory of an existing database for """ Initialise the project directory of an existing database for
use with this tokenizer. use with this tokenizer.
@@ -198,7 +202,7 @@ class LegacyTokenizer(AbstractTokenizer):
self._save_config(conn, config) self._save_config(conn, config)
def update_statistics(self): def update_statistics(self) -> None:
""" Recompute the frequency of full words. """ Recompute the frequency of full words.
""" """
with connect(self.dsn) as conn: with connect(self.dsn) as conn:
@@ -218,13 +222,13 @@ class LegacyTokenizer(AbstractTokenizer):
conn.commit() conn.commit()
def update_word_tokens(self): def update_word_tokens(self) -> None:
""" No house-keeping implemented for the legacy tokenizer. """ No house-keeping implemented for the legacy tokenizer.
""" """
LOG.info("No tokenizer clean-up available.") LOG.info("No tokenizer clean-up available.")
def name_analyzer(self): def name_analyzer(self) -> 'LegacyNameAnalyzer':
""" Create a new analyzer for tokenizing names and queries """ Create a new analyzer for tokenizing names and queries
using this tokenizer. Analyzers are context managers and should using this tokenizer. Analyzers are context managers and should
be used accordingly: be used accordingly:
@@ -244,7 +248,7 @@ class LegacyTokenizer(AbstractTokenizer):
return LegacyNameAnalyzer(self.dsn, normalizer) return LegacyNameAnalyzer(self.dsn, normalizer)
def _install_php(self, config, overwrite=True): def _install_php(self, config: Configuration, overwrite: bool = True) -> None:
""" Install the php script for the tokenizer. """ Install the php script for the tokenizer.
""" """
php_file = self.data_dir / "tokenizer.php" php_file = self.data_dir / "tokenizer.php"
@@ -258,7 +262,7 @@ class LegacyTokenizer(AbstractTokenizer):
"""), encoding='utf-8') """), encoding='utf-8')
def _init_db_tables(self, config): def _init_db_tables(self, config: Configuration) -> None:
""" Set up the word table and fill it with pre-computed word """ Set up the word table and fill it with pre-computed word
frequencies. frequencies.
""" """
@@ -271,10 +275,12 @@ class LegacyTokenizer(AbstractTokenizer):
db_utils.execute_file(self.dsn, config.lib_dir.data / 'words.sql') db_utils.execute_file(self.dsn, config.lib_dir.data / 'words.sql')
def _save_config(self, conn, config): def _save_config(self, conn: Connection, config: Configuration) -> None:
""" Save the configuration that needs to remain stable for the given """ Save the configuration that needs to remain stable for the given
database as database properties. database as database properties.
""" """
assert self.normalization is not None
properties.set_property(conn, DBCFG_NORMALIZATION, self.normalization) properties.set_property(conn, DBCFG_NORMALIZATION, self.normalization)
properties.set_property(conn, DBCFG_MAXWORDFREQ, config.MAX_WORD_FREQUENCY) properties.set_property(conn, DBCFG_MAXWORDFREQ, config.MAX_WORD_FREQUENCY)
@@ -287,8 +293,8 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
normalization. normalization.
""" """
def __init__(self, dsn, normalizer): def __init__(self, dsn: str, normalizer: Any):
self.conn = connect(dsn).connection self.conn: Optional[Connection] = connect(dsn).connection
self.conn.autocommit = True self.conn.autocommit = True
self.normalizer = normalizer self.normalizer = normalizer
psycopg2.extras.register_hstore(self.conn) psycopg2.extras.register_hstore(self.conn)
@@ -296,7 +302,7 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
self._cache = _TokenCache(self.conn) self._cache = _TokenCache(self.conn)
def close(self): def close(self) -> None:
""" Free all resources used by the analyzer. """ Free all resources used by the analyzer.
""" """
if self.conn: if self.conn:
@@ -304,7 +310,7 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
self.conn = None self.conn = None
def get_word_token_info(self, words): def get_word_token_info(self, words: Sequence[str]) -> List[Tuple[str, str, int]]:
""" Return token information for the given list of words. """ Return token information for the given list of words.
If a word starts with # it is assumed to be a full name If a word starts with # it is assumed to be a full name
otherwise is a partial name. otherwise is a partial name.
@@ -315,6 +321,7 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
The function is used for testing and debugging only The function is used for testing and debugging only
and not necessarily efficient. and not necessarily efficient.
""" """
assert self.conn is not None
with self.conn.cursor() as cur: with self.conn.cursor() as cur:
cur.execute("""SELECT t.term, word_token, word_id cur.execute("""SELECT t.term, word_token, word_id
FROM word, (SELECT unnest(%s::TEXT[]) as term) t FROM word, (SELECT unnest(%s::TEXT[]) as term) t
@@ -330,14 +337,14 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
return [(r[0], r[1], r[2]) for r in cur] return [(r[0], r[1], r[2]) for r in cur]
def normalize(self, phrase): def normalize(self, phrase: str) -> str:
""" Normalize the given phrase, i.e. remove all properties that """ Normalize the given phrase, i.e. remove all properties that
are irrelevant for search. are irrelevant for search.
""" """
return self.normalizer.transliterate(phrase) return cast(str, self.normalizer.transliterate(phrase))
def normalize_postcode(self, postcode): def normalize_postcode(self, postcode: str) -> str:
""" Convert the postcode to a standardized form. """ Convert the postcode to a standardized form.
This function must yield exactly the same result as the SQL function This function must yield exactly the same result as the SQL function
@@ -346,10 +353,12 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
return postcode.strip().upper() return postcode.strip().upper()
def update_postcodes_from_db(self): def update_postcodes_from_db(self) -> None:
""" Update postcode tokens in the word table from the location_postcode """ Update postcode tokens in the word table from the location_postcode
table. table.
""" """
assert self.conn is not None
with self.conn.cursor() as cur: with self.conn.cursor() as cur:
# This finds us the rows in location_postcode and word that are # This finds us the rows in location_postcode and word that are
# missing in the other table. # missing in the other table.
@@ -383,9 +392,12 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
def update_special_phrases(self, phrases, should_replace): def update_special_phrases(self, phrases: Sequence[Tuple[str, str, str, str]],
should_replace: bool) -> None:
""" Replace the search index for special phrases with the new phrases. """ Replace the search index for special phrases with the new phrases.
""" """
assert self.conn is not None
norm_phrases = set(((self.normalize(p[0]), p[1], p[2], p[3]) norm_phrases = set(((self.normalize(p[0]), p[1], p[2], p[3])
for p in phrases)) for p in phrases))
@@ -422,9 +434,11 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
len(norm_phrases), len(to_add), len(to_delete)) len(norm_phrases), len(to_add), len(to_delete))
def add_country_names(self, country_code, names): def add_country_names(self, country_code: str, names: Mapping[str, str]) -> None:
""" Add names for the given country to the search index. """ Add names for the given country to the search index.
""" """
assert self.conn is not None
with self.conn.cursor() as cur: with self.conn.cursor() as cur:
cur.execute( cur.execute(
"""INSERT INTO word (word_id, word_token, country_code) """INSERT INTO word (word_id, word_token, country_code)
@@ -436,12 +450,14 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
""", (country_code, list(names.values()), country_code)) """, (country_code, list(names.values()), country_code))
def process_place(self, place): def process_place(self, place: PlaceInfo) -> Mapping[str, Any]:
""" Determine tokenizer information about the given place. """ Determine tokenizer information about the given place.
Returns a JSON-serialisable structure that will be handed into Returns a JSON-serialisable structure that will be handed into
the database via the token_info field. the database via the token_info field.
""" """
assert self.conn is not None
token_info = _TokenInfo(self._cache) token_info = _TokenInfo(self._cache)
names = place.name names = place.name
@@ -450,6 +466,7 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
token_info.add_names(self.conn, names) token_info.add_names(self.conn, names)
if place.is_country(): if place.is_country():
assert place.country_code is not None
self.add_country_names(place.country_code, names) self.add_country_names(place.country_code, names)
address = place.address address = place.address
@@ -459,7 +476,8 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
return token_info.data return token_info.data
def _process_place_address(self, token_info, address): def _process_place_address(self, token_info: '_TokenInfo', address: Mapping[str, str]) -> None:
assert self.conn is not None
hnrs = [] hnrs = []
addr_terms = [] addr_terms = []
@@ -491,12 +509,12 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
class _TokenInfo: class _TokenInfo:
""" Collect token information to be sent back to the database. """ Collect token information to be sent back to the database.
""" """
def __init__(self, cache): def __init__(self, cache: '_TokenCache') -> None:
self.cache = cache self.cache = cache
self.data = {} self.data: Dict[str, Any] = {}
def add_names(self, conn, names): def add_names(self, conn: Connection, names: Mapping[str, str]) -> None:
""" Add token information for the names of the place. """ Add token information for the names of the place.
""" """
with conn.cursor() as cur: with conn.cursor() as cur:
@@ -505,7 +523,7 @@ class _TokenInfo:
(names, )) (names, ))
def add_housenumbers(self, conn, hnrs): def add_housenumbers(self, conn: Connection, hnrs: Sequence[str]) -> None:
""" Extract housenumber information from the address. """ Extract housenumber information from the address.
""" """
if len(hnrs) == 1: if len(hnrs) == 1:
@@ -516,7 +534,7 @@ class _TokenInfo:
return return
# split numbers if necessary # split numbers if necessary
simple_list = [] simple_list: List[str] = []
for hnr in hnrs: for hnr in hnrs:
simple_list.extend((x.strip() for x in re.split(r'[;,]', hnr))) simple_list.extend((x.strip() for x in re.split(r'[;,]', hnr)))
@@ -525,49 +543,53 @@ class _TokenInfo:
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute("SELECT * FROM create_housenumbers(%s)", (simple_list, )) cur.execute("SELECT * FROM create_housenumbers(%s)", (simple_list, ))
self.data['hnr_tokens'], self.data['hnr'] = cur.fetchone() self.data['hnr_tokens'], self.data['hnr'] = \
cur.fetchone() # type: ignore[no-untyped-call]
def set_postcode(self, postcode): def set_postcode(self, postcode: str) -> None:
""" Set or replace the postcode token with the given value. """ Set or replace the postcode token with the given value.
""" """
self.data['postcode'] = postcode self.data['postcode'] = postcode
def add_street(self, conn, street): def add_street(self, conn: Connection, street: str) -> None:
""" Add addr:street match terms. """ Add addr:street match terms.
""" """
def _get_street(name): def _get_street(name: str) -> List[int]:
with conn.cursor() as cur: with conn.cursor() as cur:
return cur.scalar("SELECT word_ids_from_name(%s)::text", (name, )) return cast(List[int],
cur.scalar("SELECT word_ids_from_name(%s)::text", (name, )))
tokens = self.cache.streets.get(street, _get_street) tokens = self.cache.streets.get(street, _get_street)
if tokens: if tokens:
self.data['street'] = tokens self.data['street'] = tokens
def add_place(self, conn, place): def add_place(self, conn: Connection, place: str) -> None:
""" Add addr:place search and match terms. """ Add addr:place search and match terms.
""" """
def _get_place(name): def _get_place(name: str) -> Tuple[List[int], List[int]]:
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute("""SELECT make_keywords(hstore('name' , %s))::text, cur.execute("""SELECT make_keywords(hstore('name' , %s))::text,
word_ids_from_name(%s)::text""", word_ids_from_name(%s)::text""",
(name, name)) (name, name))
return cur.fetchone() return cast(Tuple[List[int], List[int]],
cur.fetchone()) # type: ignore[no-untyped-call]
self.data['place_search'], self.data['place_match'] = \ self.data['place_search'], self.data['place_match'] = \
self.cache.places.get(place, _get_place) self.cache.places.get(place, _get_place)
def add_address_terms(self, conn, terms): def add_address_terms(self, conn: Connection, terms: Sequence[Tuple[str, str]]) -> None:
""" Add additional address terms. """ Add additional address terms.
""" """
def _get_address_term(name): def _get_address_term(name: str) -> Tuple[List[int], List[int]]:
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute("""SELECT addr_ids_from_name(%s)::text, cur.execute("""SELECT addr_ids_from_name(%s)::text,
word_ids_from_name(%s)::text""", word_ids_from_name(%s)::text""",
(name, name)) (name, name))
return cur.fetchone() return cast(Tuple[List[int], List[int]],
cur.fetchone()) # type: ignore[no-untyped-call]
tokens = {} tokens = {}
for key, value in terms: for key, value in terms:
@@ -584,13 +606,12 @@ class _LRU:
produce the item when there is a cache miss. produce the item when there is a cache miss.
""" """
def __init__(self, maxsize=128, init_data=None): def __init__(self, maxsize: int = 128):
self.data = init_data or OrderedDict() self.data: 'OrderedDict[str, Any]' = OrderedDict()
self.maxsize = maxsize self.maxsize = maxsize
if init_data is not None and len(init_data) > maxsize:
self.maxsize = len(init_data)
def get(self, key, generator):
def get(self, key: str, generator: Callable[[str], Any]) -> Any:
""" Get the item with the given key from the cache. If nothing """ Get the item with the given key from the cache. If nothing
is found in the cache, generate the value through the is found in the cache, generate the value through the
generator function and store it in the cache. generator function and store it in the cache.
@@ -613,7 +634,7 @@ class _TokenCache:
This cache is not thread-safe and needs to be instantiated per This cache is not thread-safe and needs to be instantiated per
analyzer. analyzer.
""" """
def __init__(self, conn): def __init__(self, conn: Connection):
# various LRU caches # various LRU caches
self.streets = _LRU(maxsize=256) self.streets = _LRU(maxsize=256)
self.places = _LRU(maxsize=128) self.places = _LRU(maxsize=128)
@@ -623,18 +644,18 @@ class _TokenCache:
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute("""SELECT i, ARRAY[getorcreate_housenumber_id(i::text)]::text cur.execute("""SELECT i, ARRAY[getorcreate_housenumber_id(i::text)]::text
FROM generate_series(1, 100) as i""") FROM generate_series(1, 100) as i""")
self._cached_housenumbers = {str(r[0]): r[1] for r in cur} self._cached_housenumbers: Dict[str, str] = {str(r[0]): r[1] for r in cur}
# For postcodes remember the ones that have already been added # For postcodes remember the ones that have already been added
self.postcodes = set() self.postcodes: Set[str] = set()
def get_housenumber(self, number): def get_housenumber(self, number: str) -> Optional[str]:
""" Get a housenumber token from the cache. """ Get a housenumber token from the cache.
""" """
return self._cached_housenumbers.get(number) return self._cached_housenumbers.get(number)
def add_postcode(self, conn, postcode): def add_postcode(self, conn: Connection, postcode: str) -> None:
""" Make sure the given postcode is in the database. """ Make sure the given postcode is in the database.
""" """
if postcode not in self.postcodes: if postcode not in self.postcodes: