remove extension existence helper

This is only used in one place.
This commit is contained in:
Sarah Hoffmann
2024-07-02 14:52:57 +02:00
parent e3353deee0
commit 71249bd94a
2 changed files with 6 additions and 11 deletions

View File

@@ -175,20 +175,13 @@ class Connection(psycopg2.extensions.connection):
return (int(version_parts[0]), int(version_parts[1])) return (int(version_parts[0]), int(version_parts[1]))
def extension_loaded(self, extension_name: str) -> bool:
""" Return True if the hstore extension is loaded in the database.
"""
with self.cursor() as cur:
cur.execute('SELECT extname FROM pg_extension WHERE extname = %s', (extension_name, ))
return cur.rowcount > 0
class ConnectionContext(ContextManager[Connection]): class ConnectionContext(ContextManager[Connection]):
""" Context manager of the connection that also provides direct access """ Context manager of the connection that also provides direct access
to the underlying connection. to the underlying connection.
""" """
connection: Connection connection: Connection
def connect(dsn: str) -> ConnectionContext: def connect(dsn: str) -> ConnectionContext:
""" Open a connection to the database using the specialised connection """ Open a connection to the database using the specialised connection
factory. The returned object may be used in conjunction with 'with'. factory. The returned object may be used in conjunction with 'with'.

View File

@@ -40,9 +40,11 @@ def _require_version(module: str, actual: Tuple[int, int], expected: Tuple[int,
def _require_loaded(extension_name: str, conn: Connection) -> None: def _require_loaded(extension_name: str, conn: Connection) -> None:
""" Check that the given extension is loaded. """ """ Check that the given extension is loaded. """
if not conn.extension_loaded(extension_name): with conn.cursor() as cur:
LOG.fatal('Required module %s is not loaded.', extension_name) cur.execute('SELECT * FROM pg_extension WHERE extname = %s', (extension_name, ))
raise UsageError(f'{extension_name} is not loaded.') if cur.rowcount <= 0:
LOG.fatal('Required module %s is not loaded.', extension_name)
raise UsageError(f'{extension_name} is not loaded.')
def check_existing_database_plugins(dsn: str) -> None: def check_existing_database_plugins(dsn: str) -> None: