add function for inserting data to testing cursor

This commit is contained in:
Sarah Hoffmann
2026-02-12 19:49:52 +01:00
parent 79682a94ce
commit 35a023d133
3 changed files with 57 additions and 13 deletions

View File

@@ -178,12 +178,14 @@ def place_row(place_table, temp_db_cursor):
prerequisite to the fixture.
"""
idseq = itertools.count(1001)
def _insert(osm_type='N', osm_id=None, cls='amenity', typ='cafe', names=None,
admin_level=None, address=None, extratags=None, geom=None):
temp_db_cursor.execute("INSERT INTO place VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)",
(osm_id or next(idseq), osm_type, cls, typ, names,
admin_level, address, extratags,
geom or 'SRID=4326;POINT(0 0)'))
admin_level=None, address=None, extratags=None, geom='POINT(0 0)'):
args = {'osm_type': osm_type, 'osm_id': osm_id or next(idseq),
'class': cls, 'type': typ, 'name': names, 'admin_level': admin_level,
'address': address, 'extratags': extratags,
'geometry': _with_srid(geom)}
temp_db_cursor.insert_row('place', **args)
return _insert
@@ -203,17 +205,18 @@ def place_postcode_table(temp_db_with_extensions, table_factory):
@pytest.fixture
def place_postcode_row(place_postcode_table, temp_db_cursor):
""" A factory for rows in the place table. The table is created as a
""" A factory for rows in the place_postcode table. The table is created as a
prerequisite to the fixture.
"""
idseq = itertools.count(5001)
def _insert(osm_type='N', osm_id=None, postcode=None, country=None,
centroid=None, geom=None):
temp_db_cursor.execute("INSERT INTO place_postcode VALUES (%s, %s, %s, %s, %s, %s)",
(osm_type, osm_id or next(idseq),
postcode, country,
_with_srid(centroid, 'POINT(12.0 4.0)'),
_with_srid(geom)))
centroid='POINT(12.0 4.0)', geom=None):
temp_db_cursor.insert_row('place_postcode',
osm_type=osm_type, osm_id=osm_id or next(idseq),
postcode=postcode, country_code=country,
centroid=_with_srid(centroid),
geometry=_with_srid(geom))
return _insert

View File

@@ -58,3 +58,44 @@ class CursorForTesting(psycopg.Cursor):
sql += pysql.SQL('WHERE') + pysql.SQL(where)
return self.scalar(sql)
def insert_row(self, table, **data):
""" Insert a row into the given table.
'data' is a dictionary of column names and associated values.
When the value is a pysql.Literal or pysql.SQL, then the expression
will be inserted as is instead of loading the value. When the
value is a tuple, then the first element will be added as an
SQL expression for the value and the second element is treated
as the actual value to insert. The SQL expression must contain
a %s placeholder in that case.
If data contains a 'place_id' column, then the value of the
place_id column after insert is returned. Otherwise the function
returns nothing.
"""
columns = []
placeholders = []
values = []
for k, v in data.items():
columns.append(pysql.Identifier(k))
if isinstance(v, tuple):
placeholders.append(pysql.SQL(v[0]))
values.append(v[1])
elif isinstance(v, (pysql.Literal, pysql.SQL)):
placeholders.append(v)
else:
placeholders.append(pysql.Placeholder())
values.append(v)
sql = pysql.SQL("INSERT INTO {table} ({columns}) VALUES({values})")\
.format(table=pysql.Identifier(table),
columns=pysql.SQL(',').join(columns),
values=pysql.SQL(',').join(placeholders))
if 'place_id' in data:
sql += pysql.SQL('RETURNING place_id')
self.execute(sql, values)
return self.fetchone()[0] if 'place_id' in data else None

View File

@@ -177,7 +177,7 @@ async def test_load_data(dsn, place_row, placex_table, osmline_table,
for oid in range(100, 130):
place_row(osm_id=oid)
place_row(osm_type='W', osm_id=342, cls='place', typ='houses',
geom='SRID=4326;LINESTRING(0 0, 10 10)')
geom='LINESTRING(0 0, 10 10)')
await database_import.load_data(dsn, threads)