From 34e3dc7c0f300c1307faf677934501d2631ed34e Mon Sep 17 00:00:00 2001 From: Marius Kurz Date: Mon, 7 Oct 2024 10:24:25 +0200 Subject: [PATCH] Fix: allow for additional `kwargs` to be passed to `train()`, remove unnecessary call of destructor of `runtime`, which is now handled by context manager and improve output formatting when checkpointing. --- relexi/rl/ppo/train.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/relexi/rl/ppo/train.py b/relexi/rl/ppo/train.py index c324bd3..e482c88 100644 --- a/relexi/rl/ppo/train.py +++ b/relexi/rl/ppo/train.py @@ -84,6 +84,7 @@ def train( config_file ,mpi_launch_mpmd = False ,strategy = None ,debug = 0 + ,**kwargs ): """ Main training routine. Here, the (FLEXI) environment, the art. neural networks, the optimizer,... @@ -343,9 +344,9 @@ def train( config_file # Checkpoint the policy every ckpt_interval iterations if (i % ckpt_interval) == 0: - rlxout.info('Saving checkpoint to: ' + ckpt_dir, newline=False) + rlxout.info('Saving checkpoint to: ' + ckpt_dir) train_checkpointer.save(global_step) - rlxout.info('Saving current model to: ' + save_dir) + rlxout.info('Saving current model to: ' + save_dir, newline=False) actor_net.model.save(os.path.join(save_dir,f'model_{global_step.numpy():06d}')) # Flush summary to TensorBoard @@ -358,6 +359,3 @@ def train( config_file # Close all del my_env del my_eval_env - - del runtime - time.sleep(2.) # Wait for orchestrator to be properly closed