diff --git a/demos/musicgen_app.py b/demos/musicgen_app.py index a10d52b5..2bbd6556 100644 --- a/demos/musicgen_app.py +++ b/demos/musicgen_app.py @@ -77,8 +77,7 @@ def _cleanup(self): self.files.pop(0) else: break - - + file_cleaner = FileCleaner() @@ -96,6 +95,9 @@ def load_model(version='facebook/musicgen-melody'): global MODEL print("Loading model", version) if MODEL is None or MODEL.name != version: + # Clear PyTorch CUDA cache and delete model + del MODEL + torch.cuda.empty_cache() MODEL = None # in case loading would crash MODEL = MusicGen.get_pretrained(version)