Merge pull request #46 from UT-Austin-RPL/add_info
Improve Checkpointing
jakegrigsby authored Sep 9, 2024
2 parents 12380c0 + 0861bd7 commit 06d64ea
Showing 4 changed files with 56 additions and 24 deletions.
README.md (37 changes: 23 additions & 14 deletions)
@@ -100,20 +100,29 @@ Aside from the `wandb` logging metrics, AMAGO outputs data in the following form

```bash
{Experiment.dset_root} or {--buffer_dir}/
-- {Experiment.dset_name} or {--run_name}/
-    - train/
-        # replay buffer of sequence data stored on disk as `*.traj` files
-        {environment_name}_{random_id}_{unix_time}.traj
-        {environment_name}_{another_random_id}_{later_unix_time}.traj
-        ...
-    - {Experiment.run_name} or {--run_name}/
-        - ckpts/
-            # full checkpoints that allow training to resume; saved at regular intervals
-            - {Experiment.run_name}_epoch_0
-            - {Experiment.run_name}_epoch_k
-            ...
-        - policy.pt # the current model weights - used to communicate between actor/learner processes
-        - config.txt # stores gin configuration details for reproducibility (see below)
+|-- {Experiment.dset_name} or {--run_name}/
+    |-- train/
+    |   # replay buffer of sequence data stored on disk as `*.traj` files.
+    |   {environment_name}_{random_id}_{unix_time}.traj
+    |   {environment_name}_{another_random_id}_{later_unix_time}.traj
+    |   ...
+    |-- {Experiment.run_name} or {--run_name}/
+        |-- config.txt # stores gin configuration details for reproducibility (see below)
+        |-- policy.pt # the current model weights - used to communicate between actor/learner processes
+        |-- ckpts/
+            |-- training_states/
+            |   # full checkpoints that restore the entire training setup in the same `accelerate` environment
+            |   # (including the optimizer, grad scaler, rng state, etc.)
+            |   |-- {Experiment.run_name}_epoch_0/
+            |   |   # `accelerate` files you probably don't need
+            |   |-- {Experiment.run_name}_epoch_{Experiment.ckpt_interval}/
+            |   |-- ...
+            |-- policy_weights/
+            |   # pure weight files that avoid `accelerate` state version control and are more portable for inference
+            |   |-- policy_epoch_0.pt
+            |   |-- policy_epoch_{Experiment.ckpt_interval}.pt
+            |   |-- ...
+    -- # any other runs that share this replay buffer would be listed here
```
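Given this layout, the two checkpoint formats pair with the two loading modes of `Experiment.load_checkpoint` in the `amago/learning.py` diff below. A rough usage sketch, assuming `experiment` is an already-configured `amago.Experiment` and an epoch-0 checkpoint exists:

```python
# Hypothetical usage; `experiment` stands in for a fully configured Experiment.

# Resume interrupted training from ckpts/training_states/{run_name}_epoch_0/
# (restores the optimizer, grad scaler, rng state, etc. through `accelerate`).
experiment.load_checkpoint(epoch=0, resume_training_state=True)

# Load only the portable weights from ckpts/policy_weights/policy_epoch_0.pt,
# skipping the `accelerate` training state (e.g., for evaluation).
experiment.load_checkpoint(epoch=0, resume_training_state=False)
```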

<br>
amago/envs/amago_env.py (4 changes: 2 additions & 2 deletions)
@@ -96,7 +96,7 @@ def inner_reset(self, seed=None, options=None):

    def reset(self, seed=None, options=None) -> Timestep:
        self.step_count = 0
-        obs, _ = self.inner_reset(seed=seed, options=options)
+        obs, info = self.inner_reset(seed=seed, options=options)
        if not isinstance(obs, dict):
            obs = {"observation": obs}
        timestep = Timestep(
@@ -110,7 +110,7 @@ def reset(self, seed=None, options=None) -> Timestep:
            terminal=False,
            raw_time_idx=0,
        )
-        return timestep
+        return timestep, info

    def inner_step(self, action):
        return self.env.step(action)
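This aligns `reset` with the gymnasium convention of returning an `(obs, info)` two-tuple instead of discarding `info`. A minimal sketch of the resulting contract, where `env` is a hypothetical instance of the wrapper defined in `amago/envs/amago_env.py`:

```python
# Sketch of the two-tuple reset contract this change adopts; `env` is assumed
# to be an already-constructed AMAGO environment wrapper.
timestep, info = env.reset(seed=0)
print(timestep.obs, info)
```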
amago/envs/env_utils.py (6 changes: 3 additions & 3 deletions)
@@ -202,7 +202,6 @@ def sequence_lengths(self):


class ExplorationWrapper(ABC, gym.ActionWrapper):
-
    @abstractmethod
    def add_exploration_noise(self, action: np.ndarray, local_step: int, horizon: int):
        raise NotImplementedError
@@ -291,6 +290,7 @@ def add_exploration_noise(self, action, local_step, horizon):
        return expl_action


+@gin.configurable
class EpsilonGreedy(BilevelEpsilonGreedy):
    """
    Sets the parameters of the BilevelEpsilonGreedy wrapper to be equivalent to standard epsilon-greedy.
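The new `@gin.configurable` decorator makes the wrapper's constructor arguments bindable from a gin config. A hedged illustration; the parameter names `eps_start` and `eps_end` are hypothetical and should be checked against the real `BilevelEpsilonGreedy` signature:

```python
import gin

import amago.envs.env_utils  # importing the module registers EpsilonGreedy with gin

# Hypothetical bindings; the argument names below are illustrative only.
gin.parse_config(
    """
    EpsilonGreedy.eps_start = 1.0
    EpsilonGreedy.eps_end = 0.05
    """
)
```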
@@ -390,7 +390,7 @@ def reset_stats(self):
        self.success_history = SuccessHistory(self.env_name)

    def reset(self, seed=None) -> Timestep:
-        timestep = self.env.reset(seed=seed)
+        timestep, info = self.env.reset(seed=seed)
        self.active_traj = Trajectory(
            max_goals=self.env.max_goal_seq_length, timesteps=[timestep]
        )
@@ -400,7 +400,7 @@
        )
        self.total_return = 0.0
        self._current_timestep = self.active_traj.make_sequence(last_only=True)
-        return timestep.obs, {}
+        return timestep.obs, info

    def step(self, action):
        timestep, reward, terminated, truncated, info = self.env.step(action)
amago/learning.py (33 changes: 28 additions & 5 deletions)
@@ -229,19 +229,42 @@ def init_checkpoints(self):
            self.dset_root, self.dset_name, self.run_name, "ckpts"
        )
        os.makedirs(self.ckpt_dir, exist_ok=True)
+        os.makedirs(os.path.join(self.ckpt_dir, "training_states"), exist_ok=True)
+        os.makedirs(os.path.join(self.ckpt_dir, "policy_weights"), exist_ok=True)
        self.epoch = 0

-    def load_checkpoint(self, epoch: int):
-        ckpt_name = f"{self.run_name}_epoch_{epoch}"
-        ckpt_path = os.path.join(self.ckpt_dir, ckpt_name)
-        self.accelerator.load_state(ckpt_path)
+    def load_checkpoint(self, epoch: int, resume_training_state: bool = True):
+        if not resume_training_state:
+            # load the weights without worrying about resuming the accelerate state
+            ckpt = utils.retry_load_checkpoint(
+                os.path.join(
+                    self.ckpt_dir, "policy_weights", f"policy_epoch_{epoch}.pt"
+                ),
+                map_location=self.DEVICE,
+            )
+            self.policy_aclr.load_state_dict(ckpt)
+        else:
+            # loads the weights, sets the epoch, and resumes the rest of the
+            # training state (optimizer, grad scaler, etc.)
+            ckpt_name = f"{self.run_name}_epoch_{epoch}"
+            ckpt_path = os.path.join(self.ckpt_dir, "training_states", ckpt_name)
+            self.accelerator.load_state(ckpt_path)
        self.epoch = epoch

    def save_checkpoint(self):
        ckpt_name = f"{self.run_name}_epoch_{self.epoch}"
        self.accelerator.save_state(
-            os.path.join(self.ckpt_dir, ckpt_name), safe_serialization=True
+            os.path.join(self.ckpt_dir, "training_states", ckpt_name),
+            safe_serialization=True,
        )
+        if self.accelerator.is_main_process:
+            # create a backup of the raw weights, independent of the more complex accelerate state
+            torch.save(
+                self.policy_aclr.state_dict(),
+                os.path.join(
+                    self.ckpt_dir, "policy_weights", f"policy_epoch_{self.epoch}.pt"
+                ),
+            )
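Because the `policy_weights/*.pt` files are plain `state_dict`s written by `torch.save`, they can be reloaded without `accelerate` at all. A minimal sketch; the path and the `policy` object are illustrative:

```python
import torch

# Standalone load of a portable weights file; `policy` is assumed to be a
# model with the same architecture that produced the checkpoint.
state_dict = torch.load(
    "buffer_dir/my_dset/my_run/ckpts/policy_weights/policy_epoch_0.pt",
    map_location="cpu",
)
policy.load_state_dict(state_dict)
```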

    def write_latest_policy(self):
        ckpt_name = os.path.join(
