Skip to content

Commit

Permalink
fix bug: update obs and state when environment reset after evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
Felixvillas committed Sep 20, 2022
1 parent 0f029bd commit 735828c
Showing 1 changed file with 6 additions and 29 deletions.
35 changes: 6 additions & 29 deletions learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,7 @@ def qmix_learning(
else:
episode_len += 1

last_obs = obs
last_state = state


if args.is_per:
# PER: increase beta
QMIX_agent.increase_bate(t, args.training_steps)
Expand All @@ -207,7 +205,9 @@ def qmix_learning(
eval_data = QMIX_agent.evaluate(env, args.evaluate_num)
# env reset after evaluate
env.reset()
QMIX_agent.Q.init_eval_rnn_hidden()
QMIX_agent.Q.init_eval_rnn_hidden()
obs = env.get_obs()
state = env.get_state()
writer.add_scalar(tag=f'starcraft{env_name}_eval/reward', scalar_value=eval_data[0], global_step=num_test * args.test_freq)
writer.add_scalar(tag=f'starcraft{env_name}_eval/length', scalar_value=eval_data[1], global_step=num_test * args.test_freq)
writer.add_scalar(tag=f'starcraft{env_name}_eval/wintag', scalar_value=eval_data[2], global_step=num_test * args.test_freq)
Expand All @@ -216,31 +216,8 @@ def qmix_learning(
# model save
if num_param_update % args.save_model_freq == 0:
QMIX_agent.save(checkpoint_path=os.path.join(log_dir, 'agent.pth'))

### log train results
df = pd.DataFrame({})
df.insert(loc=0, column='rewards', value=log_rewards)
df.insert(loc=1, column='steps', value=log_steps)
df.insert(loc=2, column='wintag', value=log_win)
df_avg = pd.DataFrame({})
df_avg.insert(loc=0, column='rewards',
value=df['rewards'].rolling(window=20, win_type='triang', min_periods=1).mean())
df_avg.insert(loc=0, column='steps',
value=df['steps'].rolling(window=20, win_type='triang', min_periods=1).mean())
df_avg.insert(loc=2, column='wintag',
value=df['wintag'].rolling(window=20, win_type='triang', min_periods=1).mean())
_, (ax1, ax2, ax3) = plt.subplots(3, 1)
ax1.plot(df_avg['rewards'], label='rewards')
ax1.set_ylabel('rewards')
ax2.plot(df_avg['steps'], label='steps')
ax2.set_ylabel('steps')
ax3.plot(df_avg['wintag'], label='wintag')
ax3.set_ylabel('wintag')

ax1.set_title(f'{env_name}-{num_agents}agents')
ax2.set_xlabel('∝episode')
plt.legend()
plt.savefig(log_dir + env_name)
last_obs = obs
last_state = state

writer.close()
env.close()
Expand Down

0 comments on commit 735828c

Please sign in to comment.