From 72047ba58884bc9b1be8efbe39484eefa4acc326 Mon Sep 17 00:00:00 2001 From: Suraksha Motilal Date: Thu, 5 Oct 2023 07:43:39 +0200 Subject: [PATCH 1/2] Update musicgen_app.py Fixed memory issue of models staying in memory when model is changed --- demos/musicgen_app.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/demos/musicgen_app.py b/demos/musicgen_app.py index 74c893e7..29dbeae1 100644 --- a/demos/musicgen_app.py +++ b/demos/musicgen_app.py @@ -72,6 +72,13 @@ def _cleanup(self): self.files.pop(0) else: break + + def delete_all(self): + # delete regardless of file's life time + for _, path in list(self.files): + if path.exists(): + path.unlink() + self.files = [] file_cleaner = FileCleaner() @@ -91,6 +98,10 @@ def load_model(version='facebook/musicgen-melody'): global MODEL print("Loading model", version) if MODEL is None or MODEL.name != version: + # Clear PyTorch CUDA cache and delete model + del MODEL + torch.cuda.empty_cache() + MODEL = MusicGen.get_pretrained(version) @@ -102,6 +113,9 @@ def load_diffusion(): def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs): + # get rid of temp files + file_cleaner.delete_all() + MODEL.set_generation_params(duration=duration, **gen_kwargs) print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies]) be = time.time() From ba67e4db7e711d2474909402e9f616a0e98931f1 Mon Sep 17 00:00:00 2001 From: Suraksha Motilal Date: Thu, 12 Oct 2023 10:40:29 +0200 Subject: [PATCH 2/2] Removed temp file deletion --- demos/musicgen_app.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/demos/musicgen_app.py b/demos/musicgen_app.py index 29dbeae1..9847e56c 100644 --- a/demos/musicgen_app.py +++ b/demos/musicgen_app.py @@ -73,14 +73,6 @@ def _cleanup(self): else: break - def delete_all(self): - # delete regardless of file's life time - for _, path in list(self.files): - if path.exists(): - path.unlink() - self.files = [] - - file_cleaner = FileCleaner() @@ -113,9 +105,7 @@ def load_diffusion(): def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs): - # get rid of temp files - file_cleaner.delete_all() - + MODEL.set_generation_params(duration=duration, **gen_kwargs) print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies]) be = time.time()