Remove the extra StarCraft env used for evaluation
Felixvillas committed Sep 8, 2022
1 parent 8c185f5 commit 0de1833
Showing 4 changed files with 94 additions and 154 deletions.
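In effect, this commit moves ownership of the single StarCraft II environment into the QMIX agent: previously the training loop in learn.py built one `env` for rollouts plus a second `env_eval` used only for greedy evaluation; now the agent constructs the env itself from `env_class` and exposes thin wrappers (`reset`, `step`, `get_obs`, `get_state`, `get_avail_actions`, `evaluate`, `close`). A minimal sketch of that pattern; the class name and method bodies below are assumptions inferred from the diff, not code from this repository:

```python
class QMIXAgentSketch:
    """Hypothetical post-commit agent: it owns the one env that is
    reused for both training rollouts and greedy evaluation."""

    def __init__(self, env_class, args):
        # The agent, not the training loop, constructs the environment.
        self.env = env_class(map_name=args.map_name, seed=args.seed)

    def get_env_info(self):
        # Mirrors the tuple unpacked in learn.py after this commit.
        info = self.env.get_env_info()
        return (info['obs_shape'], info['state_shape'], info['n_actions'],
                info['n_agents'], info['episode_limit'])

    def reset(self):
        # Reset the env; the real code also re-inits the eval RNN hidden state.
        self.env.reset()

    def step(self, actions):
        return self.env.step(actions)

    def close(self):
        self.env.close()
```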
README.md (4 changes: 2 additions & 2 deletions)

@@ -29,8 +29,8 @@ python main.py --map-name=10m_vs_11m --optimizer=0
 And I find that in most scenarios, `0: Adam` converges faster than `1: RMSprop`.
 ## TODO
 Now this code does very well on some of the easy scenarios, like 1c3s5z, 2s3z, 3s5z and 8m,
-and relatively well on easy scenarios like 2s_vs_1sc and 3m,
-but not well on the easy scenario 10m_vs_11m.
+and relatively well on easy scenarios like 2s_vs_1sc, 3m and 10m_vs_11m,
+but not well on hard and super-hard scenarios.
 
 I'm trying to match the results of pymarl. At the same time, I'm also trying to add some tricks to this code, such as a multi-step TD target.

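The TODO above mentions a multi-step TD target. For reference, the usual n-step construction sums the next n discounted rewards and then bootstraps from the target network: G = r_t + gamma*r_{t+1} + ... + gamma^(n-1)*r_{t+n-1} + gamma^n * max_a Q_target(s_{t+n}, a). A minimal illustrative sketch, not code from this repository:

```python
def n_step_td_target(rewards, bootstrap_value, gamma=0.99, n=3):
    """Illustrative n-step TD target: sum the next n discounted rewards,
    then bootstrap from a target-network value at step t+n (for QMIX,
    the mixed target Q value). Truncates n at episode boundaries."""
    n = min(n, len(rewards))
    discounted = sum((gamma ** k) * rewards[k] for k in range(n))
    return discounted + (gamma ** n) * bootstrap_value

# Example: rewards [1.0, 0.0, 2.0] and a bootstrapped value of 5.0
# give 1.0 + 0.99**2 * 2.0 + 0.99**3 * 5.0.
```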
learn.py (124 changes: 37 additions & 87 deletions)

@@ -32,104 +32,56 @@ def store_hyper_para(args, store_path):

 def qmix_learning(
     env_class,
-    env_id,
-    seed,
-    is_ddqn,
-    multi_steps,
-    is_per,
-    alpha,
-    beta,
-    prior_eps,
-    is_share_para,
-    is_evaluate,
     q_func,
-    learning_rate,
     exploration,
-    max_training_steps=1000000,
-    replay_buffer_size=5000,
-    batch_size=32,
-    gamma=.99,
-    learning_starts=20000,
-    evaluate_num=32,
-    target_update_freq=200,
-    save_model_freq=2000,
-    grad_norm_clip=10,
     args=None
 ):
     '''
     Parameters:
     '''
-    assert save_model_freq % target_update_freq == 0
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    random.seed(seed)
-    env = env_class(map_name=env_id, seed=seed)
-    if is_evaluate:
-        env_eval = env_class(map_name=env_id, seed=seed)
-
-    env_info = env.get_env_info()
-    obs_size = env_info['obs_shape']
-    state_size = env_info['state_shape']
-    num_actions = env_info['n_actions']
-    num_agents = env_info['n_agents']
-    episode_limit = env_info['episode_limit']
+    assert args.save_model_freq % args.target_update_freq == 0
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    random.seed(args.seed)
+    # Initialize QMIX_agent
+    QMIX_agent = q_func(
+        env_class=env_class,
+        args=args
+    )
+    obs_size, state_size, num_actions, num_agents, episode_limit = QMIX_agent.get_env_info()
 
     # Construct tensor log writer
-    env_name = env_id
+    env_name = args.map_name
     log_dir = f'./results/StarCraft/{env_name}/'
     log_dir = log_dir + env_name
-    if is_ddqn:
+    if args.is_ddqn:
         log_dir = log_dir + '_ddqn'
-    if multi_steps > 1:
-        log_dir = log_dir + f'_{multi_steps}multisteps'
-    if is_per:
+    if args.multi_steps > 1:
+        log_dir = log_dir + f'_{args.multi_steps}multisteps'
+    if args.is_per:
         log_dir = log_dir + '_per'
-    if is_share_para:
+    if args.share_para:
         log_dir = log_dir + '_sharepara'
     log_dir = log_dir + '/'
     if not os.path.exists(log_dir):
         os.makedirs(log_dir)
-    log_dir = log_dir + f'seed_{seed}_{datetime.datetime.now().strftime("%m%d_%H-%M-%S")}/'
+    log_dir = log_dir + f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d_%H-%M-%S")}/'
     writer = SummaryWriter(log_dir=log_dir)
 
     # store hyper parameters
     if args.store_hyper_para:
         store_hyper_para(args, log_dir)
 
-    # Initialize QMIX_agent
-    QMIX_agent = q_func(
-        obs_size=obs_size,
-        state_size=state_size,
-        num_agents=num_agents,
-        num_actions=num_actions,
-        is_share_para=is_share_para,
-        is_ddqn=is_ddqn,
-        multi_steps=multi_steps,
-        is_per=is_per,
-        alpha=alpha,
-        beta=beta,
-        prior_eps=prior_eps,
-        gamma=gamma,
-        replay_buffer_size=replay_buffer_size,
-        episode_limits=episode_limit,
-        batch_size=batch_size,
-        learning_rate=learning_rate,
-        grad_norm_clip=grad_norm_clip,
-        args=args
-    )
-
     #############
     # RUN ENV #
     #############
     num_param_update = 0
-    env.reset()
-    # init rnn_hidden and numpy of episode experience in the start of every episode
-    QMIX_agent.Q.init_eval_rnn_hidden()
+    QMIX_agent.reset()
     episode_obs, episode_state, episode_action, episode_reward, episode_avail_action = \
         init_episode_temp(episode_limit, state_size, num_agents, obs_size, num_actions)
 
-    last_obs = env.get_obs()
-    last_state = env.get_state()
+    last_obs = QMIX_agent.get_obs()
+    last_state = QMIX_agent.get_state()
     # for episode experience
     ep_rewards = []
     episode_len = 0
@@ -144,14 +96,14 @@ def qmix_learning(
     steps_queue = []
     win_queue = []
 
-    for t in tqdm(range(max_training_steps)):
+    for t in tqdm(range(args.training_steps)):
 
         # get avail actions for every agent
-        avail_actions = env.get_avail_actions()
+        avail_actions = QMIX_agent.get_avail_actions()
 
         # Choose random actions before learning starts; afterwards select actions eps-greedily
-        if t >= learning_starts:
-            random_selection = np.random.random(num_agents) < exploration.value(t-learning_starts)
+        if t >= args.learning_starts:
+            random_selection = np.random.random(num_agents) < exploration.value(t-args.learning_starts)
             # last_obs is a list of arrays of shape (obs_shape,) --> numpy.array:(num_agents, obs_shape)
             recent_observations = np.concatenate([np.expand_dims(ob, axis=0) for ob in last_obs], axis=0)
             action = QMIX_agent.select_actions(recent_observations, avail_actions, random_selection)
@@ -160,7 +112,7 @@
         action = [action[i].item() for i in range(num_agents)]
 
         # Advance one step
-        reward, done, info = env.step(action)
+        reward, done, info = QMIX_agent.step(action)
 
         # experience
         episode_obs[episode_len] = np.concatenate([np.expand_dims(ob, axis=0) for ob in last_obs], axis=0)
@@ -170,8 +122,8 @@
         episode_avail_action[episode_len] = np.array(avail_actions)
 
         ep_rewards.append(reward)
-        obs = env.get_obs(action)
-        state = env.get_state()
+        obs = QMIX_agent.get_obs(action)
+        state = QMIX_agent.get_state()
 
         # Reset the environment when reaching an episode boundary
         if done:
@@ -204,11 +156,9 @@
             ep_rewards = []
             episode_len = 0
 
-            env.reset()
-            # init rnn_hidden and numpy of episode experience in the start of every episode
-            QMIX_agent.Q.init_eval_rnn_hidden()
-            obs = env.get_obs()
-            state = env.get_state()
+            QMIX_agent.reset()
+            obs = QMIX_agent.get_obs()
+            state = QMIX_agent.get_state()
             # init para for new episode
             episode_obs, episode_state, episode_action, episode_reward, episode_avail_action = \
                 init_episode_temp(episode_limit, state_size, num_agents, obs_size, num_actions)
@@ -218,12 +168,12 @@
         last_obs = obs
         last_state = state
 
-        if is_per:
+        if args.is_per:
             # PER: increase beta
-            QMIX_agent.increase_bate(t, max_training_steps)
+            QMIX_agent.increase_bate(t, args.training_steps)
 
         # train and evaluate
-        if (t >= learning_starts and done and QMIX_agent.can_sample()):
+        if (t >= args.learning_starts and done and QMIX_agent.can_sample()):
             # gradient descent: train
             loss = QMIX_agent.update()
             num_param_update += 1
@@ -233,14 +183,14 @@

             # Periodically copy the Q network to the target Q network
             # and evaluate the Q-net in greedy mode
-            if num_param_update % target_update_freq == 0:
+            if num_param_update % args.target_update_freq == 0:
                 QMIX_agent.update_targets()
                 # evaluate the Q-net in greedy mode
-                eval_data = QMIX_agent.evaluate(env_eval, evaluate_num)
+                eval_data = QMIX_agent.evaluate(args.evaluate_num)
                 writer.add_scalar(tag=f'starcraft{env_name}_eval/reward', scalar_value=eval_data[0], global_step=t+1)
                 writer.add_scalar(tag=f'starcraft{env_name}_eval/length', scalar_value=eval_data[1], global_step=t+1)
                 writer.add_scalar(tag=f'starcraft{env_name}_eval/wintag', scalar_value=eval_data[2], global_step=t+1)
-            if num_param_update % save_model_freq == 0:
+            if num_param_update % args.save_model_freq == 0:
                 QMIX_agent.save(checkpoint_path=os.path.join(log_dir, 'agent.pth'))
 
             ### log train results
@@ -269,4 +219,4 @@ def qmix_learning(
     plt.savefig(log_dir + env_name)
 
     writer.close()
-    env.close()
+    QMIX_agent.close()
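The training loop above anneals the prioritized-replay importance-sampling exponent through `QMIX_agent.increase_bate(t, args.training_steps)` (the spelling is the repo's own identifier). A common choice is a linear schedule toward 1.0; the method below is a sketch of that assumption, not this repo's implementation:

```python
def increase_beta(self, t, total_steps):
    """Hypothetical linear annealing of PER's beta toward 1.0, so the
    importance-sampling correction becomes unbiased late in training."""
    fraction = min(1.0, t / float(total_steps))
    # self.beta0 stands for the initial value passed via args.beta.
    self.beta = self.beta0 + fraction * (1.0 - self.beta0)
```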
main.py (22 changes: 1 addition & 21 deletions)

@@ -17,7 +17,7 @@ def get_args():
     parser.add_argument('--anneal-start', type=float, default=1.0)
     parser.add_argument('--anneal-end', type=float, default=0.01)
     parser.add_argument('--replay-buffer-size', type=int, default=5000)
-    parser.add_argument('--learning-starts', type=int, default=20000)
+    parser.add_argument('--learning-starts', type=int, default=0)
     parser.add_argument('--target-update-freq', type=int, default=200)
     parser.add_argument('--save-model-freq', type=int, default=2000)
     parser.add_argument('--learning-rate', type=float, default=3e-4)
@@ -56,28 +56,8 @@ def main(args=get_args()):

     qmix_learning(
         env_class=env_class,
-        env_id=args.map_name,
-        seed=args.seed,
-        is_ddqn=args.is_ddqn,
-        multi_steps=args.multi_steps,
-        is_per=args.is_per,
-        alpha=args.alpha,
-        beta=args.beta,
-        prior_eps=args.prior_eps,
-        is_share_para=args.share_para,
-        is_evaluate=args.is_evaluate,
-        evaluate_num=args.evaluate_num,
         q_func=QMIX_agent,
-        learning_rate=args.learning_rate,
         exploration=exploration_schedule,
-        max_training_steps=args.training_steps,
-        replay_buffer_size=args.replay_buffer_size,
-        batch_size=args.batch_size,
-        gamma=args.gamma,
-        learning_starts=args.learning_starts,
-        target_update_freq=args.target_update_freq,
-        save_model_freq=args.save_model_freq,
-        grad_norm_clip=args.grad_norm_clip,
         args=args
     )

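main.py's `--anneal-start` and `--anneal-end` arguments, combined with the `exploration.value(t - args.learning_starts)` call in learn.py, suggest a linear epsilon schedule for exploration. A sketch under that assumption; the repo's actual `exploration_schedule` class may differ:

```python
class LinearSchedule:
    """Epsilon decays linearly from anneal_start to anneal_end over
    schedule_steps environment steps, then stays flat at anneal_end."""

    def __init__(self, schedule_steps, anneal_start=1.0, anneal_end=0.01):
        self.schedule_steps = schedule_steps
        self.anneal_start = anneal_start
        self.anneal_end = anneal_end

    def value(self, t):
        fraction = min(max(t, 0) / float(self.schedule_steps), 1.0)
        return self.anneal_start + fraction * (self.anneal_end - self.anneal_start)
```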
