-
Notifications
You must be signed in to change notification settings - Fork 8
/
test_reproducibility.py
66 lines (52 loc) · 2.32 KB
/
test_reproducibility.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import datetime
import hashlib
import os
import unittest
import pickle
from tempfile import TemporaryDirectory
from tests.utils import get_test_conf
from covid19sim.run import simulate
TEST_CONF_NAME = "base.yaml"


class ReproducibilityTests(unittest.TestCase):
    """Check that ``simulate`` produces identical output for identical seeds.

    Each run's final per-human state is hashed (MD5 over pickled
    ``make_human_as_message`` payloads). Two runs with the same seed must hash
    equal, and a run with a different seed must hash differently.
    """

    # Test configuration dict; replaced with a freshly loaded config in setUp.
    config = None

    def setUp(self):
        """Load the base test config and fix the simulation parameters."""
        self.config = get_test_conf(TEST_CONF_NAME)
        self.config['COLLECT_LOGS'] = True
        self.config['INTERVENTION_DAY'] = 0
        self.config['INTERVENTION'] = "Tracing"
        self.test_seed = 136
        self.n_people = 100
        self.location_start_time = datetime.datetime(2020, 2, 28, 0, 0)
        self.simulation_days = 20

    def _simulation_digest(self, seed):
        """Run one simulation with ``seed`` and return the MD5 hex digest of its final state.

        The digest covers, in ``city.humans`` order, the pickled result of
        ``make_human_as_message`` for every human together with that human's
        entry in ``city.global_mailbox``.
        """
        # Kept as a local import, as in the original test; imported once per run
        # instead of once per human.
        from covid19sim.inference.heavy_jobs import make_human_as_message

        md5 = hashlib.md5()
        with TemporaryDirectory() as d:
            outfile = os.path.join(d, "data")
            city = simulate(
                n_people=self.n_people,
                start_time=self.location_start_time,
                simulation_days=self.simulation_days,
                outfile=outfile,
                out_chunk_size=0,
                init_fraction_sick=0.1,
                seed=seed,
                conf=self.config,
            )
            # Hash while the temporary output directory still exists, in case
            # the city object still refers to files written there.
            for human in city.humans:
                message = make_human_as_message(
                    human, city.global_mailbox[human.name], self.config
                )
                md5.update(pickle.dumps(message))
        return md5.hexdigest()

    def test_reproducibility(self):
        """
        Run three simulations (seed, seed, seed + 1) and check that the two
        same-seed runs produce identical output while the different-seed run
        does not.
        """
        seeds = (self.test_seed, self.test_seed, self.test_seed + 1)
        digests = []
        for seed in seeds:
            with self.subTest(seed=seed):
                digests.append(self._simulation_digest(seed))

        # If a run failed inside subTest and the failure was recorded rather
        # than raised, report that clearly instead of crashing on the unpack.
        if len(digests) != len(seeds):
            self.fail(
                f"Only {len(digests)} of {len(seeds)} simulations completed; "
                f"cannot compare outputs"
            )

        md5sum, md5sum_same_seed, md5sum_diff_seed = digests
        self.assertEqual(md5sum, md5sum_same_seed,
                         msg=f"Two simulations run with the same seed "
                             f"{self.test_seed} yielded different results")
        # This assertion fails only when the outputs are equal, so the message
        # must say the different-seed runs yielded identical results.
        self.assertNotEqual(md5sum, md5sum_diff_seed,
                            msg=f"Two simulations run with different seeds "
                                f"{self.test_seed}, {self.test_seed+1} yielded "
                                f"identical results")