Commit b9d2faf
now performance is close to pymarl
Felixvillas committed Sep 14, 2022
1 parent 5728728
Showing 3 changed files with 16 additions and 22 deletions.
learn.py (3 additions, 0 deletions)
@@ -241,3 +241,6 @@ def qmix_learning(
 
     writer.close()
     env.close()
+
+    # last model save
+    QMIX_agent.save(checkpoint_path=os.path.join(log_dir, 'agent.pth'))
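
The added call checkpoints the trained agent once training finishes. QMIX_agent.save is the repo's own helper; as a rough sketch of the usual pattern it wraps (a PyTorch state_dict round-trip; the Linear stand-in, log_dir, and paths below are illustrative, not from the repo):

import os
import torch
import torch.nn as nn

# Hypothetical stand-in for the agent/mixer networks.
net = nn.Linear(8, 4)
log_dir = './logs'
os.makedirs(log_dir, exist_ok=True)
checkpoint_path = os.path.join(log_dir, 'agent.pth')

torch.save(net.state_dict(), checkpoint_path)     # save after the last update
net.load_state_dict(torch.load(checkpoint_path))  # reload for evaluation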
model.py (10 additions, 14 deletions)
@@ -209,41 +209,37 @@ def select_actions(self, obs, avail_actions, random_selection):
 
     def update(self):
         '''update Q: 1 step of gradient descent'''
-        obs_batchs, act_batchs, _, \
-            total_obs_batch, total_rew_batch, total_done_batch, \
-            next_obs_batchs, next_avail_act_batchs, next_total_obs_batch = \
-            self.replay_buffer.sample()
+        obs_batchs, act_batchs, avail_act_batchs, \
+            total_obs_batch, total_rew_batch, total_done_batch, = self.replay_buffer.sample()
 
         # Convert numpy nd_array to torch tensor for calculation
         # every agent's experience
         obs_batchs = torch.as_tensor(obs_batchs, dtype=torch.float32, device=device)
         act_batchs = torch.as_tensor(act_batchs, dtype=torch.int64, device=device)
+        avail_act_batchs = torch.as_tensor(avail_act_batchs, dtype=torch.float32, device=device)
         total_obs_batch = torch.as_tensor(total_obs_batch, dtype=torch.float32, device=device)
         total_rew_batch = torch.as_tensor(total_rew_batch, dtype=torch.float32, device=device)
         not_done_total = torch.as_tensor(1 - total_done_batch, dtype=torch.float32, device=device)
-        next_obs_batchs = torch.as_tensor(next_obs_batchs, dtype=torch.float32, device=device)
-        next_avail_act_batchs = torch.as_tensor(next_avail_act_batchs, dtype=torch.bool, device=device)
-        next_total_obs_batch = torch.as_tensor(next_total_obs_batch, dtype=torch.float32, device=device)
 
         # We choose Q based on action taken.
         all_current_Q_values = self.Q.get_batch_value(obs_batchs)
         current_Q_values = all_current_Q_values[:, :-1].gather(-1, act_batchs.unsqueeze(-1)).squeeze(-1)
-        total_current_Q_values = self.Q.get_batch_total(current_Q_values, total_obs_batch)
+        total_current_Q_values = self.Q.get_batch_total(current_Q_values, total_obs_batch[:, :-1])
 
         # compute target
-        target_Q_output = self.target_Q.get_batch_value(next_obs_batchs)
+        target_Q_output = self.target_Q.get_batch_value(obs_batchs)[:, 1:]
         # Mask out unavailable actions: refer to pymarl
-        target_Q_output[next_avail_act_batchs == 0.0] = -9999999
+        target_Q_output[avail_act_batchs[:, 1:] == 0.0] = -9999999
         if self.is_ddqn:
             # target_current_Q_values: get target values from current values
-            target_current_Q_values = all_current_Q_values.clone().detach()[:, 1:]
-            target_current_Q_values[next_avail_act_batchs == 0.0] = -9999999
-            target_act_batch = target_current_Q_values.max(-1)[1]
+            target_current_Q_values = all_current_Q_values.clone().detach()
+            target_current_Q_values[avail_act_batchs == 0.0] = -9999999
+            target_act_batch = target_current_Q_values[:, 1:].max(-1)[1]
             target_Q_values = target_Q_output.gather(-1, target_act_batch.unsqueeze(-1)).squeeze(-1)
         else:
             target_Q_values = target_Q_output.max(-1)[0]
 
-        total_target_Q_values = self.target_Q.get_batch_total(target_Q_values, next_total_obs_batch)
+        total_target_Q_values = self.target_Q.get_batch_total(target_Q_values, total_obs_batch[:, 1:])
         # mask valueless target Q values and compute the target of the current Q values
         total_target_Q_values = total_rew_batch + self.gamma * not_done_total * total_target_Q_values
 
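The substance of this hunk: instead of sampling separate next_obs/next_avail arrays, update() now runs each network once over the full episode (T+1 timesteps) and takes time-shifted slices, so [:, :-1] gives the current step and [:, 1:] the next step, in the spirit of pymarl's episode batches. A self-contained sketch of that slicing plus the masked double-Q selection; all shapes here (batch B, T+1 timesteps, A actions) are illustrative, not the repo's:

import torch

B, T, A = 2, 5, 4                           # illustrative sizes
q_all = torch.randn(B, T + 1, A)            # online net over the whole episode
q_tgt = torch.randn(B, T + 1, A)            # target net over the same episode
avail = torch.randint(0, 2, (B, T + 1, A)).float()
acts = torch.randint(0, A, (B, T))

# Q(s_t, a_t): drop the last timestep, gather the actions actually taken
q_taken = q_all[:, :-1].gather(-1, acts.unsqueeze(-1)).squeeze(-1)

# next-step target values come from a time shift, not from stored next_obs
q_next = q_tgt[:, 1:].clone()
q_next[avail[:, 1:] == 0.0] = -9999999      # mask unavailable actions

# double-Q: argmax under the masked online net, evaluate under the target net
q_online = q_all.clone().detach()
q_online[avail == 0.0] = -9999999
best_acts = q_online[:, 1:].max(-1)[1]
target_q = q_next.gather(-1, best_acts.unsqueeze(-1)).squeeze(-1)   # (B, T)

One consequence worth noting: because both networks are evaluated on the same obs_batchs and sliced afterwards, the current and next values always stay aligned on the time axis.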
utils/simple_replay_buffer.py (3 additions, 8 deletions)
@@ -94,15 +94,10 @@ def sample(self):
         idxes = np.random.choice(range(self.num_in_buffer), self.batch_size, replace=False).tolist()
         total_obs_batch, total_rew_batch, total_done_batch, max_ep_len = self.total_buffer.sample(idxes)
         obs_batchs, act_batchs, avail_act_batchs = self.buffers.sample(idxes, max_ep_len)
-        next_obs_batchs = obs_batchs[:, 1:]
-        next_avail_act_batchs = avail_act_batchs[:, 1:]
-        next_total_obs_batch = total_obs_batch[:, 1:]
-        # obs_batchs = obs_batchs[:, :-1]
         act_batchs = act_batchs[:, :-1]
-        avail_act_batchs = avail_act_batchs[:, :-1]
-        total_obs_batch = total_obs_batch[:, :-1]
+        avail_act_batchs = avail_act_batchs
+        total_obs_batch = total_obs_batch
         total_rew_batch = total_rew_batch[:, :-1]
         total_done_batch = total_done_batch[:, :-1]
         return obs_batchs, act_batchs, avail_act_batchs, \
-            total_obs_batch, total_rew_batch, total_done_batch, \
-            next_obs_batchs, next_avail_act_batchs, next_total_obs_batch
+            total_obs_batch, total_rew_batch, total_done_batch
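With this change sample() returns episode-length arrays and leaves the time-shifting to the learner; only act_batchs and the episode-level reward/done arrays are truncated to T steps. A shape sketch under assumed dimensions (B episodes, T+1 recorded steps, N agents, obs size O, A actions; all sizes hypothetical):

import numpy as np

B, T1, N, O, A = 3, 6, 2, 10, 4                      # assumed sizes: T1 = T + 1
obs = np.zeros((B, T1, N, O), dtype=np.float32)      # kept full length (T+1)
avail = np.ones((B, T1, N, A), dtype=np.float32)     # kept full length (T+1)
acts = np.zeros((B, T1, N), dtype=np.int64)[:, :-1]  # truncated to T
rew = np.zeros((B, T1), dtype=np.float32)[:, :-1]    # T transitions
done = np.zeros((B, T1), dtype=np.float32)[:, :-1]   # T transitions

# the learner then aligns obs[:, :-1] with acts and uses obs[:, 1:] as next obs
assert obs[:, :-1].shape[1] == acts.shape[1] == rew.shape[1] == done.shape[1]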
