mirror of
https://github.com/osm-search/Nominatim.git
synced 2026-02-26 11:08:13 +00:00
port code to psycopg3
This commit is contained in:
@@ -1,236 +0,0 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
#
|
||||
# This file is part of Nominatim. (https://nominatim.org)
|
||||
#
|
||||
# Copyright (C) 2024 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
""" Non-blocking database connections.
|
||||
"""
|
||||
from typing import Callable, Any, Optional, Iterator, Sequence
|
||||
import logging
|
||||
import select
|
||||
import time
|
||||
|
||||
import psycopg2
|
||||
from psycopg2.extras import wait_select
|
||||
|
||||
# psycopg2 emits different exceptions pre and post 2.8. Detect if the new error
|
||||
# module is available and adapt the error handling accordingly.
|
||||
try:
|
||||
import psycopg2.errors # pylint: disable=no-name-in-module,import-error
|
||||
__has_psycopg2_errors__ = True
|
||||
except ImportError:
|
||||
__has_psycopg2_errors__ = False
|
||||
|
||||
from ..typing import T_cursor, Query
|
||||
|
||||
LOG = logging.getLogger()
|
||||
|
||||
class DeadlockHandler:
    """ Context manager which intercepts deadlock errors raised in its
        body and runs a callback instead of propagating them. On request
        all other SQL errors are logged and suppressed as well. Any other
        exception leaves the context unchanged.
    """

    def __init__(self, handler: Callable[[], None], ignore_sql_errors: bool = False) -> None:
        self.ignore_sql_errors = ignore_sql_errors
        self.handler = handler

    def __enter__(self) -> 'DeadlockHandler':
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> bool:
        # With the psycopg2 errors module there is a dedicated exception
        # class. Without it the error code of the generic rollback
        # exception needs to be inspected.
        if __has_psycopg2_errors__:
            deadlocked = exc_type == psycopg2.errors.DeadlockDetected  # pylint: disable=E1101
        else:
            deadlocked = exc_type == psycopg2.extensions.TransactionRollbackError \
                         and exc_value.pgcode == '40P01'

        if deadlocked:
            self.handler()
            return True

        if not self.ignore_sql_errors or not isinstance(exc_value, psycopg2.Error):
            return False

        LOG.info("SQL error ignored: %s", exc_value)
        return True
|
||||
|
||||
|
||||
class DBConnection:
    """ A single non-blocking database connection.

        The query currently in flight and its parameters are remembered
        so that they can be sent again when a deadlock is reported.
    """

    def __init__(self, dsn: str,
                 cursor_factory: Optional[Callable[..., T_cursor]] = None,
                 ignore_sql_errors: bool = False) -> None:
        self.dsn = dsn

        self.current_query: Optional[Query] = None
        self.current_params: Optional[Sequence[Any]] = None
        self.ignore_sql_errors = ignore_sql_errors

        self.conn: Optional['psycopg2._psycopg.connection'] = None
        self.cursor: Optional['psycopg2._psycopg.cursor'] = None
        self.connect(cursor_factory=cursor_factory)

    def close(self) -> None:
        """ Close all open connections. Does not wait for pending requests.
        """
        if self.conn is not None:
            if self.cursor is not None:
                self.cursor.close()
                self.cursor = None
            self.conn.close()

        self.conn = None

    def connect(self, cursor_factory: Optional[Callable[..., T_cursor]] = None) -> None:
        """ (Re)connect to the database. Creates an asynchronous connection
            with JIT and parallel processing disabled. If a connection was
            already open, it is closed and a new connection established.
            The caller must ensure that no query is pending before reconnecting.
        """
        self.close()

        # Use a dict to hand in the parameters because async is a reserved
        # word in Python3.
        self.conn = psycopg2.connect(**{'dsn': self.dsn, 'async': True})  # type: ignore
        assert self.conn
        self.wait()

        if cursor_factory is not None:
            self.cursor = self.conn.cursor(cursor_factory=cursor_factory)
        else:
            self.cursor = self.conn.cursor()
        # Disable JIT and parallel workers as they are known to cause problems.
        # Update pg_settings instead of using SET because it does not yield
        # errors on older versions of Postgres where the settings are not
        # implemented.
        self.perform(
            """ UPDATE pg_settings SET setting = -1 WHERE name = 'jit_above_cost';
                UPDATE pg_settings SET setting = 0
                   WHERE name = 'max_parallel_workers_per_gather';""")
        self.wait()

    def _deadlock_handler(self) -> None:
        """ Resend the query that was in flight when the deadlock occurred.
        """
        LOG.info("Deadlock detected (params = %s), retry.", str(self.current_params))
        assert self.cursor is not None
        assert self.current_query is not None
        # current_params is None for queries that were sent without
        # parameters via perform(). Those must be retried as well, so
        # there must be no assertion on the parameters here.

        self.cursor.execute(self.current_query, self.current_params)

    def wait(self) -> None:
        """ Block until any pending operation is done.
        """
        while True:
            # When the handler swallows an exception (deadlock retry or
            # ignored SQL error), the loop goes on waiting on the connection.
            with DeadlockHandler(self._deadlock_handler, self.ignore_sql_errors):
                wait_select(self.conn)
                self.current_query = None
                return

    def perform(self, sql: Query, args: Optional[Sequence[Any]] = None) -> None:
        """ Send SQL query to the server. Returns immediately without
            blocking.
        """
        assert self.cursor is not None
        # Remember the query for a possible retry after a deadlock.
        self.current_query = sql
        self.current_params = args
        self.cursor.execute(sql, args)

    def fileno(self) -> int:
        """ File descriptor to wait for. (Makes this class select()able.)
        """
        assert self.conn is not None
        return self.conn.fileno()

    def is_done(self) -> bool:
        """ Check if the connection is available for a new query.

            Also checks if the previous query has run into a deadlock.
            If so, then the previous query is repeated.
        """
        assert self.conn is not None

        if self.current_query is None:
            return True

        with DeadlockHandler(self._deadlock_handler, self.ignore_sql_errors):
            if self.conn.poll() == psycopg2.extensions.POLL_OK:
                self.current_query = None
                return True

        return False
