From 735828cd7a3ace8b9d5475a5417c545651f8c963 Mon Sep 17 00:00:00 2001
From: tianzikang
Date: Tue, 20 Sep 2022 12:45:14 +0800
Subject: [PATCH] fix bug: update obs and state when environment reset after
 evaluation

---
 learn.py | 35 ++++++-----------------------------
 1 file changed, 6 insertions(+), 29 deletions(-)

diff --git a/learn.py b/learn.py
index 9020c31..7f6ff08 100644
--- a/learn.py
+++ b/learn.py
@@ -183,9 +183,7 @@ def qmix_learning(
         else:
             episode_len += 1
-        last_obs = obs
-        last_state = state
-
+
         if args.is_per:
             # PER: increase beta
             QMIX_agent.increase_bate(t, args.training_steps)
 
@@ -207,7 +205,9 @@ def qmix_learning(
             eval_data = QMIX_agent.evaluate(env, args.evaluate_num)
             # env reset after evaluate
             env.reset()
-            QMIX_agent.Q.init_eval_rnn_hidden()
+            QMIX_agent.Q.init_eval_rnn_hidden()
+            obs = env.get_obs()
+            state = env.get_state()
             writer.add_scalar(tag=f'starcraft{env_name}_eval/reward', scalar_value=eval_data[0], global_step=num_test * args.test_freq)
             writer.add_scalar(tag=f'starcraft{env_name}_eval/length', scalar_value=eval_data[1], global_step=num_test * args.test_freq)
             writer.add_scalar(tag=f'starcraft{env_name}_eval/wintag', scalar_value=eval_data[2], global_step=num_test * args.test_freq)
@@ -216,31 +216,8 @@ def qmix_learning(
         # model save
         if num_param_update % args.save_model_freq == 0:
             QMIX_agent.save(checkpoint_path=os.path.join(log_dir, 'agent.pth'))
-
-        ### log train results
-        df = pd.DataFrame({})
-        df.insert(loc=0, column='rewards', value=log_rewards)
-        df.insert(loc=1, column='steps', value=log_steps)
-        df.insert(loc=2, column='wintag', value=log_win)
-        df_avg = pd.DataFrame({})
-        df_avg.insert(loc=0, column='rewards',
-                      value=df['rewards'].rolling(window=20, win_type='triang', min_periods=1).mean())
-        df_avg.insert(loc=0, column='steps',
-                      value=df['steps'].rolling(window=20, win_type='triang', min_periods=1).mean())
-        df_avg.insert(loc=2, column='wintag',
-                      value=df['wintag'].rolling(window=20, win_type='triang', min_periods=1).mean())
-        _, (ax1, ax2, ax3) = plt.subplots(3, 1)
-        ax1.plot(df_avg['rewards'], label='rewards')
-        ax1.set_ylabel('rewards')
-        ax2.plot(df_avg['steps'], label='steps')
-        ax2.set_ylabel('steps')
-        ax3.plot(df_avg['wintag'], label='wintag')
-        ax3.set_ylabel('wintag')
-
-        ax1.set_title(f'{env_name}-{num_agents}agents')
-        ax2.set_xlabel('∝episode')
-        plt.legend()
-        plt.savefig(log_dir + env_name)
+        last_obs = obs
+        last_state = state
 
     writer.close()
     env.close()
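
Note (not part of the patch): the sketch below illustrates the step-loop ordering this change aims for. obs/state are re-fetched whenever the environment is reset, including the reset that follows evaluation, and last_obs/last_state are only updated at the very end of the iteration, so a stored transition can never mix pre-reset and post-reset observations. Only env.reset(), env.get_obs() and env.get_state() are taken from the patch itself; the agent methods, env.step() return values and loop scaffolding are simplified stand-ins, not the repository's actual qmix_learning code.

def step_loop_sketch(env, agent, args):
    """Simplified QMIX-style training step loop (illustrative only)."""
    env.reset()
    obs, state = env.get_obs(), env.get_state()
    last_obs, last_state = obs, state

    for t in range(args.training_steps):
        # hypothetical agent helpers; the real repo's method names may differ
        actions = agent.choose_actions(last_obs, last_state)
        reward, done, _ = env.step(actions)
        obs, state = env.get_obs(), env.get_state()
        agent.store_transition(last_obs, last_state, actions, reward, obs, state, done)

        if done:
            env.reset()
            obs, state = env.get_obs(), env.get_state()

        if t % args.test_freq == 0:
            agent.evaluate(env, args.evaluate_num)
            # evaluation leaves the env mid-episode, so reset it and
            # re-fetch obs/state (this is what the patch adds)
            env.reset()
            obs, state = env.get_obs(), env.get_state()

        # update last_obs/last_state only at the end of the iteration, so they
        # always match whatever obs/state refer to now (post-reset or not)
        last_obs, last_state = obs, state

Deferring the last_obs/last_state update to the end of the loop body, as the patch does, is what makes the post-evaluation re-fetch sufficient; otherwise the next stored transition would still start from an observation of the episode abandoned by env.reset().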