-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcontroller.py
229 lines (168 loc) · 7.42 KB
/
controller.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
from policy import ActorNet, CriticNet
from collections import deque
from common_utils import noop
import torch.optim as optim
import random
import torch
# ---------------------------------------------------------------------------
# Hyper-parameters
# ---------------------------------------------------------------------------
ACTOR_LR = 1e-3              # Adam learning rate for the actor network
CRITIC_LR = 1e-3             # Adam learning rate for the critic network
REPLAY_BUFFER_SIZE = 100_000  # max stored rollouts before the oldest are evicted
ROLLOUT_LEN = 1              # N in the N-step return
BATCH_SIZE = 64              # rollouts drawn per learning step
EPSILON = 0.1                # scale of the Gaussian exploration noise
GAMMA = 0.99                 # per-step discount factor
TAU = 1e-3                   # soft (Polyak) target-update coefficient
GRADIENT_CLIP = 1            # max gradient norm for both optimisers
LEARN_EVERY = 1              # learn once every this many calls to step()
# ---------------------------------------------------------------------------

# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
class ReplayBuffer:
    """Bounded FIFO store of experience that is sampled uniformly at random.

    Once ``size`` items are held, every further insertion silently evicts the
    oldest item (the underlying deque has ``maxlen=size``).
    """

    def __init__(self, size, batch_size, debug=noop):
        # A deque with maxlen drops items from the left when it overflows.
        self.replay_buffer = deque(maxlen=size)
        self.batch_size = batch_size
        # Kept for callers; nothing in this class calls it.
        self.debug = debug

    def get_len(self):
        """Return the number of items currently stored."""
        return len(self.replay_buffer)

    def can_sample(self):
        """Return True once at least ``batch_size`` items are available."""
        return not (self.get_len() < self.batch_size)

    def add(self, trajectories):
        """Append every item of the iterable ``trajectories`` (oldest evicted on overflow)."""
        self.replay_buffer.extend(trajectories)

    def sample(self):
        """Return a list of ``batch_size`` distinct stored items chosen uniformly."""
        return random.sample(population=self.replay_buffer, k=self.batch_size)
