test: move Testingcursor into separate class

Also adds more convenience functions: counting with a where
statement and a wrapper to execute_values().
This commit is contained in:
Sarah Hoffmann
2021-05-19 10:30:36 +02:00
parent 16bb007135
commit 510eb53f53
2 changed files with 56 additions and 36 deletions

View File

@@ -1,12 +1,11 @@
import importlib
import itertools import itertools
import sys import sys
import tempfile
from pathlib import Path from pathlib import Path
import psycopg2 import psycopg2
import psycopg2.extras import psycopg2.extras
import pytest import pytest
import tempfile
SRC_DIR = Path(__file__) / '..' / '..' / '..' SRC_DIR = Path(__file__) / '..' / '..' / '..'
@@ -21,38 +20,7 @@ import nominatim.tokenizer.factory
import dummy_tokenizer import dummy_tokenizer
import mocks import mocks
from cursor import TestingCursor
class _TestingCursor(psycopg2.extras.DictCursor):
""" Extension to the DictCursor class that provides execution
short-cuts that simplify writing assertions.
"""
def scalar(self, sql, params=None):
""" Execute a query with a single return value and return this value.
Raises an assertion when not exactly one row is returned.
"""
self.execute(sql, params)
assert self.rowcount == 1
return self.fetchone()[0]
def row_set(self, sql, params=None):
""" Execute a query and return the result as a set of tuples.
"""
self.execute(sql, params)
return set((tuple(row) for row in self))
def table_exists(self, table):
""" Check that a table with the given name exists in the database.
"""
num = self.scalar("""SELECT count(*) FROM pg_tables
WHERE tablename = %s""", (table, ))
return num == 1
def table_rows(self, table):
""" Return the number of rows in the given table.
"""
return self.scalar('SELECT count(*) FROM ' + table)
@pytest.fixture @pytest.fixture
@@ -70,7 +38,7 @@ def temp_db(monkeypatch):
conn.close() conn.close()
monkeypatch.setenv('NOMINATIM_DATABASE_DSN' , 'dbname=' + name) monkeypatch.setenv('NOMINATIM_DATABASE_DSN', 'dbname=' + name)
yield name yield name
@@ -113,7 +81,7 @@ def temp_db_cursor(temp_db):
""" """
conn = psycopg2.connect('dbname=' + temp_db) conn = psycopg2.connect('dbname=' + temp_db)
conn.set_isolation_level(0) conn.set_isolation_level(0)
with conn.cursor(cursor_factory=_TestingCursor) as cur: with conn.cursor(cursor_factory=TestingCursor) as cur:
yield cur yield cur
conn.close() conn.close()

52
test/python/cursor.py Normal file
View File

@@ -0,0 +1,52 @@
"""
Specialised psycopg2 cursor with shortcut functions useful for testing.
"""
import psycopg2.extras
class TestingCursor(psycopg2.extras.DictCursor):
""" Extension to the DictCursor class that provides execution
short-cuts that simplify writing assertions.
"""
def scalar(self, sql, params=None):
""" Execute a query with a single return value and return this value.
Raises an assertion when not exactly one row is returned.
"""
self.execute(sql, params)
assert self.rowcount == 1
return self.fetchone()[0]
def row_set(self, sql, params=None):
""" Execute a query and return the result as a set of tuples.
Fails when the SQL command returns duplicate rows.
"""
self.execute(sql, params)
result = set((tuple(row) for row in self))
assert len(result) == self.rowcount
return result
def table_exists(self, table):
""" Check that a table with the given name exists in the database.
"""
num = self.scalar("""SELECT count(*) FROM pg_tables
WHERE tablename = %s""", (table, ))
return num == 1
def table_rows(self, table, where=None):
""" Return the number of rows in the given table.
"""
if where is None:
return self.scalar('SELECT count(*) FROM ' + table)
return self.scalar('SELECT count(*) FROM {} WHERE {}'.format(table, where))
def execute_values(self, *args, **kwargs):
""" Execute the execute_values() function on the cursor.
"""
psycopg2.extras.execute_values(self, *args, **kwargs)