
Commit

some optimizations refer to pymarl
Felixvillas committed Sep 13, 2022
1 parent cf02016 commit 5728728
Showing 4 changed files with 56 additions and 52 deletions.
48 changes: 30 additions & 18 deletions learn.py
@@ -2,21 +2,19 @@
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import random
from statistics import mean
from tqdm import *
from torch.distributions import Categorical
import torch
import datetime

from tensorboardX import SummaryWriter

def init_episode_temp(ep_limits, state_shape, num_agents, obs_dim, action_dim):
episode_obs = np.zeros((ep_limits, num_agents, obs_dim), dtype=np.float32)
episode_state = np.zeros((ep_limits, state_shape), dtype=np.float32)
episode_action = np.zeros((ep_limits, num_agents), dtype=np.int64)
episode_reward = np.zeros((ep_limits), dtype=np.float32)
episode_avail_action = np.zeros((ep_limits, num_agents, action_dim), dtype=np.float32)
episode_obs = np.zeros((ep_limits+1, num_agents, obs_dim), dtype=np.float32)
episode_state = np.zeros((ep_limits+1, state_shape), dtype=np.float32)
episode_action = np.zeros((ep_limits+1, num_agents), dtype=np.int64)
episode_reward = np.zeros((ep_limits+1), dtype=np.float32)
episode_avail_action = np.zeros((ep_limits+1, num_agents, action_dim), dtype=np.float32)
return episode_obs, episode_state, episode_action, episode_reward, episode_avail_action
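
For context, the +1 leaves one extra slot per episode for the final observation, state, action, and available actions, so next-step quantities can later be read as shifted slices of the same arrays. A minimal sketch of that alignment, using hypothetical sizes rather than the repo's defaults:

import numpy as np

T, num_agents, obs_dim = 5, 3, 10              # hypothetical sizes
episode_obs = np.zeros((T + 1, num_agents, obs_dim), dtype=np.float32)

# once steps 0..T are filled, current and next observations are shifted views
current_obs = episode_obs[:-1]                 # steps 0 .. T-1
next_obs = episode_obs[1:]                     # steps 1 .. T
assert current_obs.shape == next_obs.shape
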

def store_hyper_para(args, store_path):
@@ -43,7 +41,6 @@ def qmix_learning(
last_test_t, num_test = -args.test_freq - 1, 0
np.random.seed(args.seed)
torch.manual_seed(args.seed)
random.seed(args.seed)
# Initialize Env
env = env_class(map_name=args.map_name, seed=args.seed)
env_info = env.get_env_info()
@@ -101,20 +98,19 @@ def qmix_learning(
steps_queue = []
win_queue = []

# refer to pymarl: within an episode, the t passed to exploration.value(t) stays fixed
t_exploration = 0

for t in tqdm(range(args.training_steps)):

# get avail action for every agent
avail_actions = env.get_avail_actions()

# Choose random action if not yet start learning else eps-greedily select actions
if t >= args.learning_starts:
random_selection = np.random.random(num_agents) < exploration.value(t-args.learning_starts)
# last_obs is a list of array that shape is (obs_shape,) --> numpy.array:(num_agents, obs_shape)
recent_observations = np.concatenate([np.expand_dims(ob, axis=0) for ob in last_obs], axis=0)
action = QMIX_agent.select_actions(recent_observations, avail_actions, random_selection)
else:
action = Categorical(torch.tensor(avail_actions)).sample()
action = [action[i].item() for i in range(num_agents)]
# eps-greedily select actions
random_selection = np.random.random(num_agents) < exploration.value(t_exploration)
# last_obs is a list of arrays, each of shape (obs_shape,) --> numpy.array: (num_agents, obs_shape)
recent_observations = np.concatenate([np.expand_dims(ob, axis=0) for ob in last_obs], axis=0)
action = QMIX_agent.select_actions(recent_observations, avail_actions, random_selection)

# Advance one step
reward, done, info = env.step(action)
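
For context, the block above draws one Bernoulli flag per agent (random_selection) with probability exploration.value(t_exploration), where t_exploration is held fixed within an episode as in pymarl, and QMIX_agent.select_actions then mixes greedy and random choices over the available actions. A minimal sketch of such a masked epsilon-greedy rule, assuming a hypothetical per-agent q_values array (this is not the repo's select_actions implementation):

import numpy as np

def eps_greedy(q_values, avail_actions, epsilon):
    """q_values, avail_actions: (num_agents, num_actions) arrays; returns one action per agent."""
    avail = np.asarray(avail_actions, dtype=bool)
    masked_q = np.where(avail, q_values, -np.inf)             # never pick unavailable actions
    greedy = masked_q.argmax(axis=-1)
    # sample a random available action per agent
    random_act = np.array([np.random.choice(np.flatnonzero(a)) for a in avail])
    take_random = np.random.random(len(q_values)) < epsilon   # mirrors random_selection above
    return np.where(take_random, random_act, greedy)
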
@@ -132,6 +128,20 @@

# Resets the environment when reaching an episode boundary
if done:
'''store the last experience of every episode'''
# get avail action for every agent
avail_actions = env.get_avail_actions()
# eps-greedily select actions
random_selection = np.random.random(num_agents) < exploration.value(t_exploration)
# obs is a list of arrays, each of shape (obs_shape,) --> numpy.array: (num_agents, obs_shape)
recent_observations = np.concatenate([np.expand_dims(ob, axis=0) for ob in obs], axis=0)
action = QMIX_agent.select_actions(recent_observations, avail_actions, random_selection)
episode_obs[episode_len+1] = np.concatenate([np.expand_dims(ob, axis=0) for ob in obs], axis=0)
episode_state[episode_len+1] = state
episode_action[episode_len+1] = np.array(action)
episode_reward[episode_len+1] = 0
episode_avail_action[episode_len+1] = np.array(avail_actions)

# store one episode experience into buffer
episode_dict = {
'obs': episode_obs,
@@ -168,6 +178,8 @@
# init params for a new episode
episode_obs, episode_state, episode_action, episode_reward, episode_avail_action = \
init_episode_temp(episode_limit, state_size, num_agents, obs_size, num_actions)
# update t_exploration
t_exploration = t
else:
episode_len += 1

@@ -179,7 +191,7 @@
QMIX_agent.increase_bate(t, args.training_steps)

# train and evaluate
if (t >= args.learning_starts and done and QMIX_agent.can_sample()):
if (done and QMIX_agent.can_sample()):
# gradient descent: train
loss = QMIX_agent.update()
num_param_update += 1
3 changes: 1 addition & 2 deletions main.py
@@ -12,12 +12,11 @@ def get_args():
parser.add_argument('--map-name', type=str, default='8m')
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--training-steps', type=int, default=2005000)
parser.add_argument('--training-steps', type=int, default=2050000)
parser.add_argument('--anneal-steps', type=int, default=50000)
parser.add_argument('--anneal-start', type=float, default=1.0)
parser.add_argument('--anneal-end', type=float, default=0.05)
parser.add_argument('--replay-buffer-size', type=int, default=5000)
parser.add_argument('--learning-starts', type=int, default=0)
parser.add_argument('--target-update-freq', type=int, default=200)
parser.add_argument('--save-model-freq', type=int, default=2000)
parser.add_argument('--test-freq', type=int, default=10000)
31 changes: 12 additions & 19 deletions model.py
@@ -4,7 +4,6 @@
from torch.distributions import Categorical
import numpy as np
from utils.simple_replay_buffer import ReplayBuffer
import random

################################## set device ##################################
print("============================================================================================")
@@ -129,8 +128,8 @@ def get_mix_weight(self, state):
b1 = self.hyper_b1(state).unsqueeze(-2)
w2 = self.hyper_w2(state).unsqueeze(-1)
b2 = self.hyper_b2(state).unsqueeze(-1)
# return torch.abs(w1), b1, torch.abs(w2), b2
return F.softmax(w1, dim=-2), b1, F.softmax(w2, -2), b2
return torch.abs(w1), b1, torch.abs(w2), b2
# return F.softmax(w1, dim=-2), b1, F.softmax(w2, -2), b2
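
For context, QMIX only needs the mixing weights to be non-negative so that the joint value is monotonic in each agent's utility; taking torch.abs of the hypernetwork outputs is the usual way to enforce this, while the commented-out softmax variant additionally normalizes the weights. A minimal sketch of a two-layer monotonic mixing step with abs weights, using hypothetical shapes rather than this repo's exact batch layout:

import torch
import torch.nn.functional as F

def mix(agent_qs, w1, b1, w2, b2):
    """agent_qs: (batch, n_agents); w1: (batch, n_agents, embed); b1: (batch, 1, embed);
    w2: (batch, embed, 1); b2: (batch, 1, 1)."""
    # abs keeps every weight non-negative, so dQ_total/dQ_agent >= 0
    hidden = F.elu(torch.bmm(agent_qs.unsqueeze(1), torch.abs(w1)) + b1)  # (batch, 1, embed)
    q_total = torch.bmm(hidden, torch.abs(w2)) + b2                       # (batch, 1, 1)
    return q_total.view(-1)
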

def init_train_rnn_hidden(self, episode_num):
# init a gru_hidden for every agent of every episode during training
@@ -181,8 +180,6 @@ def __init__(
elif args.optimizer == 1:
# RMSProp alpha:0.99, RMSProp epsilon:0.00001
self.optimizer = torch.optim.RMSprop(self.params, args.learning_rate, alpha=0.99, eps=1e-5)

self.MseLoss = nn.MSELoss(reduction='sum')

# Construct buffer
self.replay_buffer = ReplayBuffer(
@@ -212,7 +209,7 @@ def select_actions(self, obs, avail_actions, random_selection):

def update(self):
'''update Q: 1 step of gradient descent'''
obs_batchs, act_batchs, avail_act_batchs, \
obs_batchs, act_batchs, _, \
total_obs_batch, total_rew_batch, total_done_batch, \
next_obs_batchs, next_avail_act_batchs, next_total_obs_batch = \
self.replay_buffer.sample()
@@ -221,7 +218,6 @@ def update(self):
# every agent's experience
obs_batchs = torch.as_tensor(obs_batchs, dtype=torch.float32, device=device)
act_batchs = torch.as_tensor(act_batchs, dtype=torch.int64, device=device)
avail_act_batchs = torch.as_tensor(avail_act_batchs, dtype=torch.bool, device=device)
total_obs_batch = torch.as_tensor(total_obs_batch, dtype=torch.float32, device=device)
total_rew_batch = torch.as_tensor(total_rew_batch, dtype=torch.float32, device=device)
not_done_total = torch.as_tensor(1 - total_done_batch, dtype=torch.float32, device=device)
@@ -231,24 +227,16 @@ def update(self):

# We choose Q based on action taken.
all_current_Q_values = self.Q.get_batch_value(obs_batchs)
current_Q_values = all_current_Q_values.gather(-1, act_batchs.unsqueeze(-1)).squeeze(-1)
current_Q_values = all_current_Q_values[:, :-1].gather(-1, act_batchs.unsqueeze(-1)).squeeze(-1)
total_current_Q_values = self.Q.get_batch_total(current_Q_values, total_obs_batch)
# mask valueless current Q values: In every episode, the first step is always have value
mask = torch.cat(
(torch.ones(size=(total_done_batch.shape[0], 1), dtype=torch.float32, device=device), not_done_total[:, :-1]),
dim=1
)
# mask = torch.cat((torch.ones(total_done_batch.shape[0], 1).to(device), not_done_total[:, :-1]), dim=1)
total_current_Q_values *= mask

# compute target
target_Q_output = self.target_Q.get_batch_value(next_obs_batchs)
# Mask out unavailable actions: refer to pymarl
target_Q_output[next_avail_act_batchs == 0.0] = -9999999
if self.is_ddqn:
# target_current_Q_values: get target values from current values
target_current_Q_values = torch.zeros_like(target_Q_output, dtype=torch.float32, device=device)
target_current_Q_values[:, :-1] = all_current_Q_values.clone().detach()[:, 1:]
target_current_Q_values = all_current_Q_values.clone().detach()[:, 1:]
target_current_Q_values[next_avail_act_batchs == 0.0] = -9999999
target_act_batch = target_current_Q_values.max(-1)[1]
target_Q_values = target_Q_output.gather(-1, target_act_batch.unsqueeze(-1)).squeeze(-1)
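
For context, the double-DQN branch above picks the next action with the online network (masked to available actions and shifted by one timestep) and evaluates it with the target network. A minimal sketch of that target computation, assuming hypothetical tensors shaped (batch, T+1, agents, actions) for the online outputs and (batch, T, agents, actions) for the target outputs and availability mask:

import torch

def double_q_targets(online_q, target_q, next_avail):
    """online_q: (B, T+1, A, n_act); target_q, next_avail: (B, T, A, n_act)."""
    next_online_q = online_q.detach()[:, 1:]                  # align with the next timesteps
    next_online_q = next_online_q.masked_fill(next_avail == 0, -9999999)
    best_next = next_online_q.max(dim=-1)[1]                  # argmax under the online net
    return target_q.gather(-1, best_next.unsqueeze(-1)).squeeze(-1)
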
@@ -260,9 +248,14 @@
total_target_Q_values = total_rew_batch + self.gamma * not_done_total * total_target_Q_values

# take gradient step
# mask valueless current Q values: in every episode, the first step always has a value
mask = torch.cat(
(torch.ones(size=(total_done_batch.shape[0], 1), dtype=torch.float32, device=device), not_done_total[:, :-1]),
dim=1
)
# compute loss: detach the target from the current graph since we don't want gradients for the next Q to be propagated
loss = self.MseLoss(total_current_Q_values, total_target_Q_values.detach())
loss = loss / mask.sum()
mask_td_error = (total_current_Q_values - total_target_Q_values.detach()) * mask
loss = (mask_td_error ** 2).sum() / mask.sum()
# Clear previous gradients before backward pass
self.optimizer.zero_grad()
# run backward pass
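
For context, episodes in a batch have different lengths, so the TD error is now multiplied by a validity mask before squaring and the sum is normalized by mask.sum(), i.e. by the number of valid steps; padded steps then contribute nothing to the gradient. A minimal sketch of that masked loss, with hypothetical (batch, T) tensors:

import torch

def masked_td_loss(q_total, targets, mask):
    """q_total, targets, mask: (batch, T); mask is 1 for valid steps, 0 for padding."""
    td_error = (q_total - targets.detach()) * mask        # zero out padded steps
    return (td_error ** 2).sum() / mask.sum()             # average over valid steps only
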
26 changes: 13 additions & 13 deletions utils/simple_replay_buffer.py
@@ -1,5 +1,4 @@
import numpy as np
import random

class EpReplayBuffers:
'''
@@ -53,7 +52,7 @@ def store(self, ep_dict, ep_len, idx):

def sample(self, idxes):
# sample batch_size episodes of experience uniformly at random
max_ep_len = max(self.ep_length[idxes])
max_ep_len = max(self.ep_length[idxes]) + 1
# get experience
obs_batch = self.obs[idxes][:, :max_ep_len]
rew_batch = self.reward[idxes][:, :max_ep_len]
@@ -76,11 +75,11 @@ def __init__(self, obs_dim, state_dim, num_agents, action_dim, ep_limits, ep_siz

self.buffers = EpReplayBuffers(
obs_dim=obs_dim, num_agents=num_agents, action_dim=action_dim,
ep_limits=ep_limits, ep_size=ep_size, multi_steps=multi_steps,
ep_limits=ep_limits+1, ep_size=ep_size, multi_steps=multi_steps,
batch_size=batch_size
)
self.total_buffer = TotalEpReplayBuffer(
obs_dim=state_dim, action_dim=action_dim, ep_limits=ep_limits,
obs_dim=state_dim, action_dim=action_dim, ep_limits=ep_limits+1,
ep_size=ep_size, multi_steps=multi_steps, batch_size=batch_size
)

@@ -90,19 +89,20 @@ def store(self, ep_dict, total_ep_dict, ep_len):

self.next_idx = (self.next_idx + 1) % self.ep_size
self.num_in_buffer = min(self.num_in_buffer + 1, self.ep_size)

def next_timestep(self, current_timestep_np):
next_timestep_np = np.zeros_like(current_timestep_np)
next_timestep_np[:, :-1] = current_timestep_np[:, 1:]
return next_timestep_np

def sample(self):
idxes = random.sample(range(self.num_in_buffer), self.batch_size)
idxes = np.random.choice(range(self.num_in_buffer), self.batch_size, replace=False).tolist()
total_obs_batch, total_rew_batch, total_done_batch, max_ep_len = self.total_buffer.sample(idxes)
obs_batchs, act_batchs, avail_act_batchs = self.buffers.sample(idxes, max_ep_len)
next_obs_batchs = self.next_timestep(obs_batchs)
next_avail_act_batchs = self.next_timestep(avail_act_batchs)
next_total_obs_batch = self.next_timestep(total_obs_batch)
next_obs_batchs = obs_batchs[:, 1:]
next_avail_act_batchs = avail_act_batchs[:, 1:]
next_total_obs_batch = total_obs_batch[:, 1:]
# obs_batchs = obs_batchs[:, :-1]
act_batchs = act_batchs[:, :-1]
avail_act_batchs = avail_act_batchs[:, :-1]
total_obs_batch = total_obs_batch[:, :-1]
total_rew_batch = total_rew_batch[:, :-1]
total_done_batch = total_done_batch[:, :-1]
return obs_batchs, act_batchs, avail_act_batchs, \
total_obs_batch, total_rew_batch, total_done_batch, \
next_obs_batchs, next_avail_act_batchs, next_total_obs_batch
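
For context, index sampling switches from random.sample to np.random.choice with replace=False, so a batch still contains distinct episodes while relying only on NumPy's seeded RNG, and the "next" batches are simply shifted views of the stored arrays; obs_batchs keeps all max_ep_len timesteps because update() runs the recurrent network over the full sequence and drops the last step afterwards via [:, :-1]. A minimal usage sketch with hypothetical sizes:

import numpy as np

num_in_buffer, batch_size = 5000, 32                                   # hypothetical sizes
idxes = np.random.choice(num_in_buffer, batch_size, replace=False)     # distinct episode indices
idxes = idxes.tolist()                                                 # plain ints for indexing
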
