use context management when processing Tiger data

This commit is contained in:
Sarah Hoffmann
2022-05-11 09:45:15 +02:00
parent ae6b029543
commit 5d5f40a82f

View File

@@ -21,33 +21,57 @@ from nominatim.indexer.place_info import PlaceInfo
LOG = logging.getLogger()


class TigerInput:
    """ Context manager that goes through Tiger input files which may
        either be in a directory or gzipped together in a tar file.
    """

    def __init__(self, data_dir):
        # Open tar handle, kept for the lifetime of the context so that
        # member files can be extracted lazily; None when reading a directory.
        self.tar_handle = None
        # Remaining CSV inputs: TarInfo members for a tarball,
        # absolute file paths for a plain directory.
        self.files = []

        if data_dir.endswith('.tar.gz'):
            try:
                self.tar_handle = tarfile.open(data_dir) # pylint: disable=consider-using-with
            except tarfile.ReadError as err:
                LOG.fatal("Cannot open '%s'. Is this a tar file?", data_dir)
                raise UsageError("Cannot open Tiger data file.") from err

            self.files = [i for i in self.tar_handle.getmembers() if i.name.endswith('.csv')]
            LOG.warning("Found %d CSV files in tarfile with path %s", len(self.files), data_dir)
        else:
            files = os.listdir(data_dir)
            self.files = [os.path.join(data_dir, i) for i in files if i.endswith('.csv')]
            LOG.warning("Found %d CSV files in path %s", len(self.files), data_dir)

        if not self.files:
            LOG.warning("Tiger data import selected but no files found at %s", data_dir)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Close the tarball (if any) and drop the handle so a second
        # __exit__ or late next_file() cannot touch a closed archive.
        if self.tar_handle:
            self.tar_handle.close()
            self.tar_handle = None

    def next_file(self):
        """ Return a file handle to the next file to be processed.
            Raises an IndexError if there is no file left.
        """
        fname = self.files.pop(0)

        if self.tar_handle is not None:
            # Tar members are binary streams; wrap for text-mode reading.
            return io.TextIOWrapper(self.tar_handle.extractfile(fname))

        return open(fname, encoding='utf-8')

    def __len__(self):
        # Number of files still to be processed; also makes the object
        # falsy once exhausted (used as the loop condition by callers).
        return len(self.files)
def handle_threaded_sql_statements(pool, fd, analyzer): def handle_threaded_sql_statements(pool, fd, analyzer):
@@ -79,34 +103,27 @@ def add_tiger_data(data_dir, config, threads, tokenizer):
""" Import tiger data from directory or tar file `data dir`. """ Import tiger data from directory or tar file `data dir`.
""" """
dsn = config.get_libpq_dsn() dsn = config.get_libpq_dsn()
files, tar = handle_tarfile_or_directory(data_dir)
if not files: with TigerInput(data_dir) as tar:
return if not tar:
return
with connect(dsn) as conn: with connect(dsn) as conn:
sql = SQLPreprocessor(conn, config) sql = SQLPreprocessor(conn, config)
sql.run_sql_file(conn, 'tiger_import_start.sql') sql.run_sql_file(conn, 'tiger_import_start.sql')
# Reading files and then for each file line handling # Reading files and then for each file line handling
# sql_query in <threads - 1> chunks. # sql_query in <threads - 1> chunks.
place_threads = max(1, threads - 1) place_threads = max(1, threads - 1)
with WorkerPool(dsn, place_threads, ignore_sql_errors=True) as pool: with WorkerPool(dsn, place_threads, ignore_sql_errors=True) as pool:
with tokenizer.name_analyzer() as analyzer: with tokenizer.name_analyzer() as analyzer:
for fname in files: while tar:
if not tar: with tar.next_file() as fd:
fd = open(fname) handle_threaded_sql_statements(pool, fd, analyzer)
else:
fd = io.TextIOWrapper(tar.extractfile(fname))
handle_threaded_sql_statements(pool, fd, analyzer) print('\n')
fd.close()
if tar:
tar.close()
print('\n')
LOG.warning("Creating indexes on Tiger data") LOG.warning("Creating indexes on Tiger data")
with connect(dsn) as conn: with connect(dsn) as conn:
sql = SQLPreprocessor(conn, config) sql = SQLPreprocessor(conn, config)