correctly quote strings when copying in data

Encapsulate the copy string in a class that ensures that
copy lines are written with correct quoting.
This commit is contained in:
Sarah Hoffmann
2021-06-10 09:36:43 +02:00
parent 2f6e4edcdb
commit a0a7b05c9f
5 changed files with 202 additions and 52 deletions

View File

@@ -4,6 +4,7 @@ Helper functions for handling DB accesses.
import subprocess
import logging
import gzip
import io
from nominatim.db.connection import get_pg_env
from nominatim.errors import UsageError
@@ -57,3 +58,49 @@ def execute_file(dsn, fname, ignore_errors=False, pre_code=None, post_code=None)
if ret != 0 or remain > 0:
raise UsageError("Failed to execute SQL file.")
# Characters that must be backslash-escaped in the text format of the
# COPY command: the escape character itself, the column delimiter (tab)
# and both row terminators (newline and carriage return). A literal
# carriage return inside a value would otherwise be rejected or be
# misread as the end of a row.
_SQL_TRANSLATION = {ord('\\'): '\\\\',
                    ord('\t'): '\\t',
                    ord('\n'): '\\n',
                    ord('\r'): '\\r'}


class CopyBuffer:
    """ Data collector for the copy_from command.

        Rows are accumulated in an in-memory text buffer, one line per
        row with tab-separated columns, and are quoted such that they
        can be handed unchanged to the copy_from() function of a cursor.
        The class may be used as a context manager, which closes the
        buffer on exit.
    """

    def __init__(self):
        self.buffer = io.StringIO()


    def __enter__(self):
        return self


    def __exit__(self, exc_type, exc_value, traceback):
        if self.buffer is not None:
            self.buffer.close()


    def add(self, *data):
        """ Add another row of data to the copy buffer.

            Each positional argument becomes one column. None is written
            as the NULL marker '\\N'. Any other value is converted with
            str() and has its special characters escaped.
        """
        self.buffer.write('\t'.join(
            '\\N' if column is None
            else str(column).translate(_SQL_TRANSLATION)
            for column in data))
        self.buffer.write('\n')


    def copy_out(self, cur, table, columns=None):
        """ Copy all collected data into the given table.

            'cur' must provide a copy_from(file, table, columns=...)
            function (presumably a psycopg2 cursor - verify against
            the callers). 'columns' is passed through unchanged.
            Nothing is sent when no rows have been added.
        """
        if self.buffer.tell() > 0:
            # Rewind so that copy_from() reads from the first row.
            self.buffer.seek(0)
            cur.copy_from(self.buffer, table, columns=columns)