-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
90 lines (71 loc) · 2.33 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import time
import numpy as np
import fdlib
import gc
def format_duration(seconds):
    """Render a duration in the most readable unit, to two decimal places.

    Values below one millisecond are shown in microseconds, values below one
    second in milliseconds, and everything else in seconds.

    Args:
        seconds: Duration expressed in seconds.

    Returns:
        A string such as ``"12.34 milliseconds"``.
    """
    # (upper bound, multiplier, unit name), checked smallest-first so the
    # comparisons happen in the same order as a plain if/elif chain.
    thresholds = (
        (1e-3, 1e6, "microseconds"),
        (1, 1e3, "milliseconds"),
    )
    for upper_bound, multiplier, unit in thresholds:
        if seconds < upper_bound:
            return f"{seconds * multiplier:.2f} {unit}"
    return f"{seconds:.2f} seconds"
# Benchmark repeated get_parameters() calls against a worker connection.
# NOTE(review): assumes a server is already listening on this address.
N_CALLS = 1000

test_worker = fdlib.create_worker("tcp://127.0.0.1:3042")

# time.perf_counter() is monotonic and high-resolution; time.time() is a wall
# clock whose resolution can be too coarse for sub-millisecond measurements.
start_time = time.perf_counter()
prev_params = test_worker.get_parameters()
end_time = time.perf_counter()
print("Time for first call: ", format_duration(end_time - start_time))
# Reuse the result we already have instead of issuing an extra call.
print("Buffer length: ", len(prev_params))

# Collect garbage up front and pause the collector, presumably so GC pauses do
# not inflate the per-call timings. The finally clause guarantees GC is
# re-enabled even if a call raises mid-loop.
gc.collect()
gc.disable()
call_accumulator = 0.0
try:
    for _ in range(N_CALLS):
        start_time = time.perf_counter()
        params = test_worker.get_parameters()
        end_time = time.perf_counter()
        call_accumulator += end_time - start_time
        # array_equal returns False (rather than raising or silently
        # broadcasting) when the two buffers have different shapes.
        if not np.array_equal(params, prev_params):
            print("Parameters diverged")
        prev_params = params
    print("Rust time per call: {}".format(format_duration(call_accumulator / N_CALLS)))
finally:
    gc.enable()
    gc.collect()
# Summarize the current parameter values and display them as a histogram.
import matplotlib.pyplot as plt

# asarray makes the boolean mask below well-defined even if the worker
# returns a plain sequence instead of an ndarray.
parameters = np.asarray(test_worker.get_parameters())
# Drop exact 0.0 entries before summarizing. The original author's note says
# this was for testing a "fold back" distribution; NOTE(review): confirm intent.
parameters = parameters[parameters != 0.0]
if parameters.size == 0:
    # np.min/np.max raise ValueError on an empty array, so bail out cleanly.
    print("All parameters are 0.0; nothing to summarize or plot.")
else:
    print('Median', np.median(parameters))
    print('Min, Max', np.min(parameters), np.max(parameters))
    fig, ax = plt.subplots()
    ax.hist(parameters, bins=100)
    plt.show()
# Disabled experiment: timing of np.random.randn(1_000_000) over 1000
# iterations (reported as "Numpy time per call"), presumably as a reference
# point for the get_parameters() timing above.
# gc.collect()
# gc.disable()
# start_time = time.time()
# for _ in range(1000):
# x = np.random.randn(1_000_000)
# end_time = time.time()
# gc.enable()
# gc.collect()
# print("Numpy time per call: {}".format(format_duration((end_time - start_time) / 1000)))

# Fetch and print the parameter buffer on two consecutive calls.
for _ in range(2):
    print(test_worker.get_parameters())

# Disabled experiment: sketch of a worker loop that fetches parameters, sends
# a reward via send_returns(), and polls get_signal(). NOTE(review): the
# original indentation of these lines was lost, so loop bodies are ambiguous.
# worker = fdlib.create_worker("tcp://127.0.0.1:3042")
# worker = fdlib.create_worker("wss://127.0.0.1:3044/some/job")
# while True:
# worker.get_parameters()
# # Run environment here
# reward = 1
# worker.send_returns(reward)
# parameters = worker.get_parameters()
# print("Parameters:", parameters)
# while True:
# signal = worker.get_signal()
# if signal == "No signal":
# continue
# print(signal)
# time.sleep(0.01)