From e6157eb8712e8386feb57b20d01a48387b122a63 Mon Sep 17 00:00:00 2001
From: jakegrigsby
Date: Mon, 2 Sep 2024 14:23:14 -0500
Subject: [PATCH 1/3] pass reset info up through env stack

---
 amago/envs/amago_env.py | 4 ++--
 amago/envs/env_utils.py | 5 ++---
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/amago/envs/amago_env.py b/amago/envs/amago_env.py
index 121256a..b0463f0 100644
--- a/amago/envs/amago_env.py
+++ b/amago/envs/amago_env.py
@@ -96,7 +96,7 @@ def inner_reset(self, seed=None, options=None):

     def reset(self, seed=None, options=None) -> Timestep:
         self.step_count = 0
-        obs, _ = self.inner_reset(seed=seed, options=options)
+        obs, info = self.inner_reset(seed=seed, options=options)
         if not isinstance(obs, dict):
             obs = {"observation": obs}
         timestep = Timestep(
@@ -110,7 +110,7 @@ def reset(self, seed=None, options=None) -> Timestep:
             terminal=False,
             raw_time_idx=0,
         )
-        return timestep
+        return timestep, info

     def inner_step(self, action):
         return self.env.step(action)
diff --git a/amago/envs/env_utils.py b/amago/envs/env_utils.py
index de3d0cb..caa84e4 100644
--- a/amago/envs/env_utils.py
+++ b/amago/envs/env_utils.py
@@ -202,7 +202,6 @@ def sequence_lengths(self):


 class ExplorationWrapper(ABC, gym.ActionWrapper):
-    @abstractmethod
     def add_exploration_noise(self, action: np.ndarray, local_step: int, horizon: int):
         raise NotImplementedError

@@ -390,7 +389,7 @@ def reset_stats(self):
         self.success_history = SuccessHistory(self.env_name)

     def reset(self, seed=None) -> Timestep:
-        timestep = self.env.reset(seed=seed)
+        timestep, info = self.env.reset(seed=seed)
         self.active_traj = Trajectory(
             max_goals=self.env.max_goal_seq_length, timesteps=[timestep]
         )
@@ -400,7 +399,7 @@ def reset(self, seed=None) -> Timestep:
         )
         self.total_return = 0.0
         self._current_timestep = self.active_traj.make_sequence(last_only=True)
-        return timestep.obs, {}
+        return timestep.obs, info

     def step(self, action):
         timestep, reward, terminated, truncated, info = self.env.step(action)

From c18397feb86241fe6d7270420056bff619d18137 Mon Sep 17 00:00:00 2001
From: jakegrigsby
Date: Mon, 9 Sep 2024 11:42:31 -0500
Subject: [PATCH 2/3] save backup raw weight files during checkpoints

---
 README.md               | 10 ++++++++--
 amago/envs/env_utils.py |  1 +
 amago/learning.py       | 24 ++++++++++++++++++++----
 3 files changed, 29 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index 393a0dd..275142d 100644
--- a/README.md
+++ b/README.md
@@ -108,10 +108,16 @@ Aside from the `wandb` logging metrics, AMAGO outputs data in the following form
             ...
         - {Experiment.run_name} or {--run_name}/
             - ckpts/
-                # full checkpoints that allow training to resume; saved at regular intervals
+                # full checkpoints *that allow training to resume unchanged* in the same `accelerate` environment.
+                # (Includes optimizer, grad scaler, rng state, etc.)
                 - {Experiment.run_name}_epoch_0
-                - {Experiment.run_name}_epoch_k
+                - {Experiment.run_name}_epoch_{Experiment.ckpt_interval}
                 ...
+                # pure weight files saved at the same interval. These are
+                # backups that avoid `accelerate` state version control
+                # and are more portable for inference.
+                - policy_epoch_0.pt
+                - policy_epoch_{Experiment.ckpt_interval}.pt
         - policy.pt # the current model weights - used to communicate between actor/learner processes
         - config.txt # stores gin configuration details for reproducibility (see below)
 ```
diff --git a/amago/envs/env_utils.py b/amago/envs/env_utils.py
index caa84e4..cf185e1 100644
--- a/amago/envs/env_utils.py
+++ b/amago/envs/env_utils.py
@@ -290,6 +290,7 @@ def add_exploration_noise(self, action, local_step, horizon):
         return expl_action


+@gin.configurable
 class EpsilonGreedy(BilevelEpsilonGreedy):
     """
     Sets the parameters of the BilevelEpsilonGreedy wrapper to be equivalent to standard epsilon-greedy.
diff --git a/amago/learning.py b/amago/learning.py
index 9a121ca..21c1175 100644
--- a/amago/learning.py
+++ b/amago/learning.py
@@ -231,10 +231,20 @@ def init_checkpoints(self):
         os.makedirs(self.ckpt_dir, exist_ok=True)
         self.epoch = 0

-    def load_checkpoint(self, epoch: int):
-        ckpt_name = f"{self.run_name}_epoch_{epoch}"
-        ckpt_path = os.path.join(self.ckpt_dir, ckpt_name)
-        self.accelerator.load_state(ckpt_path)
+    def load_checkpoint(self, epoch: int, resume_training_state: bool = True):
+        if not resume_training_state:
+            # load the weights without worrying about resuming the accelerate state
+            ckpt = utils.retry_load_checkpoint(
+                os.path.join(self.ckpt_dir, f"policy_epoch_{epoch}.pt"),
+                map_location=self.DEVICE,
+            )
+            self.policy_aclr.load_state_dict(ckpt)
+        else:
+            # loads weights and will set the epoch but otherwise resumes training
+            # (optimizer, grad scaler, etc.)
+            ckpt_name = f"{self.run_name}_epoch_{epoch}"
+            ckpt_path = os.path.join(self.ckpt_dir, ckpt_name)
+            self.accelerator.load_state(ckpt_path)
         self.epoch = epoch

     def save_checkpoint(self):
@@ -242,6 +252,12 @@ def save_checkpoint(self):
         self.accelerator.save_state(
             os.path.join(self.ckpt_dir, ckpt_name), safe_serialization=True
         )
+        if self.accelerator.is_main_process:
+            # create backup of raw weights unrelated to the more complex process of resuming an accelerate state
+            weights_only = torch.save(
+                self.policy_aclr.state_dict(),
+                os.path.join(self.ckpt_dir, f"policy_epoch_{self.epoch}.pt"),
+            )

     def write_latest_policy(self):
         ckpt_name = os.path.join(

From 0861bd78a13a2d752fe90a20d22f0a2121aff7e4 Mon Sep 17 00:00:00 2001
From: jakegrigsby
Date: Mon, 9 Sep 2024 12:13:00 -0500
Subject: [PATCH 3/3] split checkpoints into subdirs

---
 README.md         | 43 +++++++++++++++++++++++--------------------
 amago/learning.py | 15 +++++++++++----
 2 files changed, 34 insertions(+), 24 deletions(-)

diff --git a/README.md b/README.md
index 275142d..bd1f3a1 100644
--- a/README.md
+++ b/README.md
@@ -100,26 +100,29 @@ Aside from the `wandb` logging metrics, AMAGO outputs data in the following form

 ```bash
 {Experiment.dset_root} or {--buffer_dir}/
-    - {Experiment.dset_name} or {--run_name}/
-        - train/
-            # replay buffer of sequence data stored on disk as `*.traj` files
-            - {environment_name}_{random_id}_{unix_time}.traj
-            - {environment_name}_{another_random_id}_{later_unix_time}.traj
-            ...
-        - {Experiment.run_name} or {--run_name}/
-            - ckpts/
-                # full checkpoints *that allow training to resume unchanged* in the same `accelerate` environment.
-                # (Includes optimizer, grad scaler, rng state, etc.)
-                - {Experiment.run_name}_epoch_0
-                - {Experiment.run_name}_epoch_{Experiment.ckpt_interval}
-                ...
-                # pure weight files saved at the same interval. These are
-                # backups that avoid `accelerate` state version control
-                # and are more portable for inference.
-                - policy_epoch_0.pt
-                - policy_epoch_{Experiment.ckpt_interval}.pt
-            - policy.pt # the current model weights - used to communicate between actor/learner processes
-            - config.txt # stores gin configuration details for reproducibility (see below)
+|-- {Experiment.dset_name} or {--run_name}/
+    |-- train/
+    |   # replay buffer of sequence data stored on disk as `*.traj` files.
+    |   {environment_name}_{random_id}_{unix_time}.traj
+    |   {environment_name}_{another_random_id}_{later_unix_time}.traj
+    |   ...
+    |-- {Experiment.run_name} or {--run_name}/
+        |-- config.txt # stores gin configuration details for reproducibility (see below)
+        |-- policy.pt # the current model weights - used to communicate between actor/learner processes
+        |-- ckpts/
+            |-- training_states/
+            |   | # full checkpoints that restore the entire training setup in the same `accelerate` environment.
+            |   | # (Including the optimizer, grad scaler, rng state, etc.)
+            |   |-- {Experiment.run_name}_epoch_0/
+            |       - # `accelerate` files you probably don't need
+            |   |-- {Experiment.run_name}_epoch_{Experiment.ckpt_interval}/
+            |   |-- ...
+            |-- policy_weights/
+            |   # pure weight files that avoid `accelerate` state version control and are more portable for inference.
+            |   |-- policy_epoch_0.pt
+            |   |-- policy_epoch_{Experiment.ckpt_interval}.pt
+            |   |-- ...
+    -- # any other runs that share this replay buffer would be listed here
 ```
diff --git a/amago/learning.py b/amago/learning.py
index 21c1175..ac67f80 100644
--- a/amago/learning.py
+++ b/amago/learning.py
@@ -229,13 +229,17 @@ def init_checkpoints(self):
             self.dset_root, self.dset_name, self.run_name, "ckpts"
         )
         os.makedirs(self.ckpt_dir, exist_ok=True)
+        os.makedirs(os.path.join(self.ckpt_dir, "training_states"), exist_ok=True)
+        os.makedirs(os.path.join(self.ckpt_dir, "policy_weights"), exist_ok=True)
         self.epoch = 0

     def load_checkpoint(self, epoch: int, resume_training_state: bool = True):
         if not resume_training_state:
             # load the weights without worrying about resuming the accelerate state
             ckpt = utils.retry_load_checkpoint(
-                os.path.join(self.ckpt_dir, f"policy_epoch_{epoch}.pt"),
+                os.path.join(
+                    self.ckpt_dir, "policy_weights", f"policy_epoch_{epoch}.pt"
+                ),
                 map_location=self.DEVICE,
             )
             self.policy_aclr.load_state_dict(ckpt)
@@ -243,20 +247,23 @@ def load_checkpoint(self, epoch: int, resume_training_state: bool = True):
             # loads weights and will set the epoch but otherwise resumes training
             # (optimizer, grad scaler, etc.)
             ckpt_name = f"{self.run_name}_epoch_{epoch}"
-            ckpt_path = os.path.join(self.ckpt_dir, ckpt_name)
+            ckpt_path = os.path.join(self.ckpt_dir, "training_states", ckpt_name)
             self.accelerator.load_state(ckpt_path)
         self.epoch = epoch

     def save_checkpoint(self):
         ckpt_name = f"{self.run_name}_epoch_{self.epoch}"
         self.accelerator.save_state(
-            os.path.join(self.ckpt_dir, ckpt_name), safe_serialization=True
+            os.path.join(self.ckpt_dir, "training_states", ckpt_name),
+            safe_serialization=True,
         )
         if self.accelerator.is_main_process:
             # create backup of raw weights unrelated to the more complex process of resuming an accelerate state
             weights_only = torch.save(
                 self.policy_aclr.state_dict(),
-                os.path.join(self.ckpt_dir, f"policy_epoch_{self.epoch}.pt"),
+                os.path.join(
+                    self.ckpt_dir, "policy_weights", f"policy_epoch_{self.epoch}.pt"
+                ),
             )

     def write_latest_policy(self):
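
PATCH 1/3 threads the gymnasium-style `info` dict returned by `reset()` up through the wrapper stack instead of discarding it. A minimal sketch of that pattern, using simplified stand-in wrappers rather than the real AMAGO classes (`InnerWrapper`/`OuterWrapper` and the dict-based timestep are illustrative assumptions, not the actual `AMAGOEnv` API):

```python
import gymnasium as gym


class InnerWrapper(gym.Wrapper):
    def reset(self, *, seed=None, options=None):
        obs, info = self.env.reset(seed=seed, options=options)
        timestep = {"observation": obs}  # the real code builds a richer Timestep here
        return timestep, info            # previously only the timestep was returned


class OuterWrapper(gym.Wrapper):
    def reset(self, *, seed=None, options=None):
        timestep, info = self.env.reset(seed=seed, options=options)
        return timestep["observation"], info  # previously returned (obs, {})


env = OuterWrapper(InnerWrapper(gym.make("CartPole-v1")))
obs, info = env.reset(seed=0)  # reset metadata now survives the full wrapper stack
```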
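
PATCH 2/3 and 3/3 keep two checkpoint formats side by side: a full `accelerate` training state under `ckpts/training_states/` for exact resumption, and a plain `state_dict` backup under `ckpts/policy_weights/` that is portable for inference. A minimal standalone sketch of that save/load pattern, assuming a toy `Linear` policy and a placeholder run name rather than AMAGO's actual model or `Experiment` API:

```python
import os

import torch
from accelerate import Accelerator

accelerator = Accelerator()
policy = accelerator.prepare(torch.nn.Linear(8, 2))  # stand-in for the real policy network
ckpt_dir, epoch = "ckpts", 0

os.makedirs(os.path.join(ckpt_dir, "training_states"), exist_ok=True)
os.makedirs(os.path.join(ckpt_dir, "policy_weights"), exist_ok=True)

# 1) Full training state (optimizer, grad scaler, rng, ...): resumable only in a
#    matching `accelerate` setup.
accelerator.save_state(os.path.join(ckpt_dir, "training_states", f"run_epoch_{epoch}"))

# 2) Raw weights only: a portable backup that sidesteps accelerate state versioning.
if accelerator.is_main_process:
    torch.save(
        accelerator.unwrap_model(policy).state_dict(),
        os.path.join(ckpt_dir, "policy_weights", f"policy_epoch_{epoch}.pt"),
    )

# Later: load just the weights without touching any accelerate state.
weights = torch.load(os.path.join(ckpt_dir, "policy_weights", f"policy_epoch_{epoch}.pt"))
accelerator.unwrap_model(policy).load_state_dict(weights)
```

Keeping the two formats in separate subdirectories is what lets `load_checkpoint(..., resume_training_state=False)` skip `accelerate`'s versioned state entirely and read a bare weight file instead.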