-
Notifications
You must be signed in to change notification settings - Fork 8
/
test_async_inference.py
49 lines (39 loc) · 1.71 KB
/
test_async_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os
import time
import unittest.mock
from tempfile import TemporaryDirectory
import covid19sim.inference.server_utils
import covid19sim.inference.server_bootstrap
def fake_proc_human_batch(sample, *args, **kwargs):
    """Stand-in for ``proc_human_batch`` that hands back its first argument.

    Any additional positional or keyword arguments are accepted and ignored,
    so the stub can be installed as the mock's ``side_effect`` no matter how
    the server invokes the real function.
    """
    del args, kwargs  # accepted only for call-signature compatibility
    echoed = sample
    return echoed
class InferenceServerWrapper(covid19sim.inference.server_utils.InferenceServer):
    """InferenceServer whose ``run`` uses ``fake_proc_human_batch`` for inference.

    For the duration of ``run``, ``server_utils.proc_human_batch`` is replaced
    by a mock whose side effect is ``fake_proc_human_batch``. Each batch is
    therefore returned as-is instead of being passed through the model. The
    patch is installed inside ``run`` itself, so it is active wherever ``run``
    executes — presumably a separate process launched by ``start()``.
    TODO confirm against InferenceServer's base class.
    """

    def __init__(self, **kwargs):
        # Forward every configuration keyword unchanged to the base server.
        super().__init__(**kwargs)

    def run(self):
        patch_target = "covid19sim.inference.server_utils.proc_human_batch"
        with unittest.mock.patch(patch_target, side_effect=fake_proc_human_batch):
            return super().run()
class ServerTests(unittest.TestCase):
    """End-to-end round trip through InferenceServer and InferenceClient."""

    def test_inference(self):
        """Send 100 requests through a stubbed server and expect each echoed back.

        ``InferenceServerWrapper`` swaps ``proc_human_batch`` for an identity
        stub, so the client should receive exactly the value it sent.
        """
        with TemporaryDirectory() as d:
            # IPC endpoints live in a fresh temp dir, so concurrent runs do
            # not share socket paths.
            frontend_address = "ipc://" + os.path.join(d, "frontend.ipc")
            backend_address = "ipc://" + os.path.join(d, "backend.ipc")
            inference_server = InferenceServerWrapper(
                model_exp_path=covid19sim.inference.server_bootstrap.default_model_exp_path,
                workers=2,
                frontend_address=frontend_address,
                backend_address=backend_address,
                verbose=True,
            )
            inference_server.start()
            try:
                # Fixed grace period before the first request.
                # NOTE(review): presumably this waits for the server/workers
                # to start; a readiness signal would be faster and less
                # flaky — confirm whether server_utils exposes one.
                time.sleep(10)
                remote_engine = covid19sim.inference.server_utils.InferenceClient(
                    server_address=frontend_address,
                )
                for test_idx in range(100):
                    remote_output = remote_engine.infer(test_idx)
                    self.assertEqual(remote_output, test_idx)
            finally:
                # Always shut the server down. Previously, a failing assertion
                # or client error skipped this, leaving the started server
                # running while the temp dir holding its sockets was removed.
                inference_server.stop_gracefully()
                inference_server.join()
if __name__ == "__main__":
    # Executing this file directly hands control to the standard unittest
    # command-line runner, which discovers and runs the TestCase classes above.
    unittest.main()