Revert "switch to threading"

This reverts commit 8b1c2181be5aa5335c68d36a49cab9c4e2cd8bef.
This commit is contained in:
Sarah Hoffmann
2020-01-19 21:56:37 +01:00
parent 0a26ca7104
commit 6c0d6d3178

View File

@@ -30,8 +30,7 @@ import getpass
from datetime import datetime from datetime import datetime
import psycopg2 import psycopg2
from psycopg2.extras import wait_select from psycopg2.extras import wait_select
import threading import select
from queue import Queue
log = logging.getLogger() log = logging.getLogger()
@@ -40,44 +39,53 @@ def make_connection(options, asynchronous=False):
password=options.password, host=options.host, password=options.password, host=options.host,
port=options.port, async_=asynchronous) port=options.port, async_=asynchronous)
class IndexingThread(threading.Thread): class IndexingThread(object):
def __init__(self, queue, barrier, options): def __init__(self, thread_num, options):
super().__init__() log.debug("Creating thread {}".format(thread_num))
self.conn = make_connection(options) self.thread_num = thread_num
self.conn.autocommit = True self.conn = make_connection(options, asynchronous=True)
self.wait()
self.cursor = self.conn.cursor() self.cursor = self.conn.cursor()
self.perform("SET lc_messages TO 'C'") self.perform("SET lc_messages TO 'C'")
self.wait()
self.perform(InterpolationRunner.prepare()) self.perform(InterpolationRunner.prepare())
self.wait()
self.perform(RankRunner.prepare()) self.perform(RankRunner.prepare())
self.queue = queue self.wait()
self.barrier = barrier
def run(self): self.current_query = None
sql = None self.current_params = None
while True:
item = self.queue.get() def wait(self):
if item is None: wait_select(self.conn)
break self.current_query = None
elif isinstance(item, str):
sql = item
self.barrier.wait()
else:
self.perform(sql, (item,))
def perform(self, sql, args=None): def perform(self, sql, args=None):
while True: self.current_query = sql
try: self.current_params = args
self.cursor.execute(sql, args) self.cursor.execute(sql, args)
return
except psycopg2.extensions.TransactionRollbackError as e: def fileno(self):
if e.pgcode is None: return self.conn.fileno()
raise RuntimeError("Postgres exception has no error code")
if e.pgcode == '40P01': def is_done(self):
log.info("Deadlock detected, retry.") if self.current_query is None:
else: return True
raise
try:
if self.conn.poll() == psycopg2.extensions.POLL_OK:
self.current_query = None
return True
except psycopg2.extensions.TransactionRollbackError as e:
if e.pgcode is None:
raise RuntimeError("Postgres exception has no error code")
if e.pgcode == '40P01':
log.info("Deadlock detected, retry.")
self.cursor.execute(self.current_query, self.current_params)
else:
raise
@@ -88,12 +96,11 @@ class Indexer(object):
self.conn = make_connection(options) self.conn = make_connection(options)
self.threads = [] self.threads = []
self.queue = Queue(maxsize=1000) self.poll = select.poll()
self.barrier = threading.Barrier(options.threads + 1)
for i in range(options.threads): for i in range(options.threads):
t = IndexingThread(self.queue, self.barrier, options) t = IndexingThread(i, options)
self.threads.append(t) self.threads.append(t)
t.start() self.poll.register(t, select.EPOLLIN)
def run(self): def run(self):
log.info("Starting indexing rank ({} to {}) using {} threads".format( log.info("Starting indexing rank ({} to {}) using {} threads".format(
@@ -107,20 +114,9 @@ class Indexer(object):
self.index(InterpolationRunner()) self.index(InterpolationRunner())
self.index(RankRunner(30)) self.index(RankRunner(30))
self.queue_all(None)
for t in self.threads:
t.join()
def queue_all(self, item):
for t in self.threads:
self.queue.put(item)
def index(self, obj): def index(self, obj):
log.info("Starting {}".format(obj.name())) log.info("Starting {}".format(obj.name()))
self.queue_all(obj.sql_index_place())
self.barrier.wait()
cur = self.conn.cursor(name="main") cur = self.conn.cursor(name="main")
cur.execute(obj.sql_index_sectors()) cur.execute(obj.sql_index_sectors())
@@ -131,6 +127,7 @@ class Indexer(object):
cur.scroll(0, mode='absolute') cur.scroll(0, mode='absolute')
next_thread = self.find_free_thread()
done_tuples = 0 done_tuples = 0
rank_start_time = datetime.now() rank_start_time = datetime.now()
for r in cur: for r in cur:
@@ -149,8 +146,9 @@ class Indexer(object):
for place in pcur: for place in pcur:
place_id = place[0] place_id = place[0]
log.debug("Processing place {}".format(place_id)) log.debug("Processing place {}".format(place_id))
thread = next(next_thread)
self.queue.put(place_id) thread.perform(obj.sql_index_place(), (place_id,))
done_tuples += 1 done_tuples += 1
pcur.close() pcur.close()
@@ -160,8 +158,8 @@ class Indexer(object):
cur.close() cur.close()
self.queue_all("") for t in self.threads:
self.barrier.wait() t.wait()
rank_end_time = datetime.now() rank_end_time = datetime.now()
diff_seconds = (rank_end_time-rank_start_time).total_seconds() diff_seconds = (rank_end_time-rank_start_time).total_seconds()
@@ -170,6 +168,22 @@ class Indexer(object):
done_tuples, int(diff_seconds), done_tuples, int(diff_seconds),
done_tuples/diff_seconds, obj.name())) done_tuples/diff_seconds, obj.name()))
def find_free_thread(self):
thread_lookup = { t.fileno() : t for t in self.threads}
done_fids = [ t.fileno() for t in self.threads ]
while True:
for fid in done_fids:
thread = thread_lookup[fid]
if thread.is_done():
yield thread
else:
print("not good", fid)
done_fids = [ x[0] for x in self.poll.poll()]
assert(False, "Unreachable code")
class RankRunner(object): class RankRunner(object):