Commit

fix bug: orth init does not take effect and change statistical method of evaluation lightly
tianzikang committed Aug 19, 2022
1 parent 2fb8e82 commit 89a4a6d
Showing 3 changed files with 19 additions and 18 deletions.
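The first fix targets the orthogonal initialization in model.py: iterating over self.parameters() yields nn.Parameter tensors, so isinstance(m, nn.Linear) is never True and the init silently never runs, whereas iterating over self.modules() visits the submodules themselves. A minimal standalone sketch (not from this repository; the network shape is made up) illustrating the difference:

import numpy as np
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 5))

# parameters() yields nn.Parameter tensors, so the isinstance check never fires.
hits_old = sum(isinstance(m, nn.Linear) for m in net.parameters())  # 0
# modules() yields the Sequential and its submodules, so both Linear layers match.
hits_new = sum(isinstance(m, nn.Linear) for m in net.modules())     # 2

for m in net.modules():
    if isinstance(m, nn.Linear):
        torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
        torch.nn.init.constant_(m.bias, 0.0)

print(hits_old, hits_new)  # prints: 0 2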
12 changes: 7 additions & 5 deletions learn.py
@@ -52,6 +52,7 @@ def qmix_learning(
     learning_starts=50000,
     evaluate_num=4,
     target_update_freq=10000,
+    grad_norm_clip=10,
     args=None
 ):
     '''
@@ -112,7 +113,8 @@ def qmix_learning(
         episode_limits=episode_limit,
         batch_size=batch_size,
         optimizer=optimizer,
-        learning_rate=learning_rate
+        learning_rate=learning_rate,
+        grad_norm_clip=grad_norm_clip
     )

     #############
@@ -239,10 +241,10 @@ def qmix_learning(
         if num_param_update % target_update_freq == 0:
             QMIX_agent.update_targets()
             # evaluate the Q-net in greedy mode
-            eval_reward, eval_step, eval_win = QMIX_agent.evaluate(env_eval, evaluate_num)
-            writer.add_scalar(tag=f'starcraft{env_name}_eval/reward', scalar_value=mean(eval_reward), global_step=t+1)
-            writer.add_scalar(tag=f'starcraft{env_name}_eval/length', scalar_value=mean(eval_step), global_step=t+1)
-            writer.add_scalar(tag=f'starcraft{env_name}_eval/wintag', scalar_value=mean(eval_win), global_step=t+1)
+            eval_data = QMIX_agent.evaluate(env_eval, evaluate_num)
+            writer.add_scalar(tag=f'starcraft{env_name}_eval/reward', scalar_value=eval_data[0], global_step=t+1)
+            writer.add_scalar(tag=f'starcraft{env_name}_eval/length', scalar_value=eval_data[1], global_step=t+1)
+            writer.add_scalar(tag=f'starcraft{env_name}_eval/wintag', scalar_value=eval_data[2], global_step=t+1)

     ### log train results
     df = pd.DataFrame({})
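In learn.py, evaluate() now returns a single aggregated array instead of three per-episode lists, and the TensorBoard calls index into it. A small sketch of the same logging pattern with placeholder values (env_name, t, and the array below are illustrative, not taken from a real run):

import numpy as np
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()
env_name = '3m'                             # placeholder map name
t = 0                                       # placeholder training step
eval_data = np.array([17.45, 47.5, 0.5])    # [mean reward, mean length, mean win rate]

# One scalar per metric, mirroring the reward/length/wintag tags in the diff.
for name, value in zip(['reward', 'length', 'wintag'], eval_data):
    writer.add_scalar(tag=f'starcraft{env_name}_eval/{name}', scalar_value=value, global_step=t + 1)
writer.close()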
2 changes: 2 additions & 0 deletions main.py
@@ -20,6 +20,7 @@ def get_args():
     parser.add_argument('--learning-starts', type=int, default=20000)
     parser.add_argument('--target-update-freq', type=int, default=200)
     parser.add_argument('--learning-rate', type=float, default=3e-4)
+    parser.add_argument('--grad-norm-clip', type=float, default=0.5)
     # seed
     parser.add_argument('--seed', type=int, default=0)
     # ddqn
@@ -73,6 +74,7 @@ def main(args=get_args()):
         gamma=args.gamma,
         learning_starts=args.learning_starts,
         target_update_freq=args.target_update_freq,
+        grad_norm_clip=args.grad_norm_clip,
         args=args
     )

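main.py now exposes the clipping threshold as --grad-norm-clip and passes it through qmix_learning into the agent, replacing the value previously hard-coded as self.grad_norm_clip = 0.5 in model.py. The update step that consumes it is not part of this diff; as a hedged sketch, a configurable threshold is typically applied with torch.nn.utils.clip_grad_norm_ like this (the network, loss, and optimizer below are stand-ins):

import torch
import torch.nn as nn

q_net = nn.Linear(16, 5)                           # stand-in for the Q network
params = list(q_net.parameters())
# Same RMSprop hyperparameters as in model.py's __init__.
optimizer = torch.optim.RMSprop(params, lr=3e-4, alpha=0.99, eps=1e-5)
grad_norm_clip = 10                                # now supplied via --grad-norm-clip

loss = q_net(torch.randn(32, 16)).pow(2).sum()     # placeholder loss
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(params, grad_norm_clip)  # clip before the step
optimizer.step()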
23 changes: 10 additions & 13 deletions model.py
@@ -55,7 +55,7 @@ def __init__(self, obs_size=16, state_size=32, num_agents=2, num_actions=5) -> N

     def orth_init(self):
         # orthogonal initialization
-        for m in list(self.parameters()):
+        for m in list(self.modules()):
             if isinstance(m, nn.Linear):
                 # orthogonal initialization
                 torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
@@ -149,7 +149,8 @@ def __init__(
         episode_limits=60,
         batch_size=32,
         optimizer=torch.optim.RMSprop,
-        learning_rate=3e-4
+        learning_rate=3e-4,
+        grad_norm_clip=10,
     ) -> None:
         super(QMIX_agent, self).__init__()
         assert multi_steps == 1 and is_per == False and is_share_para == True, \
@@ -176,7 +177,7 @@ def __init__(
         self.target_Q.load_state_dict(self.Q.state_dict())

         self.params = list(self.Q.parameters())
-        self.grad_norm_clip = 0.5
+        self.grad_norm_clip = grad_norm_clip
         # RMSProp alpha:0.99, RMSProp epsilon:0.00001
         self.optimizer = optimizer(self.params, learning_rate, alpha=0.99, eps=1e-5)
         self.MseLoss = nn.MSELoss(reduction='sum')
@@ -292,9 +293,7 @@ def update_targets(self):

     def evaluate(self, env, episode_num=32):
         '''evaluate Q model'''
-        eval_reward = []
-        eval_step = []
-        eval_win = []
+        eval_data = []
         for _ in range(episode_num):
             eval_ep_rewards = []
             done = False
@@ -310,13 +309,11 @@ def evaluate(self, env, episode_num=32):
                 eval_ep_rewards.append(reward)

                 if done:
-                    eval_reward.append(sum(eval_ep_rewards))
-                    eval_step.append(len(eval_ep_rewards))
-                    eval_win.append(1. if 'battle_won' in info and info['battle_won'] else 0.)
+                    eval_data.append(
+                        [sum(eval_ep_rewards), len(eval_ep_rewards), 1. if 'battle_won' in info and info['battle_won'] else 0.]
+                    )

         start = episode_num // 4
         end = episode_num * 3 // 4
-        eval_reward = sorted(eval_reward)[start:end]
-        eval_step = sorted(eval_step)[start:end]
-        eval_win = sorted(eval_win)[start:end]
-        return eval_reward, eval_step, eval_win
+        sort_eval_data = sorted(eval_data, key=lambda x: x[-1])[start: end]  # sorted by win or not
+        return np.mean(sort_eval_data, axis=0)
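The reworked evaluation statistic keeps each episode's [return, length, win] row together, sorts the rows by the win flag, drops the lowest and highest quartiles of episodes, and reports the column-wise mean of the middle half. A self-contained sketch of that aggregation, with made-up episode numbers:

import numpy as np

# One [return, length, win] row per evaluation episode (values are illustrative).
eval_data = [
    [12.3, 40, 0.0],
    [19.8, 35, 1.0],
    [15.1, 60, 0.0],
    [20.4, 33, 1.0],
]
episode_num = len(eval_data)
start, end = episode_num // 4, episode_num * 3 // 4
trimmed = sorted(eval_data, key=lambda x: x[-1])[start:end]   # keep the middle half, ordered by win
mean_reward, mean_length, mean_win = np.mean(trimmed, axis=0)
print(mean_reward, mean_length, mean_win)                     # 17.45 47.5 0.5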
