Skip to content

Commit

Permalink
Fix: allow for additional kwargs to be passed to train(), remove …
Browse files Browse the repository at this point in the history
…unnecessary call of destructor of `runtime`, which is now handled by context manager and improve output formatting when checkpointing.
  • Loading branch information
m-kurz committed Oct 7, 2024
1 parent a828d26 commit 34e3dc7
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions relexi/rl/ppo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def train( config_file
,mpi_launch_mpmd = False
,strategy = None
,debug = 0
,**kwargs
):
"""
Main training routine. Here, the (FLEXI) environment, the art. neural networks, the optimizer,...
Expand Down Expand Up @@ -343,9 +344,9 @@ def train( config_file

# Checkpoint the policy every ckpt_interval iterations
if (i % ckpt_interval) == 0:
rlxout.info('Saving checkpoint to: ' + ckpt_dir, newline=False)
rlxout.info('Saving checkpoint to: ' + ckpt_dir)
train_checkpointer.save(global_step)
rlxout.info('Saving current model to: ' + save_dir)
rlxout.info('Saving current model to: ' + save_dir, newline=False)
actor_net.model.save(os.path.join(save_dir,f'model_{global_step.numpy():06d}'))

# Flush summary to TensorBoard
Expand All @@ -358,6 +359,3 @@ def train( config_file
# Close all
del my_env
del my_eval_env

del runtime
time.sleep(2.) # Wait for orchestrator to be properly closed

0 comments on commit 34e3dc7

Please sign in to comment.