Commit
add model save during training and model load
tianzikang authored and tianzikang committed Aug 28, 2022
1 parent 2617bda commit a8c1c74
Showing 3 changed files with 22 additions and 7 deletions.
18 changes: 11 additions & 7 deletions learn.py
@@ -7,6 +7,7 @@
from tqdm import *
from torch.distributions import Categorical
import torch
import datetime

from tensorboardX import SummaryWriter

@@ -46,18 +47,20 @@ def qmix_learning(
learning_rate,
exploration,
max_training_steps=1000000,
replay_buffer_size=1000000,
replay_buffer_size=5000,
batch_size=32,
gamma=.99,
learning_starts=50000,
evaluate_num=4,
target_update_freq=10000,
learning_starts=20000,
evaluate_num=32,
target_update_freq=200,
save_model_freq=2000,
grad_norm_clip=10,
args=None
):
'''
Parameters:
'''
assert save_model_freq % target_update_freq == 0
np.random.seed(seed)
torch.manual_seed(seed)
random.seed(seed)
@@ -87,8 +90,7 @@ def qmix_learning(
log_dir = log_dir + '/'
if not os.path.exists(log_dir):
os.makedirs(log_dir)
num_results = len(next(os.walk(log_dir))[1])
log_dir = log_dir + f'{num_results}/'
log_dir = log_dir + f'seed_{seed}_{datetime.datetime.now().strftime("%m%d_%H-%M-%S")}/'
writer = SummaryWriter(log_dir=log_dir)

# store hyper parameters
@@ -245,6 +247,8 @@ def qmix_learning(
writer.add_scalar(tag=f'starcraft{env_name}_eval/reward', scalar_value=eval_data[0], global_step=t+1)
writer.add_scalar(tag=f'starcraft{env_name}_eval/length', scalar_value=eval_data[1], global_step=t+1)
writer.add_scalar(tag=f'starcraft{env_name}_eval/wintag', scalar_value=eval_data[2], global_step=t+1)
if num_param_update % save_model_freq == 0:
QMIX_agent.save(checkpoint_path=os.path.join(log_dir, 'agent.pth'))

### log train results
df = pd.DataFrame({})
@@ -258,7 +262,7 @@
value=df['steps'].rolling(window=20, win_type='triang', min_periods=1).mean())
df_avg.insert(loc=2, column='wintag',
value=df['wintag'].rolling(window=20, win_type='triang', min_periods=1).mean())
fig, (ax1, ax2, ax3) = plt.subplots(3, 1)
_, (ax1, ax2, ax3) = plt.subplots(3, 1)
ax1.plot(df_avg['rewards'], label='rewards')
ax1.set_ylabel('rewards')
ax2.plot(df_avg['steps'], label='steps')
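In learn.py the commit does three things: imports datetime, switches the run directory from a running counter to a seed-plus-timestamp name, and saves the agent's Q-network to agent.pth every save_model_freq parameter updates (the new assert keeps save_model_freq a multiple of target_update_freq, presumably so that checkpoints line up with target-network updates). A minimal standalone sketch of the same pattern, using a stand-in network and an illustrative loop rather than the repository's real QMIX agent:

import datetime
import os

import torch
import torch.nn as nn

# Illustrative values; in learn.py these arrive as function arguments.
seed = 0
save_model_freq = 2000
target_update_freq = 200
assert save_model_freq % target_update_freq == 0  # save steps coincide with target-update steps

# New naming scheme: one folder per (seed, start time) instead of a running counter.
log_dir = 'results/' + f'seed_{seed}_{datetime.datetime.now().strftime("%m%d_%H-%M-%S")}/'
os.makedirs(log_dir, exist_ok=True)

Q = nn.Linear(8, 4)  # stand-in for QMIX_agent.Q

num_param_update = 0
for t in range(10000):
    # ... environment step and, past learning_starts, one gradient update would happen here ...
    num_param_update += 1
    if num_param_update % save_model_freq == 0:
        torch.save(Q.state_dict(), os.path.join(log_dir, 'agent.pth'))
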
2 changes: 2 additions & 0 deletions main.py
@@ -19,6 +19,7 @@ def get_args():
parser.add_argument('--replay-buffer-size', type=int, default=5000)
parser.add_argument('--learning-starts', type=int, default=20000)
parser.add_argument('--target-update-freq', type=int, default=200)
parser.add_argument('--save-model-freq', type=int, default=2000)
parser.add_argument('--learning-rate', type=float, default=3e-4)
parser.add_argument('--grad-norm-clip', type=float, default=0.5)
# seed
@@ -74,6 +75,7 @@ def main(args=get_args()):
gamma=args.gamma,
learning_starts=args.learning_starts,
target_update_freq=args.target_update_freq,
save_model_freq=args.save_model_freq,
grad_norm_clip=args.grad_norm_clip,
args=args
)
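The main.py change is plumbing only: a new --save-model-freq flag (default 2000) forwarded into qmix_learning. A trimmed-down sketch of that argparse pattern, keeping just the two flags that interact through the new assert (the rest of get_args is omitted):

import argparse

def get_args():
    parser = argparse.ArgumentParser()
    # Only the checkpoint-related flags are shown here.
    parser.add_argument('--target-update-freq', type=int, default=200)
    parser.add_argument('--save-model-freq', type=int, default=2000)
    return parser.parse_args()

if __name__ == '__main__':
    args = get_args()
    # argparse turns dashes into underscores; main() passes args.save_model_freq on to qmix_learning.
    print(args.save_model_freq, args.target_update_freq)
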
9 changes: 9 additions & 0 deletions model.py
@@ -291,6 +291,15 @@ def update_targets(self):
'''load para from Q to target_Q'''
self.target_Q.load_state_dict(self.Q.state_dict())

def save(self, checkpoint_path):
'''save model'''
torch.save(self.Q.state_dict(), checkpoint_path)

def load(self, checkpoint_path):
'''load model'''
self.Q.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))
self.target_Q.load_state_dict(torch.load(checkpoint_path, map_location=lambda storage, loc: storage))

def evaluate(self, env, episode_num=32):
'''evaluate Q model'''
eval_data = []
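The new save/load pair round-trips only the Q-network's state dict; load restores both Q and target_Q from the same file, and map_location pulls the tensors onto CPU so a GPU-trained checkpoint opens anywhere. A rough usage sketch with small stand-in modules instead of the real agent class, whose constructor isn't shown in this diff:

import torch
import torch.nn as nn

# Stand-ins for the agent's online and target Q networks.
Q = nn.Linear(8, 4)
target_Q = nn.Linear(8, 4)

checkpoint_path = 'agent.pth'  # learn.py writes this file into the run's log dir

# save(): persist the online network's parameters only.
torch.save(Q.state_dict(), checkpoint_path)

# load(): restore both networks from the same checkpoint so they start in sync;
# map_location keeps the tensors on CPU regardless of where they were saved.
state = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
Q.load_state_dict(state)
target_Q.load_state_dict(state)

The commit's load method calls torch.load once per network; loading the state dict once and reusing it, as above, is equivalent because load_state_dict copies the parameters.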
