add type annotations for Tiger import function

This commit is contained in:
Sarah Hoffmann
2022-07-17 11:01:44 +02:00
parent 9963261d8d
commit 25d854dc5c

View File

@@ -7,6 +7,7 @@
"""
Functions for importing tiger data and handling tarbar and directory files
"""
from typing import Any, TextIO, List, Union, cast
import csv
import io
import logging
@@ -15,11 +16,13 @@ import tarfile
from psycopg2.extras import Json
from nominatim.config import Configuration
from nominatim.db.connection import connect
from nominatim.db.async_connection import WorkerPool
from nominatim.db.sql_preprocessor import SQLPreprocessor
from nominatim.errors import UsageError
from nominatim.data.place_info import PlaceInfo
from nominatim.tokenizer.base import AbstractAnalyzer, AbstractTokenizer
LOG = logging.getLogger()
@@ -28,9 +31,9 @@ class TigerInput:
either be in a directory or gzipped together in a tar file.
"""
def __init__(self, data_dir):
def __init__(self, data_dir: str) -> None:
self.tar_handle = None
self.files = []
self.files: List[Union[str, tarfile.TarInfo]] = []
if data_dir.endswith('.tar.gz'):
try:
@@ -50,33 +53,36 @@ class TigerInput:
LOG.warning("Tiger data import selected but no files found at %s", data_dir)
def __enter__(self):
def __enter__(self) -> 'TigerInput':
return self
def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if self.tar_handle:
self.tar_handle.close()
self.tar_handle = None
def next_file(self):
def next_file(self) -> TextIO:
""" Return a file handle to the next file to be processed.
Raises an IndexError if there is no file left.
"""
fname = self.files.pop(0)
if self.tar_handle is not None:
return io.TextIOWrapper(self.tar_handle.extractfile(fname))
extracted = self.tar_handle.extractfile(fname)
assert extracted is not None
return io.TextIOWrapper(extracted)
return open(fname, encoding='utf-8')
return open(cast(str, fname), encoding='utf-8')
def __len__(self):
def __len__(self) -> int:
return len(self.files)
def handle_threaded_sql_statements(pool, fd, analyzer):
def handle_threaded_sql_statements(pool: WorkerPool, fd: TextIO,
analyzer: AbstractAnalyzer) -> None:
""" Handles sql statement with multiplexing
"""
lines = 0
@@ -101,7 +107,8 @@ def handle_threaded_sql_statements(pool, fd, analyzer):
lines = 0
def add_tiger_data(data_dir, config, threads, tokenizer):
def add_tiger_data(data_dir: str, config: Configuration, threads: int,
tokenizer: AbstractTokenizer) -> None:
""" Import tiger data from directory or tar file `data dir`.
"""
dsn = config.get_libpq_dsn()