Merge pull request #3291 from lonvia/fix-timezone-handling

Fix timezone handling for timestamps from the database
This commit is contained in:
Sarah Hoffmann
2024-01-07 15:22:42 +01:00
committed by GitHub
2 changed files with 16 additions and 4 deletions

View File

@@ -37,7 +37,10 @@ async def get_status(conn: SearchConnection) -> StatusResult:
status.data_updated = await conn.scalar(sql) status.data_updated = await conn.scalar(sql)
if status.data_updated is not None: if status.data_updated is not None:
status.data_updated = status.data_updated.replace(tzinfo=dt.timezone.utc) if status.data_updated.tzinfo is None:
status.data_updated = status.data_updated.replace(tzinfo=dt.timezone.utc)
else:
status.data_updated = status.data_updated.astimezone(dt.timezone.utc)
# Database version # Database version
try: try:

View File

@@ -7,13 +7,14 @@
""" """
Exporting a Nominatim database to SQlite. Exporting a Nominatim database to SQlite.
""" """
from typing import Set from typing import Set, Any
import datetime as dt
import logging import logging
from pathlib import Path from pathlib import Path
import sqlalchemy as sa import sqlalchemy as sa
from nominatim.typing import SaSelect from nominatim.typing import SaSelect, SaRow
from nominatim.db.sqlalchemy_types import Geometry, IntArray from nominatim.db.sqlalchemy_types import Geometry, IntArray
from nominatim.api.search.query_analyzer_factory import make_query_analyzer from nominatim.api.search.query_analyzer_factory import make_query_analyzer
import nominatim.api as napi import nominatim.api as napi
@@ -124,12 +125,20 @@ class SqliteWriter:
async def copy_data(self) -> None: async def copy_data(self) -> None:
""" Copy data for all registered tables. """ Copy data for all registered tables.
""" """
def _getfield(row: SaRow, key: str) -> Any:
value = getattr(row, key)
if isinstance(value, dt.datetime):
if value.tzinfo is not None:
value = value.astimezone(dt.timezone.utc)
return value
for table in self.dest.t.meta.sorted_tables: for table in self.dest.t.meta.sorted_tables:
LOG.warning("Copying '%s'", table.name) LOG.warning("Copying '%s'", table.name)
async_result = await self.src.connection.stream(self.select_from(table.name)) async_result = await self.src.connection.stream(self.select_from(table.name))
async for partition in async_result.partitions(10000): async for partition in async_result.partitions(10000):
data = [{('class_' if k == 'class' else k): getattr(r, k) for k in r._fields} data = [{('class_' if k == 'class' else k): _getfield(r, k)
for k in r._fields}
for r in partition] for r in partition]
await self.dest.execute(table.insert(), data) await self.dest.execute(table.insert(), data)