experience reshape (batch_size * num_agents, dim) for faster parallel computing
Felixvillas committed Sep 1, 2022
1 parent 6a95b32 commit e0bf720
Showing 1 changed file with 15 additions and 13 deletions.
model.py: 28 changes (15 additions & 13 deletions)
@@ -23,6 +23,7 @@
 class QMIX(nn.Module):
     def __init__(self, obs_size=16, state_size=32, num_agents=2, num_actions=5) -> None:
         super(QMIX, self).__init__()
+        self.obs_size = obs_size
         self.num_agents = num_agents
         self.num_actions = num_actions
         self.net_embed_dim = 64
@@ -77,20 +78,21 @@ def get_batch_value(self, obs):
         # input : obs: (batch_size, episode_limits, num_agents, obs_size),
         # avail_actions: (batch_size, episode_limits, num_agents, num_actions)
         # output: q_value: (batch_size, episode_limits, num_agents, num_actions)
-        [batch_size, timesteps] = obs.shape[:2]  # batch_size here is actually episode_number
+        batch_size, timesteps, num_agents, obs_dim = obs.shape  # batch_size here is actually episode_number
         self.init_train_rnn_hidden(episode_num=batch_size)
-        q_1 = F.relu(self.fc1(obs))
-        rnn_value = []
+        _, _, hidden_dim = self.train_rnn_hidden.shape
+
+        q_value = []
         for t in range(timesteps):
-            batch_value = []
-            for a_id in range(self.num_agents):
-                batch_value.append(
-                    self.rnn(q_1[:, t, a_id, :], self.train_rnn_hidden[:, a_id, :])
-                )
-            self.train_rnn_hidden = torch.stack(batch_value, dim=1)
-            rnn_value.append(self.train_rnn_hidden)
-        rnn_value = torch.stack(rnn_value, dim=1)
-        q_value = self.fc2(rnn_value)
+            # note: (batch_size, num_agents, dim) --> (batch_size * num_agents, dim) via tensor.reshape
+            # Since there is no temporal relationship among agents and nn.GRUCell only accepts 2-D inputs,
+            # we can concatenate the experiences of all agents along the batch dimension for faster CUDA parallel computing.
+            q_1 = F.relu(self.fc1(obs[:, t].reshape(-1, obs_dim)))
+            rnn_value = self.rnn(q_1, self.train_rnn_hidden.reshape(-1, hidden_dim))
+            q_2 = self.fc2(rnn_value)
+            self.train_rnn_hidden = rnn_value.reshape(batch_size, num_agents, hidden_dim)
+            q_value.append(q_2.reshape(batch_size, num_agents, -1))
+        q_value = torch.stack(q_value, dim=1)
         return q_value

     def get_batch_total(self, max_q_value, state):
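
The heart of this hunk is the loop body: one nn.GRUCell call over a (batch_size * num_agents, dim) tensor replaces a Python loop of per-agent calls. A minimal standalone sketch (not from this repository; all sizes are made up) checking that the two are numerically equivalent:

import torch
import torch.nn as nn

batch_size, num_agents, in_dim, hidden_dim = 4, 2, 16, 64
rnn = nn.GRUCell(in_dim, hidden_dim)
x = torch.randn(batch_size, num_agents, in_dim)
h = torch.randn(batch_size, num_agents, hidden_dim)

# per-agent loop, as in the code this commit removes
looped = torch.stack([rnn(x[:, a], h[:, a]) for a in range(num_agents)], dim=1)

# single batched call on (batch_size * num_agents, dim), as in the new code
flat = rnn(x.reshape(-1, in_dim), h.reshape(-1, hidden_dim))
flat = flat.reshape(batch_size, num_agents, hidden_dim)

# reshape pairs row b * num_agents + a of the flat input with the same row of
# the flat hidden state, so every agent still meets its own hidden state
print(torch.allclose(looped, flat, atol=1e-6))  # True

The payoff is one kernel launch per timestep instead of num_agents launches, which is where the "faster CUDA parallel computing" in the commit message comes from.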
@@ -255,10 +257,10 @@ def update(self):
         # take gradient step
         # compute loss: detach the target from the current graph since we don't want gradients for the next Q to propagate
         loss = self.MseLoss(total_current_Q_values, total_target_Q_values.detach())
-        loss = loss / not_done_total.sum()
         # Clear previous gradients before backward pass
         self.optimizer.zero_grad()
         # run backward pass
+        loss = loss / not_done_total.sum()
         loss.backward()
         # grad_norm_clip: Reduce magnitude of gradients above this L2 norm
         nn.utils.clip_grad_norm_(self.params, self.grad_norm_clip)
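
The only functional change in this hunk is moving loss = loss / not_done_total.sum() below zero_grad(); since zero_grad() only clears stored gradients and never touches the autograd graph, the computed gradients are identical. A hedged sketch of the masked-mean pattern the division implements, assuming (as the names suggest) the loss is summed over padded episode steps and not_done_total masks out steps past episode termination; pred, target, and mask below are hypothetical stand-ins, not names from this repo:

import torch

pred = torch.randn(2, 5, requires_grad=True)   # (episodes, timesteps), hypothetical predictions
target = torch.randn(2, 5)                     # hypothetical targets
mask = torch.tensor([[1., 1., 1., 0., 0.],     # stand-in for not_done_total:
                     [1., 1., 1., 1., 1.]])    # 0 after an episode ends (padding)

loss = (mask * (pred - target) ** 2).sum()  # padded steps contribute exactly zero
loss = loss / mask.sum()                    # mean over the 8 valid steps, not all 10
loss.backward()                             # dividing before or after zero_grad() is immaterial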
