forked from lemairecarl/hypertrainer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
worker.py
executable file
·75 lines (53 loc) · 1.98 KB
/
worker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#!/usr/bin/env python
import argparse
import socket
from multiprocessing import Process
from typing import List
from redis import Redis
from rq import Connection, Worker
from rq.worker import StopRequested
# Preload libraries
# TODO import library_that_you_want_preloaded
from hypertrainer.utils import config_context
class WorkerContext:
    """Context manager that runs rq worker processes for this machine.

    Entering the context enters the rq ``Connection`` built in ``__init__``.
    It then starts one process running ``work`` on the host-specific queue
    (named ``hostname``), plus ``num_workers`` processes running ``work`` on
    the shared ``'jobs'`` queue. Exiting terminates and reaps every started
    process, then exits the ``Connection``.
    """

    def __init__(self, hostname=None, num_workers=1):
        """Read the Redis port from the hypertrainer config and build the connection.

        Args:
            hostname: name of the host-specific queue; ``None`` falls back to
                ``socket.gethostname()``.
            num_workers: number of processes consuming the shared ``'jobs'``
                queue (a value <= 0 starts none of them).
        """
        self.hostname = hostname if hostname is not None else socket.gethostname()
        with config_context() as config:
            redis_port = config['ht_platform']['redis_port']
        self.redis_conn = Redis(port=redis_port)
        self.conn = Connection(self.redis_conn)
        # Only processes whose start() succeeded are recorded here, so that
        # __exit__ and wait() never act on a process that was never started.
        self.worker_processes: List[Process] = []
        self.num_workers = num_workers
        print('Redis port:', redis_port)

    def __enter__(self):
        """Enter the rq connection and start all worker processes; return self."""
        self.conn.__enter__()
        try:
            # Worker-specific queue first, then the shared 'jobs' queue workers.
            queue_names = [self.hostname] + ['jobs'] * self.num_workers
            for queue_name in queue_names:
                proc = Process(target=work, args=(queue_name,))
                proc.start()
                self.worker_processes.append(proc)
        except BaseException as exc:
            # `with` does not call __exit__ when __enter__ raises: clean up the
            # connection and any workers that did start before propagating.
            self.__exit__(type(exc), exc, exc.__traceback__)
            raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Terminate and reap the worker processes, then exit the rq connection.

        Exceptions raised inside the ``with`` body are not suppressed.
        """
        try:
            # Signal every running worker first so they shut down in parallel...
            for proc in self.worker_processes:
                if proc.is_alive():
                    proc.terminate()
            # ...then join them so no zombie processes are left behind.
            for proc in self.worker_processes:
                if proc.pid is not None:  # pid is None for a never-started process
                    proc.join()
        finally:
            # Release resources in reverse order of acquisition.
            self.conn.__exit__(exc_type, exc_val, exc_tb)

    def wait(self):
        """Block until every started worker process has exited."""
        for proc in self.worker_processes:
            proc.join()
def work(queue_name):
    """Run a blocking rq worker that consumes a single queue.

    NOTE: Executed in a separate process. This affects print and logging.

    The ``Worker`` is built without an explicit connection. Presumably it
    relies on the rq connection entered by the parent process before forking.
    TODO confirm this still holds under non-fork multiprocessing start methods.

    Args:
        queue_name: name of the rq queue to listen on.
    """
    print('Working on queue', queue_name)
    rq_worker = Worker([queue_name])
    try:
        rq_worker.work()
    except StopRequested:
        # A stop request escaping work() is reported and treated as a normal exit.
        print('StopRequested')
def start_worker(**kwargs):
    """Start the worker processes and block until all of them exit.

    Keyword arguments are forwarded unchanged to ``WorkerContext``
    (``hostname``, ``num_workers``).
    """
    context = WorkerContext(**kwargs)
    with context:
        context.wait()
if __name__ == '__main__':
    # Command-line entry point. --hostname names the host-specific queue
    # (omitted -> None, resolved later by WorkerContext). --workers sets how
    # many processes consume the shared 'jobs' queue (default 1).
    parser = argparse.ArgumentParser()
    parser.add_argument('--hostname', type=str)
    parser.add_argument('--workers', type=int, default=1)
    options = parser.parse_args()
    start_worker(
        hostname=options.hostname,
        num_workers=options.workers,
    )