diff --git a/scripts/demo/video_sampling.py b/scripts/demo/video_sampling.py index 1f4fcfc4..73f19652 100644 --- a/scripts/demo/video_sampling.py +++ b/scripts/demo/video_sampling.py @@ -2,6 +2,7 @@ import sys sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../"))) +import torch from pytorch_lightning import seed_everything from scripts.demo.streamlit_helpers import * from scripts.demo.sv3d_helpers import * @@ -251,6 +252,7 @@ saving_fps = value_dict["fps"] if st.button("Sample"): + torch.cuda.empty_cache() out = do_sample( model, sampler,