save videos during training #597
A lot of the test failures on Mac seem to be down to |
The approach seems good at a high level and much simpler than the previous one.
The key issue is that, unfortunately, I think checkpoint_interval is not measured in timesteps or episodes but in algorithm iterations, and what an iteration means differs from algorithm to algorithm. So we either need to do a conversion, trigger the video recording in some other way, or specify the video recording frequency separately in the config.
I'd also suggest having a way to disable video recording and probably leave it off by default as it introduces overhead and extra binary dependencies.
The other higher-level change: I suspect there's a way to cut down on code duplication in the script by putting a lot of this logic into the common Sacred ingredient. But I've not figured out the details, so this suggestion may end up being off-base.
Other comments were fairly minor.
@@ -177,4 +177,4 @@ def make_venv(
     try:
         yield venv
     finally:
-        venv.close()
+        venv.close()
Was this change (removing newline) intentional?
Nope -- fixing now
@@ -10,6 +10,8 @@
 from imitation.data import rollout
 from imitation.policies import base
 from imitation.scripts.common import common
+from imitation.util import video_wrapper
unused import?
Fixed now!
@@ -21,6 +23,8 @@
 from imitation.scripts.config.train_preference_comparisons import (
     train_preference_comparisons_ex,
 )
+import imitation.util.video_wrapper as video_wrapper
duplicate import?
Good catch -- fixed
tests/scripts/test_scripts.py
Outdated
Saves a preference comparisons ensemble reward, then loads it for transfer learning.
do we want to remove this?
tests/scripts/test_scripts.py
Outdated
def _check_video_exists(log_dir, algo):
    video_dir = VIDEO_PATH_DICT[algo](log_dir)
    assert os.path.exists(video_dir)
If video_dir is a pathlib.Path, I think just video_dir.exists() works.
Just updated!
tests/scripts/test_scripts.py
Outdated
def _check_video_exists(log_dir, algo):
    video_dir = VIDEO_PATH_DICT[algo](log_dir)
    assert os.path.exists(video_dir)
    assert VIDEO_FILE_PATH in os.listdir(video_dir)
(I'd guess there's a pathlib version of this too but not sure.)
Should be updated now
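For reference, a pathlib-only version of the check discussed in this thread might look like the sketch below. It assumes video_dir is a pathlib.Path. The file name "video.000000.mp4" is a hypothetical stand-in for VIDEO_FILE_PATH, whose actual value is not shown here, and check_video_exists is an illustrative name rather than the helper in the PR.

```python
import pathlib
import tempfile

# Hypothetical stand-in: the real VIDEO_FILE_PATH value is not shown in this thread.
VIDEO_FILE_NAME = "video.000000.mp4"


def check_video_exists(video_dir: pathlib.Path) -> None:
    """Assert that the video directory and the expected video file both exist."""
    assert video_dir.exists()
    # Joining with "/" and calling exists() replaces `name in os.listdir(dir)`.
    assert (video_dir / VIDEO_FILE_NAME).exists()


# Small demonstration on a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    demo_dir = pathlib.Path(tmp) / "videos"
    demo_dir.mkdir(parents=True)
    (demo_dir / VIDEO_FILE_NAME).touch()
    check_video_exists(demo_dir)  # passes silently
```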
    assert os.path.exists(video_dir)
    assert VIDEO_FILE_PATH in os.listdir(video_dir)


def test_train_rl_video_saving(tmpdir):
There's a lot of duplication here with test_train_rl_main; perhaps we can combine them somehow? This comment also applies to some extent to test_train_adversarial_* and test_train_preference_comparisons_* below.
@@ -113,6 +113,11 @@ def test_wandb_output_format():
         {"_step": 0, "foo": 42, "fizz": 12},
         {"_step": 3, "fizz": 21},
     ]

+    with pytest.raises(ValueError, match=r"wandb.Video accepts a file path.*"):
Testing invalid input error handling is good but it's a bit odd we're not also testing that it does the right thing with a valid input?
Testing the basic saving features already exists in a previous test (since manual video saving is already supported) -- I just added this test since it was in the original PR and figured it couldn't hurt.
Thanks for the changes, I think this is almost there.
Left some comments but most of them are pretty minor stylistic things.
One high-level issue I noticed: I think the video_save_interval only has any effect if single_video is False. Otherwise, I think the video just keeps recording once you start it (though I might be wrong here -- I've not actually tested that the bug exists). Unfortunately it's True by default! If I'm right, we should probably do input validation to require that single_video is False whenever video_save_interval != 1, and consider adding a test case to make sure videos are saved at an appropriate interval (could probably just mock this to avoid actually stepping through an environment and saving multiple videos).
There are a lot of lint errors; they seem to mostly be about docstring formatting. Our linter here is a bit obscure, so let me know if you have trouble figuring out what any of them mean. We're expecting docstrings to be in this format: https://www.sphinx-doc.org/en/master/usage/extensions/example_google.html
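The input validation suggested in this review could be sketched roughly as follows. This is only an illustration of the idea: validate_video_config is a hypothetical helper name, and the real wrapper may prefer to perform the check in its constructor.

```python
def validate_video_config(single_video: bool, video_save_interval: int) -> None:
    """Reject configurations where the save interval would be silently ignored.

    Hypothetical helper: the review suggests requiring single_video to be
    False whenever video_save_interval != 1.

    Args:
        single_video: Whether all episodes are recorded into one video file.
        video_save_interval: Requested interval between saved videos.

    Raises:
        ValueError: If single_video is True and video_save_interval != 1.
    """
    if single_video and video_save_interval != 1:
        raise ValueError(
            "video_save_interval only has an effect when single_video is False; "
            f"got single_video=True and video_save_interval={video_save_interval}."
        )


validate_video_config(single_video=False, video_save_interval=5)  # accepted
validate_video_config(single_video=True, video_save_interval=1)  # accepted
```

Failing loudly here avoids the situation described above, where the default single_video=True would make the interval setting a no-op without any warning.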
video_dir = base_dir / "videos"
video_dir.mkdir(parents=True, exist_ok=True)

post_wrappers_copy = [wrapper for wrapper in post_wrappers] \
Suggested change:
- post_wrappers_copy = [wrapper for wrapper in post_wrappers] \
+ post_wrappers_copy = list(post_wrappers) if post_wrappers != None else []
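A slightly more idiomatic spelling of this copy compares against None with "is not" rather than "!=". The sketch below wraps it in a hypothetical helper, copy_wrappers, purely for illustration, and shows that the caller's sequence is left untouched.

```python
from typing import Callable, List, Optional, Sequence


def copy_wrappers(post_wrappers: Optional[Sequence[Callable]] = None) -> List[Callable]:
    """Return a new list of wrappers, or an empty list if none were given."""
    # `is not None` is the identity check usually recommended for None.
    return list(post_wrappers) if post_wrappers is not None else []


original = [lambda env, i: env]
copied = copy_wrappers(original)
copied.append(lambda env, i: env)  # mutating the copy leaves `original` alone
```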
    _run,
    base_dir: pathlib.Path,
    video_save_interval: int,
    post_wrappers: Optional[Sequence[Callable[[gym.Env, int], gym.Env]]] = None,
Sequence[Callable[[gym.Env, int], gym.Env]] is quite long and repeated -- maybe introduce a type alias for it?
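One way such an alias could look is sketched below. The names PostWrapper and count_wrappers are hypothetical, and gym.Env is written as a string forward reference so that the snippet runs even where gym is not installed.

```python
from typing import TYPE_CHECKING, Callable, Optional, Sequence

if TYPE_CHECKING:
    import gym  # only needed by static type checkers

# Hypothetical alias: a function taking (env, env index) and returning a wrapped env.
PostWrapper = Callable[["gym.Env", int], "gym.Env"]


def count_wrappers(post_wrappers: Optional[Sequence[PostWrapper]] = None) -> int:
    """Return how many wrappers were supplied (illustration of the alias only)."""
    return 0 if post_wrappers is None else len(post_wrappers)
```

With the alias in place, each signature that previously spelled out the full Callable type can use Optional[Sequence[PostWrapper]] instead.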
@@ -30,6 +30,7 @@ def defaults():
     algorithm_specific = {}  # algorithm_specific[algorithm] is merged with config

     checkpoint_interval = 0  # Num epochs between checkpoints (<0 disables)
+    video_save_interval = 0  # Number of steps before saving video (<=0 disables)
I'll note you could put this in configs/common/common.py. However, eval_policy has its own video logic that is not easy to unify (it saves every episode for every environment, not just the first, which makes sense given its purpose), so it's not truly common and it seems OK to leave it per-algorithm.
post_wrappers = common_config.setup_video_saving(
    base_dir=checkpoint_dir,
    video_save_interval=video_save_interval,
    post_wrappers=None
None is the default so you could probably omit this line?
post_wrappers = common.setup_video_saving(
    base_dir=log_dir,
    video_save_interval=video_save_interval,
    post_wrappers=None
None is the default so you could probably omit this line?
src/imitation/util/video_wrapper.py
Outdated
@@ -81,7 +82,7 @@ def reset(self):
     def step(self, action):
         res = self.env.step(action)
         self.step_count += 1
-        if self.step_count % self.cadence == 0:
+        if self.step_count % self.video_save_interval == 0:
             self.should_record == 0
This line doesn't do anything (it's an equality check) -- was this meant to be an assignment? Also, self.should_record has type bool, not int.
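To illustrate the difference, here is a toy stand-in for the wrapper's step counter. It is not the actual VideoWrapper, and setting the flag to True at each interval boundary is an assumption about what the author intended; the point is only that an assignment with a bool value changes state, whereas "== 0" is evaluated and discarded.

```python
class RecordingSchedule:
    """Toy model of the step counter (hypothetical, not the real wrapper)."""

    def __init__(self, video_save_interval: int):
        self.video_save_interval = video_save_interval
        self.step_count = 0
        self.should_record = False

    def step(self) -> bool:
        self.step_count += 1
        if self.step_count % self.video_save_interval == 0:
            # Assignment (=) with a bool; `self.should_record == 0` would be a no-op.
            # Assumption: reaching the interval means recording should start.
            self.should_record = True
        return self.should_record


schedule = RecordingSchedule(video_save_interval=3)
flags = [schedule.step() for _ in range(3)]
```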
tests/scripts/test_scripts.py
Outdated
"""Set Sacred capture mode to "sys" because default "fd" option leads to error.
That's odd, most of our functions do not have whitespace before docstring, I am not sure why this would cause a linter error.
Is it this?
tests/scripts/test_scripts.py:1:1: D205 1 blank line required between summary line and description
I think that's complaining there's not a newline following this line, i.e. it's expecting docstrings in format:
A short one sentence summary.
Optionally, a more elaborate description of what this function does.
Some details only for the astute reader.
Args:
foo: ...
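As a concrete instance of that layout, the sketch below uses a hypothetical function, should_save_video, that is not part of the PR. The one-sentence summary is followed by a single blank line, then the longer description, then an Args section in the Google style linked above.

```python
def should_save_video(step_count: int, video_save_interval: int) -> bool:
    """Decide whether a video should be saved at this step.

    Saving is disabled when the interval is zero or negative. Otherwise a
    video is due whenever the step count is a multiple of the interval.

    Args:
        step_count: Number of environment steps taken so far.
        video_save_interval: Number of steps between videos (<=0 disables).

    Returns:
        True if a video is due at this step, False otherwise.
    """
    if video_save_interval <= 0:
        return False
    return step_count % video_save_interval == 0
```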
new list with just the video wrapper. If the video_save_interval is <=0,
it will just return the inputted post_wrapper
"""
"""

if video_save_interval > 0:
return post_wrappers_copy

return post_wrappers
We normally have two blank lines between functions, not four (though I guess more is OK if you want to emphasize the separation between them).
Codecov Report
@@ Coverage Diff @@
## master #597 +/- ##
==========================================
+ Coverage 97.49% 97.54% +0.04%
==========================================
Files 85 85
Lines 8099 8176 +77
==========================================
+ Hits 7896 7975 +79
+ Misses 203 201 -2
Description
This addresses issue #523 by automatically saving videos during training. It builds on the following earlier PR.
Known Limitations:
(1) Will not necessarily save a video at the end of the training run - it just saves a video at the first episode after each checkpoint.
(2) Saves videos during training episodes (and not in a separate, evaluation environment)
Testing
Added tests to test_scripts.