Skip to content

Commit

Permalink
MAGNeT v1 release (#33)
Browse files Browse the repository at this point in the history
MAGNeT v1
  • Loading branch information
lonzi authored and alonzi committed Jan 15, 2024
1 parent 7581ba3 commit 763cdc1
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 3 deletions.
12 changes: 9 additions & 3 deletions audiocraft/models/genmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
# LICENSE file in the root directory of this source tree.

"""
Base model for audio generative models. This will combine all the required components
and provide easy access to the generation API.
Base implementation for audio generative models. This base implementation
combines all the required components to run inference with pretrained audio
generative models. It can be easily inherited by downstream model classes to
provide easy access to the generation API.
"""

from abc import ABC, abstractmethod
Expand Down Expand Up @@ -61,6 +63,10 @@ def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,

self.max_duration: float = max_duration
self.duration = self.max_duration

# self.extend_stride is the length of audio extension when generating samples longer
# than self.max_duration. NOTE: the derived class must set self.extend_stride to a
# positive float value when generating with self.duration > self.max_duration.
self.extend_stride: tp.Optional[float] = None
self.device = next(iter(lm.parameters())).device
self.generation_params: dict = {}
Expand Down Expand Up @@ -161,7 +167,7 @@ def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
progress: bool = False, return_tokens: bool = False) \
-> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
"""Generate samples conditioned on audio prompts.
"""Generate samples conditioned on audio prompts and an optional text description.
Args:
prompt (torch.Tensor): A batch of waveforms used for continuation.
Expand Down
1 change: 1 addition & 0 deletions audiocraft/models/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def load_lm_model_magnet(file_or_url_or_id: tp.Union[Path, str], compression_mod
cfg.transformer_lm.segment_duration = cfg.dataset.segment_duration
cfg.transformer_lm.span_len = cfg.masking.span_len

# MAGNeT v1 models support only the xformers backend.
from audiocraft.modules.transformer import set_efficient_attention_backend
if cfg.transformer_lm.memory_efficient:
set_efficient_attention_backend("xformers")
Expand Down
1 change: 1 addition & 0 deletions audiocraft/solvers/magnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(self, cfg: DictConfig):
def build_model(self) -> None:
    """Prepare the transformer LM config for MAGNeT and build the model.

    Copies ``dataset.segment_duration`` and ``masking.span_len`` from the
    solver config onto ``transformer_lm``, checks that the configured
    efficient attention backend is ``"xformers"``, and then delegates to
    the parent class ``build_model``.

    Raises:
        AssertionError: If ``cfg.efficient_attention_backend`` is not
            ``"xformers"``.
    """
    self.cfg.transformer_lm.segment_duration = self.cfg.dataset.segment_duration
    self.cfg.transformer_lm.span_len = self.cfg.masking.span_len
    # Explicit raise rather than an `assert` statement: asserts are removed
    # when Python runs with -O, which would let an unsupported backend slip
    # through. AssertionError is kept so the exception type seen by callers
    # is unchanged.
    if self.cfg.efficient_attention_backend != "xformers":
        raise AssertionError("MAGNeT v1 models support only xformers backend.")
    super().build_model()

def _calc_mean_maskrate_to_u_LUT(self, T: int):
Expand Down

0 comments on commit 763cdc1

Please sign in to comment.