test.py
from collections import defaultdict

import imageio

import util
from env.portfolio_env import PortfolioEnv
from env.util import plot_portfolio


def test(data, agent, log_interval_steps, log_comet, experiment, visualize_portfolio=False):
    # TODO: keep learning during test?
    num_days = data.shape[0]
    env = PortfolioEnv(data)
    agent.is_training = False
    current_state = env.reset()
    results = defaultdict(list)  # per-step metrics, used for logging below
    if visualize_portfolio:
        holdings_imgs = []  # one rendered portfolio snapshot per day, for the GIF

    for t in range(num_days - 1):
        # let the agent select an action based on the current observation
        current_action = agent.select_action(current_state)
        # execute the action on the environment, observe the new state and reward
        next_state, current_reward, done, _ = env.step(current_action)

        if visualize_portfolio:
            portfolio_img = plot_portfolio(env.portfolio, env.total_gains, title='day-{}'.format(t + 1))
            holdings_imgs.append(portfolio_img)

        # logging
        results['reward'].append(current_reward)
        results['current_volatility'].append(env.current_volatility)
        results['current_gains'].append(env.current_gains)

        if t % log_interval_steps == 0:
            avg_reward = util.avg_results(results, 'reward', lookback=log_interval_steps)
            avg_vol = util.avg_results(results, 'current_volatility', lookback=log_interval_steps)
            avg_gains = util.avg_results(results, 'current_gains', lookback=log_interval_steps)
            total_gains = env.total_gains
            print('Test: step: {} | avg_reward: {:.5f} | avg vol: {:.2f} | avg_step_gains: {:.2f} | total_gains: {:.2f}'
                  .format(t, avg_reward, avg_vol, avg_gains, total_gains))
            env.render()
            if log_comet:
                experiment.log_metric('test_interval_reward', avg_reward, step=t)
                experiment.log_metric('test_interval_avg_vol', avg_vol, step=t)
                experiment.log_metric('test_interval_avg_gains', avg_gains, step=t)
                experiment.log_metric('test_interval_total_gains', total_gains, step=t)

        current_state = next_state

    if visualize_portfolio:
        imageio.mimwrite('test_holdings.gif', holdings_imgs)

    # final logging over the whole test run
    avg_reward = util.avg_results(results, 'reward')
    avg_vol = util.avg_results(results, 'current_volatility')
    avg_gains = util.avg_results(results, 'current_gains')
    total_gains = env.total_gains
    print('Test final results - reward: {:.2f} | avg vol: {:.2f} | avg_gains: {:.2f} | total_gains: {:.2f}'
          .format(avg_reward, avg_vol, avg_gains, total_gains))
    if log_comet:
        experiment.log_metric('test_final_avg_reward', avg_reward)
        experiment.log_metric('test_final_avg_vol', avg_vol)
        experiment.log_metric('test_final_avg_gains', avg_gains)
        experiment.log_metric('test_final_total_gains', total_gains)
        if visualize_portfolio:
            experiment.log_image('test_holdings.gif', 'test_holdings')
    env.render()
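

# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not the repo's actual entry point).
# Assumptions: `data` is the price history array expected by PortfolioEnv
# (num_days along axis 0), and `agent` is any object exposing
# `select_action(state)` and an `is_training` flag (e.g. the trained agent
# produced by this repo's training script). The file path and loader name
# below are hypothetical placeholders.
#
# if __name__ == '__main__':
#     import numpy as np
#
#     prices = np.load('prices.npy')            # placeholder data file
#     agent = load_trained_agent('agent.ckpt')  # placeholder loader
#     test(prices, agent,
#          log_interval_steps=50,
#          log_comet=False,   # set True and pass a comet_ml.Experiment to log metrics
#          experiment=None,
#          visualize_portfolio=True)
# ----------------------------------------------------------------------------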