From 35a023d133330da3d06e4fdf120b351495f172a6 Mon Sep 17 00:00:00 2001 From: Sarah Hoffmann Date: Thu, 12 Feb 2026 19:49:52 +0100 Subject: [PATCH] add function for inserting data to testing cursor --- test/python/conftest.py | 27 ++++++++------- test/python/cursor.py | 41 +++++++++++++++++++++++ test/python/tools/test_database_import.py | 2 +- 3 files changed, 57 insertions(+), 13 deletions(-) diff --git a/test/python/conftest.py b/test/python/conftest.py index 2f19ed4c..95f3cb0a 100644 --- a/test/python/conftest.py +++ b/test/python/conftest.py @@ -178,12 +178,14 @@ def place_row(place_table, temp_db_cursor): prerequisite to the fixture. """ idseq = itertools.count(1001) + def _insert(osm_type='N', osm_id=None, cls='amenity', typ='cafe', names=None, - admin_level=None, address=None, extratags=None, geom=None): - temp_db_cursor.execute("INSERT INTO place VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)", - (osm_id or next(idseq), osm_type, cls, typ, names, - admin_level, address, extratags, - geom or 'SRID=4326;POINT(0 0)')) + admin_level=None, address=None, extratags=None, geom='POINT(0 0)'): + args = {'osm_type': osm_type, 'osm_id': osm_id or next(idseq), + 'class': cls, 'type': typ, 'name': names, 'admin_level': admin_level, + 'address': address, 'extratags': extratags, + 'geometry': _with_srid(geom)} + temp_db_cursor.insert_row('place', **args) return _insert @@ -203,17 +205,18 @@ def place_postcode_table(temp_db_with_extensions, table_factory): @pytest.fixture def place_postcode_row(place_postcode_table, temp_db_cursor): - """ A factory for rows in the place table. The table is created as a + """ A factory for rows in the place_postcode table. The table is created as a prerequisite to the fixture. """ idseq = itertools.count(5001) + def _insert(osm_type='N', osm_id=None, postcode=None, country=None, - centroid=None, geom=None): - temp_db_cursor.execute("INSERT INTO place_postcode VALUES (%s, %s, %s, %s, %s, %s)", - (osm_type, osm_id or next(idseq), - postcode, country, - _with_srid(centroid, 'POINT(12.0 4.0)'), - _with_srid(geom))) + centroid='POINT(12.0 4.0)', geom=None): + temp_db_cursor.insert_row('place_postcode', + osm_type=osm_type, osm_id=osm_id or next(idseq), + postcode=postcode, country_code=country, + centroid=_with_srid(centroid), + geometry=_with_srid(geom)) return _insert diff --git a/test/python/cursor.py b/test/python/cursor.py index b9237727..1fc18720 100644 --- a/test/python/cursor.py +++ b/test/python/cursor.py @@ -58,3 +58,44 @@ class CursorForTesting(psycopg.Cursor): sql += pysql.SQL('WHERE') + pysql.SQL(where) return self.scalar(sql) + + def insert_row(self, table, **data): + """ Insert a row into the given table. + + 'data' is a dictionary of column names and associated values. + When the value is a pysql.Literal or pysql.SQL, then the expression + will be inserted as is instead of loading the value. When the + value is a tuple, then the first element will be added as an + SQL expression for the value and the second element is treated + as the actual value to insert. The SQL expression must contain + a %s placeholder in that case. + + If data contains a 'place_id' column, then the value of the + place_id column after insert is returned. Otherwise the function + returns nothing. + """ + columns = [] + placeholders = [] + values = [] + for k, v in data.items(): + columns.append(pysql.Identifier(k)) + if isinstance(v, tuple): + placeholders.append(pysql.SQL(v[0])) + values.append(v[1]) + elif isinstance(v, (pysql.Literal, pysql.SQL)): + placeholders.append(v) + else: + placeholders.append(pysql.Placeholder()) + values.append(v) + + sql = pysql.SQL("INSERT INTO {table} ({columns}) VALUES({values})")\ + .format(table=pysql.Identifier(table), + columns=pysql.SQL(',').join(columns), + values=pysql.SQL(',').join(placeholders)) + + if 'place_id' in data: + sql += pysql.SQL('RETURNING place_id') + + self.execute(sql, values) + + return self.fetchone()[0] if 'place_id' in data else None diff --git a/test/python/tools/test_database_import.py b/test/python/tools/test_database_import.py index 221e4fba..df22100c 100644 --- a/test/python/tools/test_database_import.py +++ b/test/python/tools/test_database_import.py @@ -177,7 +177,7 @@ async def test_load_data(dsn, place_row, placex_table, osmline_table, for oid in range(100, 130): place_row(osm_id=oid) place_row(osm_type='W', osm_id=342, cls='place', typ='houses', - geom='SRID=4326;LINESTRING(0 0, 10 10)') + geom='LINESTRING(0 0, 10 10)') await database_import.load_data(dsn, threads)