convert connect() into a context manager

This commit is contained in:
Sarah Hoffmann
2021-02-23 10:11:21 +01:00
parent 204fe20b4b
commit e520613362
12 changed files with 53 additions and 59 deletions

View File

@@ -54,9 +54,8 @@ class AdminFuncs:
if args.analyse_indexing:
LOG.warning('Analysing performance of indexing function')
from ..tools import admin
conn = connect(args.config.get_libpq_dsn())
admin.analyse_indexing(conn, osm_id=args.osm_id, place_id=args.place_id)
conn.close()
with connect(args.config.get_libpq_dsn()) as conn:
admin.analyse_indexing(conn, osm_id=args.osm_id, place_id=args.place_id)
return 0

View File

@@ -29,9 +29,8 @@ class SetupFreeze:
def run(args):
from ..tools import freeze
conn = connect(args.config.get_libpq_dsn())
freeze.drop_update_tables(conn)
with connect(args.config.get_libpq_dsn()) as conn:
freeze.drop_update_tables(conn)
freeze.drop_flatnode_file(args.config.FLATNODE_FILE)
conn.close()
return 0

View File

@@ -51,8 +51,7 @@ class UpdateIndex:
if not args.no_boundaries and not args.boundaries_only \
and args.minrank == 0 and args.maxrank == 30:
conn = connect(args.config.get_libpq_dsn())
status.set_indexed(conn, True)
conn.close()
with connect(args.config.get_libpq_dsn()) as conn:
status.set_indexed(conn, True)
return 0

View File