|
||||
|
||||
|
||||
class WorkerPool:
    """ A pool of asynchronous database connections.

        The pool may be used as a context manager. On leaving the context
        the pool waits for all pending queries and then closes all
        connections, even when waiting ends with an error.
    """
    # Number of handed-out workers after which all connections are reopened.
    REOPEN_CONNECTIONS_AFTER = 100000

    def __init__(self, dsn: str, pool_size: int, ignore_sql_errors: bool = False) -> None:
        self.threads = [DBConnection(dsn, ignore_sql_errors=ignore_sql_errors)
                        for _ in range(pool_size)]
        self.free_workers = self._yield_free_worker()
        # Accumulated seconds spent in select() waiting for a free connection.
        self.wait_time = 0.0

    def finish_all(self) -> None:
        """ Wait for all connection to finish.
        """
        for thread in self.threads:
            while not thread.is_done():
                thread.wait()

        self.free_workers = self._yield_free_worker()

    def close(self) -> None:
        """ Close all connections and clear the pool.
        """
        for thread in self.threads:
            thread.close()
        self.threads = []
        self.free_workers = iter([])

    def next_free_worker(self) -> 'DBConnection':
        """ Get the next free connection.
        """
        return next(self.free_workers)

    def _yield_free_worker(self) -> Iterator['DBConnection']:
        """ Endless generator over connections that are ready for a new
            query. Blocks in select() when no connection is available.
        """
        ready = self.threads
        command_stat = 0
        while True:
            for thread in ready:
                if thread.is_done():
                    command_stat += 1
                    yield thread

            if command_stat > self.REOPEN_CONNECTIONS_AFTER:
                self._reconnect_threads()
                ready = self.threads
                command_stat = 0
            else:
                tstart = time.time()
                _, ready, _ = select.select([], self.threads, [])
                self.wait_time += time.time() - tstart

    def _reconnect_threads(self) -> None:
        """ Wait for each connection to become idle and reopen it.
        """
        for thread in self.threads:
            while not thread.is_done():
                thread.wait()
            thread.connect()

    def __enter__(self) -> 'WorkerPool':
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # Waiting for pending queries may raise (for example on an SQL
        # error). The connections must be released in that case as well.
        try:
            self.finish_all()
        finally:
            self.close()
