add type annotations for Tiger import function

This commit is contained in:
Sarah Hoffmann
2022-07-17 11:01:44 +02:00
parent 9963261d8d
commit 25d854dc5c

View File

@@ -7,6 +7,7 @@
""" """
Functions for importing tiger data and handling tarball and directory files Functions for importing tiger data and handling tarball and directory files
""" """
from typing import Any, TextIO, List, Union, cast
import csv import csv
import io import io
import logging import logging
@@ -15,11 +16,13 @@ import tarfile
from psycopg2.extras import Json from psycopg2.extras import Json
from nominatim.config import Configuration
from nominatim.db.connection import connect from nominatim.db.connection import connect
from nominatim.db.async_connection import WorkerPool from nominatim.db.async_connection import WorkerPool
from nominatim.db.sql_preprocessor import SQLPreprocessor from nominatim.db.sql_preprocessor import SQLPreprocessor
from nominatim.errors import UsageError from nominatim.errors import UsageError
from nominatim.data.place_info import PlaceInfo from nominatim.data.place_info import PlaceInfo
from nominatim.tokenizer.base import AbstractAnalyzer, AbstractTokenizer
LOG = logging.getLogger() LOG = logging.getLogger()
@@ -28,9 +31,9 @@ class TigerInput:
either be in a directory or gzipped together in a tar file. either be in a directory or gzipped together in a tar file.
""" """
def __init__(self, data_dir): def __init__(self, data_dir: str) -> None:
self.tar_handle = None self.tar_handle = None
self.files = [] self.files: List[Union[str, tarfile.TarInfo]] = []
if data_dir.endswith('.tar.gz'): if data_dir.endswith('.tar.gz'):
try: try:
@@ -50,33 +53,36 @@ class TigerInput:
LOG.warning("Tiger data import selected but no files found at %s", data_dir) LOG.warning("Tiger data import selected but no files found at %s", data_dir)
def __enter__(self): def __enter__(self) -> 'TigerInput':
return self return self
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if self.tar_handle: if self.tar_handle:
self.tar_handle.close() self.tar_handle.close()
self.tar_handle = None self.tar_handle = None
def next_file(self): def next_file(self) -> TextIO:
""" Return a file handle to the next file to be processed. """ Return a file handle to the next file to be processed.
Raises an IndexError if there is no file left. Raises an IndexError if there is no file left.
""" """
fname = self.files.pop(0) fname = self.files.pop(0)
if self.tar_handle is not None: if self.tar_handle is not None:
return io.TextIOWrapper(self.tar_handle.extractfile(fname)) extracted = self.tar_handle.extractfile(fname)
assert extracted is not None
return io.TextIOWrapper(extracted)
return open(fname, encoding='utf-8') return open(cast(str, fname), encoding='utf-8')
def __len__(self): def __len__(self) -> int:
return len(self.files) return len(self.files)
def handle_threaded_sql_statements(pool, fd, analyzer): def handle_threaded_sql_statements(pool: WorkerPool, fd: TextIO,
analyzer: AbstractAnalyzer) -> None:
""" Handles sql statement with multiplexing """ Handles sql statement with multiplexing
""" """
lines = 0 lines = 0
@@ -101,7 +107,8 @@ def handle_threaded_sql_statements(pool, fd, analyzer):
lines = 0 lines = 0
def add_tiger_data(data_dir, config, threads, tokenizer): def add_tiger_data(data_dir: str, config: Configuration, threads: int,
tokenizer: AbstractTokenizer) -> None:
""" Import tiger data from directory or tar file `data dir`. """ Import tiger data from directory or tar file `data dir`.
""" """
dsn = config.get_libpq_dsn() dsn = config.get_libpq_dsn()