From 763cdc1dab8a21be81a2e70ff4511c4b2bc23042 Mon Sep 17 00:00:00 2001 From: Alon Ziv <30550331+lonzi@users.noreply.github.com> Date: Mon, 15 Jan 2024 18:24:36 +0200 Subject: [PATCH] MAGNeT v1 release (#33) MAGNeT v1 --- audiocraft/models/genmodel.py | 12 +++++++++--- audiocraft/models/loaders.py | 1 + audiocraft/solvers/magnet.py | 1 + 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/audiocraft/models/genmodel.py b/audiocraft/models/genmodel.py index 9cd3e0cd..96397450 100644 --- a/audiocraft/models/genmodel.py +++ b/audiocraft/models/genmodel.py @@ -5,8 +5,10 @@ # LICENSE file in the root directory of this source tree. """ -Base model for audio generative models. This will combine all the required components -and provide easy access to the generation API. +Base implementation for audio generative models. This base implementation +combines all the required components to run inference with pretrained audio +generative models. It can be easily inherited by downstream model classes to +provide easy access to the generation API. """ from abc import ABC, abstractmethod @@ -61,6 +63,10 @@ def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel, self.max_duration: float = max_duration self.duration = self.max_duration + + # self.extend_stride is the length of audio extension when generating samples longer + # than self.max_duration. NOTE: the derived class must set self.extend_stride to a + # positive float value when generating with self.duration > self.max_duration. self.extend_stride: tp.Optional[float] = None self.device = next(iter(lm.parameters())).device self.generation_params: dict = {} @@ -161,7 +167,7 @@ def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int, descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None, progress: bool = False, return_tokens: bool = False) \ -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]: - """Generate samples conditioned on audio prompts. 
+ """Generate samples conditioned on audio prompts and an optional text description. Args: prompt (torch.Tensor): A batch of waveforms used for continuation. diff --git a/audiocraft/models/loaders.py b/audiocraft/models/loaders.py index 995953bc..a6ec475e 100644 --- a/audiocraft/models/loaders.py +++ b/audiocraft/models/loaders.py @@ -134,6 +134,7 @@ def load_lm_model_magnet(file_or_url_or_id: tp.Union[Path, str], compression_mod cfg.transformer_lm.segment_duration = cfg.dataset.segment_duration cfg.transformer_lm.span_len = cfg.masking.span_len + # MAGNeT v1 models support only the xformers backend. from audiocraft.modules.transformer import set_efficient_attention_backend if cfg.transformer_lm.memory_efficient: set_efficient_attention_backend("xformers") diff --git a/audiocraft/solvers/magnet.py b/audiocraft/solvers/magnet.py index 12e2778f..5c401202 100644 --- a/audiocraft/solvers/magnet.py +++ b/audiocraft/solvers/magnet.py @@ -47,6 +47,7 @@ def __init__(self, cfg: DictConfig): def build_model(self) -> None: self.cfg.transformer_lm.segment_duration = self.cfg.dataset.segment_duration self.cfg.transformer_lm.span_len = self.cfg.masking.span_len + assert self.cfg.efficient_attention_backend == "xformers", "MAGNeT v1 models support only xformers backend." super().build_model() def _calc_mean_maskrate_to_u_LUT(self, T: int):