|
||||
@@ -7,73 +7,27 @@
|
||||
"""
|
||||
Specialised connection and cursor functions.
|
||||
"""
|
||||
from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload,\
|
||||
Tuple, Iterable
|
||||
import contextlib
|
||||
from typing import Optional, Any, Dict, Tuple
|
||||
import logging
|
||||
import os
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extensions
|
||||
import psycopg2.extras
|
||||
from psycopg2 import sql as pysql
|
||||
import psycopg
|
||||
import psycopg.types.hstore
|
||||
from psycopg import sql as pysql
|
||||
|
||||
from ..typing import SysEnv, Query, T_cursor
|
||||
from ..typing import SysEnv
|
||||
from ..errors import UsageError
|
||||
|
||||
LOG = logging.getLogger()
|
||||
|
||||
class Cursor(psycopg2.extras.DictCursor):
|
||||
""" A cursor returning dict-like objects and providing specialised
|
||||
execution functions.
|
||||
"""
|
||||
# pylint: disable=arguments-renamed,arguments-differ
|
||||
def execute(self, query: Query, args: Any = None) -> None:
|
||||
""" Query execution that logs the SQL query when debugging is enabled.
|
||||
"""
|
||||
if LOG.isEnabledFor(logging.DEBUG):
|
||||
LOG.debug(self.mogrify(query, args).decode('utf-8'))
|
||||
Cursor = psycopg.Cursor[Any]
|
||||
Connection = psycopg.Connection[Any]
|
||||
|
||||
super().execute(query, args)
|
||||
|
||||
|
||||
def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]],
|
||||
template: Optional[Query] = None) -> None:
|
||||
""" Wrapper for the psycopg2 convenience function to execute
|
||||
SQL for a list of values.
|
||||
"""
|
||||
LOG.debug("SQL execute_values(%s, %s)", sql, argslist)
|
||||
|
||||
psycopg2.extras.execute_values(self, sql, argslist, template=template)
|
||||
|
||||
|
||||
class Connection(psycopg2.extensions.connection):
    """ Database connection whose cursors are of the specialised
        Cursor type unless a different cursor factory is requested.
    """
    @overload # type: ignore[override]
    def cursor(self) -> Cursor:
        ...

    @overload
    def cursor(self, name: str) -> Cursor:
        ...

    @overload
    def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor:
        ...

    def cursor(self, cursor_factory = Cursor, **kwargs): # type: ignore
        """ Create a new cursor, using the specialised cursor class
            when no other factory is given.
        """
        kwargs['cursor_factory'] = cursor_factory
        return super().cursor(**kwargs)
|
||||
|
||||
|
||||
def execute_scalar(conn: Connection, sql: Query, args: Any = None) -> Any:
|
||||
def execute_scalar(conn: Connection, sql: psycopg.abc.Query, args: Any = None) -> Any:
|
||||
""" Execute query that returns a single value. The value is returned.
|
||||
If the query yields more than one row, a ValueError is raised.
|
||||
"""
|
||||
with conn.cursor() as cur:
|
||||
with conn.cursor(row_factory=psycopg.rows.tuple_row) as cur:
|
||||
cur.execute(sql, args)
|
||||
|
||||
if cur.rowcount != 1:
|
||||
@@ -144,7 +98,7 @@ def server_version_tuple(conn: Connection) -> Tuple[int, int]:
|
||||
""" Return the server version as a tuple of (major, minor).
|
||||
Converts correctly for pre-10 and post-10 PostgreSQL versions.
|
||||
"""
|
||||
version = conn.server_version
|
||||
version = conn.info.server_version
|
||||
if version < 100000:
|
||||
return (int(version / 10000), int((version % 10000) / 100))
|
||||
|
||||
@@ -164,31 +118,25 @@ def postgis_version_tuple(conn: Connection) -> Tuple[int, int]:
|
||||
|
||||
return (int(version_parts[0]), int(version_parts[1]))
|
||||
|
||||
|
||||
def register_hstore(conn: Connection) -> None:
|
||||
""" Register the hstore type with psycopg for the connection.
|
||||
"""
|
||||
psycopg2.extras.register_hstore(conn)
|
||||
info = psycopg.types.TypeInfo.fetch(conn, "hstore")
|
||||
if info is None:
|
||||
raise RuntimeError('Hstore extension is requested but not installed.')
|
||||
psycopg.types.hstore.register_hstore(info, conn)
|
||||
|
||||
|
||||
class ConnectionContext(ContextManager[Connection]):
    """ Context manager of the connection that also provides direct access
        to the underlying connection.
    """
    # The wrapped connection, for use outside of a 'with' statement.
    connection: Connection
