From e0bf720d2afe039cf24170b7b86f7b4e1f6185e4 Mon Sep 17 00:00:00 2001
From: tianzikang
Date: Thu, 1 Sep 2022 20:06:33 +0800
Subject: [PATCH] Reshape experience to (batch_size * num_agents, dim) for
 faster parallel computing

---
 model.py | 28 +++++++++++++++-------------
 1 file changed, 15 insertions(+), 13 deletions(-)

diff --git a/model.py b/model.py
index 3e36ddc..9694800 100644
--- a/model.py
+++ b/model.py
@@ -23,6 +23,7 @@ class QMIX(nn.Module):
 
     def __init__(self, obs_size=16, state_size=32, num_agents=2, num_actions=5) -> None:
         super(QMIX, self).__init__()
+        self.obs_size = obs_size
         self.num_agents = num_agents
         self.num_actions = num_actions
         self.net_embed_dim = 64
@@ -77,20 +78,21 @@ def get_batch_value(self, obs):
         # input : obs: (batch_size, episode_limits, num_agents, obs_size),
         # avail_actions:(batch_size, episode_limits, num_agents, num_actions)
         # output: q_value: (batch_size, episode_limits, num_agents, num_actions)
-        [batch_size, timesteps] = obs.shape[:2] # batch_size here is actually episode_number
+        batch_size, timesteps, num_agents, obs_dim = obs.shape  # batch_size here is actually the number of episodes
         self.init_train_rnn_hidden(episode_num=batch_size)
-        q_1 = F.relu(self.fc1(obs))
-        rnn_value = []
+        _, _, hidden_dim = self.train_rnn_hidden.shape
+
+        q_value = []
         for t in range(timesteps):
-            batch_value = []
-            for a_id in range(self.num_agents):
-                batch_value.append(
-                    self.rnn(q_1[:, t, a_id, :], self.train_rnn_hidden[:, a_id, :])
-                )
-            self.train_rnn_hidden = torch.stack(batch_value, dim=1)
-            rnn_value.append(self.train_rnn_hidden)
-        rnn_value = torch.stack(rnn_value, dim=1)
-        q_value = self.fc2(rnn_value)
+            # Fold the agent axis into the batch axis via tensor.reshape: (batch_size, num_agents, dim) -> (batch_size * num_agents, dim).
+            # Agents are independent of one another within a timestep and nn.GRUCell only accepts 2-D inputs, so stacking
+            # all agents' experiences into one batch replaces the per-agent loop with a single call for faster CUDA parallelism.
+            q_1 = F.relu(self.fc1(obs[:, t].reshape(-1, obs_dim)))
+            rnn_value = self.rnn(q_1, self.train_rnn_hidden.reshape(-1, hidden_dim))
+            q_2 = self.fc2(rnn_value)
+            self.train_rnn_hidden = rnn_value.reshape(batch_size, num_agents, hidden_dim)
+            q_value.append(q_2.reshape(batch_size, num_agents, -1))
+        q_value = torch.stack(q_value, dim=1)
         return q_value
 
     def get_batch_total(self, max_q_value, state):
@@ -255,10 +257,10 @@ def update(self):
         # take gradient step
         # compute loss: Detach variable from the current graph since we don't want gradients for next Q to propagated
         loss = self.MseLoss(total_current_Q_values, total_target_Q_values.detach())
+        loss = loss / not_done_total.sum()
         # Clear previous gradients before backward pass
         self.optimizer.zero_grad()
         # run backward pass
-        loss = loss / not_done_total.sum()
         loss.backward()
         # grad_norm_clip: Reduce magnitude of gradients above this L2 norm
         nn.utils.clip_grad_norm_(self.params, self.grad_norm_clip)
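
Reviewer note: below is a minimal, self-contained sketch of the reshape trick this
patch applies in get_batch_value, runnable outside the repo. All names and sizes
here (fc1, gru, batch_size, and so on) are illustrative stand-ins, not the repo's
actual modules; it only demonstrates that folding the agent axis into the batch
axis gives the same numbers as the old per-agent loop, with one GRUCell call per
timestep instead of num_agents calls.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # Illustrative sizes only.
    batch_size, num_agents, obs_dim, hidden_dim = 4, 3, 16, 64

    fc1 = nn.Linear(obs_dim, hidden_dim)      # stand-in for self.fc1
    gru = nn.GRUCell(hidden_dim, hidden_dim)  # stand-in for self.rnn

    obs_t = torch.randn(batch_size, num_agents, obs_dim)      # one timestep of observations
    hidden = torch.zeros(batch_size, num_agents, hidden_dim)  # per-agent GRU state

    # Old path: one GRUCell call per agent, stacked back along the agent axis.
    loop_out = torch.stack(
        [gru(F.relu(fc1(obs_t[:, a])), hidden[:, a]) for a in range(num_agents)],
        dim=1,
    )

    # New path: fold agents into the batch axis, run a single GRUCell call,
    # then unfold back to (batch_size, num_agents, hidden_dim).
    flat_out = gru(
        F.relu(fc1(obs_t.reshape(-1, obs_dim))),
        hidden.reshape(-1, hidden_dim),
    ).reshape(batch_size, num_agents, hidden_dim)

    # reshape is row-major, so row b * num_agents + a is exactly (batch b, agent a):
    # the two paths agree up to floating-point rounding.
    assert torch.allclose(loop_out, flat_out, atol=1e-6)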
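
On the update() hunk: moving `loss = loss / not_done_total.sum()` above
optimizer.zero_grad() only tidies the ordering (normalize the loss, then clear
gradients, then backpropagate); the computed gradients are unchanged. The sketch
below shows why dividing by not_done_total.sum() yields a mean over valid
timesteps. It assumes, as a guess about the surrounding code not shown in this
diff, that MseLoss was built with reduction='sum' and that padded steps after
episode end are masked to zero:

    import torch
    import torch.nn as nn

    batch_size, timesteps = 4, 10  # illustrative sizes
    current_q = torch.randn(batch_size, timesteps, requires_grad=True)
    target_q = torch.randn(batch_size, timesteps)
    not_done = (torch.rand(batch_size, timesteps) > 0.2).float()  # 1 = valid step

    mse_sum = nn.MSELoss(reduction='sum')
    # Detach the target so gradients flow only through the current Q estimate;
    # dividing the summed error by the number of valid steps gives a masked mean
    # rather than an average diluted by padded, post-terminal timesteps.
    loss = mse_sum(current_q * not_done, (target_q * not_done).detach()) / not_done.sum()
    loss.backward()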