mirror of
https://github.com/osm-search/Nominatim.git
synced 2026-03-08 02:54:08 +00:00
switch to threading
This commit is contained in:
@@ -30,7 +30,8 @@ import getpass
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import psycopg2
|
import psycopg2
|
||||||
from psycopg2.extras import wait_select
|
from psycopg2.extras import wait_select
|
||||||
import select
|
import threading
|
||||||
|
from queue import Queue
|
||||||
|
|
||||||
log = logging.getLogger()
|
log = logging.getLogger()
|
||||||
|
|
||||||
@@ -39,53 +40,44 @@ def make_connection(options, asynchronous=False):
|
|||||||
password=options.password, host=options.host,
|
password=options.password, host=options.host,
|
||||||
port=options.port, async_=asynchronous)
|
port=options.port, async_=asynchronous)
|
||||||
|
|
||||||
class IndexingThread(object):
|
class IndexingThread(threading.Thread):
|
||||||
|
|
||||||
def __init__(self, thread_num, options):
|
def __init__(self, queue, barrier, options):
|
||||||
log.debug("Creating thread {}".format(thread_num))
|
super().__init__()
|
||||||
self.thread_num = thread_num
|
self.conn = make_connection(options)
|
||||||
self.conn = make_connection(options, asynchronous=True)
|
self.conn.autocommit = True
|
||||||
self.wait()
|
|
||||||
|
|
||||||
self.cursor = self.conn.cursor()
|
self.cursor = self.conn.cursor()
|
||||||
self.perform("SET lc_messages TO 'C'")
|
self.perform("SET lc_messages TO 'C'")
|
||||||
self.wait()
|
|
||||||
self.perform(InterpolationRunner.prepare())
|
self.perform(InterpolationRunner.prepare())
|
||||||
self.wait()
|
|
||||||
self.perform(RankRunner.prepare())
|
self.perform(RankRunner.prepare())
|
||||||
self.wait()
|
self.queue = queue
|
||||||
|
self.barrier = barrier
|
||||||
|
|
||||||
self.current_query = None
|
def run(self):
|
||||||
self.current_params = None
|
sql = None
|
||||||
|
while True:
|
||||||
def wait(self):
|
item = self.queue.get()
|
||||||
wait_select(self.conn)
|
if item is None:
|
||||||
self.current_query = None
|
break
|
||||||
|
elif isinstance(item, str):
|
||||||
|
sql = item
|
||||||
|
self.barrier.wait()
|
||||||
|
else:
|
||||||
|
self.perform(sql, (item,))
|
||||||
|
|
||||||
def perform(self, sql, args=None):
|
def perform(self, sql, args=None):
|
||||||
self.current_query = sql
|
while True:
|
||||||
self.current_params = args
|
try:
|
||||||
self.cursor.execute(sql, args)
|
self.cursor.execute(sql, args)
|
||||||
|
return
|
||||||
def fileno(self):
|
except psycopg2.extensions.TransactionRollbackError as e:
|
||||||
return self.conn.fileno()
|
if e.pgcode is None:
|
||||||
|
raise RuntimeError("Postgres exception has no error code")
|
||||||
def is_done(self):
|
if e.pgcode == '40P01':
|
||||||
if self.current_query is None:
|
log.info("Deadlock detected, retry.")
|
||||||
return True
|
else:
|
||||||
|
raise
|
||||||
try:
|
|
||||||
if self.conn.poll() == psycopg2.extensions.POLL_OK:
|
|
||||||
self.current_query = None
|
|
||||||
return True
|
|
||||||
except psycopg2.extensions.TransactionRollbackError as e:
|
|
||||||
if e.pgcode is None:
|
|
||||||
raise RuntimeError("Postgres exception has no error code")
|
|
||||||
if e.pgcode == '40P01':
|
|
||||||
log.info("Deadlock detected, retry.")
|
|
||||||
self.cursor.execute(self.current_query, self.current_params)
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -96,11 +88,12 @@ class Indexer(object):
|
|||||||
self.conn = make_connection(options)
|
self.conn = make_connection(options)
|
||||||
|
|
||||||
self.threads = []
|
self.threads = []
|
||||||
self.poll = select.poll()
|
self.queue = Queue(maxsize=1000)
|
||||||
|
self.barrier = threading.Barrier(options.threads + 1)
|
||||||
for i in range(options.threads):
|
for i in range(options.threads):
|
||||||
t = IndexingThread(i, options)
|
t = IndexingThread(self.queue, self.barrier, options)
|
||||||
self.threads.append(t)
|
self.threads.append(t)
|
||||||
self.poll.register(t, select.EPOLLIN)
|
t.start()
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
log.info("Starting indexing rank ({} to {}) using {} threads".format(
|
log.info("Starting indexing rank ({} to {}) using {} threads".format(
|
||||||
@@ -114,9 +107,20 @@ class Indexer(object):
|
|||||||
self.index(InterpolationRunner())
|
self.index(InterpolationRunner())
|
||||||
self.index(RankRunner(30))
|
self.index(RankRunner(30))
|
||||||
|
|
||||||
|
self.queue_all(None)
|
||||||
|
for t in self.threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
def queue_all(self, item):
|
||||||
|
for t in self.threads:
|
||||||
|
self.queue.put(item)
|
||||||
|
|
||||||
def index(self, obj):
|
def index(self, obj):
|
||||||
log.info("Starting {}".format(obj.name()))
|
log.info("Starting {}".format(obj.name()))
|
||||||
|
|
||||||
|
self.queue_all(obj.sql_index_place())
|
||||||
|
self.barrier.wait()
|
||||||
|
|
||||||
cur = self.conn.cursor(name="main")
|
cur = self.conn.cursor(name="main")
|
||||||
cur.execute(obj.sql_index_sectors())
|
cur.execute(obj.sql_index_sectors())
|
||||||
|
|
||||||
@@ -127,7 +131,6 @@ class Indexer(object):
|
|||||||
|
|
||||||
cur.scroll(0, mode='absolute')
|
cur.scroll(0, mode='absolute')
|
||||||
|
|
||||||
next_thread = self.find_free_thread()
|
|
||||||
done_tuples = 0
|
done_tuples = 0
|
||||||
rank_start_time = datetime.now()
|
rank_start_time = datetime.now()
|
||||||
for r in cur:
|
for r in cur:
|
||||||
@@ -146,9 +149,8 @@ class Indexer(object):
|
|||||||
for place in pcur:
|
for place in pcur:
|
||||||
place_id = place[0]
|
place_id = place[0]
|
||||||
log.debug("Processing place {}".format(place_id))
|
log.debug("Processing place {}".format(place_id))
|
||||||
thread = next(next_thread)
|
|
||||||
|
|
||||||
thread.perform(obj.sql_index_place(), (place_id,))
|
self.queue.put(place_id)
|
||||||
done_tuples += 1
|
done_tuples += 1
|
||||||
|
|
||||||
pcur.close()
|
pcur.close()
|
||||||
@@ -158,8 +160,8 @@ class Indexer(object):
|
|||||||
|
|
||||||
cur.close()
|
cur.close()
|
||||||
|
|
||||||
for t in self.threads:
|
self.queue_all("")
|
||||||
t.wait()
|
self.barrier.wait()
|
||||||
|
|
||||||
rank_end_time = datetime.now()
|
rank_end_time = datetime.now()
|
||||||
diff_seconds = (rank_end_time-rank_start_time).total_seconds()
|
diff_seconds = (rank_end_time-rank_start_time).total_seconds()
|
||||||
@@ -168,22 +170,6 @@ class Indexer(object):
|
|||||||
done_tuples, int(diff_seconds),
|
done_tuples, int(diff_seconds),
|
||||||
done_tuples/diff_seconds, obj.name()))
|
done_tuples/diff_seconds, obj.name()))
|
||||||
|
|
||||||
def find_free_thread(self):
|
|
||||||
thread_lookup = { t.fileno() : t for t in self.threads}
|
|
||||||
|
|
||||||
done_fids = [ t.fileno() for t in self.threads ]
|
|
||||||
|
|
||||||
while True:
|
|
||||||
for fid in done_fids:
|
|
||||||
thread = thread_lookup[fid]
|
|
||||||
if thread.is_done():
|
|
||||||
yield thread
|
|
||||||
else:
|
|
||||||
print("not good", fid)
|
|
||||||
|
|
||||||
done_fids = [ x[0] for x in self.poll.poll()]
|
|
||||||
|
|
||||||
assert(False, "Unreachable code")
|
|
||||||
|
|
||||||
class RankRunner(object):
|
class RankRunner(object):
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user