Commit e7dc4a0

initial

tianzikang authored and committed Jul 7, 2022
1 parent b26defc commit e7dc4a0
Showing 15 changed files with 1,146 additions and 0 deletions.
270 changes: 270 additions & 0 deletions learn.py
@@ -0,0 +1,270 @@
import os
import random
from statistics import mean

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch.distributions import Categorical
from tqdm import tqdm
from tensorboardX import SummaryWriter

def init_episode_temp(ep_limits, state_shape, num_agents, obs_dim, action_dim):
    episode_obs = np.zeros((ep_limits, num_agents, obs_dim), dtype=np.float32)
    episode_state = np.zeros((ep_limits, state_shape), dtype=np.float32)
    episode_action = np.zeros((ep_limits, num_agents), dtype=np.int64)
    episode_reward = np.zeros((ep_limits,), dtype=np.float32)
    episode_avail_action = np.zeros((ep_limits, num_agents, action_dim), dtype=np.float32)
    return episode_obs, episode_state, episode_action, episode_reward, episode_avail_action

def store_hyper_para(args, store_path):
    argsDict = args.__dict__
    with open(os.path.join(store_path, 'hyper_para.txt'), 'w') as f:
        f.write('======================starts========================' + '\n')
        for key, value in argsDict.items():
            f.write(key + ':' + str(value) + '\n')
        f.write('======================ends========================' + '\n')
    print('==================hyper parameters stored!==================')


def dqn_learning(
    env_class,
    env_id,
    seed,
    is_ddqn,
    multi_steps,
    is_per,
    alpha,
    beta,
    prior_eps,
    is_share_para,
    is_evaluate,
    q_func,
    optimizer,
    learning_rate,
    exploration,
    max_training_steps=1000000,
    replay_buffer_size=1000000,
    batch_size=32,
    gamma=.99,
    learning_starts=50000,
    evaluate_num=4,
    target_update_freq=10000,
    args=None
):
    '''
    Parameters:
        env_class / env_id: environment constructor and SMAC map name
        q_func: constructor of the QMIX agent
        exploration: epsilon schedule; exploration.value(t) gives the random-action probability
        remaining arguments: DQN-style hyper-parameters (buffer size, batch size, gamma,
        warm-up steps, target-network update frequency, ...) plus the parsed args namespace
    '''
    env = env_class(env_id)
    if is_evaluate:
        env_eval = env_class(env_id)

    env_info = env.get_env_info()
    obs_size = env_info['obs_shape']
    state_size = env_info['state_shape']
    num_actions = env_info['n_actions']
    num_agents = env_info['n_agents']
    episode_limit = env_info['episode_limit']

    # Construct the tensorboard SummaryWriter; the run directory name encodes the enabled variants
    env_name = env_id
    log_dir = f'./results/StarCraft/{env_name}/'
    log_dir = log_dir + env_name
    if is_ddqn:
        log_dir = log_dir + '_ddqn'
    if multi_steps > 1:
        log_dir = log_dir + f'_{multi_steps}multisteps'
    if is_per:
        log_dir = log_dir + '_per'
    if is_share_para:
        log_dir = log_dir + '_sharepara'
    log_dir = log_dir + '/'
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    num_results = len(next(os.walk(log_dir))[1])
    log_dir = log_dir + f'{num_results}/'
    writer = SummaryWriter(log_dir=log_dir)

    # store hyper parameters
    if args.store_hyper_para:
        store_hyper_para(args, log_dir)

    # Initialize the QMIX agent
    QMIX_agent = q_func(
        obs_size=obs_size,
        state_size=state_size,
        num_agents=num_agents,
        num_actions=num_actions,
        is_share_para=is_share_para,
        is_ddqn=is_ddqn,
        multi_steps=multi_steps,
        is_per=is_per,
        alpha=alpha,
        beta=beta,
        prior_eps=prior_eps,
        gamma=gamma,
        replay_buffer_size=replay_buffer_size,
        episode_limits=episode_limit,
        batch_size=batch_size,
        optimizer=optimizer,
        learning_rate=learning_rate
    )

    #############
    #  RUN ENV  #
    #############
    num_param_update = 0
    env.reset()
    # init rnn hidden states and the episode-experience arrays at the start of every episode
    QMIX_agent.Q.init_eval_rnn_hidden()
    episode_obs, episode_state, episode_action, episode_reward, episode_avail_action = \
        init_episode_temp(episode_limit, state_size, num_agents, obs_size, num_actions)

    last_obs = env.get_obs()
    last_state = env.get_state()
    # per-episode experience
    ep_rewards = []
    episode_len = 0

    # logging parameters: running windows over the last `queue_maxsize` episodes
    log_rewards = []
    log_steps = []
    log_win = []
    queue_maxsize = 32
    queue_cursor = 0
    rewards_queue = []
    steps_queue = []
    win_queue = []

    for t in tqdm(range(max_training_steps)):

        # get the available actions for every agent
        avail_actions = env.get_avail_actions()

        # choose random actions until learning starts, then select actions epsilon-greedily
        if t > learning_starts:
            random_selection = np.random.random(num_agents) < exploration.value(t - learning_starts)
            # last_obs is a list of arrays of shape (obs_shape,) --> numpy array (num_agents, obs_shape)
            recent_observations = np.concatenate([np.expand_dims(ob, axis=0) for ob in last_obs], axis=0)
            action = QMIX_agent.select_actions(recent_observations, avail_actions, random_selection)
        else:
            action = Categorical(torch.tensor(avail_actions)).sample()
            action = [action[i].item() for i in range(num_agents)]

        # advance one environment step
        reward, done, info = env.step(action)

        # record the transition in the episode arrays
        episode_obs[episode_len] = np.concatenate([np.expand_dims(ob, axis=0) for ob in last_obs], axis=0)
        episode_state[episode_len] = last_state
        episode_action[episode_len] = np.array(action)
        episode_reward[episode_len] = reward
        episode_avail_action[episode_len] = np.array(avail_actions)

        ep_rewards.append(reward)
        obs = env.get_obs(action)
        state = env.get_state()

        # reset the environment when reaching an episode boundary
        if done:
            # store one episode of experience per agent into the replay buffers
            for i in range(num_agents):
                episode_dict = {
                    'obs': episode_obs[:, i],
                    'action': episode_action[:, i],
                    'reward': episode_reward,
                    'avail_action': episode_avail_action[:, i]
                }
                QMIX_agent.replay_buffer[i].store(episode_dict, episode_len)

            # the last buffer holds the global state
            episode_dict = {
                'obs': episode_state,
                'action': np.zeros(episode_limit),
                'reward': episode_reward,
                'avail_action': np.zeros((episode_limit, num_actions))
            }
            QMIX_agent.replay_buffer[-1].store(episode_dict, episode_len)

            # tensorboard log
            rewards_queue.append(sum(ep_rewards))
            steps_queue.append(len(ep_rewards))
            win_queue.append(1. if 'battle_won' in info and info['battle_won'] else 0.)
            queue_cursor = min(queue_cursor + 1, queue_maxsize)
            if queue_cursor == queue_maxsize:
                log_rewards.append(mean(rewards_queue[-queue_maxsize:]))
                log_steps.append(mean(steps_queue[-queue_maxsize:]))
                log_win.append(mean(win_queue[-queue_maxsize:]))
                writer.add_scalar(tag=f'starcraft{env_name}_train/reward', scalar_value=log_rewards[-1], global_step=t+1)
                writer.add_scalar(tag=f'starcraft{env_name}_train/length', scalar_value=log_steps[-1], global_step=t+1)
                writer.add_scalar(tag=f'starcraft{env_name}_train/wintag', scalar_value=log_win[-1], global_step=t+1)

            ep_rewards = []
            episode_len = 0

            env.reset()
            # re-init rnn hidden states and the episode-experience arrays for the new episode
            QMIX_agent.Q.init_eval_rnn_hidden()
            obs = env.get_obs()
            state = env.get_state()
            episode_obs, episode_state, episode_action, episode_reward, episode_avail_action = \
                init_episode_temp(episode_limit, state_size, num_agents, obs_size, num_actions)
        else:
            episode_len += 1

        last_obs = obs
        last_state = state

        if is_per:
            # PER: anneal beta towards 1 as training progresses
            QMIX_agent.increase_bate(t, max_training_steps)

        # train and evaluate
        if t > learning_starts and done:
            # one gradient-descent step on the QMIX networks
            loss = QMIX_agent.update()
            num_param_update += 1

            # tensorboard log
            writer.add_scalar(tag=f'starcraft{env_name}_train/loss', scalar_value=loss, global_step=t+1)

            # periodically copy the Q network weights to the target network
            # and evaluate the Q network in greedy mode
            if num_param_update % target_update_freq == 0:
                QMIX_agent.update_targets()
                if is_evaluate:
                    eval_reward, eval_step, eval_win = QMIX_agent.evaluate(env_eval, evaluate_num)
                    writer.add_scalar(tag=f'starcraft{env_name}_eval/reward', scalar_value=mean(eval_reward), global_step=t+1)
                    writer.add_scalar(tag=f'starcraft{env_name}_eval/length', scalar_value=mean(eval_step), global_step=t+1)
                    writer.add_scalar(tag=f'starcraft{env_name}_eval/wintag', scalar_value=mean(eval_win), global_step=t+1)

    ### Log progress and keep track of statistics
    df = pd.DataFrame({})
    df.insert(loc=0, column='rewards', value=log_rewards)
    df.insert(loc=1, column='steps', value=log_steps)
    df.insert(loc=2, column='wintag', value=log_win)
    df_avg = pd.DataFrame({})
    df_avg.insert(loc=0, column='rewards',
                  value=df['rewards'].rolling(window=20, win_type='triang', min_periods=1).mean())
    df_avg.insert(loc=1, column='steps',
                  value=df['steps'].rolling(window=20, win_type='triang', min_periods=1).mean())
    df_avg.insert(loc=2, column='wintag',
                  value=df['wintag'].rolling(window=20, win_type='triang', min_periods=1).mean())
    fig, (ax1, ax2, ax3) = plt.subplots(3, 1)
    ax1.plot(df_avg['rewards'], label='rewards')
    ax1.set_ylabel('rewards')
    ax2.plot(df_avg['steps'], label='steps')
    ax2.set_ylabel('steps')
    ax3.plot(df_avg['wintag'], label='wintag')
    ax3.set_ylabel('wintag')

    ax1.set_title(f'{env_name}-{num_agents}agents')
    ax2.set_xlabel('∝episode')
    plt.legend()
    plt.savefig(log_dir + env_name)

    writer.close()
    env.close()
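
utils/schedule.py is among the 15 files in this commit but is not shown in this excerpt. As a rough sketch (an assumption about that helper, not its actual contents), a LinearSchedule compatible with the LinearSchedule(anneal_steps, anneal_end, anneal_start) construction in main.py below and the exploration.value(t) calls above could look like this:

# hypothetical sketch of utils/schedule.py, assuming the usual
# (schedule_timesteps, final_p, initial_p) linear-annealing interface
class LinearSchedule:
    def __init__(self, schedule_timesteps, final_p, initial_p=1.0):
        self.schedule_timesteps = schedule_timesteps
        self.final_p = final_p
        self.initial_p = initial_p

    def value(self, t):
        # fraction of the annealing period elapsed, clipped to [0, 1]
        fraction = min(float(t) / self.schedule_timesteps, 1.0)
        return self.initial_p + fraction * (self.final_p - self.initial_p)
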
80 changes: 80 additions & 0 deletions main.py
@@ -0,0 +1,80 @@
import argparse
import torch.optim as optim
from smac.env import StarCraft2Env

from model import QMIX_agent
from learn import dqn_learning
from utils.schedule import LinearSchedule
from utils.sc_wrapper import single_net_sc2env

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--map-name', type=str, default='8m')
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--training-steps', type=int, default=2000000)
    parser.add_argument('--anneal-steps', type=int, default=100000)
    parser.add_argument('--anneal-start', type=float, default=1.0)
    parser.add_argument('--anneal-end', type=float, default=0.01)
    parser.add_argument('--replay-buffer-size', type=int, default=5000)
    parser.add_argument('--learning-starts', type=int, default=20000)
    parser.add_argument('--target-update-freq', type=int, default=200)
    parser.add_argument('--learning-rate', type=float, default=3e-4)
    # seed
    parser.add_argument('--seed', type=int, default=0)
    # ddqn
    parser.add_argument('--is-ddqn', type=int, default=True)
    # per
    parser.add_argument('--is-per', type=int, default=False)
    parser.add_argument('--alpha', type=float, default=0.6)
    parser.add_argument('--beta', type=float, default=0.2)
    parser.add_argument('--prior-eps', type=float, default=1e-6)
    # multi_step
    parser.add_argument('--multi-steps', type=int, default=1)
    # share networks
    parser.add_argument('--share-para', type=int, default=True)
    # evaluate
    parser.add_argument('--is-evaluate', type=int, default=True)
    parser.add_argument('--evaluate-num', type=int, default=32)
    # store hyper parameters
    parser.add_argument('--store-hyper-para', type=int, default=True)

    return parser.parse_args()

def main(args=get_args()):

    exploration_schedule = LinearSchedule(args.anneal_steps, args.anneal_end, args.anneal_start)

    if args.share_para:
        env_class = single_net_sc2env
    else:
        env_class = StarCraft2Env

    dqn_learning(
        env_class=env_class,
        env_id=args.map_name,
        seed=args.seed,
        is_ddqn=args.is_ddqn,
        multi_steps=args.multi_steps,
        is_per=args.is_per,
        alpha=args.alpha,
        beta=args.beta,
        prior_eps=args.prior_eps,
        is_share_para=args.share_para,
        is_evaluate=args.is_evaluate,
        evaluate_num=args.evaluate_num,
        q_func=QMIX_agent,
        optimizer=optim.RMSprop,
        learning_rate=args.learning_rate,
        exploration=exploration_schedule,
        max_training_steps=args.training_steps,
        replay_buffer_size=args.replay_buffer_size,
        batch_size=args.batch_size,
        gamma=args.gamma,
        learning_starts=args.learning_starts,
        target_update_freq=args.target_update_freq,
        args=args
    )

if __name__ == '__main__':
    main()
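
For reference, a training run with this commit would be launched roughly as follows; every flag corresponds to an argument defined in get_args() above, and the values shown simply make the defaults explicit:

# illustrative invocation (default 8m map, DDQN on, PER off, single-step returns)
python main.py --map-name 8m --training-steps 2000000 --is-ddqn 1 --is-per 0 --multi-steps 1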