@@ -50,29 +50,25 @@ class UpdateRefresh:
if args.postcodes:
LOG.warning("Update postcodes centroid")
conn = connect(args.config.get_libpq_dsn())
refresh.update_postcodes(conn, args.sqllib_dir)
conn.close()
with connect(args.config.get_libpq_dsn()) as conn:
refresh.update_postcodes(conn, args.sqllib_dir)
if args.word_counts:
LOG.warning('Recompute frequency of full-word search terms')
conn = connect(args.config.get_libpq_dsn())
refresh.recompute_word_counts(conn, args.sqllib_dir)
conn.close()
with connect(args.config.get_libpq_dsn()) as conn:
refresh.recompute_word_counts(conn, args.sqllib_dir)
if args.address_levels:
cfg = Path(args.config.ADDRESS_LEVEL_CONFIG)
LOG.warning('Updating address levels from %s', cfg)
conn = connect(args.config.get_libpq_dsn())
refresh.load_address_levels_from_file(conn, cfg)
conn.close()
with connect(args.config.get_libpq_dsn()) as conn:
refresh.load_address_levels_from_file(conn, cfg)
if args.functions:
LOG.warning('Create functions')
conn = connect(args.config.get_libpq_dsn())
refresh.create_functions(conn, args.config, args.sqllib_dir,
args.diffs, args.enable_debug_statements)
conn.close()
with connect(args.config.get_libpq_dsn()) as conn:
refresh.create_functions(conn, args.config, args.sqllib_dir,
args.diffs, args.enable_debug_statements)
if args.wiki_data:
run_legacy_script('setup.php', '--import-wikipedia-articles',

View File

@@ -62,13 +62,12 @@ class UpdateReplication:
from ..tools import replication, refresh
LOG.warning("Initialising replication updates")
conn = connect(args.config.get_libpq_dsn())
replication.init_replication(conn, base_url=args.config.REPLICATION_URL)
if args.update_functions:
LOG.warning("Create functions")
refresh.create_functions(conn, args.config, args.sqllib_dir,
True, False)
conn.close()
with connect(args.config.get_libpq_dsn()) as conn:
replication.init_replication(conn, base_url=args.config.REPLICATION_URL)
if args.update_functions:
LOG.warning("Create functions")
refresh.create_functions(conn, args.config, args.sqllib_dir,
True, False)
return 0
@@ -76,10 +75,8 @@ class UpdateReplication:
def _check_for_updates(args):
from ..tools import replication
conn = connect(args.config.get_libpq_dsn())
ret = replication.check_for_updates(conn, base_url=args.config.REPLICATION_URL)
conn.close()
return ret
with connect(args.config.get_libpq_dsn()) as conn:
return replication.check_for_updates(conn, base_url=args.config.REPLICATION_URL)
@staticmethod
def _report_update(batchdate, start_import, start_index):
@@ -122,13 +119,12 @@ class UpdateReplication:
recheck_interval = args.config.get_int('REPLICATION_RECHECK_INTERVAL')
while True:
conn = connect(args.config.get_libpq_dsn())
start = dt.datetime.now(dt.timezone.utc)
state = replication.update(conn, params)
if state is not replication.UpdateState.NO_CHANGES:
status.log_status(conn, start, 'import')
batchdate, _, _ = status.get_status(conn)
conn.close()
with connect(args.config.get_libpq_dsn()) as conn:
start = dt.datetime.now(dt.timezone.utc)
state = replication.update(conn, params)
if state is not replication.UpdateState.NO_CHANGES:
status.log_status(conn, start, 'import')
batchdate, _, _ = status.get_status(conn)
if state is not replication.UpdateState.NO_CHANGES and args.do_index:
index_start = dt.datetime.now(dt.timezone.utc)
@@ -137,10 +133,9 @@ class UpdateReplication:
indexer.index_boundaries(0, 30)
indexer.index_by_rank(0, 30)
conn = connect(args.config.get_libpq_dsn())
status.set_indexed(conn, True)
status.log_status(conn, index_start, 'index')
conn.close()
with connect(args.config.get_libpq_dsn()) as conn:
status.set_indexed(conn, True)
status.log_status(conn, index_start, 'index')
else:
index_start = None

View File

@@ -1,6 +1,7 @@
"""
Specialised connection and cursor functions.
"""
import contextlib
import logging
import psycopg2
@@ -84,9 +85,14 @@ class _Connection(psycopg2.extensions.connection):
def connect(dsn):
""" Open a connection to the database using the specialised connection
factory.
factory. The returned object may be used in conjunction with 'with'.
When used outside a context manager, use the `connection` attribute
to get the connection.
"""
try:
return psycopg2.connect(dsn, connection_factory=_Connection)
conn = psycopg2.connect(dsn, connection_factory=_Connection)
ctxmgr = contextlib.closing(conn)
ctxmgr.connection = conn
return ctxmgr
except psycopg2.OperationalError as err:
raise UsageError("Cannot connect to database: {}".format(err)) from err

View File

@@ -60,7 +60,7 @@ def check_database(config):
""" Run a number of checks on the database and return the status.
"""
try:
conn = connect(config.get_libpq_dsn())
conn = connect(config.get_libpq_dsn()).connection
except UsageError as err:
conn = _BadConnection(str(err))

View File

@@ -85,9 +85,8 @@ def temp_db_with_extensions(temp_db):
def temp_db_conn(temp_db):
""" Connection to the test database.
"""
conn = connection.connect('dbname=' + temp_db)
yield conn
conn.close()
with connection.connect('dbname=' + temp_db) as conn:
yield conn
@pytest.fixture

View File

@@ -7,9 +7,8 @@ from nominatim.db.connection import connect
@pytest.fixture
def db(temp_db):
conn = connect('dbname=' + temp_db)
yield conn
conn.close()
with connect('dbname=' + temp_db) as conn:
yield conn
def test_connection_table_exists(db, temp_db_cursor):

View File

@@ -9,9 +9,8 @@ from nominatim.tools import admin
@pytest.fixture
def db(temp_db, placex_table):
conn = connect('dbname=' + temp_db)
yield conn
conn.close()
with connect('dbname=' + temp_db) as conn:
yield conn
def test_analyse_indexing_no_objects(db):
with pytest.raises(UsageError):

View File

@@ -10,6 +10,10 @@ def test_check_database_unknown_db(def_config, monkeypatch):
assert 1 == chkdb.check_database(def_config)
def test_check_database_fatal_test(def_config, temp_db):
assert 1 == chkdb.check_database(def_config)
def test_check_conection_good(temp_db_conn, def_config):
assert chkdb.check_connection(temp_db_conn, def_config) == chkdb.CheckState.OK

View File

@@ -11,9 +11,8 @@ SQL_DIR = (Path(__file__) / '..' / '..' / '..' / 'lib-sql').resolve()
@pytest.fixture
def db(temp_db):
conn = connect('dbname=' + temp_db)
yield conn
conn.close()
with connect('dbname=' + temp_db) as conn:
yield conn
@pytest.fixture
def db_with_tables(db):