correctly quote strings when copying in data

Encapsulate the copy string in a class that ensures that
copy lines are written with correct quoting.
Sarah Hoffmann
2021-06-10 09:36:43 +02:00
parent 2f6e4edcdb
commit a0a7b05c9f
5 changed files with 202 additions and 52 deletions
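
Background for the change: in PostgreSQL's COPY text format, which cur.copy_from feeds, a tab separates columns, a newline separates rows, backslash is the escape character, and \N marks NULL. Any value containing one of these characters must be escaped before it goes into the copy stream, otherwise the row is silently split or rejected. A minimal standalone sketch of the quoting rule this commit introduces (the values are made up for illustration):

    import io

    # the three characters that are special in COPY text format
    ESCAPES = {ord('\\'): '\\\\', ord('\t'): '\\t', ord('\n'): '\\n'}

    buf = io.StringIO()
    for row in [('foo\tbar', 1), (None, 2)]:
        buf.write('\t'.join('\\N' if col is None
                            else str(col).translate(ESCAPES)
                            for col in row))
        buf.write('\n')

    print(repr(buf.getvalue()))   # 'foo\\tbar\t1\n\\N\t2\n'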

nominatim/db/utils.py

@@ -4,6 +4,7 @@ Helper functions for handling DB accesses.
 import subprocess
 import logging
 import gzip
+import io
 
 from nominatim.db.connection import get_pg_env
 from nominatim.errors import UsageError
@@ -57,3 +58,49 @@ def execute_file(dsn, fname, ignore_errors=False, pre_code=None, post_code=None)
 
     if ret != 0 or remain > 0:
         raise UsageError("Failed to execute SQL file.")
+
+
+# List of characters that need to be quoted for the copy command.
+_SQL_TRANSLATION = {ord(u'\\') : u'\\\\',
+                    ord(u'\t') : u'\\t',
+                    ord(u'\n') : u'\\n'}
+
+
+class CopyBuffer:
+    """ Data collector for the copy_from command.
+    """
+
+    def __init__(self):
+        self.buffer = io.StringIO()
+
+
+    def __enter__(self):
+        return self
+
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        if self.buffer is not None:
+            self.buffer.close()
+
+
+    def add(self, *data):
+        """ Add another row of data to the copy buffer.
+        """
+        first = True
+        for column in data:
+            if first:
+                first = False
+            else:
+                self.buffer.write('\t')
+            if column is None:
+                self.buffer.write('\\N')
+            else:
+                self.buffer.write(str(column).translate(_SQL_TRANSLATION))
+        self.buffer.write('\n')
+
+
+    def copy_out(self, cur, table, columns=None):
+        """ Copy all collected data into the given table.
+        """
+        if self.buffer.tell() > 0:
+            self.buffer.seek(0)
+            cur.copy_from(self.buffer, table, columns=columns)
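
CopyBuffer is meant to be used as a context manager wrapped around a psycopg2 cursor; copy_out is a no-op when nothing was added, so the empty case needs no special handling by the caller. A usage sketch (the connection string and table are illustrative only, not from this commit):

    import psycopg2
    from nominatim.db.utils import CopyBuffer

    conn = psycopg2.connect('dbname=nominatim')   # illustrative DSN
    with conn.cursor() as cur:
        with CopyBuffer() as buf:
            buf.add('foo bar', 42)   # embedded tabs/newlines get escaped
            buf.add(None, 0)         # None is written as the NULL marker \N
            buf.copy_out(cur, 'word',
                         columns=['word_token', 'search_name_count'])
    conn.commit()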

nominatim/tokenizer/icu_rule_loader.py

@@ -93,7 +93,7 @@ class ICURuleLoader:
 
     def _load_from_yaml(self):
-        rules = yaml.load(self.configfile.read_text())
+        rules = yaml.safe_load(self.configfile.read_text())
 
         self.normalization_rules = self._cfg_to_icu_rules(rules, 'normalization')
         self.transliteration_rules = self._cfg_to_icu_rules(rules, 'transliteration')
@@ -122,6 +122,9 @@ class ICURuleLoader:
         """
         content = self._get_section(rules, section)
 
+        if content is None:
+            return ''
+
         if isinstance(content, str):
             return (self.configfile.parent / content).read_text().replace('\n', ' ')
@@ -160,4 +163,5 @@ class ICURuleLoader:
             abbrterms = (norm.transliterate(t.strip()) for t in parts[1].split(','))
 
             for full, abbr in itertools.product(fullterms, abbrterms):
-                self.abbreviations[full].append(abbr)
+                if full and abbr:
+                    self.abbreviations[full].append(abbr)
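
The new `content is None` guard is needed because yaml.safe_load parses a section that is present but empty as None, not as an empty string or list; safe_load also refuses to construct arbitrary Python objects from untrusted input, which plain yaml.load allows. A quick illustration of the parsing behaviour (plain PyYAML):

    import yaml

    rules = yaml.safe_load("normalization:\ntransliteration:\n")
    print(rules)   # {'normalization': None, 'transliteration': None}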

nominatim/tokenizer/legacy_icu_tokenizer.py

@@ -14,6 +14,7 @@ import psycopg2.extras
 
 from nominatim.db.connection import connect
 from nominatim.db.properties import set_property, get_property
+from nominatim.db.utils import CopyBuffer
 from nominatim.db.sql_preprocessor import SQLPreprocessor
 from nominatim.tokenizer.icu_rule_loader import ICURuleLoader
 from nominatim.tokenizer.icu_name_processor import ICUNameProcessor, ICUNameProcessorRules
@@ -134,7 +135,7 @@ class LegacyICUTokenizer:
                 @define('CONST_Term_Normalization_Rules', "{0.term_normalization}");
                 @define('CONST_Transliteration', "{0.naming_rules.search_rules}");
                 require_once('{1}/tokenizer/legacy_icu_tokenizer.php');
-                """.format(self, phpdir)))
+                """.format(self, phpdir))) # pylint: disable=missing-format-attribute
 
 
     def _save_config(self, config):
@@ -171,14 +172,15 @@ class LegacyICUTokenizer:
                     words[term] += cnt
 
             # copy them back into the word table
-            copystr = io.StringIO(''.join(('{}\t{}\n'.format(*args) for args in words.items())))
-
-            with conn.cursor() as cur:
-                copystr.seek(0)
-                cur.copy_from(copystr, 'word', columns=['word_token', 'search_name_count'])
-                cur.execute("""UPDATE word SET word_id = nextval('seq_word')
-                               WHERE word_id is null""")
+            with CopyBuffer() as copystr:
+                for args in words.items():
+                    copystr.add(*args)
+
+                with conn.cursor() as cur:
+                    copystr.copy_out(cur, 'word',
+                                     columns=['word_token', 'search_name_count'])
+                    cur.execute("""UPDATE word SET word_id = nextval('seq_word')
+                                   WHERE word_id is null""")
 
             conn.commit()
@@ -265,7 +267,6 @@ class LegacyICUNameAnalyzer:
             table.
         """
         to_delete = []
-        copystr = io.StringIO()
         with self.conn.cursor() as cur:
             # This finds us the rows in location_postcode and word that are
             # missing in the other table.
@@ -278,26 +279,25 @@ class LegacyICUNameAnalyzer:
                                   ON pc = word) x
                            WHERE pc is null or word is null""")
 
-            for postcode, word in cur:
-                if postcode is None:
-                    to_delete.append(word)
-                else:
-                    copystr.write(postcode)
-                    copystr.write('\t ')
-                    copystr.write(self.name_processor.get_search_normalized(postcode))
-                    copystr.write('\tplace\tpostcode\t0\n')
-
-            if to_delete:
-                cur.execute("""DELETE FROM WORD
-                               WHERE class ='place' and type = 'postcode'
-                                     and word = any(%s)
-                            """, (to_delete, ))
-
-            if copystr.getvalue():
-                copystr.seek(0)
-                cur.copy_from(copystr, 'word',
-                              columns=['word', 'word_token', 'class', 'type',
-                                       'search_name_count'])
+            with CopyBuffer() as copystr:
+                for postcode, word in cur:
+                    if postcode is None:
+                        to_delete.append(word)
+                    else:
+                        copystr.add(
+                            postcode,
+                            ' ' + self.name_processor.get_search_normalized(postcode),
+                            'place', 'postcode', 0)
+
+                if to_delete:
+                    cur.execute("""DELETE FROM WORD
+                                   WHERE class ='place' and type = 'postcode'
+                                         and word = any(%s)
+                                """, (to_delete, ))
+
+                copystr.copy_out(cur, 'word',
+                                 columns=['word', 'word_token', 'class', 'type',
+                                          'search_name_count'])
 
 
     def update_special_phrases(self, phrases, should_replace):
@@ -331,34 +331,24 @@ class LegacyICUNameAnalyzer:
         """
         to_add = new_phrases - existing_phrases
 
-        copystr = io.StringIO()
         added = 0
-        for word, cls, typ, oper in to_add:
-            term = self.name_processor.get_search_normalized(word)
-            if term:
-                copystr.write(word)
-                copystr.write('\t ')
-                copystr.write(term)
-                copystr.write('\t')
-                copystr.write(cls)
-                copystr.write('\t')
-                copystr.write(typ)
-                copystr.write('\t')
-                copystr.write(oper if oper in ('in', 'near') else '\\N')
-                copystr.write('\t0\n')
-                added += 1
-
-        if copystr.tell() > 0:
-            copystr.seek(0)
-            cursor.copy_from(copystr, 'word',
-                             columns=['word', 'word_token', 'class', 'type',
-                                      'operator', 'search_name_count'])
+        with CopyBuffer() as copystr:
+            for word, cls, typ, oper in to_add:
+                term = self.name_processor.get_search_normalized(word)
+                if term:
+                    copystr.add(word, term, cls, typ,
+                                oper if oper in ('in', 'near') else None, 0)
+                    added += 1
+
+            copystr.copy_out(cursor, 'word',
+                             columns=['word', 'word_token', 'class', 'type',
+                                      'operator', 'search_name_count'])
 
         return added
 
 
-    def _remove_special_phrases(self, cursor, new_phrases, existing_phrases):
+    @staticmethod
+    def _remove_special_phrases(cursor, new_phrases, existing_phrases):
         """ Remove all phrases from the databse that are no longer in the
             new phrase list.
         """

test/python/test_db_utils.py

@@ -50,3 +50,68 @@ def test_execute_file_with_post_code(dsn, tmp_path, temp_db_cursor):
     db_utils.execute_file(dsn, tmpfile, post_code='INSERT INTO test VALUES(23)')
 
     assert temp_db_cursor.row_set('SELECT * FROM test') == {(23, )}
+
+
+class TestCopyBuffer:
+    TABLE_NAME = 'copytable'
+
+    @pytest.fixture(autouse=True)
+    def setup_test_table(self, table_factory):
+        table_factory(self.TABLE_NAME, 'colA INT, colB TEXT')
+
+
+    def table_rows(self, cursor):
+        return cursor.row_set('SELECT * FROM ' + self.TABLE_NAME)
+
+
+    def test_copybuffer_empty(self):
+        with db_utils.CopyBuffer() as buf:
+            buf.copy_out(None, "dummy")
+
+
+    def test_all_columns(self, temp_db_cursor):
+        with db_utils.CopyBuffer() as buf:
+            buf.add(3, 'hum')
+            buf.add(None, 'f\\t')
+
+            buf.copy_out(temp_db_cursor, self.TABLE_NAME)
+
+        assert self.table_rows(temp_db_cursor) == {(3, 'hum'), (None, 'f\\t')}
+
+
+    def test_selected_columns(self, temp_db_cursor):
+        with db_utils.CopyBuffer() as buf:
+            buf.add('foo')
+
+            buf.copy_out(temp_db_cursor, self.TABLE_NAME,
+                         columns=['colB'])
+
+        assert self.table_rows(temp_db_cursor) == {(None, 'foo')}
+
+
+    def test_reordered_columns(self, temp_db_cursor):
+        with db_utils.CopyBuffer() as buf:
+            buf.add('one', 1)
+            buf.add(' two ', 2)
+
+            buf.copy_out(temp_db_cursor, self.TABLE_NAME,
+                         columns=['colB', 'colA'])
+
+        assert self.table_rows(temp_db_cursor) == {(1, 'one'), (2, ' two ')}
+
+
+    def test_special_characters(self, temp_db_cursor):
+        with db_utils.CopyBuffer() as buf:
+            buf.add('foo\tbar')
+            buf.add('sun\nson')
+            buf.add('\\N')
+
+            buf.copy_out(temp_db_cursor, self.TABLE_NAME,
+                         columns=['colB'])
+
+        assert self.table_rows(temp_db_cursor) == {(None, 'foo\tbar'),
+                                                   (None, 'sun\nson'),
+                                                   (None, '\\N')}
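
The last test is the subtle one: a literal two-character string \N must survive the round trip instead of being read back as NULL, which only works because the backslash itself is escaped. The escaping can be checked without a database:

    SQL_TRANSLATION = {ord('\\'): '\\\\', ord('\t'): '\\t', ord('\n'): '\\n'}

    assert '\\N'.translate(SQL_TRANSLATION) == '\\\\N'       # literal \N -> \\N
    assert 'sun\nson'.translate(SQL_TRANSLATION) == 'sun\\nson'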

test/python/test_tokenizer_icu_rule_loader.py

@@ -21,6 +21,7 @@ def cfgfile(tmp_path, suffix='.yaml'):
             - ":: NFC ()"
         transliteration:
             - ":: Latin ()"
+            - "[[:Punctuation:][:Space:]]+ > ' '"
         """)
         content += "compound_suffixes:\n"
         content += '\n'.join((" - " + s for s in suffixes)) + '\n'
@@ -32,13 +33,33 @@ def cfgfile(tmp_path, suffix='.yaml'):
 
     return _create_config
 
-def test_missing_normalization(tmp_path):
+
+def test_empty_rule_file(tmp_path):
     fpath = tmp_path / ('test_config.yaml')
     fpath.write_text(dedent("""\
-        normalizatio:
-            - ":: NFD ()"
+        normalization:
+        transliteration:
+        compound_suffixes:
+        abbreviations:
         """))
 
+    rules = ICURuleLoader(fpath)
+    assert rules.get_search_rules() == ''
+    assert rules.get_normalization_rules() == ''
+    assert rules.get_transliteration_rules() == ''
+    assert rules.get_replacement_pairs() == []
+
+
+CONFIG_SECTIONS = ('normalization', 'transliteration',
+                   'compound_suffixes', 'abbreviations')
+
+@pytest.mark.parametrize("section", CONFIG_SECTIONS)
+def test_missing_normalization(tmp_path, section):
+    fpath = tmp_path / ('test_config.yaml')
+    with fpath.open('w') as fd:
+        for name in CONFIG_SECTIONS:
+            if name != section:
+                fd.write(name + ':\n')
+
     with pytest.raises(UsageError):
         ICURuleLoader(fpath)
@@ -53,6 +74,7 @@ def test_get_search_rules(cfgfile):
     rules = loader.get_search_rules()
     trans = Transliterator.createFromRules("test", rules)
 
+    assert trans.transliterate(" Baum straße ") == " baum straße "
     assert trans.transliterate(" Baumstraße ") == " baum straße "
     assert trans.transliterate(" Baumstrasse ") == " baum strasse "
     assert trans.transliterate(" Baumstr ") == " baum str "
@@ -61,6 +83,28 @@ def test_get_search_rules(cfgfile):
     assert trans.transliterate(" проспект ") == " prospekt "
 
 
+def test_get_normalization_rules(cfgfile):
+    fpath = cfgfile(['strasse', 'straße', 'weg'],
+                    ['strasse,straße => str'])
+
+    loader = ICURuleLoader(fpath)
+    rules = loader.get_normalization_rules()
+    trans = Transliterator.createFromRules("test", rules)
+
+    assert trans.transliterate(" проспект-Prospekt ") == " проспект prospekt "
+
+
+def test_get_transliteration_rules(cfgfile):
+    fpath = cfgfile(['strasse', 'straße', 'weg'],
+                    ['strasse,straße => str'])
+
+    loader = ICURuleLoader(fpath)
+    rules = loader.get_transliteration_rules()
+    trans = Transliterator.createFromRules("test", rules)
+
+    assert trans.transliterate(" проспект-Prospekt ") == " prospekt Prospekt "
+
+
 def test_get_synonym_pairs(cfgfile):
     fpath = cfgfile(['Weg', 'Strasse'],
                     ['Strasse => str,st'])
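
For reference, the rule added to the fixture, "[[:Punctuation:][:Space:]]+ > ' '", is what turns the hyphen in " проспект-Prospekt " into a plain space in the transliteration test above. The rule can be checked in isolation (assuming PyICU's Transliterator, as the tests use; note the terminating semicolon that ICU rule syntax requires):

    from icu import Transliterator

    trans = Transliterator.createFromRules(
        "test", "[[:Punctuation:][:Space:]]+ > ' ';")
    assert trans.transliterate("проспект-Prospekt") == "проспект Prospekt"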