diff --git a/README.md b/README.md
index 393a0dd..bd1f3a1 100644
--- a/README.md
+++ b/README.md
@@ -100,20 +100,29 @@ Aside from the `wandb` logging metrics, AMAGO outputs data in the following form
 ```bash
 {Experiment.dset_root} or {--buffer_dir}/
-    - {Experiment.dset_name} or {--run_name}/
-        - train/
-            # replay buffer of sequence data stored on disk as `*.traj` files
-            {environment_name}_{random_id}_{unix_time}.traj
-            {environment_name}_{another_random_id}_{later_unix_time}.traj
-            ...
-        - {Experiment.run_name} or {--run_name}/
-            - ckpts/
-                # full checkpoints that allow training to resume; saved at regular intervals
-                - {Experiment.run_name}_epoch_0
-                - {Experiment.run_name}_epoch_k
-                ...
-            - policy.pt # the current model weights - used to communicate between actor/learner processes
-            - config.txt # stores gin configuration details for reproducibility (see below)
+    |-- {Experiment.dset_name} or {--run_name}/
+        |-- train/
+        |     # replay buffer of sequence data stored on disk as `*.traj` files.
+        |     {environment_name}_{random_id}_{unix_time}.traj
+        |     {environment_name}_{another_random_id}_{later_unix_time}.traj
+        |     ...
+        |-- {Experiment.run_name} or {--run_name}/
+            |-- config.txt # stores gin configuration details for reproducibility (see below)
+            |-- policy.pt # the current model weights - used to communicate between actor/learner processes
+            |-- ckpts/
+                |-- training_states/
+                |     |  # full checkpoints that restore the entire training setup in the same `accelerate` environment.
+                |     |  # (Including the optimizer, grad scaler, rng state, etc.)
+                |     |-- {Experiment.run_name}_epoch_0/
+                |     |     - # `accelerate` files you probably don't need
+                |     |-- {Experiment.run_name}_epoch_{Experiment.ckpt_interval}/
+                |     |-- ...
+                |-- policy_weights/
+                |     # pure weight files that avoid `accelerate` state version control and are more portable for inference.
+                |     |-- policy_epoch_0.pt
+                |     |-- policy_epoch_{Experiment.ckpt_interval}.pt
+                |     |-- ...
+        -- # any other runs that share this replay buffer would be listed here
 ```
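The `policy_weights/` files added above are plain `state_dict` checkpoints, so they can be restored with `torch.load` alone, without recreating the `accelerate` training state. Below is a minimal sketch of that idea; the helper name `load_portable_weights` and its arguments are hypothetical, and it assumes `policy` is an already-built policy module whose architecture matches the saved run:

```python
import os

import torch
from torch import nn


def load_portable_weights(policy: nn.Module, ckpt_dir: str, epoch: int) -> nn.Module:
    """Load a file from ckpts/policy_weights/ into an existing policy module.

    `ckpt_dir` is the .../ckpts directory from the layout above; `policy` must
    have the same architecture as the module whose weights were saved.
    """
    path = os.path.join(ckpt_dir, "policy_weights", f"policy_epoch_{epoch}.pt")
    state_dict = torch.load(path, map_location="cpu")
    policy.load_state_dict(state_dict)
    return policy
```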
diff --git a/amago/envs/amago_env.py b/amago/envs/amago_env.py
index 121256a..b0463f0 100644
--- a/amago/envs/amago_env.py
+++ b/amago/envs/amago_env.py
@@ -96,7 +96,7 @@ def inner_reset(self, seed=None, options=None):
 
     def reset(self, seed=None, options=None) -> Timestep:
         self.step_count = 0
-        obs, _ = self.inner_reset(seed=seed, options=options)
+        obs, info = self.inner_reset(seed=seed, options=options)
         if not isinstance(obs, dict):
             obs = {"observation": obs}
         timestep = Timestep(
@@ -110,7 +110,7 @@ def reset(self, seed=None, options=None) -> Timestep:
             terminal=False,
             raw_time_idx=0,
         )
-        return timestep
+        return timestep, info
 
     def inner_step(self, action):
         return self.env.step(action)
diff --git a/amago/envs/env_utils.py b/amago/envs/env_utils.py
index de3d0cb..cf185e1 100644
--- a/amago/envs/env_utils.py
+++ b/amago/envs/env_utils.py
@@ -202,7 +202,6 @@ def sequence_lengths(self):
 
 
 class ExplorationWrapper(ABC, gym.ActionWrapper):
-    @abstractmethod
     def add_exploration_noise(self, action: np.ndarray, local_step: int, horizon: int):
         raise NotImplementedError
 
@@ -291,6 +290,7 @@ def add_exploration_noise(self, action, local_step, horizon):
         return expl_action
 
 
+@gin.configurable
class EpsilonGreedy(BilevelEpsilonGreedy):
     """
     Sets the parameters of the BilevelEpsilonGreedy wrapper to be equivalent to standard epsilon-greedy.
@@ -390,7 +390,7 @@ def reset_stats(self):
         self.success_history = SuccessHistory(self.env_name)
 
     def reset(self, seed=None) -> Timestep:
-        timestep = self.env.reset(seed=seed)
+        timestep, info = self.env.reset(seed=seed)
         self.active_traj = Trajectory(
             max_goals=self.env.max_goal_seq_length, timesteps=[timestep]
         )
@@ -400,7 +400,7 @@ def reset(self, seed=None) -> Timestep:
         )
         self.total_return = 0.0
         self._current_timestep = self.active_traj.make_sequence(last_only=True)
-        return timestep.obs, {}
+        return timestep.obs, info
 
     def step(self, action):
         timestep, reward, terminated, truncated, info = self.env.step(action)
diff --git a/amago/learning.py b/amago/learning.py
index 9a121ca..ac67f80 100644
--- a/amago/learning.py
+++ b/amago/learning.py
@@ -229,19 +229,42 @@ def init_checkpoints(self):
             self.dset_root, self.dset_name, self.run_name, "ckpts"
         )
         os.makedirs(self.ckpt_dir, exist_ok=True)
+        os.makedirs(os.path.join(self.ckpt_dir, "training_states"), exist_ok=True)
+        os.makedirs(os.path.join(self.ckpt_dir, "policy_weights"), exist_ok=True)
         self.epoch = 0
 
-    def load_checkpoint(self, epoch: int):
-        ckpt_name = f"{self.run_name}_epoch_{epoch}"
-        ckpt_path = os.path.join(self.ckpt_dir, ckpt_name)
-        self.accelerator.load_state(ckpt_path)
+    def load_checkpoint(self, epoch: int, resume_training_state: bool = True):
+        if not resume_training_state:
+            # load the weights without worrying about resuming the accelerate state
+            ckpt = utils.retry_load_checkpoint(
+                os.path.join(
+                    self.ckpt_dir, "policy_weights", f"policy_epoch_{epoch}.pt"
+                ),
+                map_location=self.DEVICE,
+            )
+            self.policy_aclr.load_state_dict(ckpt)
+        else:
+            # resume the full training state saved by `accelerate`
+            # (model weights, optimizer, grad scaler, rng state, etc.)
+            ckpt_name = f"{self.run_name}_epoch_{epoch}"
+            ckpt_path = os.path.join(self.ckpt_dir, "training_states", ckpt_name)
+            self.accelerator.load_state(ckpt_path)
         self.epoch = epoch
 
     def save_checkpoint(self):
         ckpt_name = f"{self.run_name}_epoch_{self.epoch}"
         self.accelerator.save_state(
-            os.path.join(self.ckpt_dir, ckpt_name), safe_serialization=True
+            os.path.join(self.ckpt_dir, "training_states", ckpt_name),
+            safe_serialization=True,
         )
+        if self.accelerator.is_main_process:
+            # save a copy of the raw policy weights, separate from the more complex accelerate state used to resume training
+            torch.save(
+                self.policy_aclr.state_dict(),
+                os.path.join(
+                    self.ckpt_dir, "policy_weights", f"policy_epoch_{self.epoch}.pt"
+                ),
+            )
 
     def write_latest_policy(self):
         ckpt_name = os.path.join(
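For reference, a usage sketch of the new `resume_training_state` flag. The import path and the surrounding function are assumptions for illustration only, and `experiment` is assumed to be a fully constructed `Experiment` whose checkpoint directories already exist:

```python
from amago.learning import Experiment  # import path assumed from the file in this diff


def load_for(experiment: Experiment, epoch: int, inference_only: bool) -> None:
    """Sketch of the two loading paths behind load_checkpoint's new flag."""
    if inference_only:
        # Weights-only: reads ckpts/policy_weights/policy_epoch_{epoch}.pt and leaves
        # the optimizer, grad scaler, and rng state as freshly initialized.
        experiment.load_checkpoint(epoch, resume_training_state=False)
    else:
        # Full resume: restores the accelerate state saved under
        # ckpts/training_states/{run_name}_epoch_{epoch}.
        experiment.load_checkpoint(epoch, resume_training_state=True)
```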