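"""Dueling Double Deep Q-Network (D3QN) agent in PyTorch.

Combines a dueling Q-network architecture with Double DQN target
selection, a soft-updated target network, and epsilon-greedy exploration.
"""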
import os

import torch as T
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np

from buffer import ReplayBuffer

# Run on the first GPU when available, otherwise fall back to CPU.
device = T.device("cuda:0" if T.cuda.is_available() else "cpu")
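# NOTE: ReplayBuffer lives in buffer.py elsewhere in this repo. From the way
# it is used below, it is assumed to expose roughly this interface:
#
#     ReplayBuffer(state_dim, action_dim, max_size, batch_size)
#     .store_transition(state, action, reward, state_, done)
#     .ready() -> bool            # True once at least one batch is stored
#     .sample_buffer() -> (states, actions, rewards, next_states, terminals)
#
# sample_buffer() is assumed to return NumPy arrays with batch_size rows each.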
class DuelingDeepQNetwork(nn.Module):
    """Q-network with a dueling head: a scalar state value V(s) and
    per-action advantages A(s, a), combined into Q(s, a)."""

    def __init__(self, alpha, state_dim, action_dim, fc1_dim, fc2_dim):
        super(DuelingDeepQNetwork, self).__init__()

        self.fc1 = nn.Linear(state_dim, fc1_dim)
        self.fc2 = nn.Linear(fc1_dim, fc2_dim)
        self.V = nn.Linear(fc2_dim, 1)            # state-value stream
        self.A = nn.Linear(fc2_dim, action_dim)   # advantage stream

        self.optimizer = optim.Adam(self.parameters(), lr=alpha)
        self.to(device)

    def forward(self, state):
        x = T.relu(self.fc1(state))
        x = T.relu(self.fc2(x))
        V = self.V(x)
        A = self.A(x)
        # Subtracting the mean advantage keeps the V/A decomposition
        # identifiable: Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a').
        Q = V + A - T.mean(A, dim=-1, keepdim=True)
        return Q

    def save_checkpoint(self, checkpoint_file):
        T.save(self.state_dict(), checkpoint_file)

    def load_checkpoint(self, checkpoint_file):
        # map_location lets a checkpoint saved on GPU load on a CPU-only box.
        self.load_state_dict(T.load(checkpoint_file, map_location=device))
class D3QN:
    """Dueling Double DQN agent with epsilon-greedy exploration and a
    softly-updated target network."""

    def __init__(self, alpha, state_dim, action_dim, fc1_dim, fc2_dim, ckpt_dir,
                 gamma=0.99, tau=0.005, epsilon=1.0, eps_end=0.01, eps_dec=5e-7,
                 max_size=1000000, batch_size=256):
        self.gamma = gamma
        self.tau = tau
        self.epsilon = epsilon
        self.eps_min = eps_end
        self.eps_dec = eps_dec
        self.batch_size = batch_size
        self.checkpoint_dir = ckpt_dir
        self.action_space = [i for i in range(action_dim)]

        self.q_eval = DuelingDeepQNetwork(alpha=alpha, state_dim=state_dim, action_dim=action_dim,
                                          fc1_dim=fc1_dim, fc2_dim=fc2_dim)
        self.q_target = DuelingDeepQNetwork(alpha=alpha, state_dim=state_dim, action_dim=action_dim,
                                            fc1_dim=fc1_dim, fc2_dim=fc2_dim)
        self.memory = ReplayBuffer(state_dim=state_dim, action_dim=action_dim,
                                   max_size=max_size, batch_size=batch_size)

        # tau=1.0 makes the first update a hard copy, so both networks
        # start from identical weights.
        self.update_network_parameters(tau=1.0)

    def update_network_parameters(self, tau=None):
        # Polyak averaging:
        #     theta_target <- tau * theta_eval + (1 - tau) * theta_target
        if tau is None:
            tau = self.tau
        for q_target_params, q_eval_params in zip(self.q_target.parameters(),
                                                  self.q_eval.parameters()):
            q_target_params.data.copy_(tau * q_eval_params.data +
                                       (1 - tau) * q_target_params.data)
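    # Worked example of the soft update with the default tau = 0.005: a
    # target weight at 1.0 tracking an online weight at 2.0 moves to
    # 0.005 * 2.0 + 0.995 * 1.0 = 1.005, i.e. 0.5% of the gap per call.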
    def remember(self, state, action, reward, state_, done):
        self.memory.store_transition(state, action, reward, state_, done)

    def decrement_epsilon(self):
        # Linear decay down to eps_min.
        self.epsilon = self.epsilon - self.eps_dec \
            if self.epsilon > self.eps_min else self.eps_min

    def choose_action(self, observation, isTrain=True):
        # Wrapping the observation in np.array first avoids the slow
        # list-of-arrays construction path in T.tensor.
        state = T.tensor(np.array([observation]), dtype=T.float).to(device)
        # No gradients are needed for action selection.
        with T.no_grad():
            q_vals = self.q_eval.forward(state)
        action = T.argmax(q_vals).item()
        # Epsilon-greedy exploration, applied only during training.
        if (np.random.random() < self.epsilon) and isTrain:
            action = np.random.choice(self.action_space)
        return action
    def learn(self):
        if not self.memory.ready():
            return

        states, actions, rewards, next_states, terminals = self.memory.sample_buffer()
        batch_idx = T.arange(self.batch_size, dtype=T.long).to(device)

        states_tensor = T.tensor(states, dtype=T.float).to(device)
        actions_tensor = T.tensor(actions, dtype=T.long).to(device)
        rewards_tensor = T.tensor(rewards, dtype=T.float).to(device)
        next_states_tensor = T.tensor(next_states, dtype=T.float).to(device)
        # Boolean mask of terminal transitions, used to zero bootstrap targets.
        terminals_tensor = T.tensor(terminals, dtype=T.bool).to(device)

        # Double DQN target: the online network picks the next action,
        # the target network evaluates it:
        #     y = r + gamma * (1 - done) * Q_target(s', argmax_a Q_eval(s', a))
        with T.no_grad():
            q_ = self.q_target.forward(next_states_tensor)
            max_actions = T.argmax(self.q_eval.forward(next_states_tensor), dim=-1)
            q_[terminals_tensor] = 0.0
            target = rewards_tensor + self.gamma * q_[batch_idx, max_actions]
        q = self.q_eval.forward(states_tensor)[batch_idx, actions_tensor]

        loss = F.mse_loss(q, target.detach())
        self.q_eval.optimizer.zero_grad()
        loss.backward()
        self.q_eval.optimizer.step()

        # Soft-update the target network and anneal exploration.
        self.update_network_parameters()
        self.decrement_epsilon()
    def save_models(self, episode):
        # Create the checkpoint subdirectories up front so T.save cannot fail
        # on a missing path; os.path.join also removes the assumption that
        # checkpoint_dir ends with a separator.
        os.makedirs(os.path.join(self.checkpoint_dir, 'Q_eval'), exist_ok=True)
        os.makedirs(os.path.join(self.checkpoint_dir, 'Q_target'), exist_ok=True)
        self.q_eval.save_checkpoint(os.path.join(
            self.checkpoint_dir, 'Q_eval', 'D3QN_q_eval_{}.pth'.format(episode)))
        print('Saved Q_eval network successfully!')
        self.q_target.save_checkpoint(os.path.join(
            self.checkpoint_dir, 'Q_target', 'D3QN_Q_target_{}.pth'.format(episode)))
        print('Saved Q_target network successfully!')

    def load_models(self, episode):
        self.q_eval.load_checkpoint(os.path.join(
            self.checkpoint_dir, 'Q_eval', 'D3QN_q_eval_{}.pth'.format(episode)))
        print('Loaded Q_eval network successfully!')
        self.q_target.load_checkpoint(os.path.join(
            self.checkpoint_dir, 'Q_target', 'D3QN_Q_target_{}.pth'.format(episode)))
        print('Loaded Q_target network successfully!')
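

# A minimal smoke-test sketch, not part of the original file. It assumes the
# classic `gym` API (reset() returning obs, step() returning a 4-tuple) and
# that buffer.ReplayBuffer matches the interface noted above; environment
# name and hyperparameters here are illustrative, not tuned. For evaluation,
# pass isTrain=False to choose_action to disable exploration.
if __name__ == '__main__':
    import gym

    env = gym.make('CartPole-v1')
    agent = D3QN(alpha=3e-4,
                 state_dim=env.observation_space.shape[0],
                 action_dim=env.action_space.n,
                 fc1_dim=128, fc2_dim=128,
                 ckpt_dir='./checkpoints/', batch_size=64)

    for episode in range(5):
        observation = env.reset()
        done, score = False, 0.0
        while not done:
            action = agent.choose_action(observation, isTrain=True)
            observation_, reward, done, info = env.step(action)
            agent.remember(observation, action, reward, observation_, done)
            agent.learn()   # no-op until the buffer holds one full batch
            score += reward
            observation = observation_
        print('episode {}: score {:.1f}, epsilon {:.3f}'.format(
            episode, score, agent.epsilon))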