class Controller:
    """Actor-critic agent with a categorical (distributional) critic.

    The deterministic actor is trained to maximise the critic's expected
    return. The critic scores (state, action) pairs over ``n_atoms`` evenly
    spaced return values in ``[v_min, v_max]``. Its output is fed to ``log``
    and to an expectation against the support, so it is assumed to be a
    probability vector per row (e.g. a softmax) — confirm in ``CriticNet``.

    The critic is trained by cross-entropy against an N-step target
    distribution projected onto the fixed support. Target copies of both
    networks track the local ones via soft (Polyak) updates.
    """

    def __init__(self, state_size, action_size, debug=noop, seed=1337,
                 n_atoms=51, v_min=-0.1, v_max=0.1):
        """Build the networks, optimisers and replay storage.

        Args:
            state_size: length of one state vector, forwarded to the networks.
            action_size: length of one action vector.
            debug: callable handed to the replay buffer and kept on ``self``.
            seed: seeds Python's ``random`` module; also forwarded to the
                network constructors.
            n_atoms: number of support points of the return distribution.
            v_min: lower bound of the return-distribution support.
            v_max: upper bound of the return-distribution support.
                (The three support defaults are the previous hard-coded values.)

        Raises:
            ValueError: if ``n_atoms < 2`` or ``v_max <= v_min``.
        """
        if n_atoms < 2:
            raise ValueError(f"n_atoms must be >= 2, got {n_atoms}")
        if not v_max > v_min:
            raise ValueError(f"v_max ({v_max}) must be greater than v_min ({v_min})")
        random.seed(seed)
        self.n_atoms = n_atoms
        self.v_min = v_min
        self.v_max = v_max
        # Spacing between neighbouring atoms of the support.
        self.delta = (self.v_max - self.v_min) / (self.n_atoms - 1)
        self.v_lin = torch.linspace(self.v_min, self.v_max, self.n_atoms).to(device)
        self.actor = ActorNet(state_size, action_size, seed).to(device)
        self.critic = CriticNet(state_size, action_size, self.n_atoms, seed).to(device)
        self.target_actor = ActorNet(state_size, action_size, seed).to(device)
        self.target_critic = CriticNet(state_size, action_size, self.n_atoms, seed).to(device)
        # tau=1 copies the local weights into the target networks verbatim.
        self.soft_update_target_nets(1)
        self.actor_opt = optim.Adam(self.actor.parameters(), lr=ACTOR_LR)
        self.critic_opt = optim.Adam(self.critic.parameters(), lr=CRITIC_LR)
        self.replay_buffer = ReplayBuffer(REPLAY_BUFFER_SIZE, BATCH_SIZE, debug=debug)
        # Sliding window of the last ROLLOUT_LEN steps (one list of exp tuples per step).
        self.trajectories = deque(maxlen=ROLLOUT_LEN)
        self.rollout_len = ROLLOUT_LEN
        self.epsilon = EPSILON
        self.action_size = action_size
        self.debug = debug
        self.num_steps = 0

    def reset(self):
        """Clear the in-progress rollout window so rollouts never span a reset."""
        self.trajectories = deque(maxlen=self.rollout_len)

    def act(self, states, test=False):
        """Choose actions for a batch of states.

        Args:
            states: tensor whose last dim holds the state features
                (presumably (num_agents, state_size) — confirm with caller).
            test: when True, skip exploration noise.

        Returns:
            ``{"actions": tensor}`` on ``device``, clamped to [-1, 1]. If ANY
            row of ``states`` is all zeros, every returned action is zero.
        """
        states = states.to(device)
        if self.contains_zero_state(states):
            # Allocate on `device` so both return paths agree on placement.
            return {
                "actions": torch.zeros(states.shape[0], self.action_size, device=device)
            }
        # eval_no_grad is an ActorNet helper; presumably eval mode + no autograd.
        with self.actor.eval_no_grad():
            actions = self.actor(states)
        if not test:
            # Gaussian exploration noise with std = epsilon.
            noise = torch.distributions.Normal(torch.zeros(actions.size()), 1)
            actions += self.epsilon * noise.sample().to(device)
        actions = actions.clamp(-1, 1)
        return {
            "actions": actions,
        }

    def step(self, transitions):
        """Possibly learn, then record one environment step.

        Learning happens every ``LEARN_EVERY`` calls once the replay buffer
        holds a full batch. Steps where any state row is all zeros are not
        recorded. Once ``rollout_len`` consecutive steps are buffered, one
        rollout per row (presumably per agent) is pushed to the replay buffer.

        Args:
            transitions: dict with "states", "actions", "rewards",
                "next_states" and "dones" tensors sharing a leading dim.
        """
        if self.num_steps % LEARN_EVERY == 0 and self.replay_buffer.can_sample():
            self.learn()
        self.num_steps += 1
        if self.contains_zero_state(transitions["states"]):
            return
        exp_tuples = self.convert_to_exp_tuples(transitions)
        self.trajectories.append(exp_tuples)
        if len(self.trajectories) < self.rollout_len:
            return
        # Transpose step-major -> row-major: one tuple of N exp tuples per row.
        self.replay_buffer.add(list(zip(*self.trajectories)))

    def contains_zero_state(self, states):
        """Return True if any row of ``states`` is entirely zero."""
        return (states.abs().sum(-1) == 0).any().item()

    def learn(self):
        """Run one critic and one actor update on a sampled batch, then soft-update targets."""
        samples = self.replay_buffer.sample()
        # Each sample is a tuple of N exp tuples; regroup so trajectories[t]
        # holds step t of every sampled rollout as batched tensors.
        trajectories = [self.convert_to_transitions(exp_tuples) for exp_tuples in zip(*samples)]
        first_transitions = trajectories[0]
        first_states = first_transitions["states"]
        first_actions = first_transitions["actions"]

        # Critic: cross-entropy between projected target and predicted distribution.
        local_q_dists = self.critic(first_states, first_actions)
        projected_target_q_dists = self.get_projected_target_q_dists(trajectories)
        # 1e-10 keeps log() finite for zero-probability atoms.
        total_critic_loss = -(torch.log(local_q_dists + 1e-10) * projected_target_q_dists).sum(dim=-1).mean()
        self.critic_opt.zero_grad()
        total_critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), GRADIENT_CLIP)
        self.critic_opt.step()

        # Actor: maximise the critic's expected return (distribution · support).
        actions = self.actor(first_states)
        q_dists = self.critic(first_states, actions)
        total_actor_loss = -q_dists.matmul(self.v_lin).mean()
        self.actor_opt.zero_grad()
        # This backward also writes critic grads; critic_opt.zero_grad()
        # clears them before the next critic update.
        total_actor_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.actor.parameters(), GRADIENT_CLIP)
        self.actor_opt.step()

        self.soft_update_target_nets(TAU)

    def soft_update_target_nets(self, tau):
        """Blend both local networks into their targets with coefficient ``tau``."""
        self.soft_update_target_net(self.actor, self.target_actor, tau)
        self.soft_update_target_net(self.critic, self.target_critic, tau)

    def soft_update_target_net(self, local_net, target_net, tau):
        """Set each target parameter to ``tau * local + (1 - tau) * target`` in place."""
        for target_params, local_params in zip(
            target_net.parameters(),
            local_net.parameters(),
        ):
            target_params.data.copy_(
                tau * local_params.data + (1 - tau) * target_params.data
            )

    def convert_to_exp_tuples(self, transitions):
        """Split a batched transition dict into per-row CPU tuples.

        Returns:
            list of (state, action, reward, next_state, done) tuples, one per
            leading-dim row of the input tensors.
        """
        return list(zip(
            transitions["states"].cpu(),
            transitions["actions"].cpu(),
            transitions["rewards"].cpu(),
            transitions["next_states"].cpu(),
            transitions["dones"].cpu(),
        ))

    def convert_to_transitions(self, exp_tuples):
        """Inverse of ``convert_to_exp_tuples``: stack tuples into batched tensors on ``device``."""
        states, actions, rewards, next_states, dones = zip(*exp_tuples)
        return {
            "states": torch.stack(states).to(device),
            "actions": torch.stack(actions).to(device),
            "rewards": torch.stack(rewards).to(device),
            "next_states": torch.stack(next_states).to(device),
            "dones": torch.stack(dones).to(device),
        }

    def get_projected_target_q_dists(self, trajectories):
        """Return the N-step target distribution projected onto the fixed support.

        Args:
            trajectories: list of N batched transition dicts (as produced by
                ``convert_to_transitions``), ordered oldest step first.

        Returns:
            Detached tensor shaped like the target critic's output, holding
            the projected target distribution for every sample.
        """
        n_steps = len(trajectories)
        last_transitions = trajectories[-1]
        last_next_states = last_transitions["next_states"]

        rewards_shape = last_transitions["rewards"].size()
        discounted_rewards = torch.zeros(rewards_shape).to(device)
        # Product of (1 - done) over the rollout: becomes 0 once any step terminated.
        alive = torch.ones(rewards_shape).to(device)
        # Fold rewards back-to-front; a done at step t stops later rewards
        # from leaking into the return. Assumes dones hold one 0/1 (or bool)
        # flag per reward — TODO confirm against the environment wrapper.
        for transitions in reversed(trajectories):
            rewards = transitions["rewards"]
            not_done = 1.0 - transitions["dones"].float().reshape(rewards.shape)
            discounted_rewards = rewards + GAMMA * not_done * discounted_rewards
            alive = alive * not_done
        discounted_rewards = discounted_rewards.squeeze()
        # Terminal rollouts must not bootstrap from next_states.
        bootstrap_discount = (GAMMA ** n_steps) * alive.squeeze()

        # Targets are constants for the critic loss; no graph needed.
        with torch.no_grad():
            target_actions = self.target_actor(last_next_states)
            target_q_dists = self.target_critic(last_next_states, target_actions)

        projected_target_q_dists = torch.zeros(target_q_dists.size()).to(device)
        for j in range(self.n_atoms):
            # Bellman-shift atom j, then clip it into the support.
            tz_j = torch.clamp(
                discounted_rewards + bootstrap_discount * (self.v_min + j * self.delta),
                min=self.v_min,
                max=self.v_max,
            )
            # Fractional atom index; clamped so float rounding at v_min/v_max
            # can never yield an out-of-range floor/ceil index.
            b_j = ((tz_j - self.v_min) / self.delta).clamp(0, self.n_atoms - 1)
            lower = b_j.floor().long()
            upper = b_j.ceil().long()
            # When b_j lands exactly on an atom both linear weights below are
            # 0, so that atom receives the full probability mass instead.
            eq_mask = lower == upper
            ne_mask = ~eq_mask
            projected_target_q_dists[eq_mask, lower[eq_mask]] += target_q_dists[eq_mask, j]
            # Otherwise split mass between the two neighbours by proximity.
            projected_target_q_dists[ne_mask, lower[ne_mask]] += (
                target_q_dists[ne_mask, j] * (upper.float() - b_j)[ne_mask]
            )
            projected_target_q_dists[ne_mask, upper[ne_mask]] += (
                target_q_dists[ne_mask, j] * (b_j - lower.float())[ne_mask]
            )
        return projected_target_q_dists.detach()

    @staticmethod
    def _cpu_state_dict(net):
        """Return ``net``'s state_dict with every tensor copied to CPU (net is untouched)."""
        state = net.state_dict()
        # state_dict() builds a fresh mapping, so replacing its values does
        # not affect the module; keeping the same object preserves metadata.
        for key, tensor in state.items():
            state[key] = tensor.cpu()
        return state

    def save(self, i):
        """Write actor/critic weights to ``actor_{i}.pth`` and ``critic_{i}.pth``.

        The saved tensors are CPU copies. The live networks stay on
        ``device``, so acting and training can continue after a checkpoint.
        """
        torch.save(self._cpu_state_dict(self.actor), f"actor_{i}.pth")
        torch.save(self._cpu_state_dict(self.critic), f"critic_{i}.pth")