diff --git a/learn.py b/learn.py
index 7f6ff08..19268b2 100644
--- a/learn.py
+++ b/learn.py
@@ -1,37 +1,17 @@
 import os
-import matplotlib.pyplot as plt
-import pandas as pd
 import numpy as np
 from statistics import mean
 from tqdm import *
 import torch
-import datetime
-
 from tensorboardX import SummaryWriter

-def init_episode_temp(ep_limits, state_shape, num_agents, obs_dim, action_dim):
-    episode_obs = np.zeros((ep_limits+1, num_agents, obs_dim), dtype=np.float32)
-    episode_state = np.zeros((ep_limits+1, state_shape), dtype=np.float32)
-    episode_action = np.zeros((ep_limits+1, num_agents), dtype=np.int64)
-    episode_reward = np.zeros((ep_limits+1), dtype=np.float32)
-    episode_avail_action = np.zeros((ep_limits+1, num_agents, action_dim), dtype=np.float32)
-    return episode_obs, episode_state, episode_action, episode_reward, episode_avail_action
-
-def store_hyper_para(args, store_path):
-    argsDict = args.__dict__
-    f = open(os.path.join(store_path, 'hyper_para.txt'), 'w')
-    f.writelines('======================starts========================' + '\n')
-    for key, value in argsDict.items():
-        f.writelines(key + ':' + str(value) + '\n')
-    f.writelines('======================ends========================' + '\n')
-    f.close()
-    print('==================hyper parameters store done!==================')
-
+from utils.tools import store_hyper_para, construct_results_dir
+from utils.schedule import LinearSchedule
+from utils.sc_wrapper import single_net_sc2env
+from smac.env import StarCraft2Env
+from model import QMIX_agent

 def qmix_learning(
-    env_class,
-    q_func,
-    exploration,
     args=None
 ):
     '''
@@ -42,33 +22,25 @@ def qmix_learning(
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
     # Initialize Env
-    env = env_class(map_name=args.map_name, seed=args.seed)
+    if args.share_para:
+        env = single_net_sc2env(map_name=args.map_name, seed=args.seed)
+    else:
+        env = StarCraft2Env(map_name=args.map_name, seed=args.seed)
     env_info = env.get_env_info()
-    # Initialize QMIX_agent
-    QMIX_agent = q_func(
+
+    # Initialize qmix_agent
+    qmix_agent = QMIX_agent(
         env_info=env_info,
         args=args
     )
-    obs_size, state_size, num_actions, num_agents, episode_limit = QMIX_agent.get_env_info()

     # Construct tensor log writer
-    env_name = args.map_name
-    log_dir = f'./results/StarCraft/{env_name}/'
-    log_dir = log_dir + env_name
-    if args.is_ddqn:
-        log_dir = log_dir + '_ddqn'
-    if args.multi_steps > 1:
-        log_dir = log_dir + f'_{args.multi_steps}multisteps'
-    if args.is_per:
-        log_dir = log_dir + '_per'
-    if args.share_para:
-        log_dir = log_dir + '_sharepara'
-    log_dir = log_dir + '/'
-    if not os.path.exists(log_dir):
-        os.makedirs(log_dir)
-    log_dir = log_dir + f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d_%H-%M-%S")}/'
+    log_dir = construct_results_dir(args)
     writer = SummaryWriter(log_dir=log_dir)

+    # Construct linear schedule
+    exploration_schedule = LinearSchedule(args.anneal_steps, args.anneal_end, args.anneal_start)
+
     # store hyper parameters
     if args.store_hyper_para:
         store_hyper_para(args, log_dir)
@@ -77,16 +49,6 @@ def qmix_learning(
     #  RUN ENV  #
     #############
     num_param_update = 0
-    env.reset()
-    QMIX_agent.Q.init_eval_rnn_hidden()
-    episode_obs, episode_state, episode_action, episode_reward, episode_avail_action = \
-        init_episode_temp(episode_limit, state_size, num_agents, obs_size, num_actions)
-
-    last_obs = env.get_obs()
-    last_state = env.get_state()
-    # for episode experience
-    ep_rewards = []
-    episode_len = 0

     # log paramaters
     log_rewards = []
@@ -99,128 +61,61 @@ def qmix_learning(
     win_queue = []
     # refer pymarl: in every episode, t in exploration.value(t) is consistent
-    t_exploration = 0
-
-    for t in tqdm(range(args.training_steps)):
-
-        # get avail action for every agent
-        avail_actions = env.get_avail_actions()
-
-        # eps-greedily select actions
-        random_selection = np.random.random(num_agents) < exploration.value(t_exploration)
-        # last_obs is a list of array that shape is (obs_shape,) --> numpy.array:(num_agents, obs_shape)
-        recent_observations = np.concatenate([np.expand_dims(ob, axis=0) for ob in last_obs], axis=0)
-        action = QMIX_agent.select_actions(recent_observations, avail_actions, random_selection)
-
-        # Advance one step
-        reward, done, info = env.step(action)
-
-        # experience
-        episode_obs[episode_len] = np.concatenate([np.expand_dims(ob, axis=0) for ob in last_obs], axis=0)
-        episode_state[episode_len] = last_state
-        episode_action[episode_len] = np.array(action)
-        episode_reward[episode_len] = reward
-        episode_avail_action[episode_len] = np.array(avail_actions)
-
-        ep_rewards.append(reward)
-        obs = env.get_obs(action)
-        state = env.get_state()
-
-        # Resets the environment when reaching an episode boundary
-        if done:
-            '''for last experience in every episode'''
-            # get avail action for every agent
-            avail_actions = env.get_avail_actions()
-            # eps-greedily select actions
-            random_selection = np.random.random(num_agents) < exploration.value(t_exploration)
-            # last_obs is a list of array that shape is (obs_shape,) --> numpy.array:(num_agents, obs_shape)
-            recent_observations = np.concatenate([np.expand_dims(ob, axis=0) for ob in obs], axis=0)
-            action = QMIX_agent.select_actions(recent_observations, avail_actions, random_selection)
-            episode_obs[episode_len+1] = np.concatenate([np.expand_dims(ob, axis=0) for ob in obs], axis=0)
-            episode_state[episode_len+1] = state
-            episode_action[episode_len+1] = np.array(action)
-            episode_reward[episode_len+1] = 0
-            episode_avail_action[episode_len+1] = np.array(avail_actions)
-
-            # store one episode experience into buffer
-            episode_dict = {
-                'obs': episode_obs,
-                'action': episode_action,
-                'avail_action': episode_avail_action
-            }
-            total_episode_dict = {
-                'obs': episode_state,
-                'reward': episode_reward,
-            }
-            QMIX_agent.replay_buffer.store(episode_dict, total_episode_dict, episode_len)
-
+    t = 0
+    pbar = tqdm(total=args.training_steps)
+
+    while t < args.training_steps:
+        # run episode
+        epsilon = exploration_schedule.value(t)
+        ep_rewards, win_flag, episode_len = qmix_agent.run_episode(env, epsilon)
+
+        rewards_queue.append(sum(ep_rewards))
+        steps_queue.append(len(ep_rewards))
+        win_queue.append(win_flag)
+        queue_cursor = min(queue_cursor + 1, queue_maxsize)
+        if queue_cursor == queue_maxsize:
+            log_rewards.append(mean(rewards_queue[-queue_maxsize:]))
+            log_steps.append(mean(steps_queue[-queue_maxsize:]))
+            log_win.append(mean(win_queue[-queue_maxsize:]))
             # tensorboard log
-            rewards_queue.append(sum(ep_rewards))
-            steps_queue.append(len(ep_rewards))
-            win_queue.append(1. if 'battle_won' in info and info['battle_won'] else 0.)
-            queue_cursor = min(queue_cursor + 1, queue_maxsize)
-            if queue_cursor == queue_maxsize:
-                log_rewards.append(mean(rewards_queue[-queue_maxsize:]))
-                log_steps.append(mean(steps_queue[-queue_maxsize:]))
-                log_win.append(mean(win_queue[-queue_maxsize:]))
-                # tensorboard log
-                writer.add_scalar(tag=f'starcraft{env_name}_train/reward', scalar_value=log_rewards[-1], global_step=t+1)
-                writer.add_scalar(tag=f'starcraft{env_name}_train/length', scalar_value=log_steps[-1], global_step=t+1)
-                writer.add_scalar(tag=f'starcraft{env_name}_train/wintag', scalar_value=log_win[-1], global_step=t+1)
+            writer.add_scalar(tag=f'starcraft{args.map_name}_train/reward', scalar_value=log_rewards[-1], global_step=t+1)
+            writer.add_scalar(tag=f'starcraft{args.map_name}_train/length', scalar_value=log_steps[-1], global_step=t+1)
+            writer.add_scalar(tag=f'starcraft{args.map_name}_train/wintag', scalar_value=log_win[-1], global_step=t+1)

-            ep_rewards = []
-            episode_len = 0
-
-            env.reset()
-            QMIX_agent.Q.init_eval_rnn_hidden()
-            obs = env.get_obs()
-            state = env.get_state()
-            # init para for new episide
-            episode_obs, episode_state, episode_action, episode_reward, episode_avail_action = \
-                init_episode_temp(episode_limit, state_size, num_agents, obs_size, num_actions)
-            # update t_exploration
-            t_exploration = t
-        else:
-            episode_len += 1
-
+        t += episode_len
+        pbar.update(episode_len)

         if args.is_per:
             # PER: increase beta
-            QMIX_agent.increase_bate(t, args.training_steps)
+            qmix_agent.increase_bate(t, args.training_steps)

         # train and evaluate
-        if (done and QMIX_agent.can_sample()):
+        if qmix_agent.can_sample():
             # gradient descent: train
-            loss = QMIX_agent.update()
+            loss = qmix_agent.update()
             num_param_update += 1
             # tensorboard log
-            writer.add_scalar(tag=f'starcraft{env_name}_train/loss', scalar_value=loss, global_step=t+1)
+            writer.add_scalar(tag=f'starcraft{args.map_name}_train/loss', scalar_value=loss, global_step=t+1)

             # Periodically update the target network by Q network to target Q network
             if num_param_update % args.target_update_freq == 0:
-                QMIX_agent.update_targets()
+                qmix_agent.update_targets()

             # evaluate the Q-net in greedy mode
             if (t - last_test_t) / args.test_freq >= 1.0:
-                eval_data = QMIX_agent.evaluate(env, args.evaluate_num)
-                # env reset after evaluate
-                env.reset()
-                QMIX_agent.Q.init_eval_rnn_hidden()
-                obs = env.get_obs()
-                state = env.get_state()
-                writer.add_scalar(tag=f'starcraft{env_name}_eval/reward', scalar_value=eval_data[0], global_step=num_test * args.test_freq)
-                writer.add_scalar(tag=f'starcraft{env_name}_eval/length', scalar_value=eval_data[1], global_step=num_test * args.test_freq)
-                writer.add_scalar(tag=f'starcraft{env_name}_eval/wintag', scalar_value=eval_data[2], global_step=num_test * args.test_freq)
+                eval_data = qmix_agent.evaluate(env, args.evaluate_num)
+                # env reset after evaluate
+                writer.add_scalar(tag=f'starcraft{args.map_name}_eval/reward', scalar_value=eval_data[0], global_step=num_test * args.test_freq)
+                writer.add_scalar(tag=f'starcraft{args.map_name}_eval/length', scalar_value=eval_data[1], global_step=num_test * args.test_freq)
+                writer.add_scalar(tag=f'starcraft{args.map_name}_eval/wintag', scalar_value=eval_data[2], global_step=num_test * args.test_freq)
                 last_test_t = t
                 num_test += 1

             # model save
             if num_param_update % args.save_model_freq == 0:
-                QMIX_agent.save(checkpoint_path=os.path.join(log_dir, 'agent.pth'))
-        last_obs = obs
-        last_state = state
+                qmix_agent.save(checkpoint_path=os.path.join(log_dir, 'agent.pth'))

     writer.close()
     env.close()
     # last model save
-    QMIX_agent.save(checkpoint_path=os.path.join(log_dir, 'agent.pth'))
\ No newline at end of file
+    qmix_agent.save(checkpoint_path=os.path.join(log_dir, 'agent.pth'))
\ No newline at end of file
diff --git a/main.py b/main.py
index c24cead..e615083 100644
--- a/main.py
+++ b/main.py
@@ -1,10 +1,5 @@
 import argparse
-from smac.env import StarCraft2Env
-
-from model import QMIX_agent
 from learn import qmix_learning
-from utils.schedule import LinearSchedule
-from utils.sc_wrapper import single_net_sc2env

 def get_args():
     parser = argparse.ArgumentParser()
@@ -46,17 +41,7 @@ def get_args():


 def main(args=get_args()):
-    exploration_schedule = LinearSchedule(args.anneal_steps, args.anneal_end, args.anneal_start)
-
-    if args.share_para:
-        env_class = single_net_sc2env
-    else:
-        env_class = StarCraft2Env
-
     qmix_learning(
-        env_class=env_class,
-        q_func=QMIX_agent,
-        exploration=exploration_schedule,
         args=args
     )
diff --git a/model.py b/model.py
index 62595f9..725373a 100644
--- a/model.py
+++ b/model.py
@@ -4,6 +4,7 @@ from torch.distributions import Categorical
 import numpy as np

 from utils.simple_replay_buffer import ReplayBuffer
+from utils.tools import init_episode_temp

 ################################## set device ##################################
 print("============================================================================================")
@@ -304,3 +305,55 @@ def evaluate(self, env, episode_num=32):
         )

         return np.mean(eval_data, axis=0)
+
+    def run_episode(self, env, epsilon):
+        env.reset()
+        self.Q.init_eval_rnn_hidden()
+        episode_len = 0
+        done = False
+        action = None
+        episode_obs, episode_state, episode_action, episode_reward, episode_avail_action = \
+            init_episode_temp(self.episode_limits, self.state_size, self.num_agents, self.obs_size, self.num_actions)
+        reward_list = []
+        while not done:
+            obs = env.get_obs(action)
+            state = env.get_state()
+            avail_actions = env.get_avail_actions()
+            random_selection = np.random.random(self.num_agents) < epsilon
+            recent_observations = np.concatenate([np.expand_dims(ob, axis=0) for ob in obs], axis=0)
+            action = self.select_actions(recent_observations.copy(), avail_actions, random_selection)
+            reward, done, info = env.step(action)
+            # experience
+            episode_obs[episode_len] = recent_observations
+            episode_state[episode_len] = state
+            episode_action[episode_len] = np.array(action)
+            episode_reward[episode_len] = reward
+            episode_avail_action[episode_len] = np.array(avail_actions)
+            reward_list.append(reward)
+            episode_len += 1
+
+        '''done: for last experience in every episode'''
+        obs = env.get_obs(action)
+        state = env.get_state()
+        avail_actions = env.get_avail_actions()
+        random_selection = np.random.random(self.num_agents) < epsilon
+        recent_observations = np.concatenate([np.expand_dims(ob, axis=0) for ob in obs], axis=0)
+        action = self.select_actions(recent_observations.copy(), avail_actions, random_selection)
+        episode_obs[episode_len] = recent_observations
+        episode_state[episode_len] = state
+        episode_action[episode_len] = np.array(action)
+        episode_reward[episode_len] = 0
+        episode_avail_action[episode_len] = np.array(avail_actions)
+        episode_dict = {
+            'obs': episode_obs,
+            'action': episode_action,
+            'avail_action': episode_avail_action
+        }
+        total_episode_dict = {
+            'obs': episode_state,
+            'reward': episode_reward,
+        }
+
+        self.replay_buffer.store(episode_dict, total_episode_dict, episode_len-1)
+        win_flag = 1. if 'battle_won' in info and info['battle_won'] else 0.
+        return reward_list, win_flag, episode_len
\ No newline at end of file
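
Note: the new `utils.tools` module imported above (`store_hyper_para` and `construct_results_dir` in learn.py, `init_episode_temp` in model.py) is not included in this patch. The sketch below is reconstructed from the helpers and log-dir code removed from learn.py; `construct_results_dir` in particular is an assumed wrapper around that removed logic, not the repository's actual file.

```python
# utils/tools.py -- hypothetical sketch, reconstructed from code removed from learn.py
import os
import datetime
import numpy as np


def init_episode_temp(ep_limits, state_shape, num_agents, obs_dim, action_dim):
    # Pre-allocate per-episode buffers (one extra slot for the terminal transition).
    episode_obs = np.zeros((ep_limits+1, num_agents, obs_dim), dtype=np.float32)
    episode_state = np.zeros((ep_limits+1, state_shape), dtype=np.float32)
    episode_action = np.zeros((ep_limits+1, num_agents), dtype=np.int64)
    episode_reward = np.zeros((ep_limits+1), dtype=np.float32)
    episode_avail_action = np.zeros((ep_limits+1, num_agents, action_dim), dtype=np.float32)
    return episode_obs, episode_state, episode_action, episode_reward, episode_avail_action


def store_hyper_para(args, store_path):
    # Dump all argparse hyper-parameters to a text file in the run directory.
    with open(os.path.join(store_path, 'hyper_para.txt'), 'w') as f:
        f.write('======================starts========================\n')
        for key, value in args.__dict__.items():
            f.write(key + ':' + str(value) + '\n')
        f.write('======================ends========================\n')
    print('==================hyper parameters store done!==================')


def construct_results_dir(args):
    # Assumed wrapper around the run-directory logic removed from learn.py.
    log_dir = f'./results/StarCraft/{args.map_name}/' + args.map_name
    if args.is_ddqn:
        log_dir += '_ddqn'
    if args.multi_steps > 1:
        log_dir += f'_{args.multi_steps}multisteps'
    if args.is_per:
        log_dir += '_per'
    if args.share_para:
        log_dir += '_sharepara'
    log_dir += '/'
    os.makedirs(log_dir, exist_ok=True)
    return log_dir + f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d_%H-%M-%S")}/'
```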
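Likewise, `utils.schedule.LinearSchedule` is only visible through its call sites, `LinearSchedule(args.anneal_steps, args.anneal_end, args.anneal_start)` and `.value(t)`. A minimal baselines-style implementation consistent with that usage is sketched below; treat it as an assumption rather than the repository's actual code.

```python
# utils/schedule.py -- hypothetical sketch; only the constructor arguments and
# .value(t) are visible in the diff, so this mirrors the common LinearSchedule pattern.
class LinearSchedule:
    def __init__(self, schedule_timesteps, final_p, initial_p=1.0):
        self.schedule_timesteps = schedule_timesteps  # args.anneal_steps
        self.final_p = final_p                        # args.anneal_end
        self.initial_p = initial_p                    # args.anneal_start

    def value(self, t):
        # Linearly interpolate from initial_p to final_p over schedule_timesteps,
        # then hold at final_p.
        fraction = min(float(t) / self.schedule_timesteps, 1.0)
        return self.initial_p + fraction * (self.final_p - self.initial_p)
```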