add optimizer selection
tianzikang committed Aug 30, 2022
1 parent a8c1c74 commit ca7da80
Showing 4 changed files with 54 additions and 11 deletions.
44 changes: 41 additions & 3 deletions README.md
@@ -1,8 +1,46 @@
 # QMIXRNN
-Referring to pymarl, qmix is implemented with RNN to cope with SMAC environment
+Referring to pymarl, QMIX is implemented clearly with an RNN to cope with the SMAC environment.
+This clear implementation can help you figure out how QMIX works.
 
 ## Run
-`python main.py --map-name=3s5z`
+Note: `--optimizer=0/1` means that both optimizers, `Adam` (0) and `RMSprop` (1), work well on that scenario; select one of the two when running.
+`python main.py --map-name=3s5z --optimizer=0/1`
+`python main.py --map-name=1c3s5z --optimizer=0/1`
+`python main.py --map-name=2s3z --optimizer=0/1`
+`python main.py --map-name=8m --optimizer=0/1`
+`python main.py --map-name=2s_vs_1sc --optimizer=0`
+`python main.py --map-name=3m --optimizer=0`
+`python main.py --map-name=10m_vs_11m --optimizer=0`

 ## TODO
-Now this code can deal with some easy scenarios like 2s3z, 3s5z, 3m, 8m, and I'm trying to approach the result of pymarl. At the same time, I'm also trying to achieve some tricks on this code like multi step TD target and so on.
+Now this code performs very well on some of the easy scenarios, such as 1c3s5z, 2s3z, 3s5z and 8m,
+relatively well on the easy scenarios 2s_vs_1sc and 3m,
+but not well on the easy scenario 10m_vs_11m.
+
+I'm trying to approach the results of pymarl. At the same time, I'm also trying to add some tricks to this code, such as multi-step TD targets (sketched after this file's diff).

+## Reference
+@inproceedings{rashid2018qmix,
+  title={QMIX: Monotonic Value Function Factorisation for Deep Multi-Agent Reinforcement Learning},
+  author={Rashid, Tabish and Samvelyan, Mikayel and Schroeder, Christian and Farquhar, Gregory and Foerster, Jakob and Whiteson, Shimon},
+  booktitle={International Conference on Machine Learning},
+  pages={4295--4304},
+  year={2018},
+  organization={PMLR}
+}
+
+@article{samvelyan19smac,
+  title = {The {StarCraft} Multi-Agent Challenge},
+  author = {Mikayel Samvelyan and Tabish Rashid and Christian Schroeder de Witt and Gregory Farquhar and Nantas Nardelli and Tim G. J. Rudner and Chia-Man Hung and Philip H. S. Torr and Jakob Foerster and Shimon Whiteson},
+  journal = {CoRR},
+  volume = {abs/1902.04043},
+  year = {2019},
+}

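The multi-step TD target mentioned in the TODO above is not implemented in this commit. As a rough illustration only, here is a minimal sketch of how an n-step target could be computed, assuming per-step reward, done, and target-network max-Q tensors of shape (batch, T); the function name and shapes are hypothetical, not part of the repository:

```python
import torch

def n_step_td_target(rewards, dones, target_max_q, gamma=0.99, n=3):
    # Hypothetical n-step TD target (not part of this commit):
    # y_t = sum_{k=0}^{n-1} gamma^k r_{t+k} + gamma^n * max_a Q_target(s_{t+n}, a),
    # truncated at episode boundaries via the done mask.
    batch, T = rewards.shape
    targets = torch.zeros(batch, T - n)
    for t in range(T - n):
        ret = torch.zeros(batch)
        discount = 1.0
        alive = torch.ones(batch)
        for k in range(n):
            ret = ret + discount * alive * rewards[:, t + k]
            alive = alive * (1.0 - dones[:, t + k])
            discount = discount * gamma
        # bootstrap from the target network n steps ahead, unless the episode ended
        targets[:, t] = ret + discount * alive * target_max_q[:, t + n]
    return targets
```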
7 changes: 3 additions & 4 deletions learn.py
Expand Up @@ -43,7 +43,6 @@ def qmix_learning(
is_share_para,
is_evaluate,
q_func,
optimizer,
learning_rate,
exploration,
max_training_steps=1000000,
@@ -113,10 +112,10 @@ def qmix_learning(
         gamma=gamma,
         replay_buffer_size=replay_buffer_size,
         episode_limits=episode_limit,
-        batch_size=batch_size,
-        optimizer=optimizer,
+        batch_size=batch_size,
         learning_rate=learning_rate,
-        grad_norm_clip=grad_norm_clip
+        grad_norm_clip=grad_norm_clip,
+        args=args
     )

#############
3 changes: 2 additions & 1 deletion main.py
@@ -40,6 +40,8 @@ def get_args():
     parser.add_argument('--evaluate-num', type=int, default=32)
     # store hyper parameters
     parser.add_argument('--store-hyper-para', type=int, default=True)
+    # optimizer
+    parser.add_argument('--optimizer', type=int, default=0, help="0: Adam--[3m, 2s_vs_1sc]; 1: RMSprop--[others]")
 
     return parser.parse_args()

@@ -66,7 +68,6 @@ def main(args=get_args()):
         is_evaluate=args.is_evaluate,
         evaluate_num=args.evaluate_num,
         q_func=QMIX_agent,
-        optimizer=optim.RMSprop,
         learning_rate=args.learning_rate,
         exploration=exploration_schedule,
         max_training_steps=args.training_steps,
11 changes: 8 additions & 3 deletions model.py
@@ -148,9 +148,9 @@ def __init__(
         replay_buffer_size=5000,
         episode_limits=60,
         batch_size=32,
-        optimizer=torch.optim.RMSprop,
         learning_rate=3e-4,
         grad_norm_clip=10,
+        args=None
     ) -> None:
         super(QMIX_agent, self).__init__()
         assert multi_steps == 1 and is_per == False and is_share_para == True, \
@@ -178,8 +178,13 @@ def __init__(
 
         self.params = list(self.Q.parameters())
         self.grad_norm_clip = grad_norm_clip
-        # RMSProp alpha:0.99, RMSProp epsilon:0.00001
-        self.optimizer = optimizer(self.params, learning_rate, alpha=0.99, eps=1e-5)
+        if args.optimizer == 0:
+            # Adam: 3m, 2s_vs_1sc
+            self.optimizer = torch.optim.Adam(self.params, learning_rate)
+        elif args.optimizer == 1:
+            # RMSProp alpha:0.99, RMSProp epsilon:0.00001
+            self.optimizer = torch.optim.RMSprop(self.params, learning_rate, alpha=0.99, eps=1e-5)
 
         self.MseLoss = nn.MSELoss(reduction='sum')
 
         # Construct buffer
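A side note on the model.py hunk above: the new if/elif covers only `--optimizer` values 0 and 1, so any other value leaves `self.optimizer` unset and training would later fail with an AttributeError. A minimal defensive variant (a sketch under that assumption, not what the commit does) could fail fast instead:

```python
import torch

def build_optimizer(params, learning_rate, optimizer_id):
    # Sketch of the selection logic above, with an explicit error branch.
    if optimizer_id == 0:
        # Adam: reported to work well on 3m and 2s_vs_1sc
        return torch.optim.Adam(params, learning_rate)
    elif optimizer_id == 1:
        # RMSprop with alpha=0.99, eps=1e-5, as in the commit
        return torch.optim.RMSprop(params, learning_rate, alpha=0.99, eps=1e-5)
    raise ValueError(f"unknown optimizer id: {optimizer_id} (expected 0 or 1)")
```

Inside `QMIX_agent.__init__` this would be called as `self.optimizer = build_optimizer(self.params, learning_rate, args.optimizer)`; `build_optimizer` is a hypothetical helper, not part of the repository.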
