DDPG-agent.py
"""
@File :DDPG_agent.py
@Author :JohsuaWu1997
@Date :01/05/2020
"""
import numpy as np
import pandas as pd
import torch
from DDPG import DDPG
from market import MarketEnv
cuda = torch.device('cuda')

# Market data for sh000016 (SSE 50): traded amount plus buy-side and
# sell-side price series, loaded as plain NumPy arrays.
raw_amount = pd.read_csv('../sh000016/i_amount.csv', header=0, index_col=0).values
raw_buy = pd.read_csv('../sh000016/o_buy.csv', header=0, index_col=0).values
raw_sell = pd.read_csv('../sh000016/o_sell.csv', header=0, index_col=0).values

# Row indices bounding the data slices; only rows before START are used
# below (END is not referenced in this script).
START = 10441
END = 13899
def scale(data):
    # Column-wise min-max normalisation to [0, 1]. Constant columns have
    # their range clamped to 1 to avoid division by zero, so they map to 0.
    data_min = np.min(data, axis=0)
    data_max = np.max(data, axis=0)
    data_max[data_max - data_min == 0] = 1
    data = (data - data_min) / (data_max - data_min)
    return data
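# Note: scale() is defined but never called in this script. If the raw
# inputs needed normalising before building the environment, a minimal
# (hypothetical) use would be:
#
#     amount_train = scale(raw_amount[:START])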
def train(Train_Env, lb, node, Epoch):
    # lb and node were previously read from module globals; they are now
    # passed in explicitly, and the Train_Env parameter is used throughout.
    agent = DDPG(Train_Env, lb, node)
    # epochs at which the target actor network is checkpointed
    save_iter = [1, 2, 5, 10, 20, 30, 50, 100, 150, 200]
    for t in range(Epoch):
        print('epoch:', t)
        state, done = Train_Env.reset(), False
        while not done:
            # act on the current state, step the market, then let the agent
            # store the transition and learn from it
            action = agent.act(state, Train_Env.portfolio)
            next_state, reward, done, _ = Train_Env.step(action)
            agent.perceive(state, action, reward, next_state, done)
            state = next_state
            if Train_Env.n_step % 300 == 299:
                # progress log: step, cumulative reward, cumulative cost,
                # available cash, and the current critic loss
                print(Train_Env.n_step, ':',
                      int(Train_Env.rewards[Train_Env.n_step]), '\t',
                      int(sum(Train_Env.cost)), '\t',
                      int(Train_Env.available_cash[Train_Env.n_step]), '\t',
                      agent.critic_network.loss.data
                      )
        total_reward = Train_Env.rewards[-1]
        total_cost = sum(Train_Env.cost)
        print('DDPG: Evaluation Average Reward:', total_reward)
        print('DDPG: Average Cost: ', total_cost)
        if t in save_iter:
            torch.save(agent.actor_network.target.state_dict(), 'DDPG_model' + str(t) + '.pth')
    return agent
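# For context: in a standard DDPG agent, perceive() is typically expected to
# (1) push the transition into a replay buffer and (2) run one gradient step
# on the critic (TD error against the target networks) and on the actor
# (deterministic policy gradient), followed by a soft target update. The
# actual behaviour here is defined in DDPG.py.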
if __name__ == '__main__':
    lb, node, epoch = 12, 1024, 201  # hyper-parameters forwarded to DDPG, plus the epoch count
    buy_train = raw_buy[:START]
    sell_train = raw_sell[:START]
    amount_train = raw_amount[:START]
    train_env = MarketEnv([buy_train, sell_train, amount_train], 0)
    agent = train(train_env, lb, node, epoch)
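# A minimal sketch of reloading a checkpoint for evaluation. This is not
# part of the original script: it assumes the rows from START to END are
# the held-out slice, that DDPG exposes its target actor as
# agent.actor_network.target (the module saved above), and that
# 'DDPG_model200.pth' is the checkpoint written at epoch 200.
#
#     eval_env = MarketEnv([raw_buy[START:END], raw_sell[START:END],
#                           raw_amount[START:END]], 0)
#     eval_agent = DDPG(eval_env, lb, node)
#     eval_agent.actor_network.target.load_state_dict(
#         torch.load('DDPG_model200.pth'))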