|
||||
|
||||
|
||||
def connect(dsn: str) -> ConnectionContext:
|
||||
def connect(dsn: str, **kwargs: Any) -> Connection:
|
||||
""" Open a connection to the database using the specialised connection
|
||||
factory. The returned object may be used in conjunction with 'with'.
|
||||
When used outside a context manager, use the `connection` attribute
|
||||
to get the connection.
|
||||
"""
|
||||
try:
|
||||
conn = psycopg2.connect(dsn, connection_factory=Connection)
|
||||
ctxmgr = cast(ConnectionContext, contextlib.closing(conn))
|
||||
ctxmgr.connection = conn
|
||||
return ctxmgr
|
||||
except psycopg2.OperationalError as err:
|
||||
return psycopg.connect(dsn, row_factory=psycopg.rows.namedtuple_row, **kwargs)
|
||||
except psycopg.OperationalError as err:
|
||||
raise UsageError(f"Cannot connect to database: {err}") from err
|
||||
|
||||
|
||||
@@ -233,10 +181,18 @@ def get_pg_env(dsn: str,
|
||||
"""
|
||||
env = dict(base_env if base_env is not None else os.environ)
|
||||
|
||||
for param, value in psycopg2.extensions.parse_dsn(dsn).items():
|
||||
for param, value in psycopg.conninfo.conninfo_to_dict(dsn).items():
|
||||
if param in _PG_CONNECTION_STRINGS:
|
||||
env[_PG_CONNECTION_STRINGS[param]] = value
|
||||
env[_PG_CONNECTION_STRINGS[param]] = str(value)
|
||||
else:
|
||||
LOG.error("Unknown connection parameter '%s' ignored.", param)
|
||||
|
||||
return env
|
||||
|
||||
|
||||
async def run_async_query(dsn: str, query: psycopg.abc.Query) -> None:
    """ Execute a single query over a freshly opened asynchronous
        connection. The connection serves as the context manager
        around the execution of the query.
    """
    aconn = await psycopg.AsyncConnection.connect(dsn)
    async with aconn:
        await aconn.execute(query)
|
||||
|
||||
87
src/nominatim_db/db/query_pool.py
Normal file
87
src/nominatim_db/db/query_pool.py
Normal file
@@ -0,0 +1,87 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
#
|
||||
# This file is part of Nominatim. (https://nominatim.org)
|
||||
#
|
||||
# Copyright (C) 2024 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
"""
|
||||
A connection pool that executes incoming queries in parallel.
|
||||
"""
|
||||
from typing import Any, Tuple, Optional
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
|
||||
import psycopg
|
||||
|
||||
LOG = logging.getLogger()
|
||||
|
||||
QueueItem = Optional[Tuple[psycopg.abc.Query, Any]]
|
||||
|
||||
class QueryPool:
    """ Pool to run SQL queries in parallel asynchronous execution.

        All queries are run in autocommit mode. If parallel execution leads
        to a deadlock, then the query is repeated.
        The results of the queries are discarded.

        When a worker dies from any other error, the pool is cancelled:
        queries that are still queued are dropped and the error of the
        worker is raised by put_query() and finish().
    """
    def __init__(self, dsn: str, pool_size: int = 1, **conn_args: Any) -> None:
        self.wait_time = 0.0
        # Set as soon as one worker has terminated with an unexpected error.
        self.is_cancelled = False
        # Ensures that the stop markers for the workers are queued only once.
        self._stop_requested = False
        self.query_queue: 'asyncio.Queue[QueueItem]' = asyncio.Queue(maxsize=2 * pool_size)

        self.pool = [asyncio.create_task(self._worker_loop(dsn, **conn_args))
                     for _ in range(pool_size)]

    async def put_query(self, query: psycopg.abc.Query, params: Any) -> None:
        """ Schedule a query for execution.

            Raises the error of a failed worker once the pool has been
            cancelled instead of waiting on a queue that nobody empties.
        """
        if not self.is_cancelled:
            tstart = time.time()
            await self.query_queue.put((query, params))
            self.wait_time += time.time() - tstart

        # Checked again after the put because a worker may have failed
        # while this coroutine was waiting for a free slot in the queue.
        if self.is_cancelled:
            # Drops what is still queued and raises the worker's error.
            await self.finish()
            return

        await asyncio.sleep(0)

    async def finish(self) -> None:
        """ Wait for all queries to finish and close the pool.

            Raises the first exception found among the worker tasks.
        """
        if self.is_cancelled:
            self._drop_pending_queries()

        if not self._stop_requested:
            self._stop_requested = True
            # One stop marker per worker.
            for _ in self.pool:
                await self.query_queue.put(None)

        tstart = time.time()
        await asyncio.wait(self.pool)
        self.wait_time += time.time() - tstart

        for task in self.pool:
            excp = task.exception()
            if excp is not None:
                raise excp

    def _drop_pending_queries(self) -> None:
        """ Remove all queued queries but keep the stop markers, so that
            workers which are still running can terminate regularly.
            Emptying the queue also wakes up producers blocked on it.
        """
        stop_markers = 0
        while not self.query_queue.empty():
            if self.query_queue.get_nowait() is None:
                stop_markers += 1

        for _ in range(stop_markers):
            self.query_queue.put_nowait(None)

    async def _worker_loop(self, dsn: str, **conn_args: Any) -> None:
        """ Open a connection and execute queries from the queue until
            a stop marker (None) is received.
        """
        conn_args['autocommit'] = True
        try:
            aconn = await psycopg.AsyncConnection.connect(dsn, **conn_args)
            async with aconn:
                async with aconn.cursor() as cur:
                    item = await self.query_queue.get()
                    while item is not None:
                        try:
                            if item[1] is None:
                                await cur.execute(item[0])
                            else:
                                await cur.execute(item[0], item[1])

                            item = await self.query_queue.get()
                        except psycopg.errors.DeadlockDetected:
                            assert item is not None
                            LOG.info("Deadlock detected (sql = %s, params = %s), retry.",
                                     str(item[0]), str(item[1]))
                            # item is still valid here, causing a retry
        except Exception:
            # This worker no longer takes queries from the queue. Flag
            # the failure and release producers that wait for a free slot,
            # then let the error surface through the task.
            self.is_cancelled = True
            self._drop_pending_queries()
            raise

    async def __aenter__(self) -> 'QueryPool':
        return self

    async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        await self.finish()
|
||||
@@ -8,11 +8,12 @@
|
||||
Preprocessing of SQL files.
|
||||
"""
|
||||
from typing import Set, Dict, Any, cast
|
||||
|
||||
import jinja2
|
||||
|
||||
from .connection import Connection, server_version_tuple, postgis_version_tuple
|
||||
from .async_connection import WorkerPool
|
||||
from ..config import Configuration
|
||||
from ..db.query_pool import QueryPool
|
||||
|
||||
def _get_partitions(conn: Connection) -> Set[int]:
|
||||
""" Get the set of partitions currently in use.
|
||||
@@ -125,8 +126,8 @@ class SQLPreprocessor:
|
||||
conn.commit()
|
||||
|
||||
|
||||
def run_parallel_sql_file(self, dsn: str, name: str, num_threads: int = 1,
|
||||
**kwargs: Any) -> None:
|
||||
async def run_parallel_sql_file(self, dsn: str, name: str, num_threads: int = 1,
|
||||
**kwargs: Any) -> None:
|
||||
""" Execute the given SQL files using parallel asynchronous connections.
|
||||
The keyword arguments may supply additional parameters for
|
||||
preprocessing.
|
||||
@@ -138,6 +139,6 @@ class SQLPreprocessor:
|
||||
|
||||
parts = sql.split('\n---\n')
|
||||
|
||||
with WorkerPool(dsn, num_threads) as pool:
|
||||
async with QueryPool(dsn, num_threads) as pool:
|
||||
for part in parts:
|
||||
pool.next_free_worker().perform(part)
|
||||
await pool.put_query(part, None)
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
"""
|
||||
Access and helper functions for the status and status log table.
|
||||
"""
|
||||
from typing import Optional, Tuple, cast
|
||||
from typing import Optional, Tuple
|
||||
import datetime as dt
|
||||
import logging
|
||||
import re
|
||||
@@ -15,20 +15,11 @@ import re
|
||||
from .connection import Connection, table_exists, execute_scalar
|
||||
from ..utils.url_utils import get_url
|
||||
from ..errors import UsageError
|
||||
from ..typing import TypedDict
|
||||
|
||||
LOG = logging.getLogger()
|
||||
ISODATE_FORMAT = '%Y-%m-%dT%H:%M:%S'
|
||||
|
||||
|
||||
class StatusRow(TypedDict):
    """ Dictionary of columns of the import_status table.
    """
    # The only column that is not declared optional.
    lastimportdate: dt.datetime
    # May be None.
    sequence_id: Optional[int]
    # May be None.
    indexed: Optional[bool]
|
||||
|
||||
|
||||
def compute_database_date(conn: Connection, offline: bool = False) -> dt.datetime:
|
||||
""" Determine the date of the database from the newest object in the
|
||||
data base.
|
||||
@@ -102,8 +93,9 @@ def get_status(conn: Connection) -> Tuple[Optional[dt.datetime], Optional[int],
|
||||
if cur.rowcount < 1:
|
||||
return None, None, None
|
||||
|
||||
row = cast(StatusRow, cur.fetchone())
|
||||
return row['lastimportdate'], row['sequence_id'], row['indexed']
|
||||
row = cur.fetchone()
|
||||
assert row
|
||||
return row.lastimportdate, row.sequence_id, row.indexed
|
||||
|
||||
|
||||
def set_indexed(conn: Connection, state: bool) -> None:
|
||||
|
||||
@@ -7,14 +7,13 @@
|
||||
"""
|
||||
Helper functions for handling DB accesses.
|
||||
"""
|
||||
from typing import IO, Optional, Union, Any, Iterable
|
||||
from typing import IO, Optional, Union
|
||||
import subprocess
|
||||
import logging
|
||||
import gzip
|
||||
import io
|
||||
from pathlib import Path
|
||||
|
||||
from .connection import get_pg_env, Cursor
|
||||
from .connection import get_pg_env
|
||||
from ..errors import UsageError
|
||||
|
||||
LOG = logging.getLogger()
|
||||
@@ -72,58 +71,3 @@ def execute_file(dsn: str, fname: Path,
|
||||
|
||||
if ret != 0 or remain > 0:
|
||||
raise UsageError("Failed to execute SQL file.")
|
||||
|
||||
|
||||
# List of characters that need to be quoted for the copy command.
# Carriage return must be escaped as well: the text format of COPY
# rejects a literal carriage return inside a column value.
_SQL_TRANSLATION = {ord('\\'): '\\\\',
                    ord('\t'): '\\t',
                    ord('\n'): '\\n',
                    ord('\r'): '\\r'}


class CopyBuffer:
    """ Data collector for the copy_from command.

        Rows are collected in an in-memory text buffer: columns separated
        by tabs, rows terminated by newline and None written as '\\N'.
        May be used as a context manager, which closes the buffer on exit.
    """

    def __init__(self) -> None:
        self.buffer = io.StringIO()

    def __enter__(self) -> 'CopyBuffer':
        return self

    def size(self) -> int:
        """ Return the current write position of the buffer. This is
            0 for an empty buffer and grows with the data added.
        """
        return self.buffer.tell()

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        if self.buffer is not None:
            self.buffer.close()

    def add(self, *data: Any) -> None:
        """ Add another row of data to the copy buffer.

            Each column is converted with str() and escaped. None is
            written as the NULL marker.
        """
        first = True
        for column in data:
            if first:
                first = False
            else:
                self.buffer.write('\t')
            if column is None:
                self.buffer.write('\\N')
            else:
                self.buffer.write(str(column).translate(_SQL_TRANSLATION))
        self.buffer.write('\n')

    def copy_out(self, cur: 'Cursor', table: str,
                 columns: Optional[Iterable[str]] = None) -> None:
        """ Copy all collected data into the given table.

            The buffer is empty and reusable after this operation.
            Nothing is sent to the database when no data was collected.
        """
        if self.buffer.tell() > 0:
            # Rewind, so that the cursor reads the data from the start.
            self.buffer.seek(0)
            cur.copy_from(self.buffer, table, columns=columns)
            self.buffer = io.StringIO()
|
||||
|
||||
Reference in New Issue
Block a user