-
Notifications
You must be signed in to change notification settings - Fork 4
/
samplingthread.py
71 lines (66 loc) · 1.78 KB
/
samplingthread.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import sampling
import queue
from threading import Thread, Event
import logging
class SamplerServer:
    """Serve sampling requests from a background daemon worker thread.

    Requests are submitted through ``self.queue``.  The worker collects the
    ``samples`` of every in-flight request and advances them together with
    one ``self.sampler.single_step(samples)`` call per loop iteration.  A
    request is completed (``run_outchain()``, ``finished = True``,
    ``on_finish()``) once all of its samples report ``finished``.  ``None``
    on the queue is the shutdown sentinel.
    """

    def __init__(self, model):
        """Create the sampler for ``model`` and start the worker thread."""
        self.sampler = sampling.Sampler(model)
        self.queue = queue.Queue()
        # In-flight requests of the current worker loop; threadmain replaces
        # it on every (re)start and threadmain_h reads it after an error.
        self.requests = []
        # Set by the worker once it has consumed the None sentinel, so that
        # a restart after an exception still honours the pending stop.
        self._stop_requested = False
        self.thread = Thread(target=self.threadmain_h)
        self.thread.daemon = True
        # NOTE(review): starts as True, so __del__ never calls stop(); the
        # thread's bound-method target also keeps ``self`` alive while the
        # worker runs.  Left unchanged to avoid blocking inside a finalizer.
        self.stopped = True
        self.thread.start()

    def stop(self):
        """Ask the worker to exit once queued work is done, and wait for it.

        Blocks until every request already taken from the queue has been
        completed (``queue.join()``) and the worker thread has returned.
        Calling it again after the worker has exited returns immediately.
        """
        if not self.thread.is_alive():
            # Nobody would consume a sentinel, so queue.join() would block
            # forever.
            self.stopped = True
            return
        self.queue.put(None)
        self.queue.join()
        self.thread.join()
        self.stopped = True

    def __del__(self):
        if not self.stopped:
            self.stop()

    def threadmain_h(self):
        """Worker-thread entry point: run threadmain, restarting on errors.

        Returns when threadmain exits normally (the stop sentinel has been
        handled).  On an exception the traceback is logged, every in-flight
        request is released, and the loop is restarted.
        """
        while True:
            try:
                self.threadmain()
            except Exception:
                logging.exception("Exception in server thread")
                self._abort_requests()
            else:
                return

    def _abort_requests(self):
        """Release the callers and queue slots of requests lost to an error."""
        pending = list(self.requests)
        del self.requests[:]
        for r in pending:
            try:
                # r.finished is not set here, so run_request_sync's
                # assertion does not see these requests as finished.
                r.on_finish()
            except Exception:
                logging.exception("Exception in on_finish of failed request")
            finally:
                # Balance the queue.get() for this request so that
                # queue.join() in stop() can still return.
                self.queue.task_done()

    def threadmain(self):
        """Serve queued requests until the ``None`` sentinel is consumed.

        Each outer iteration pulls new work from the queue, advances every
        unfinished sample with one ``single_step`` call, and completes the
        requests whose samples have all finished.  Returns once a stop was
        requested and no request is in flight.
        """
        requests = []
        samples = []
        # Published so threadmain_h can release these callers on error.
        self.requests = requests
        # Resume a stop that was already requested before an error restart.
        stop = self._stop_requested
        while True:
            # Block for work only while idle and not stopping; otherwise just
            # drain whatever is already queued.
            while (not requests and not stop) or not self.queue.empty():
                r = self.queue.get(True)
                if r is None:
                    self._stop_requested = True
                    self.queue.task_done()
                    stop = True
                    break
                # Track the request before its in-chain runs, so a failure
                # there still releases the caller and the queue slot.
                requests.append(r)
                r.run_inchain()
                samples.extend(r.samples)
            if not requests and stop:
                return
            if samples:
                self.sampler.single_step(samples)
            samples = [s for s in samples if not s.finished]
            # Iterate over a snapshot: completed requests are removed below.
            for r in list(requests):
                if all(rs.finished for rs in r.samples):
                    r.run_outchain()
                    r.finished = True
                    # Remove first so that a raising callback is not invoked
                    # a second time by _abort_requests.
                    requests.remove(r)
                    try:
                        r.on_finish()
                    finally:
                        self.queue.task_done()

    def run_request_sync(self, request):
        """Submit ``request`` and block until the worker reports it finished."""
        evt = Event()
        request.on_finish = evt.set
        self.queue.put(request)
        evt.wait()
        assert request.finished
        return request

    def run_request(self, request, on_finish):
        """Submit ``request``; the worker calls ``on_finish()`` when done."""
        request.on_finish = on_finish
        self.queue.put(request)