mirror of
https://github.com/osm-search/Nominatim.git
synced 2026-03-11 05:14:07 +00:00
convert connect() into a context manager
This commit is contained in:
@@ -54,9 +54,8 @@ class AdminFuncs:
|
|||||||
if args.analyse_indexing:
|
if args.analyse_indexing:
|
||||||
LOG.warning('Analysing performance of indexing function')
|
LOG.warning('Analysing performance of indexing function')
|
||||||
from ..tools import admin
|
from ..tools import admin
|
||||||
conn = connect(args.config.get_libpq_dsn())
|
with connect(args.config.get_libpq_dsn()) as conn:
|
||||||
admin.analyse_indexing(conn, osm_id=args.osm_id, place_id=args.place_id)
|
admin.analyse_indexing(conn, osm_id=args.osm_id, place_id=args.place_id)
|
||||||
conn.close()
|
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|||||||
@@ -29,9 +29,8 @@ class SetupFreeze:
|
|||||||
def run(args):
|
def run(args):
|
||||||
from ..tools import freeze
|
from ..tools import freeze
|
||||||
|
|
||||||
conn = connect(args.config.get_libpq_dsn())
|
with connect(args.config.get_libpq_dsn()) as conn:
|
||||||
freeze.drop_update_tables(conn)
|
freeze.drop_update_tables(conn)
|
||||||
freeze.drop_flatnode_file(args.config.FLATNODE_FILE)
|
freeze.drop_flatnode_file(args.config.FLATNODE_FILE)
|
||||||
conn.close()
|
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
|||||||
@@ -51,8 +51,7 @@ class UpdateIndex:
|
|||||||
|
|
||||||
if not args.no_boundaries and not args.boundaries_only \
|
if not args.no_boundaries and not args.boundaries_only \
|
||||||
and args.minrank == 0 and args.maxrank == 30:
|
and args.minrank == 0 and args.maxrank == 30:
|
||||||
conn = connect(args.config.get_libpq_dsn())
|
with connect(args.config.get_libpq_dsn()) as conn:
|
||||||
status.set_indexed(conn, True)
|
status.set_indexed(conn, True)
|
||||||
conn.close()
|
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
|||||||
@@ -50,29 +50,25 @@ class UpdateRefresh:
|
|||||||
|
|
||||||
if args.postcodes:
|
if args.postcodes:
|
||||||
LOG.warning("Update postcodes centroid")
|
LOG.warning("Update postcodes centroid")
|
||||||
conn = connect(args.config.get_libpq_dsn())
|
with connect(args.config.get_libpq_dsn()) as conn:
|
||||||
refresh.update_postcodes(conn, args.sqllib_dir)
|
refresh.update_postcodes(conn, args.sqllib_dir)
|
||||||
conn.close()
|
|
||||||
|
|
||||||
if args.word_counts:
|
if args.word_counts:
|
||||||
LOG.warning('Recompute frequency of full-word search terms')
|
LOG.warning('Recompute frequency of full-word search terms')
|
||||||
conn = connect(args.config.get_libpq_dsn())
|
with connect(args.config.get_libpq_dsn()) as conn:
|
||||||
refresh.recompute_word_counts(conn, args.sqllib_dir)
|
refresh.recompute_word_counts(conn, args.sqllib_dir)
|
||||||
conn.close()
|
|
||||||
|
|
||||||
if args.address_levels:
|
if args.address_levels:
|
||||||
cfg = Path(args.config.ADDRESS_LEVEL_CONFIG)
|
cfg = Path(args.config.ADDRESS_LEVEL_CONFIG)
|
||||||
LOG.warning('Updating address levels from %s', cfg)
|
LOG.warning('Updating address levels from %s', cfg)
|
||||||
conn = connect(args.config.get_libpq_dsn())
|
with connect(args.config.get_libpq_dsn()) as conn:
|
||||||
refresh.load_address_levels_from_file(conn, cfg)
|
refresh.load_address_levels_from_file(conn, cfg)
|
||||||
conn.close()
|
|
||||||
|
|
||||||
if args.functions:
|
if args.functions:
|
||||||
LOG.warning('Create functions')
|
LOG.warning('Create functions')
|
||||||
conn = connect(args.config.get_libpq_dsn())
|
with connect(args.config.get_libpq_dsn()) as conn:
|
||||||
refresh.create_functions(conn, args.config, args.sqllib_dir,
|
refresh.create_functions(conn, args.config, args.sqllib_dir,
|
||||||
args.diffs, args.enable_debug_statements)
|
args.diffs, args.enable_debug_statements)
|
||||||
conn.close()
|
|
||||||
|
|
||||||
if args.wiki_data:
|
if args.wiki_data:
|
||||||
run_legacy_script('setup.php', '--import-wikipedia-articles',
|
run_legacy_script('setup.php', '--import-wikipedia-articles',
|
||||||
|
|||||||
@@ -62,13 +62,12 @@ class UpdateReplication:
|
|||||||
from ..tools import replication, refresh
|
from ..tools import replication, refresh
|
||||||
|
|
||||||
LOG.warning("Initialising replication updates")
|
LOG.warning("Initialising replication updates")
|
||||||
conn = connect(args.config.get_libpq_dsn())
|
with connect(args.config.get_libpq_dsn()) as conn:
|
||||||
replication.init_replication(conn, base_url=args.config.REPLICATION_URL)
|
replication.init_replication(conn, base_url=args.config.REPLICATION_URL)
|
||||||
if args.update_functions:
|
if args.update_functions:
|
||||||
LOG.warning("Create functions")
|
LOG.warning("Create functions")
|
||||||
refresh.create_functions(conn, args.config, args.sqllib_dir,
|
refresh.create_functions(conn, args.config, args.sqllib_dir,
|
||||||
True, False)
|
True, False)
|
||||||
conn.close()
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
@@ -76,10 +75,8 @@ class UpdateReplication:
|
|||||||
def _check_for_updates(args):
|
def _check_for_updates(args):
|
||||||
from ..tools import replication
|
from ..tools import replication
|
||||||
|
|
||||||
conn = connect(args.config.get_libpq_dsn())
|
with connect(args.config.get_libpq_dsn()) as conn:
|
||||||
ret = replication.check_for_updates(conn, base_url=args.config.REPLICATION_URL)
|
return replication.check_for_updates(conn, base_url=args.config.REPLICATION_URL)
|
||||||
conn.close()
|
|
||||||
return ret
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _report_update(batchdate, start_import, start_index):
|
def _report_update(batchdate, start_import, start_index):
|
||||||
@@ -122,13 +119,12 @@ class UpdateReplication:
|
|||||||
recheck_interval = args.config.get_int('REPLICATION_RECHECK_INTERVAL')
|
recheck_interval = args.config.get_int('REPLICATION_RECHECK_INTERVAL')
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
conn = connect(args.config.get_libpq_dsn())
|
with connect(args.config.get_libpq_dsn()) as conn:
|
||||||
start = dt.datetime.now(dt.timezone.utc)
|
start = dt.datetime.now(dt.timezone.utc)
|
||||||
state = replication.update(conn, params)
|
state = replication.update(conn, params)
|
||||||
if state is not replication.UpdateState.NO_CHANGES:
|
if state is not replication.UpdateState.NO_CHANGES:
|
||||||
status.log_status(conn, start, 'import')
|
status.log_status(conn, start, 'import')
|
||||||
batchdate, _, _ = status.get_status(conn)
|
batchdate, _, _ = status.get_status(conn)
|
||||||
conn.close()
|
|
||||||
|
|
||||||
if state is not replication.UpdateState.NO_CHANGES and args.do_index:
|
if state is not replication.UpdateState.NO_CHANGES and args.do_index:
|
||||||
index_start = dt.datetime.now(dt.timezone.utc)
|
index_start = dt.datetime.now(dt.timezone.utc)
|
||||||
@@ -137,10 +133,9 @@ class UpdateReplication:
|
|||||||
indexer.index_boundaries(0, 30)
|
indexer.index_boundaries(0, 30)
|
||||||
indexer.index_by_rank(0, 30)
|
indexer.index_by_rank(0, 30)
|
||||||
|
|
||||||
conn = connect(args.config.get_libpq_dsn())
|
with connect(args.config.get_libpq_dsn()) as conn:
|
||||||
status.set_indexed(conn, True)
|
status.set_indexed(conn, True)
|
||||||
status.log_status(conn, index_start, 'index')
|
status.log_status(conn, index_start, 'index')
|
||||||
conn.close()
|
|
||||||
else:
|
else:
|
||||||
index_start = None
|
index_start = None
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Specialised connection and cursor functions.
|
Specialised connection and cursor functions.
|
||||||
"""
|
"""
|
||||||
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import psycopg2
|
import psycopg2
|
||||||
@@ -84,9 +85,14 @@ class _Connection(psycopg2.extensions.connection):
|
|||||||
|
|
||||||
def connect(dsn):
|
def connect(dsn):
|
||||||
""" Open a connection to the database using the specialised connection
|
""" Open a connection to the database using the specialised connection
|
||||||
factory.
|
factory. The returned object may be used in conjunction with 'with'.
|
||||||
|
When used outside a context manager, use the `connection` attribute
|
||||||
|
to get the connection.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return psycopg2.connect(dsn, connection_factory=_Connection)
|
conn = psycopg2.connect(dsn, connection_factory=_Connection)
|
||||||
|
ctxmgr = contextlib.closing(conn)
|
||||||
|
ctxmgr.connection = conn
|
||||||
|
return ctxmgr
|
||||||
except psycopg2.OperationalError as err:
|
except psycopg2.OperationalError as err:
|
||||||
raise UsageError("Cannot connect to database: {}".format(err)) from err
|
raise UsageError("Cannot connect to database: {}".format(err)) from err
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ def check_database(config):
|
|||||||
""" Run a number of checks on the database and return the status.
|
""" Run a number of checks on the database and return the status.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
conn = connect(config.get_libpq_dsn())
|
conn = connect(config.get_libpq_dsn()).connection
|
||||||
except UsageError as err:
|
except UsageError as err:
|
||||||
conn = _BadConnection(str(err))
|
conn = _BadConnection(str(err))
|
||||||
|
|
||||||
|
|||||||
@@ -85,9 +85,8 @@ def temp_db_with_extensions(temp_db):
|
|||||||
def temp_db_conn(temp_db):
|
def temp_db_conn(temp_db):
|
||||||
""" Connection to the test database.
|
""" Connection to the test database.
|
||||||
"""
|
"""
|
||||||
conn = connection.connect('dbname=' + temp_db)
|
with connection.connect('dbname=' + temp_db) as conn:
|
||||||
yield conn
|
yield conn
|
||||||
conn.close()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
@@ -7,9 +7,8 @@ from nominatim.db.connection import connect
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def db(temp_db):
|
def db(temp_db):
|
||||||
conn = connect('dbname=' + temp_db)
|
with connect('dbname=' + temp_db) as conn:
|
||||||
yield conn
|
yield conn
|
||||||
conn.close()
|
|
||||||
|
|
||||||
|
|
||||||
def test_connection_table_exists(db, temp_db_cursor):
|
def test_connection_table_exists(db, temp_db_cursor):
|
||||||
|
|||||||
@@ -9,9 +9,8 @@ from nominatim.tools import admin
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def db(temp_db, placex_table):
|
def db(temp_db, placex_table):
|
||||||
conn = connect('dbname=' + temp_db)
|
with connect('dbname=' + temp_db) as conn:
|
||||||
yield conn
|
yield conn
|
||||||
conn.close()
|
|
||||||
|
|
||||||
def test_analyse_indexing_no_objects(db):
|
def test_analyse_indexing_no_objects(db):
|
||||||
with pytest.raises(UsageError):
|
with pytest.raises(UsageError):
|
||||||
|
|||||||
@@ -10,6 +10,10 @@ def test_check_database_unknown_db(def_config, monkeypatch):
|
|||||||
assert 1 == chkdb.check_database(def_config)
|
assert 1 == chkdb.check_database(def_config)
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_database_fatal_test(def_config, temp_db):
|
||||||
|
assert 1 == chkdb.check_database(def_config)
|
||||||
|
|
||||||
|
|
||||||
def test_check_conection_good(temp_db_conn, def_config):
|
def test_check_conection_good(temp_db_conn, def_config):
|
||||||
assert chkdb.check_connection(temp_db_conn, def_config) == chkdb.CheckState.OK
|
assert chkdb.check_connection(temp_db_conn, def_config) == chkdb.CheckState.OK
|
||||||
|
|
||||||
|
|||||||
@@ -11,9 +11,8 @@ SQL_DIR = (Path(__file__) / '..' / '..' / '..' / 'lib-sql').resolve()
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def db(temp_db):
|
def db(temp_db):
|
||||||
conn = connect('dbname=' + temp_db)
|
with connect('dbname=' + temp_db) as conn:
|
||||||
yield conn
|
yield conn
|
||||||
conn.close()
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def db_with_tables(db):
|
def db_with_tables(db):
|
||||||
|
|||||||
Reference in New Issue
Block a user