switch reverse() to new Geometry datatype

Also switches to using bind parameters for recurring parameters.
This commit is contained in:
Sarah Hoffmann
2023-06-25 14:02:00 +02:00
parent 4bb4db0668
commit 6c4c9ec1f2
8 changed files with 127 additions and 100 deletions

View File

@@ -7,7 +7,7 @@
"""
Functions for specialised logging with HTML output.
"""
from typing import Any, Iterator, Optional, List, Tuple, cast
from typing import Any, Iterator, Optional, List, Tuple, cast, Union, Mapping, Sequence
from contextvars import ContextVar
import datetime as dt
import textwrap
@@ -74,22 +74,26 @@ class BaseLogger:
"""
def sql(self, conn: AsyncConnection, statement: 'sa.Executable') -> None:
def sql(self, conn: AsyncConnection, statement: 'sa.Executable',
params: Union[Mapping[str, Any], Sequence[Mapping[str, Any]], None]) -> None:
""" Print the SQL for the given statement.
"""
def format_sql(self, conn: AsyncConnection, statement: 'sa.Executable') -> str:
def format_sql(self, conn: AsyncConnection, statement: 'sa.Executable',
extra_params: Union[Mapping[str, Any], Sequence[Mapping[str, Any]], None]) -> str:
""" Return the comiled version of the statement.
"""
try:
return str(cast('sa.ClauseElement', statement)
.compile(conn.sync_engine, compile_kwargs={"literal_binds": True}))
except sa.exc.CompileError:
pass
except NotImplementedError:
pass
compiled = cast('sa.ClauseElement', statement).compile(conn.sync_engine)
return str(cast('sa.ClauseElement', statement).compile(conn.sync_engine))
params = dict(compiled.params)
if isinstance(extra_params, Mapping):
for k, v in extra_params.items():
params[k] = str(v)
elif isinstance(extra_params, Sequence) and extra_params:
for k in extra_params[0]:
params[k] = f':{k}'
return str(compiled) % params
class HTMLLogger(BaseLogger):
@@ -183,9 +187,10 @@ class HTMLLogger(BaseLogger):
self._write(f'</dl><b>TOTAL:</b> {total}</p>')
def sql(self, conn: AsyncConnection, statement: 'sa.Executable') -> None:
def sql(self, conn: AsyncConnection, statement: 'sa.Executable',
params: Union[Mapping[str, Any], Sequence[Mapping[str, Any]], None]) -> None:
self._timestamp()
sqlstr = self.format_sql(conn, statement)
sqlstr = self.format_sql(conn, statement, params)
if CODE_HIGHLIGHT:
sqlstr = highlight(sqlstr, PostgresLexer(),
HtmlFormatter(nowrap=True, lineseparator='<br />'))
@@ -276,8 +281,9 @@ class TextLogger(BaseLogger):
self._write(f'TOTAL: {total}\n\n')
def sql(self, conn: AsyncConnection, statement: 'sa.Executable') -> None:
sqlstr = '\n| '.join(textwrap.wrap(self.format_sql(conn, statement), width=78))
def sql(self, conn: AsyncConnection, statement: 'sa.Executable',
params: Union[Mapping[str, Any], Sequence[Mapping[str, Any]], None]) -> None:
sqlstr = '\n| '.join(textwrap.wrap(self.format_sql(conn, statement, params), width=78))
self._write(f"| {sqlstr}\n\n")