Skip to content

Commit

Permalink
Fix test errors
Browse files Browse the repository at this point in the history
  • Loading branch information
taufeeque9 authored and ernestum committed Sep 26, 2023
1 parent 0628772 commit b01d51b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
12 changes: 9 additions & 3 deletions src/imitation/algorithms/adversarial/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ def _on_step(self) -> bool:
return True

def _on_rollout_end(self) -> None:
if self.gen_ctx_manager is not None:
self.exit_gen_ctx_manager()
gen_trajs, ep_lens = self.adversarial_trainer.venv_buffering.pop_trajectories()
self.adversarial_trainer._check_fixed_horizon(ep_lens)
gen_samples = rollout.flatten_trajectories_with_rew(gen_trajs)
Expand All @@ -133,9 +135,13 @@ def _on_rollout_end(self) -> None:
self.gen_ctx_manager = self.adversarial_trainer.logger.accumulate_means("gen")
self.gen_ctx_manager.__enter__()

def _on_training_end(self) -> None:
def exit_gen_ctx_manager(self) -> None:
assert self.gen_ctx_manager is not None
self.gen_ctx_manager.__exit__(None, None, None)
self.gen_ctx_manager = None

def _on_training_end(self) -> None:
self.exit_gen_ctx_manager()


class AdversarialTrainer(base.DemonstrationAlgorithm[types.Transitions]):
Expand Down Expand Up @@ -514,8 +520,8 @@ def train(
) -> None:
"""Alternates between training the generator and discriminator.
Every "round" consists of a call to `train_gen_with_disc(self.gen_train_timesteps)`,
a call to `train_disc`, and finally a call to `callback(round)`.
Every "round" consists of a call to
`train_gen_with_disc(self.gen_train_timesteps)` and a call to `callback(round)`.
Training ends once an additional "round" would cause the number of transitions
sampled from the environment to exceed `total_timesteps`.
Expand Down
2 changes: 1 addition & 1 deletion tests/algorithms/test_adversarial.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def test_train_gen_train_disc_no_crash(
n_updates: int = 2,
) -> None:
trainer_parametrized.train_gen_with_disc(
n_updates * trainer_parametrized.gen_train_timesteps
n_updates * trainer_parametrized.gen_train_timesteps,
)


Expand Down

0 comments on commit b01d51b

Please sign in to comment.