diff --git a/.github/actions/audiocraft_build/action.yml b/.github/actions/audiocraft_build/action.yml
index be5dae26..b412cd02 100644
--- a/.github/actions/audiocraft_build/action.yml
+++ b/.github/actions/audiocraft_build/action.yml
@@ -21,6 +21,8 @@ runs:
python3 -m venv env
. env/bin/activate
python -m pip install --upgrade pip
+ pip install torch torchvision torchaudio
+ pip install xformers
pip install -e '.[dev]'
- name: System Dependencies
shell: bash
diff --git a/CHANGELOG.md b/CHANGELOG.md
index aa599e4d..6036b72f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,31 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
+## [1.2.0a] - TBD
+
+Adding stereo models.
+
+
+## [1.1.0] - 2023-11-06
+
+Not using torchaudio anymore when writing audio files, relying instead directly on the command-line ffmpeg. Also not using it anymore for reading audio files, for similar reasons.
+
+Fixed DAC support with a non-default number of codebooks.
+
+Fixed a bug where `two_step_cfg` was overridden when calling `generate()`.
+
+Fixed generated samples always being prompted with audio, rather than having both prompted and unprompted samples.
+
+**Backward incompatible change:** A `torch.no_grad` around the computation of the conditioning made its way into the public release.
+ The released models were trained without it. This impacts the linear layers applied to the output of the T5 or melody conditioners.
+ We removed it, so you might need to retrain models.
+
+**Backward incompatible change:** Fixed a wrong sample rate in CLAP (WARNING: if you trained a model with CLAP before, it may need retraining).
+
+**Backward incompatible change:** Renamed VALLEPattern to CoarseFirstPattern, as it was wrongly named. Probably no one
+ retrained a model with this pattern, so hopefully this won't impact you!
+
+
## [1.0.0] - 2023-09-07
Major revision, added training code for EnCodec, AudioGen, MusicGen, and MultiBandDiffusion.
diff --git a/README.md b/README.md
index 21b3f497..e3687f1e 100644
--- a/README.md
+++ b/README.md
@@ -13,11 +13,11 @@ AudioCraft requires Python 3.9, PyTorch 2.0.0. To install AudioCraft, you can ru
```shell
# Best to make sure you have torch installed first, in particular before installing xformers.
# Don't run this if you already have PyTorch installed.
-pip install 'torch>=2.0'
+python -m pip install 'torch>=2.0'
# Then proceed to one of the following
-pip install -U audiocraft # stable release
-pip install -U git+https://git@github.com/facebookresearch/audiocraft#egg=audiocraft # bleeding edge
-pip install -e . # or if you cloned the repo locally (mandatory if you want to train).
+python -m pip install -U audiocraft # stable release
+python -m pip install -U git+https://git@github.com/facebookresearch/audiocraft#egg=audiocraft # bleeding edge
+python -m pip install -e . # or if you cloned the repo locally (mandatory if you want to train).
```
We also recommend having `ffmpeg` installed, either through your system or Anaconda:
@@ -72,11 +72,11 @@ Finally, if you use a model that relies on Demucs (e.g. `musicgen-melody`) and w
For the general framework of AudioCraft, please cite the following.
```
-@article{copet2023simple,
+@inproceedings{copet2023simple,
title={Simple and Controllable Music Generation},
author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez},
+ booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
year={2023},
- journal={arXiv preprint arXiv:2306.05284},
}
```
diff --git a/audiocraft/__init__.py b/audiocraft/__init__.py
index 6ab34607..8b7acf22 100644
--- a/audiocraft/__init__.py
+++ b/audiocraft/__init__.py
@@ -23,4 +23,4 @@
# flake8: noqa
from . import data, modules, models
-__version__ = '1.0.0'
+__version__ = '1.2.0a1'
diff --git a/audiocraft/data/audio.py b/audiocraft/data/audio.py
index 2ac5e6cf..a35dfd9c 100644
--- a/audiocraft/data/audio.py
+++ b/audiocraft/data/audio.py
@@ -18,11 +18,11 @@
import soundfile
import torch
from torch.nn import functional as F
-import torchaudio as ta
import av
+import subprocess as sp
-from .audio_utils import f32_pcm, i16_pcm, normalize_audio
+from .audio_utils import f32_pcm, normalize_audio
_av_initialized = False
@@ -136,12 +136,6 @@ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
wav = torch.from_numpy(wav).t().contiguous()
if len(wav.shape) == 1:
wav = torch.unsqueeze(wav, 0)
- elif (
- fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats()
- and duration <= 0 and seek_time == 0
- ):
- # Torchaudio is faster if we load an entire file at once.
- wav, sr = ta.load(fp)
else:
wav, sr = _av_read(filepath, seek_time, duration)
if pad and duration > 0:
@@ -150,10 +144,22 @@ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
return wav, sr
+def _piping_to_ffmpeg(out_path: tp.Union[str, Path], wav: torch.Tensor, sample_rate: int, flags: tp.List[str]):
+ # ffmpeg is always installed and torchaudio is a bit unstable lately, so let's bypass it entirely.
+ assert wav.dim() == 2, wav.shape
+ command = [
+ 'ffmpeg',
+ '-loglevel', 'error',
+ '-y', '-f', 'f32le', '-ar', str(sample_rate), '-ac', str(wav.shape[0]),
+ '-i', '-'] + flags + [str(out_path)]
+ input_ = f32_pcm(wav).t().detach().cpu().numpy().tobytes()
+ sp.run(command, input=input_, check=True)
+
+
def audio_write(stem_name: tp.Union[str, Path],
wav: torch.Tensor, sample_rate: int,
- format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
- strategy: str = 'peak', peak_clip_headroom_db: float = 1,
+ format: str = 'wav', mp3_rate: int = 320, ogg_rate: tp.Optional[int] = None,
+ normalize: bool = True, strategy: str = 'peak', peak_clip_headroom_db: float = 1,
rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
loudness_compressor: bool = False,
log_clipping: bool = True, make_parent_dir: bool = True,
@@ -164,8 +170,9 @@ def audio_write(stem_name: tp.Union[str, Path],
stem_name (str or Path): Filename without extension which will be added automatically.
wav (torch.Tensor): Audio data to save.
sample_rate (int): Sample rate of audio data.
- format (str): Either "wav" or "mp3".
+ format (str): Either "wav", "mp3", "ogg", or "flac".
mp3_rate (int): kbps when using mp3s.
+ ogg_rate (int): kbps when using ogg/vorbis. If not provided, ffmpeg is left to choose the bitrate.
normalize (bool): if `True` (default), normalizes according to the prescribed
strategy (see after). If `False`, the strategy is only used in case clipping
would happen.
@@ -193,14 +200,20 @@ def audio_write(stem_name: tp.Union[str, Path],
rms_headroom_db, loudness_headroom_db, loudness_compressor,
log_clipping=log_clipping, sample_rate=sample_rate,
stem_name=str(stem_name))
- kwargs: dict = {}
if format == 'mp3':
suffix = '.mp3'
- kwargs.update({"compression": mp3_rate})
+ flags = ['-f', 'mp3', '-c:a', 'libmp3lame', '-b:a', f'{mp3_rate}k']
elif format == 'wav':
- wav = i16_pcm(wav)
suffix = '.wav'
- kwargs.update({"encoding": "PCM_S", "bits_per_sample": 16})
+ flags = ['-f', 'wav', '-c:a', 'pcm_s16le']
+ elif format == 'ogg':
+ suffix = '.ogg'
+ flags = ['-f', 'ogg', '-c:a', 'libvorbis']
+ if ogg_rate is not None:
+ flags += ['-b:a', f'{ogg_rate}k']
+ elif format == 'flac':
+ suffix = '.flac'
+ flags = ['-f', 'flac']
else:
- raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
+ raise RuntimeError(f"Invalid format {format}. Only wav, mp3, ogg, or flac are supported.")
if not add_suffix:
@@ -209,7 +222,7 @@ def audio_write(stem_name: tp.Union[str, Path],
if make_parent_dir:
path.parent.mkdir(exist_ok=True, parents=True)
try:
- ta.save(path, wav, sample_rate, **kwargs)
+ _piping_to_ffmpeg(path, wav, sample_rate, flags)
except Exception:
if path.exists():
# we do not want to leave half written files around.
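A minimal usage sketch (not part of the patch) for the ffmpeg-backed writer introduced above, assuming the `audio_write` signature shown in this hunk and an `ffmpeg` binary on the PATH; the stem name and parameter values are illustrative.

```python
import torch
from audiocraft.data.audio import audio_write

sr = 32000
t = torch.arange(sr * 2, dtype=torch.float32) / sr          # 2 seconds
wav = torch.sin(2 * torch.pi * 440 * t).unsqueeze(0)        # [channels=1, samples], a 440 Hz tone

# Each call pipes f32 PCM into ffmpeg with the flags built by audio_write.
audio_write('demo_tone', wav, sr, format='mp3', mp3_rate=192)   # -> demo_tone.mp3 (libmp3lame)
audio_write('demo_tone', wav, sr, format='ogg')                 # -> demo_tone.ogg (bitrate left to ffmpeg)
audio_write('demo_tone', wav, sr, format='flac')                # -> demo_tone.flac
```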
diff --git a/audiocraft/grids/musicgen/musicgen_stereo_finetune_32khz.py b/audiocraft/grids/musicgen/musicgen_stereo_finetune_32khz.py
new file mode 100644
index 00000000..2904e73d
--- /dev/null
+++ b/audiocraft/grids/musicgen/musicgen_stereo_finetune_32khz.py
@@ -0,0 +1,57 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from pathlib import Path
+from ._explorers import LMExplorer
+from ...environment import AudioCraftEnvironment
+
+
+@LMExplorer
+def explorer(launcher):
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
+ launcher.slurm_(gpus=32, partition=partitions)
+ launcher.bind_(solver='musicgen/musicgen_base_32khz')
+ # replace this with the desired music dataset, which needs to be stereo
+ launcher.bind_(dset='audio/example')
+
+ fsdp = {'autocast': False, 'fsdp.use': True}
+ medium = {'model/lm/model_scale': 'medium'}
+ large = {'model/lm/model_scale': 'large'}
+
+ cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
+ wd_low = {'conditioners.description.t5.word_dropout': 0.2}
+
+ adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}
+
+ stereo = {
+ 'codebooks_pattern.delay.delays': [0, 0, 1, 1, 2, 2, 3, 3],
+ 'transformer_lm.n_q': 8,
+ 'interleave_stereo_codebooks.use': True,
+ 'channels': 2,
+ }
+
+ # You must follow the instructions in docs/MUSICGEN.md about the creation
+ # of the proper fine-tuning checkpoints. We will assume they are stored under
+ # ~/checkpoints/{model_name}.
+
+ checkpoints = Path.home() / 'checkpoints'
+
+ launcher.bind_(fsdp, stereo, {'optim.epochs': 100})
+
+ launcher.slurm_(gpus=32).bind_(label='32gpus')
+ with launcher.job_array():
+ sub = launcher.bind({'continue_from': str(checkpoints / 'stereo_finetune_musicgen-small.th')})
+ sub()
+
+ launcher.slurm_(gpus=64).bind_(label='64gpus')
+ with launcher.job_array():
+ sub = launcher.bind({'continue_from': str(checkpoints / 'stereo_finetune_musicgen-medium.th')})
+ sub(medium, adam)
+
+ launcher.slurm_(gpus=96).bind_(label='96gpus')
+ with launcher.job_array():
+ sub = launcher.bind({'continue_from': str(checkpoints / 'stereo_finetune_musicgen-large.th')})
+ sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})
diff --git a/audiocraft/models/audiogen.py b/audiocraft/models/audiogen.py
index 5cb88998..b4df536e 100644
--- a/audiocraft/models/audiogen.py
+++ b/audiocraft/models/audiogen.py
@@ -38,6 +38,10 @@ def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
self.name = name
self.compression_model = compression_model
self.lm = lm
+ # Just to be safe, let's put everything in eval mode.
+ self.compression_model.eval()
+ self.lm.eval()
+
if max_duration is None:
if hasattr(lm, 'cfg'):
max_duration = lm.cfg.dataset.segment_duration # type: ignore
diff --git a/audiocraft/models/builders.py b/audiocraft/models/builders.py
index 038bf99c..b7144874 100644
--- a/audiocraft/models/builders.py
+++ b/audiocraft/models/builders.py
@@ -15,7 +15,7 @@
import omegaconf
import torch
-from .encodec import CompressionModel, EncodecModel
+from .encodec import CompressionModel, EncodecModel, InterleaveStereoCompressionModel
from .lm import LMModel
from ..modules.codebooks_patterns import (
CodebooksPatternProvider,
@@ -23,7 +23,7 @@
MusicLMPattern,
ParallelPatternProvider,
UnrolledPatternProvider,
- VALLEPattern,
+ CoarseFirstPattern,
)
from ..modules.conditioners import (
BaseConditioner,
@@ -172,7 +172,7 @@ def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> Codeb
'parallel': ParallelPatternProvider,
'delay': DelayedPatternProvider,
'unroll': UnrolledPatternProvider,
- 'valle': VALLEPattern,
+ 'coarse_first': CoarseFirstPattern,
'musiclm': MusicLMPattern,
}
name = cfg.modeling
@@ -196,7 +196,6 @@ def get_debug_compression_model(device='cpu', sample_rate: int = 32000):
'dimension': 32,
'ratios': ratios,
}
- print(seanet_kwargs)
encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs)
decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs)
quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4)
@@ -248,5 +247,12 @@ def get_debug_lm_model(device='cpu'):
def get_wrapped_compression_model(
compression_model: CompressionModel,
cfg: omegaconf.DictConfig) -> CompressionModel:
- # more to come.
+ if hasattr(cfg, 'interleave_stereo_codebooks'):
+ if cfg.interleave_stereo_codebooks.use:
+ kwargs = dict_from_config(cfg.interleave_stereo_codebooks)
+ kwargs.pop('use')
+ compression_model = InterleaveStereoCompressionModel(compression_model, **kwargs)
+ if hasattr(cfg, 'compression_model_n_q'):
+ if cfg.compression_model_n_q is not None:
+ compression_model.set_num_codebooks(cfg.compression_model_n_q)
return compression_model
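Hedged sketch of a config fragment that would trigger the wrapping logic above, mirroring the keys added to `config/solver/musicgen/default.yaml`; the values are illustrative only.

```python
import omegaconf

cfg = omegaconf.OmegaConf.create({
    'interleave_stereo_codebooks': {'use': True, 'per_timestep': False},
    'compression_model_n_q': 4,
})
# With such a cfg, get_wrapped_compression_model() wraps the mono compression model in
# InterleaveStereoCompressionModel(model, per_timestep=False) and then calls
# set_num_codebooks(4), i.e. 4 codebooks per channel before interleaving (8 after).
```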
diff --git a/audiocraft/models/encodec.py b/audiocraft/models/encodec.py
index 40d13301..d4e77a94 100644
--- a/audiocraft/models/encodec.py
+++ b/audiocraft/models/encodec.py
@@ -13,6 +13,7 @@
from pathlib import Path
import typing as tp
+from einops import rearrange
import numpy as np
import torch
from torch import nn
@@ -276,7 +277,7 @@ def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
codes = self.model.encode(x, self.n_quantizers)[1]
- return codes, None
+ return codes[:, :self.n_quantizers], None
def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
assert scale is None
@@ -391,3 +392,115 @@ def set_num_codebooks(self, n: int):
if n not in self.possible_num_codebooks:
raise ValueError(f"Allowed values for num codebooks: {self.possible_num_codebooks}")
self._num_codebooks = n
+
+
+class InterleaveStereoCompressionModel(CompressionModel):
+ """Wraps a CompressionModel to support stereo inputs. The wrapped model
+ will be applied independently to the left and right channels, and both codebooks
+ will be interleaved. If the wrapped model returns a representation `[B, K ,T]` per
+ channel, then the output will be `[B, K * 2, T]` or `[B, K, T * 2]` depending on
+ `per_timestep`.
+
+ Args:
+ model (CompressionModel): Compression model to wrap.
+ per_timestep (bool): Whether to interleave on the timestep dimension
+ or on the codebooks dimension.
+ """
+ def __init__(self, model: CompressionModel, per_timestep: bool = False):
+ super().__init__()
+ self.model = model
+ self.per_timestep = per_timestep
+ assert self.model.channels == 1, "Wrapped model is expected to be for monophonic audio"
+
+ @property
+ def total_codebooks(self):
+ return self.model.total_codebooks
+
+ @property
+ def num_codebooks(self):
+ """Active number of codebooks used by the quantizer.
+
+ ..Warning:: this reports the number of codebooks after the interleaving
+ of the codebooks!
+ """
+ return self.model.num_codebooks if self.per_timestep else self.model.num_codebooks * 2
+
+ def set_num_codebooks(self, n: int):
+ """Set the active number of codebooks used by the quantizer.
+
+ ..Warning:: this sets the number of codebooks before the interleaving!
+ """
+ self.model.set_num_codebooks(n)
+
+ @property
+ def num_virtual_steps(self) -> float:
+ """Return the number of virtual steps, i.e. one real step
+ will be split into that many steps.
+ """
+ return 2 if self.per_timestep else 1
+
+ @property
+ def frame_rate(self) -> float:
+ return self.model.frame_rate * self.num_virtual_steps
+
+ @property
+ def sample_rate(self) -> int:
+ return self.model.sample_rate
+
+ @property
+ def channels(self) -> int:
+ return 2
+
+ @property
+ def cardinality(self):
+ """Cardinality of each codebook.
+ """
+ return self.model.cardinality
+
+ def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
+ raise NotImplementedError("Not supported, use encode and decode.")
+
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+ B, C, T = x.shape
+ assert C == self.channels, f"Expecting stereo audio but audio num channels is {C}"
+
+ indices_c0, scales_c0 = self.model.encode(x[:, 0, ...].unsqueeze(1))
+ indices_c1, scales_c1 = self.model.encode(x[:, 1, ...].unsqueeze(1))
+ indices = torch.stack([indices_c0, indices_c1], dim=0)
+ scales: tp.Optional[torch.Tensor] = None
+ if scales_c0 is not None and scales_c1 is not None:
+ scales = torch.stack([scales_c0, scales_c1], dim=1)
+
+ if self.per_timestep:
+ indices = rearrange(indices, 'c b k t -> b k (t c)', c=2)
+ else:
+ indices = rearrange(indices, 'c b k t -> b (k c) t', c=2)
+
+ return (indices, scales)
+
+ def get_left_right_codes(self, codes: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+ if self.per_timestep:
+ codes = rearrange(codes, 'b k (t c) -> c b k t', c=2)
+ else:
+ codes = rearrange(codes, 'b (k c) t -> c b k t', c=2)
+ return codes[0], codes[1]
+
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
+ B, K, T = codes.shape
+ assert T % self.num_virtual_steps == 0, "Provided codes' number of timesteps does not match"
+ assert K == self.num_codebooks, "Provided codes' number of codebooks does not match"
+
+ scale_c0, scale_c1 = None, None
+ if scale is not None:
+ assert scale.size(0) == B and scale.size(1) == 2, f"Scale has unexpected shape: {scale.shape}"
+ scale_c0 = scale[0, ...]
+ scale_c1 = scale[1, ...]
+
+ codes_c0, codes_c1 = self.get_left_right_codes(codes)
+ audio_c0 = self.model.decode(codes_c0, scale_c0)
+ audio_c1 = self.model.decode(codes_c1, scale_c1)
+ return torch.cat([audio_c0, audio_c1], dim=1)
+
+ def decode_latent(self, codes: torch.Tensor):
+ """Decode from the discrete codes to continuous latent space."""
+ raise NotImplementedError("Not supported by interleaved stereo wrapped models.")
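The interleaving convention can be checked in isolation with dummy code tensors; this sketch (not part of the patch) only exercises the `einops.rearrange` patterns used by `encode` above.

```python
import torch
from einops import rearrange

B, K, T = 1, 4, 6
left = torch.zeros(B, K, T, dtype=torch.long)    # stand-in for the left-channel codes
right = torch.ones(B, K, T, dtype=torch.long)    # stand-in for the right-channel codes
stacked = torch.stack([left, right], dim=0)      # [c=2, B, K, T], as built in encode()

per_codebook = rearrange(stacked, 'c b k t -> b (k c) t', c=2)   # per_timestep=False
assert per_codebook.shape == (B, 2 * K, T)
# Codebooks alternate L, R, L, R, ... i.e. [1_L, 1_R, 2_L, 2_R, ...]
assert per_codebook[0, 0::2].eq(0).all() and per_codebook[0, 1::2].eq(1).all()

per_timestep = rearrange(stacked, 'c b k t -> b k (t c)', c=2)   # per_timestep=True
assert per_timestep.shape == (B, K, 2 * T)
assert per_timestep[0, :, 0::2].eq(0).all() and per_timestep[0, :, 1::2].eq(1).all()
```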
diff --git a/audiocraft/models/lm.py b/audiocraft/models/lm.py
index 8cefd2c5..c4ea2e5e 100644
--- a/audiocraft/models/lm.py
+++ b/audiocraft/models/lm.py
@@ -314,7 +314,8 @@ def _sample_next_token(self,
temp: float = 1.0,
top_k: int = 0,
top_p: float = 0.0,
- cfg_coef: tp.Optional[float] = None) -> torch.Tensor:
+ cfg_coef: tp.Optional[float] = None,
+ two_step_cfg: tp.Optional[bool] = None) -> torch.Tensor:
"""Sample next token from the model given a sequence and a set of conditions. The model supports
multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).
@@ -335,7 +336,8 @@ def _sample_next_token(self,
B = sequence.shape[0]
cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
model = self if self._fsdp is None else self._fsdp
- if self.two_step_cfg and cfg_conditions != {}:
+ two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
+ if two_step_cfg and cfg_conditions != {}:
assert isinstance(cfg_conditions, tuple), type(cfg_conditions)
condition_tensors, null_condition_tensors = cfg_conditions
cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors)
@@ -493,7 +495,7 @@ def generate(self,
# sample next token from the model, next token shape is [B, K, 1]
next_token = self._sample_next_token(
curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
- cfg_coef=cfg_coef)
+ cfg_coef=cfg_coef, two_step_cfg=two_step_cfg)
# ensure the tokens that should be masked are properly set to special_token_id
# as the model never output special_token_id
valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
diff --git a/audiocraft/models/loaders.py b/audiocraft/models/loaders.py
index 7fd49d84..f02ba115 100644
--- a/audiocraft/models/loaders.py
+++ b/audiocraft/models/loaders.py
@@ -27,6 +27,7 @@
from omegaconf import OmegaConf, DictConfig
import torch
+import audiocraft
from . import builders
from .encodec import CompressionModel
@@ -60,7 +61,9 @@ def _get_state_dict(
else:
assert filename is not None, "filename needs to be defined if using HF checkpoints"
- file = hf_hub_download(repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir)
+ file = hf_hub_download(
+ repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir,
+ library_name="audiocraft", library_version=audiocraft.__version__)
return torch.load(file, map_location=device)
diff --git a/audiocraft/models/musicgen.py b/audiocraft/models/musicgen.py
index 557d1196..88ee13b6 100644
--- a/audiocraft/models/musicgen.py
+++ b/audiocraft/models/musicgen.py
@@ -12,11 +12,12 @@
import typing as tp
import warnings
+import omegaconf
import torch
from .encodec import CompressionModel
from .lm import LMModel
-from .builders import get_debug_compression_model, get_debug_lm_model
+from .builders import get_debug_compression_model, get_debug_lm_model, get_wrapped_compression_model
from .loaders import load_compression_model, load_lm_model
from ..data.audio_utils import convert_audio
from ..modules.conditioners import ConditioningAttributes, WavCondition
@@ -52,14 +53,28 @@ def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
self.name = name
self.compression_model = compression_model
self.lm = lm
+ self.cfg: tp.Optional[omegaconf.DictConfig] = None
+ # Just to be safe, let's put everything in eval mode.
+ self.compression_model.eval()
+ self.lm.eval()
+
+ if hasattr(lm, 'cfg'):
+ cfg = lm.cfg
+ assert isinstance(cfg, omegaconf.DictConfig)
+ self.cfg = cfg
+
+ if self.cfg is not None:
+ self.compression_model = get_wrapped_compression_model(self.compression_model, self.cfg)
+
if max_duration is None:
- if hasattr(lm, 'cfg'):
+ if self.cfg is not None:
max_duration = lm.cfg.dataset.segment_duration # type: ignore
else:
raise ValueError("You must provide max_duration when building directly MusicGen")
assert max_duration is not None
self.max_duration: float = max_duration
self.device = next(iter(lm.parameters())).device
+
self.generation_params: dict = {}
self.set_generation_params(duration=15) # 15 seconds by default
self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
@@ -118,6 +133,7 @@ def get_pretrained(name: str = 'facebook/musicgen-melody', device=None):
compression_model = load_compression_model(name, device=device)
if 'self_wav' in lm.condition_provider.conditioners:
lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True
+ lm.condition_provider.conditioners['self_wav']._use_masking = False
return MusicGen(name, compression_model, lm)
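A hedged end-to-end sketch of loading one of the new stereo checkpoints through the updated wrapper (which now forces eval mode and wraps the compression model when the config asks for stereo). The model name and prompt are taken from the docs and demo elsewhere in this patch; downloading the weights requires network access and a GPU is recommended.

```python
from audiocraft.models import MusicGen

model = MusicGen.get_pretrained('facebook/musicgen-stereo-small')
model.set_generation_params(duration=8)
wavs = model.generate(['lofi slow bpm electro chill with organic samples'])
print(wavs.shape)   # expected [B, 2, T]: the wrapped compression model decodes two channels
```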
diff --git a/audiocraft/modules/codebooks_patterns.py b/audiocraft/modules/codebooks_patterns.py
index 3cf3bb41..61362588 100644
--- a/audiocraft/modules/codebooks_patterns.py
+++ b/audiocraft/modules/codebooks_patterns.py
@@ -486,9 +486,14 @@ def get_pattern(self, timesteps: int) -> Pattern:
return Pattern(out, n_q=self.n_q, timesteps=timesteps)
-class VALLEPattern(CodebooksPatternProvider):
- """Almost VALL-E style pattern.
- We further allow some delays for the codebooks other than the first one.
+class CoarseFirstPattern(CodebooksPatternProvider):
+ """First generates all timesteps of codebook #1 (i.e. the coarsest one), then the remaining codebooks,
+ potentially with delays.
+
+ ..Warning:: You must always generate the full training duration at test time, for instance
+ 30 seconds, as otherwise the fine codebooks will start being generated at an unexpected
+ location. This is due to the non-causality of the remaining codebooks with respect to
+ the first ones.
Args:
n_q (int): Number of codebooks.
diff --git a/audiocraft/modules/conditioners.py b/audiocraft/modules/conditioners.py
index d10ac8dc..178957d1 100644
--- a/audiocraft/modules/conditioners.py
+++ b/audiocraft/modules/conditioners.py
@@ -469,6 +469,8 @@ class WaveformConditioner(BaseConditioner):
def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]):
super().__init__(dim, output_dim)
self.device = device
+ # If False, no masking is done. Used in ChromaStemConditioner when completing a sample by periodicity.
+ self._use_masking = True
def tokenize(self, x: WavCondition) -> WavCondition:
wav, length, sample_rate, path, seek_time = x
@@ -496,13 +498,12 @@ def forward(self, x: WavCondition) -> ConditionType:
embeds = embeds.to(self.output_proj.weight)
embeds = self.output_proj(embeds)
- if lengths is not None:
+ if lengths is not None and self._use_masking:
lengths = lengths / self._downsampling_factor()
mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore
else:
- mask = torch.ones_like(embeds)
- embeds = (embeds * mask.unsqueeze(2).to(self.device))
-
+ mask = torch.ones_like(embeds[..., 0])
+ embeds = (embeds * mask.unsqueeze(-1))
return embeds, mask
@@ -537,6 +538,8 @@ def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp:
self.autocast = TorchAutocast(enabled=device != 'cpu', device_type=self.device, dtype=torch.float32)
self.sample_rate = sample_rate
self.match_len_on_eval = match_len_on_eval
+ if match_len_on_eval:
+ self._use_masking = False
self.duration = duration
self.__dict__['demucs'] = pretrained.get_model('htdemucs').to(device)
stem_sources: list = self.demucs.sources # type: ignore
@@ -792,6 +795,8 @@ def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
import laion_clap # type: ignore
except ImportError:
raise ImportError("Please install CLAP to use the CLAPEmbeddingConditioner: 'pip install laion_clap'")
+ warnings.warn("Sample rate for CLAP conditioner was fixed in version v1.1.0 (from 44.1 to 48 kHz). "
+ "Please retrain all models.")
checkpoint = AudioCraftEnvironment.resolve_reference_path(checkpoint)
clap_tokenize = RobertaTokenizer.from_pretrained('roberta-base')
clap_model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
diff --git a/audiocraft/modules/rope.py b/audiocraft/modules/rope.py
index 503e6748..c12cee09 100644
--- a/audiocraft/modules/rope.py
+++ b/audiocraft/modules/rope.py
@@ -81,13 +81,16 @@ def get_rotation(self, start: int, end: int):
self.rotation = torch.polar(torch.ones_like(angles), angles)
return self.rotation[start:end]
- def rotate(self, x: torch.Tensor, start: int = 0, invert_decay: bool = False):
+ def rotate(self, x: torch.Tensor, start: int = 0, time_dim: int = 1, invert_decay: bool = False):
"""Apply rope rotation to query or key tensor."""
- T = x.shape[1]
- rotation = self.get_rotation(start, start + T).unsqueeze(0).unsqueeze(2)
+ T = x.shape[time_dim]
+ target_shape = [1] * x.dim()
+ target_shape[time_dim] = T
+ target_shape[-1] = -1
+ rotation = self.get_rotation(start, start + T).view(target_shape)
if self.xpos:
- decay = self.xpos.get_decay(start, start + T).unsqueeze(0).unsqueeze(2)
+ decay = self.xpos.get_decay(start, start + T).view(target_shape)
else:
decay = 1.0
@@ -96,11 +99,11 @@ def rotate(self, x: torch.Tensor, start: int = 0, invert_decay: bool = False):
x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2))
scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale)
- x_out = torch.view_as_real(x_complex * scaled_rotation).flatten(-2)
+ x_out = torch.view_as_real(x_complex * scaled_rotation).view_as(x)
return x_out.type_as(x)
- def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0):
+ def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0, time_dim: int = 1):
""" Apply rope rotation to both query and key tensors.
Supports streaming mode, in which query and key are not expected to have the same shape.
In streaming mode, key will be of length [P + C] with P the cached past timesteps, but
@@ -110,12 +113,13 @@ def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0):
query (torch.Tensor): Query to rotate.
key (torch.Tensor): Key to rotate.
start (int): Start index of the sequence for time offset.
+ time_dim (int): which dimension represents the time steps.
"""
- query_timesteps = query.shape[1]
- key_timesteps = key.shape[1]
+ query_timesteps = query.shape[time_dim]
+ key_timesteps = key.shape[time_dim]
streaming_offset = key_timesteps - query_timesteps
- query_out = self.rotate(query, start + streaming_offset)
- key_out = self.rotate(key, start, invert_decay=True)
+ query_out = self.rotate(query, start + streaming_offset, time_dim)
+ key_out = self.rotate(key, start, time_dim, invert_decay=True)
return query_out, key_out
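Illustrative sketch (dummy tensors, not part of the patch) of the view-based broadcasting used in `rotate` above: the rotation of length T is reshaped so it lines up with whichever axis of the input holds the timesteps.

```python
import torch

x = torch.randn(2, 8, 16, 64)                 # e.g. [B, heads, T, D] when time_dim=2
time_dim = 2
T, D = x.shape[time_dim], x.shape[-1]
target_shape = [1] * x.dim()
target_shape[time_dim] = T
target_shape[-1] = -1
rotation = torch.polar(torch.ones(T, D // 2),
                       torch.randn(T, D // 2))        # stand-in for get_rotation(start, start + T)
x_complex = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2))
out = torch.view_as_real(x_complex * rotation.view(target_shape)).view_as(x)
print(out.shape)   # torch.Size([2, 8, 16, 64]); with time_dim=1 the same code handles [B, T, heads, D]
```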
diff --git a/audiocraft/modules/transformer.py b/audiocraft/modules/transformer.py
index 048c06df..e8100a4c 100644
--- a/audiocraft/modules/transformer.py
+++ b/audiocraft/modules/transformer.py
@@ -35,8 +35,8 @@ def set_efficient_attention_backend(backend: str = 'torch'):
_efficient_attention_backend = backend
-def _get_attention_time_dimension() -> int:
- if _efficient_attention_backend == 'torch':
+def _get_attention_time_dimension(memory_efficient: bool) -> int:
+ if _efficient_attention_backend == 'torch' and memory_efficient:
return 2
else:
return 1
@@ -89,11 +89,11 @@ def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float =
return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
-def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
+def expand_repeated_kv(x: torch.Tensor, n_rep: int, memory_efficient: bool) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers."""
if n_rep == 1:
return x
- if _efficient_attention_backend == 'torch':
+ if _efficient_attention_backend == 'torch' and memory_efficient:
bs, n_kv_heads, slen, head_dim = x.shape
return (
x[:, :, None, :, :]
@@ -234,7 +234,7 @@ def _get_mask(self, current_steps: int, device: torch.device, dtype: torch.dtype
# Return a causal mask, accounting for potentially stored past keys/values
# We actually return a bias for the attention score, as this has the same
# convention both in the builtin MHA in Pytorch, and Xformers functions.
- time_dim = _get_attention_time_dimension()
+ time_dim = _get_attention_time_dimension(self.memory_efficient)
if self.memory_efficient:
from xformers.ops import LowerTriangularMask
if current_steps == 1:
@@ -264,7 +264,7 @@ def _get_mask(self, current_steps: int, device: torch.device, dtype: torch.dtype
torch.full([], float('-inf'), device=device, dtype=dtype))
def _complete_kv(self, k, v):
- time_dim = _get_attention_time_dimension()
+ time_dim = _get_attention_time_dimension(self.memory_efficient)
if self.cross_attention:
# With cross attention we assume all keys and values
# are already available, and streaming is with respect
@@ -298,8 +298,7 @@ def _complete_kv(self, k, v):
return nk, nv
def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
- # TODO: fix and verify layout.
- assert _efficient_attention_backend == 'xformers', "Rope not supported with torch attn."
+ time_dim = _get_attention_time_dimension(self.memory_efficient)
# Apply rope embeddings to query and key tensors.
assert self.rope is not None
if 'past_keys' in self._streaming_state:
@@ -311,7 +310,7 @@ def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
else:
past_context_offset = 0
streaming_offset = past_context_offset + past_keys_offset
- return self.rope.rotate_qk(query, key, start=streaming_offset)
+ return self.rope.rotate_qk(query, key, start=streaming_offset, time_dim=time_dim)
def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
key_padding_mask=None, need_weights=False, attn_mask=None,
@@ -320,7 +319,7 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
assert not is_causal, ("New param added in torch 2.0.1 not supported, "
"use the causal args in the constructor.")
- time_dim = _get_attention_time_dimension()
+ time_dim = _get_attention_time_dimension(self.memory_efficient)
if time_dim == 2:
layout = "b h t d"
else:
@@ -394,8 +393,8 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
q, k = self._apply_rope(q, k)
k, v = self._complete_kv(k, v)
if self.kv_repeat > 1:
- k = expand_repeated_kv(k, self.kv_repeat)
- v = expand_repeated_kv(v, self.kv_repeat)
+ k = expand_repeated_kv(k, self.kv_repeat, self.memory_efficient)
+ v = expand_repeated_kv(v, self.kv_repeat, self.memory_efficient)
if self.attention_as_float32:
q, k, v = [x.float() for x in [q, k, v]]
if self.memory_efficient:
@@ -649,7 +648,6 @@ def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforwar
# see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the
# backward hook inside of FSDP...
layer._magma_checkpointed = True # type: ignore
- assert layer.layer_drop == 0., "Need further checking" # type: ignore
def _apply_layer(self, layer, *args, **kwargs):
method = self.checkpointing
diff --git a/audiocraft/optim/dadam.py b/audiocraft/optim/dadam.py
index a84402f7..e009969f 100644
--- a/audiocraft/optim/dadam.py
+++ b/audiocraft/optim/dadam.py
@@ -5,19 +5,15 @@
# LICENSE file in the root directory of this source tree.
import logging
-from typing import TYPE_CHECKING, Any
+from typing import Any
import torch
import torch.optim
import torch.distributed as dist
-if TYPE_CHECKING:
- from torch.optim.optimizer import _params_t
-else:
- _params_t = Any
-
logger = logging.getLogger(__name__)
+_params_t = Any
def to_real(x):
diff --git a/audiocraft/optim/fsdp.py b/audiocraft/optim/fsdp.py
index b3c1a55b..1090d3d7 100644
--- a/audiocraft/optim/fsdp.py
+++ b/audiocraft/optim/fsdp.py
@@ -143,8 +143,8 @@ def _name_without_fsdp_prefix(name: str) -> str:
new_parts = [part for part in parts if part != FSDP_WRAPPED_MODULE]
return '.'.join(new_parts)
- def state_dict(self) -> tp.Dict[str, tp.Any]: # type: ignore
- state = dict(super().state_dict())
+ def state_dict(self, *args, **kwargs) -> tp.Dict[str, tp.Any]: # type: ignore
+ state = dict(super().state_dict(*args, **kwargs))
for key, value in list(state.items()):
if is_sharded_tensor(value):
del state[key]
diff --git a/audiocraft/solvers/musicgen.py b/audiocraft/solvers/musicgen.py
index bb615abf..2439da33 100644
--- a/audiocraft/solvers/musicgen.py
+++ b/audiocraft/solvers/musicgen.py
@@ -7,6 +7,7 @@
from pathlib import Path
import time
import typing as tp
+import warnings
import flashy
import math
@@ -226,7 +227,6 @@ def _compute_cross_entropy(
ce = ce / K
return ce, ce_per_codebook
- @torch.no_grad()
def _prepare_tokens_and_attributes(
self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
check_synchronization_points: bool = False
@@ -243,6 +243,12 @@ def _prepare_tokens_and_attributes(
with B the batch size, K the number of codebooks, T_s the token timesteps.
Padding mask (torch.Tensor): Mask with valid positions in the tokens tensor, of shape [B, K, T_s].
"""
+ if self.model.training:
+ warnings.warn(
+ "Up to version 1.0.1, the _prepare_tokens_and_attributes method was evaluated with `torch.no_grad()`. "
+ "This is inconsistent with how models were trained in the MusicGen paper. We removed the "
+ "`torch.no_grad()` in version 1.1.0. Small changes to the final performance are expected. "
+ "Really sorry about that.")
if self._cached_batch_loader is None or self.current_stage != "train":
audio, infos = batch
audio = audio.to(self.device)
@@ -533,7 +539,7 @@ def get_hydrated_conditions(meta: tp.List[SegmentWithAttributes]):
rtf = 1.
else:
gen_unprompted_outputs = self.run_generate_step(
- batch, gen_duration=target_duration, prompt_duration=prompt_duration,
+ batch, gen_duration=target_duration, prompt_duration=None,
**self.generation_params)
gen_unprompted_audio = gen_unprompted_outputs['gen_audio'].cpu()
rtf = gen_unprompted_outputs['rtf']
diff --git a/audiocraft/train.py b/audiocraft/train.py
index 22dd1178..5851222c 100644
--- a/audiocraft/train.py
+++ b/audiocraft/train.py
@@ -12,6 +12,7 @@
import logging
import multiprocessing
import os
+from pathlib import Path
import sys
import typing as tp
@@ -119,6 +120,11 @@ def init_seed_and_system(cfg):
logger.debug('Setting num threads to %d', cfg.num_threads)
set_efficient_attention_backend(cfg.efficient_attention_backend)
logger.debug('Setting efficient attention backend to %s', cfg.efficient_attention_backend)
+ if 'SLURM_JOB_ID' in os.environ:
+ tmpdir = Path('/scratch/slurm_tmpdir/' + os.environ['SLURM_JOB_ID'])
+ if tmpdir.exists():
+ logger.info("Changing tmpdir to %s", tmpdir)
+ os.environ['TMPDIR'] = str(tmpdir)
@hydra_main(config_path='../config', config_name='config', version_base='1.1')
diff --git a/audiocraft/utils/cache.py b/audiocraft/utils/cache.py
index 2fccc0ac..6ba017a7 100644
--- a/audiocraft/utils/cache.py
+++ b/audiocraft/utils/cache.py
@@ -57,7 +57,7 @@ class EmbeddingCache:
specify the index corresponding to the current embedding in the object that can represent batch metadata.
If not specified, will return the full embedding unmodified.
"""
- def __init__(self, cache_path: tp.Union[Path], device: tp.Union[str, torch.device],
+ def __init__(self, cache_path: tp.Union[str, Path], device: tp.Union[str, torch.device],
compute_embed_fn: tp.Callable[[Path, tp.Any, int], torch.Tensor],
extract_embed_fn: tp.Optional[tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor]] = None):
self.cache_path = Path(cache_path)
@@ -287,6 +287,7 @@ def _load_one(self, index: int):
if isinstance(part[0], torch.Tensor):
out.append(torch.stack(part))
else:
+ assert isinstance(part, torch.Tensor)
out.append(part)
return out
except Exception:
diff --git a/audiocraft/utils/export_legacy.py b/audiocraft/utils/export_legacy.py
index 52f145f3..367c3f3c 100644
--- a/audiocraft/utils/export_legacy.py
+++ b/audiocraft/utils/export_legacy.py
@@ -14,13 +14,21 @@
from omegaconf import OmegaConf, DictConfig
import torch
+from audiocraft import __version__
+
def _clean_lm_cfg(cfg: DictConfig):
OmegaConf.set_struct(cfg, False)
# This used to be set automatically in the LM solver, need a more robust solution
# for the future.
cfg['transformer_lm']['card'] = 2048
- cfg['transformer_lm']['n_q'] = 4
+ n_q = 4
+ stereo_cfg = getattr(cfg, 'interleave_stereo_codebooks', None)
+ if stereo_cfg is not None and stereo_cfg.use:
+ if 'downsample' in stereo_cfg:
+ del stereo_cfg['downsample']
+ n_q = 8
+ cfg['transformer_lm']['n_q'] = n_q
# Experimental params no longer supported.
bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters',
'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop']
@@ -30,27 +38,33 @@ def _clean_lm_cfg(cfg: DictConfig):
return cfg
-def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
- sig = Path(checkpoint_path).parent.name
- assert len(sig) == 8, "Not a valid Dora signature"
+def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
pkg = torch.load(checkpoint_path, 'cpu')
new_pkg = {
'best_state': pkg['ema']['state']['model'],
'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
+ # The following params were NOT exported for the first release of MusicGen.
+ 'version': __version__,
+ 'exported': True,
}
- out_file = Path(out_folder) / f'{sig}.th'
+ Path(out_file).parent.mkdir(exist_ok=True, parents=True)
torch.save(new_pkg, out_file)
return out_file
-def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
- sig = Path(checkpoint_path).parent.name
- assert len(sig) == 8, "Not a valid Dora signature"
+def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
pkg = torch.load(checkpoint_path, 'cpu')
+ if pkg['fsdp_best_state']:
+ best_state = pkg['fsdp_best_state']['model']
+ else:
+ best_state = pkg['best_state']['model']
new_pkg = {
- 'best_state': pkg['fsdp_best_state']['model'],
- 'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg']))
+ 'best_state': best_state,
+ 'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg'])),
+ # The following params were NOT exported for the first release of MusicGen.
+ 'version': __version__,
+ 'exported': True,
}
- out_file = Path(out_folder) / f'{sig}.th'
+ Path(out_file).parent.mkdir(exist_ok=True, parents=True)
torch.save(new_pkg, out_file)
return out_file
diff --git a/audiocraft/utils/utils.py b/audiocraft/utils/utils.py
index 3135d70e..2c5799f8 100644
--- a/audiocraft/utils/utils.py
+++ b/audiocraft/utils/utils.py
@@ -185,7 +185,7 @@ def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> t
assert len(lengths.shape) == 1, "Length shape should be 1 dimensional."
final_length = lengths.max().item() if not max_len else max_len
final_length = max(final_length, 1) # if all seqs are of len zero we don't want a zero-size tensor
- return torch.arange(final_length)[None, :].to(lengths.device) < lengths[:, None]
+ return torch.arange(final_length, device=lengths.device)[None, :] < lengths[:, None]
def hash_trick(word: str, vocab_size: int) -> int:
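A tiny sketch (not part of the patch) of what `length_to_mask` returns, now computed directly on the device of `lengths`.

```python
import torch
from audiocraft.utils.utils import length_to_mask

lengths = torch.tensor([2, 4])            # would typically live on a CUDA device
mask = length_to_mask(lengths, max_len=4)
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True]])
```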
diff --git a/config/conditioner/clapemb2music.yaml b/config/conditioner/clapemb2music.yaml
index 8500a826..d44ac774 100644
--- a/config/conditioner/clapemb2music.yaml
+++ b/config/conditioner/clapemb2music.yaml
@@ -23,7 +23,7 @@ conditioners:
checkpoint: //reference/clap/music_audioset_epoch_15_esc_90.14.pt
model_arch: 'HTSAT-base'
enable_fusion: false
- sample_rate: 44100
+ sample_rate: 48000
max_audio_length: 10
audio_stride: 1
dim: 512
diff --git a/config/model/lm/audiogen_lm.yaml b/config/model/lm/audiogen_lm.yaml
index 696f7462..d17e7a93 100644
--- a/config/model/lm/audiogen_lm.yaml
+++ b/config/model/lm/audiogen_lm.yaml
@@ -18,7 +18,7 @@ codebooks_pattern:
delays: [0, 0, 0, 0]
music_lm:
group_by: 2
- valle:
+ coarse_first:
delays: [0, 0, 0]
transformer_lm:
diff --git a/config/model/lm/musicgen_lm.yaml b/config/model/lm/musicgen_lm.yaml
index 5bc87a62..be1fbc14 100644
--- a/config/model/lm/musicgen_lm.yaml
+++ b/config/model/lm/musicgen_lm.yaml
@@ -18,7 +18,7 @@ codebooks_pattern:
delays: [0, 0, 0, 0]
music_lm:
group_by: 2
- valle:
+ coarse_first:
delays: [0, 0, 0]
transformer_lm:
diff --git a/config/solver/musicgen/default.yaml b/config/solver/musicgen/default.yaml
index 59e01137..8bdf9c74 100644
--- a/config/solver/musicgen/default.yaml
+++ b/config/solver/musicgen/default.yaml
@@ -14,10 +14,20 @@ solver: musicgen
sample_rate: ???
channels: ???
compression_model_checkpoint: ???
+# The following will set the number of codebooks on the underlying
+# compression model. This might differ from the actual n_q value
+# given to the transformer when the model output is post-processed, for instance
+# for stereo channels. If not provided, the default value for the compression model
+# will be used.
+compression_model_n_q: null
tokens:
padding_with_special_token: false
+interleave_stereo_codebooks:
+ use: false
+ per_timestep: false
+
cache:
path:
write: false
diff --git a/demos/musicgen_app.py b/demos/musicgen_app.py
index 9847e56c..2bbd6556 100644
--- a/demos/musicgen_app.py
+++ b/demos/musicgen_app.py
@@ -9,24 +9,29 @@
import argparse
from concurrent.futures import ProcessPoolExecutor
+import logging
import os
from pathlib import Path
import subprocess as sp
+import sys
from tempfile import NamedTemporaryFile
import time
import typing as tp
import warnings
+from einops import rearrange
import torch
import gradio as gr
from audiocraft.data.audio_utils import convert_audio
from audiocraft.data.audio import audio_write
+from audiocraft.models.encodec import InterleaveStereoCompressionModel
from audiocraft.models import MusicGen, MultiBandDiffusion
MODEL = None # Last used model
-IS_BATCHED = "facebook/MusicGen" in os.environ.get('SPACE_ID', '')
+SPACE_ID = os.environ.get('SPACE_ID', '')
+IS_BATCHED = "facebook/MusicGen" in SPACE_ID or 'musicgen-internal/musicgen_dev' in SPACE_ID
print(IS_BATCHED)
MAX_BATCH_SIZE = 12
BATCHED_DURATION = 15
@@ -93,7 +98,7 @@ def load_model(version='facebook/musicgen-melody'):
# Clear PyTorch CUDA cache and delete model
del MODEL
torch.cuda.empty_cache()
-
+ MODEL = None # in case loading would crash
MODEL = MusicGen.get_pretrained(version)
@@ -104,8 +109,7 @@ def load_diffusion():
MBD = MultiBandDiffusion.get_mbd_musicgen()
-def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
-
+def _do_predictions(texts, melodies, duration, progress=False, gradio_progress=None, **gen_kwargs):
MODEL.set_generation_params(duration=duration, **gen_kwargs)
print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
be = time.time()
@@ -123,18 +127,30 @@ def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
melody = convert_audio(melody, sr, target_sr, target_ac)
processed_melodies.append(melody)
- if any(m is not None for m in processed_melodies):
- outputs = MODEL.generate_with_chroma(
- descriptions=texts,
- melody_wavs=processed_melodies,
- melody_sample_rate=target_sr,
- progress=progress,
- return_tokens=USE_DIFFUSION
- )
- else:
- outputs = MODEL.generate(texts, progress=progress, return_tokens=USE_DIFFUSION)
+ try:
+ if any(m is not None for m in processed_melodies):
+ outputs = MODEL.generate_with_chroma(
+ descriptions=texts,
+ melody_wavs=processed_melodies,
+ melody_sample_rate=target_sr,
+ progress=progress,
+ return_tokens=USE_DIFFUSION
+ )
+ else:
+ outputs = MODEL.generate(texts, progress=progress, return_tokens=USE_DIFFUSION)
+ except RuntimeError as e:
+ raise gr.Error("Error while generating " + e.args[0])
if USE_DIFFUSION:
- outputs_diffusion = MBD.tokens_to_wav(outputs[1])
+ if gradio_progress is not None:
+ gradio_progress(1, desc='Running MultiBandDiffusion...')
+ tokens = outputs[1]
+ if isinstance(MODEL.compression_model, InterleaveStereoCompressionModel):
+ left, right = MODEL.compression_model.get_left_right_codes(tokens)
+ tokens = torch.cat([left, right])
+ outputs_diffusion = MBD.tokens_to_wav(tokens)
+ if isinstance(MODEL.compression_model, InterleaveStereoCompressionModel):
+ assert outputs_diffusion.shape[1] == 1 # output is mono
+ outputs_diffusion = rearrange(outputs_diffusion, '(s b) c t -> b (s c) t', s=2)
outputs = torch.cat([outputs[0], outputs_diffusion], dim=0)
outputs = outputs.detach().cpu().float()
pending_videos = []
@@ -158,15 +174,24 @@ def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
def predict_batched(texts, melodies):
max_text_length = 512
texts = [text[:max_text_length] for text in texts]
- load_model('facebook/musicgen-melody')
+ load_model('facebook/musicgen-stereo-melody')
res = _do_predictions(texts, melodies, BATCHED_DURATION)
return res
-def predict_full(model, decoder, text, melody, duration, topk, topp, temperature, cfg_coef, progress=gr.Progress()):
+def predict_full(model, model_path, decoder, text, melody, duration, topk, topp, temperature, cfg_coef, progress=gr.Progress()):
global INTERRUPTING
global USE_DIFFUSION
INTERRUPTING = False
+ progress(0, desc="Loading model...")
+ model_path = model_path.strip()
+ if model_path:
+ if not Path(model_path).exists():
+ raise gr.Error(f"Model path {model_path} doesn't exist.")
+ if not Path(model_path).is_dir():
+ raise gr.Error(f"Model path {model_path} must be a folder containing "
+ "state_dict.bin and compression_state_dict_.bin.")
+ model = model_path
if temperature < 0:
raise gr.Error("Temperature must be >= 0.")
if topk < 0:
@@ -177,20 +202,26 @@ def predict_full(model, decoder, text, melody, duration, topk, topp, temperature
topk = int(topk)
if decoder == "MultiBand_Diffusion":
USE_DIFFUSION = True
+ progress(0, desc="Loading diffusion model...")
load_diffusion()
else:
USE_DIFFUSION = False
load_model(model)
+ max_generated = 0
+
def _progress(generated, to_generate):
- progress((min(generated, to_generate), to_generate))
+ nonlocal max_generated
+ max_generated = max(generated, max_generated)
+ progress((min(max_generated, to_generate), to_generate))
if INTERRUPTING:
raise gr.Error("Interrupted.")
MODEL.set_custom_progress_callback(_progress)
videos, wavs = _do_predictions(
[text], [melody], duration, progress=True,
- top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
+ top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef,
+ gradio_progress=progress)
if USE_DIFFUSION:
return videos[0], wavs[0], videos[1], wavs[1]
return videos[0], wavs[0], None, None
@@ -235,8 +266,12 @@ def ui_full(launch_kwargs):
_ = gr.Button("Interrupt").click(fn=interrupt, queue=False)
with gr.Row():
model = gr.Radio(["facebook/musicgen-melody", "facebook/musicgen-medium", "facebook/musicgen-small",
- "facebook/musicgen-large"],
- label="Model", value="facebook/musicgen-melody", interactive=True)
+ "facebook/musicgen-large", "facebook/musicgen-melody-large",
+ "facebook/musicgen-stereo-small", "facebook/musicgen-stereo-medium",
+ "facebook/musicgen-stereo-melody", "facebook/musicgen-stereo-large",
+ "facebook/musicgen-stereo-melody-large"],
+ label="Model", value="facebook/musicgen-stereo-melody", interactive=True)
+ model_path = gr.Text(label="Model Path (custom models)")
with gr.Row():
decoder = gr.Radio(["Default", "MultiBand_Diffusion"],
label="Decoder", value="Default", interactive=True)
@@ -253,7 +288,7 @@ def ui_full(launch_kwargs):
diffusion_output = gr.Video(label="MultiBand Diffusion Decoder")
audio_diffusion = gr.Audio(label="MultiBand Diffusion Decoder (wav)", type='filepath')
submit.click(toggle_diffusion, decoder, [diffusion_output, audio_diffusion], queue=False,
- show_progress=False).then(predict_full, inputs=[model, decoder, text, melody, duration, topk, topp,
+ show_progress=False).then(predict_full, inputs=[model, model_path, decoder, text, melody, duration, topk, topp,
temperature, cfg_coef],
outputs=[output, audio_output, diffusion_output, audio_diffusion])
radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
@@ -264,37 +299,37 @@ def ui_full(launch_kwargs):
[
"An 80s driving pop song with heavy drums and synth pads in the background",
"./assets/bach.mp3",
- "facebook/musicgen-melody",
+ "facebook/musicgen-stereo-melody",
"Default"
],
[
"A cheerful country song with acoustic guitars",
"./assets/bolero_ravel.mp3",
- "facebook/musicgen-melody",
+ "facebook/musicgen-stereo-melody",
"Default"
],
[
"90s rock song with electric guitar and heavy drums",
None,
- "facebook/musicgen-medium",
+ "facebook/musicgen-stereo-medium",
"Default"
],
[
"a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions",
"./assets/bach.mp3",
- "facebook/musicgen-melody",
+ "facebook/musicgen-stereo-melody",
"Default"
],
[
"lofi slow bpm electro chill with organic samples",
None,
- "facebook/musicgen-medium",
+ "facebook/musicgen-stereo-medium",
"Default"
],
[
"Punk rock with loud drum and power guitar",
None,
- "facebook/musicgen-medium",
+ "facebook/musicgen-stereo-medium",
"MultiBand_Diffusion"
],
],
@@ -306,8 +341,18 @@ def ui_full(launch_kwargs):
### More details
The model will generate a short music extract based on the description you provided.
- The model can generate up to 30 seconds of audio in one pass. It is now possible
- to extend the generation by feeding back the end of the previous chunk of audio.
+ The model can generate up to 30 seconds of audio in one pass.
+
+ The model was trained with descriptions from a stock music catalog. Descriptions that work best
+ include some level of detail on the instruments present, along with an intended use case
+ (e.g. adding "perfect for a commercial" can somehow help).
+
+ Using one of the `melody` models (e.g. `musicgen-melody-*`), you can optionally provide a reference audio
+ from which a broad melody will be extracted.
+ The model will then try to follow both the description and melody provided.
+ For best results, the melody should be 30 seconds long (I know, the samples we provide are not...)
+
+ It is now possible to extend the generation by feeding back the end of the previous chunk of audio.
This can take a long time, and the model might lose consistency. The model might also
decide at arbitrary positions that the song ends.
@@ -315,23 +360,23 @@ def ui_full(launch_kwargs):
An overlap of 12 seconds is kept with the previously generated chunk, and 18 "new" seconds
are generated each time.
- We present 4 model variations:
+ We present 10 model variations:
1. facebook/musicgen-melody -- a music generation model capable of generating music condition
on text and melody inputs. **Note**, you can also use text only.
2. facebook/musicgen-small -- a 300M transformer decoder conditioned on text only.
3. facebook/musicgen-medium -- a 1.5B transformer decoder conditioned on text only.
4. facebook/musicgen-large -- a 3.3B transformer decoder conditioned on text only.
+ 5. facebook/musicgen-melody-large -- a 3.3B transformer decoder conditioned on text and melody.
+ 6. facebook/musicgen-stereo-*: same as the previous models but fine-tuned to output stereo audio.
We also present two way of decoding the audio tokens
- 1. Use the default GAN based compression model
- 2. Use MultiBand Diffusion from (paper linknano )
+ 1. Use the default GAN-based compression model. It can suffer from artifacts, especially
+ on crashes, snares, etc.
+ 2. Use [MultiBand Diffusion](https://arxiv.org/abs/2308.02560). It should improve the audio quality,
+ at an extra computational cost. When this is selected, we provide both the GAN-decoded
+ audio and the audio obtained with MBD.
- When using `facebook/musicgen-melody`, you can optionally provide a reference audio from
- which a broad melody will be extracted. The model will then try to follow both
- the description and melody provided.
-
- You can also use your own GPU or a Google Colab by following the instructions on our repo.
- See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
+ See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md)
for more details.
"""
)
@@ -345,7 +390,7 @@ def ui_batched(launch_kwargs):
"""
# MusicGen
- This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft),
+ This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md),
a simple and controllable model for music generation
presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
@@ -403,15 +448,27 @@ def ui_batched(launch_kwargs):
gr.Markdown("""
### More details
- The model will generate 12 seconds of audio based on the description you provided.
+ The model will generate 15 seconds of audio based on the description you provided.
+ The model was trained with descriptions from a stock music catalog. Descriptions that work best
+ include some level of detail on the instruments present, along with an intended use case
+ (e.g. adding "perfect for a commercial" can somehow help).
+
You can optionally provide a reference audio from which a broad melody will be extracted.
The model will then try to follow both the description and melody provided.
- All samples are generated with the `melody` model.
+ For best results, the melody should be 30 seconds long (I know, the samples we provide are not...)
- You can also use your own GPU or a Google Colab by following the instructions on our repo.
+ You can access more control (longer generation, more models etc.) by duplicating this Space
+ (you will then need a paid GPU from HuggingFace).
+ If you have a GPU, you can run the gradio demo locally (click the link to our repo below for more info).
+ Finally, you can get a GPU for free from Google
+ and run the demo in [a Google Colab](https://ai.honu.io/red/musicgen-colab).
- See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
- for more details.
+ See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md)
+ for more details. All samples are generated with the `stereo-melody` model.
""")
demo.queue(max_size=8 * 4).launch(**launch_kwargs)
@@ -458,6 +515,8 @@ def ui_batched(launch_kwargs):
if args.share:
launch_kwargs['share'] = args.share
+ logging.basicConfig(level=logging.INFO, stream=sys.stderr)
+
# Show the interface
if IS_BATCHED:
global USE_DIFFUSION
diff --git a/docs/MBD.md b/docs/MBD.md
index 4288a89d..b6629184 100644
--- a/docs/MBD.md
+++ b/docs/MBD.md
@@ -113,5 +113,5 @@ Learn more about AudioCraft training pipelines in the [dedicated section](./TRAI
See license information in the [README](../README.md).
-[arxiv]: https://dl.fbaipublicfiles.com/encodec/Diffusion/paper.pdf
+[arxiv]: https://arxiv.org/abs/2308.02560
[mbd_samples]: https://ai.honu.io/papers/mbd/
diff --git a/docs/MUSICGEN.md b/docs/MUSICGEN.md
index 606ce858..9a6b1e74 100644
--- a/docs/MUSICGEN.md
+++ b/docs/MUSICGEN.md
@@ -9,7 +9,7 @@ a small delay between the codebooks, we show we can predict them in parallel, th
steps per second of audio.
Check out our [sample page][musicgen_samples] or test the available demo!
-
+
@@ -38,7 +38,7 @@ We offer a number of way to interact with MusicGen:
1. A demo is also available on the [`facebook/MusicGen` Hugging Face Space](https://huggingface.co/spaces/facebook/MusicGen)
(huge thanks to all the HF team for their support).
2. You can run the extended demo on a Colab:
-[colab notebook](https://colab.research.google.com/drive/1JlTOjB-G0A2Hz3h8PK63vLZk4xdCI5QB?usp=sharing)
+[colab notebook](https://ai.honu.io/red/musicgen-colab)
3. You can use the gradio demo locally by running [`python -m demos.musicgen_app --share`](../demos/musicgen_app.py).
4. You can play with MusicGen by running the jupyter notebook at [`demos/musicgen_demo.ipynb`](../demos/musicgen_demo.ipynb) locally (if you have a GPU).
5. Finally, checkout [@camenduru Colab page](https://github.com/camenduru/MusicGen-colab)
@@ -47,11 +47,18 @@ which is regularly updated with contributions from @camenduru and the community.
## API
-We provide a simple API and 4 pre-trained models. The pre trained models are:
+We provide a simple API and 10 pre-trained models. The pre-trained models are:
- `facebook/musicgen-small`: 300M model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-small)
- `facebook/musicgen-medium`: 1.5B model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-medium)
- `facebook/musicgen-melody`: 1.5B model, text to music and text+melody to music - [🤗 Hub](https://huggingface.co/facebook/musicgen-melody)
- `facebook/musicgen-large`: 3.3B model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-large)
+- `facebook/musicgen-melody-large`: 3.3B model, text to music and text+melody to music - [🤗 Hub](https://huggingface.co/facebook/musicgen-melody-large)
+- `facebook/musicgen-stereo-*`: All the previous models fine-tuned for stereo generation -
+ [small](https://huggingface.co/facebook/musicgen-stereo-small),
+ [medium](https://huggingface.co/facebook/musicgen-stereo-medium),
+ [large](https://huggingface.co/facebook/musicgen-stereo-large),
+ [melody](https://huggingface.co/facebook/musicgen-stereo-melody),
+ [melody large](https://huggingface.co/facebook/musicgen-stereo-melody-large).
We observe the best trade-off between quality and compute with the `facebook/musicgen-medium` or `facebook/musicgen-melody` model.
In order to use MusicGen locally **you must have a GPU**. We recommend 16GB of memory, but smaller
@@ -143,6 +150,10 @@ We provide a dummy dataset containing just a few examples for illustrative purpo
Please read first the [TRAINING documentation](./TRAINING.md), in particular the Environment Setup section.
+
+**Warning:** As of version 1.1.0, a few breaking changes were introduced. Check the [CHANGELOG.md](../CHANGELOG.md)
+file for more information. You might need to retrain some of your models.
+
### Example configurations and grids
We provide configurations to reproduce the released models and our research.
@@ -205,6 +216,19 @@ dora run solver=musicgen/debug \
**Warning:** you are responsible for setting the proper value for `transformer_lm.n_q` and `transformer_lm.card` (cardinality of the codebooks). You also have to update the codebook_pattern to match `n_q` as shown in the example for using DAC.
+### Training stereo models
+
+Set the option `interleave_stereo_codebooks.use` to `True`, along with `channels=2`, to activate stereo training. Left and right channels will be
+encoded separately by the compression model, then their codebooks will be interleaved, i.e. the codebook order is
+`[1_L, 1_R, 2_L, 2_R, ...]`. You will also need to update the delays for the codebook patterns and the `n_q` value passed to the transformer LM to match the doubled number of codebooks:
+```
+dora run solver=musicgen/debug \
+ compression_model_checkpoint=//pretrained/facebook/encodec_32khz \
+ channels=2 interleave_stereo_codebooks.use=True \
+ transformer_lm.n_q=8 transformer_lm.card=2048 \
+ codebooks_pattern.delay.delays='[0, 0, 1, 1, 2, 2, 3, 3]'
+```
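+
+For illustration only (this is just a sketch, not the solver's internal implementation), the resulting
+codebook order for a model with 4 codebooks per channel can be pictured as follows:
+
+```python
+import torch
+
+def interleave_stereo_codes(left: torch.Tensor, right: torch.Tensor) -> torch.Tensor:
+    """Interleave per-channel codebooks [B, K, T] into [B, 2K, T] as [1_L, 1_R, 2_L, 2_R, ...]."""
+    B, K, T = left.shape
+    stacked = torch.stack([left, right], dim=2)  # [B, K, 2, T]
+    return stacked.reshape(B, 2 * K, T)
+
+left = torch.zeros(1, 4, 10, dtype=torch.long)   # 4 codebooks for the left channel
+right = torch.ones(1, 4, 10, dtype=torch.long)   # 4 codebooks for the right channel
+codes = interleave_stereo_codes(left, right)
+assert codes.shape == (1, 8, 10)                 # matches transformer_lm.n_q=8 above
+assert codes[0, 0, 0] == 0 and codes[0, 1, 0] == 1  # 1_L comes right before 1_R
+```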
+
### Fine tuning existing models
You can initialize your model to one of the pretrained models by using the `continue_from` argument, in particular
@@ -228,6 +252,39 @@ dora run solver=musicgen/musicgen_base_32khz model/lm/model_scale=medium continu
If you decide to do so, make sure your checkpoint is saved with `torch.save` and contains a dict
`{'best_state': {'model': model_state_dict_here}}`. Directly give the path to `continue_from` without a `//pretrained/` prefix.
+
+#### Fine tuning mono model to stereo
+
+You will not be able to `continue_from` a mono model with stereo training, as the shape of the embeddings and output linears
+would not match. You can use the following snippet to prepare a proper finetuning checkpoint.
+
+```python
+from pathlib import Path
+import torch
+
+# Download the pretrained model, e.g. from
+# https://huggingface.co/facebook/musicgen-melody/blob/main/state_dict.bin
+
+model_name = 'musicgen-melody'
+root = Path.home() / 'checkpoints'
+# You are responsible for downloading the following checkpoint in the proper location
+input_state_dict_path = root / model_name / 'state_dict.bin'
+state = torch.load(input_state_dict_path, 'cpu')
+bs = state['best_state']
+# There is a slight difference in format between training checkpoints and exported public checkpoints.
+# If you want to use one of your own mono models from a training checkpoint, follow the instructions
+# for exporting a model explained later on this page.
+assert 'model' not in bs, 'The following code is for using an exported pretrained model'
+nbs = dict(bs)
+for k in range(8):
+ # Copy each mono embedding and linear so there is one for the left and one for the right channel.
+ nbs[f'linears.{k}.weight'] = bs[f'linears.{k//2}.weight']
+ nbs[f'emb.{k}.weight'] = bs[f'emb.{k//2}.weight']
+torch.save({'best_state': {'model': nbs}}, root / f'stereo_finetune_{model_name}.th')
+```
+
+Now, you can use `$HOME/checkpoints/stereo_finetune_musicgen-melody.th` as a `continue_from` target (without a `//pretrained` prefix!).
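+
+As a quick, purely illustrative sanity check (reusing the key names from the snippet above), you can
+verify that the new checkpoint now contains 8 embeddings and output linears:
+
+```python
+from pathlib import Path
+import torch
+
+ckpt = torch.load(Path.home() / 'checkpoints' / 'stereo_finetune_musicgen-melody.th', 'cpu')
+state = ckpt['best_state']['model']
+# The stereo fine-tuning checkpoint should expose 8 codebook embeddings and output linears (4 per channel).
+assert all(f'emb.{k}.weight' in state for k in range(8))
+assert all(f'linears.{k}.weight' in state for k in range(8))
+```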
+
### Caching of EnCodec tokens
It is possible to precompute the EnCodec tokens and other metadata.
@@ -283,9 +340,9 @@ Once you have launched some experiments, you can easily get access
to the Solver with the latest trained model using the following snippet.
```python
-from audiocraft.solvers.musicgen import MusicGen
+from audiocraft.solvers.musicgen import MusicGenSolver
-solver = MusicGen.get_eval_solver_from_sig('SIG', device='cpu', batch_size=8)
+solver = MusicGenSolver.get_eval_solver_from_sig('SIG', device='cpu', batch_size=8)
solver.model
solver.dataloaders
```
@@ -344,11 +401,11 @@ activations by sharding the optimizer state.
## Citation
```
-@article{copet2023simple,
+@inproceedings{copet2023simple,
title={Simple and Controllable Music Generation},
author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez},
+ booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
year={2023},
- journal={arXiv preprint arXiv:2306.05284},
}
```
diff --git a/model_cards/MUSICGEN_MODEL_CARD.md b/model_cards/MUSICGEN_MODEL_CARD.md
index 95431368..68e81d44 100644
--- a/model_cards/MUSICGEN_MODEL_CARD.md
+++ b/model_cards/MUSICGEN_MODEL_CARD.md
@@ -87,4 +87,19 @@ More information can be found in the paper [Simple and Controllable Music Genera
**Use cases:** Users must be aware of the biases, limitations and risks of the model. MusicGen is a model developed for artificial intelligence research on controllable music generation. As such, it should not be used for downstream applications without further investigation and mitigation of risks.
+## Update: stereo models and large melody.
+
+We further release a set of stereophonic-capable models. Those were fine-tuned for 200k updates starting
+from the mono models. The training data is otherwise identical, and capabilities and limitations are shared with the base models. The stereo models work by getting 2 streams of tokens from the EnCodec model, and interleaving those using
+the delay pattern. We also release a mono large model with melody conditioning capabilities. The list of new models
+is as follows:
+
+- facebook/musicgen-stereo-small
+- facebook/musicgen-stereo-medium
+- facebook/musicgen-stereo-large
+- facebook/musicgen-stereo-melody
+- facebook/musicgen-melody-large
+- facebook/musicgen-stereo-melody-large
+
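+As a hypothetical usage sketch (the prompt and duration below are placeholders), the stereo checkpoints
+load through the same `MusicGen` API as the mono ones and return 2-channel waveforms:
+
+```python
+from audiocraft.models import MusicGen
+from audiocraft.data.audio import audio_write
+
+model = MusicGen.get_pretrained('facebook/musicgen-stereo-small')
+model.set_generation_params(duration=8)
+wavs = model.generate(['drum and bass beat with a deep bass line'])
+for idx, one_wav in enumerate(wavs):  # each one_wav is [2, samples] for stereo models
+    audio_write(f'stereo_{idx}', one_wav.cpu(), model.sample_rate, strategy='loudness')
+```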
+
[arxiv]: https://arxiv.org/abs/2306.05284
diff --git a/requirements.txt b/requirements.txt
index e44fe159..a6fa5809 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,7 +9,7 @@ num2words
numpy
sentencepiece
spacy==3.5.2
-torch>=2.0.0
+torch==2.1.0
torchaudio>=2.0.0
huggingface_hub
tqdm
@@ -20,4 +20,4 @@ librosa
gradio
torchmetrics
encodec
-protobuf
\ No newline at end of file
+protobuf
diff --git a/tests/common_utils/wav_utils.py b/tests/common_utils/wav_utils.py
index d3a563ee..cc14a9ca 100644
--- a/tests/common_utils/wav_utils.py
+++ b/tests/common_utils/wav_utils.py
@@ -5,10 +5,10 @@
# LICENSE file in the root directory of this source tree.
from pathlib import Path
-import typing as tp
import torch
-import torchaudio
+
+from audiocraft.data.audio import audio_write
def get_white_noise(chs: int = 1, num_frames: int = 1):
@@ -22,11 +22,8 @@ def get_batch_white_noise(bs: int = 1, chs: int = 1, num_frames: int = 1):
def save_wav(path: str, wav: torch.Tensor, sample_rate: int):
+ assert wav.dim() == 2, wav.shape
fp = Path(path)
- kwargs: tp.Dict[str, tp.Any] = {}
- if fp.suffix == '.wav':
- kwargs['encoding'] = 'PCM_S'
- kwargs['bits_per_sample'] = 16
- elif fp.suffix == '.mp3':
- kwargs['compression'] = 320
- torchaudio.save(str(fp), wav, sample_rate, **kwargs)
+ assert fp.suffix in ['.mp3', '.ogg', '.wav', '.flac'], fp
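+ # Write the samples unnormalized; the 'clip' strategy with 0 dB headroom only clamps values outside [-1, 1].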
+ audio_write(fp.parent / fp.stem, wav, sample_rate, fp.suffix[1:],
+ normalize=False, strategy='clip', peak_clip_headroom_db=0)
diff --git a/tests/models/test_musicgen.py b/tests/models/test_musicgen.py
index 65618a9e..2b32ac5d 100644
--- a/tests/models/test_musicgen.py
+++ b/tests/models/test_musicgen.py
@@ -56,3 +56,10 @@ def test_generate_long(self):
wav = mg.generate(
['youpi', 'lapin dort'])
assert list(wav.shape) == [2, 1, 32000 * 4]
+
+ def test_generate_two_step_cfg(self):
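+ # two_step_cfg runs the conditional and unconditional CFG passes as two separate forward calls
+ # instead of one batched call; the generated shape should be unchanged.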
+ mg = self.get_musicgen()
+ mg.set_generation_params(duration=2.0, extend_stride=2., two_step_cfg=True)
+ wav = mg.generate(
+ ['youpi', 'lapin dort'])
+ assert list(wav.shape) == [2, 1, 64000]
diff --git a/tests/modules/test_rope.py b/tests/modules/test_rope.py
index 067c6f06..ec8d16c0 100644
--- a/tests/modules/test_rope.py
+++ b/tests/modules/test_rope.py
@@ -11,7 +11,7 @@
def test_rope():
- set_efficient_attention_backend('xformers')
+ set_efficient_attention_backend('torch')
B, T, H, C = 8, 75, 16, 128
rope = RotaryEmbedding(dim=C)
@@ -24,7 +24,7 @@ def test_rope():
def test_rope_io_dtypes():
- set_efficient_attention_backend('xformers')
+ set_efficient_attention_backend('torch')
B, T, H, C = 8, 75, 16, 128
rope_32 = RotaryEmbedding(dim=C, dtype=torch.float32)
@@ -48,7 +48,7 @@ def test_rope_io_dtypes():
def test_transformer_with_rope():
- set_efficient_attention_backend('xformers')
+ set_efficient_attention_backend('torch')
torch.manual_seed(1234)
for pos in ['rope', 'sin_rope']:
tr = StreamingTransformer(
@@ -64,7 +64,7 @@ def test_transformer_with_rope():
@torch.no_grad()
def test_rope_streaming():
- set_efficient_attention_backend('xformers')
+ set_efficient_attention_backend('torch')
torch.manual_seed(1234)
tr = StreamingTransformer(
16, 4, 2, causal=True, dropout=0.,
@@ -92,7 +92,7 @@ def test_rope_streaming():
@torch.no_grad()
def test_rope_streaming_past_context():
- set_efficient_attention_backend('xformers')
+ set_efficient_attention_backend('torch')
torch.manual_seed(1234)
for context in [None, 10]:
@@ -122,7 +122,7 @@ def test_rope_streaming_past_context():
def test_rope_memory_efficient():
- set_efficient_attention_backend('xformers')
+ set_efficient_attention_backend('torch')
torch.manual_seed(1234)
tr = StreamingTransformer(
16, 4, 2, custom=True, dropout=0., layer_scale=0.1,
@@ -143,7 +143,7 @@ def test_rope_memory_efficient():
def test_rope_with_xpos():
- set_efficient_attention_backend('xformers')
+ set_efficient_attention_backend('torch')
B, T, H, C = 8, 75, 16, 128
rope = RotaryEmbedding(dim=C, xpos=True)
@@ -156,7 +156,7 @@ def test_rope_with_xpos():
def test_positional_scale():
- set_efficient_attention_backend('xformers')
+ set_efficient_attention_backend('torch')
B, T, H, C = 8, 75, 16, 128
rope = RotaryEmbedding(dim=C, xpos=True, scale=0.0)
diff --git a/tests/modules/test_transformer.py b/tests/modules/test_transformer.py
index 2bb79bfd..ee74ba06 100644
--- a/tests/modules/test_transformer.py
+++ b/tests/modules/test_transformer.py
@@ -86,7 +86,7 @@ def test_streaming_api():
def test_memory_efficient():
- for backend in ['torch', 'xformers']:
+ for backend in ['torch']:
torch.manual_seed(1234)
set_efficient_attention_backend(backend)
@@ -132,7 +132,7 @@ def test_attention_as_float32():
@torch.no_grad()
def test_streaming_memory_efficient():
- for backend in ['torch', 'xformers']:
+ for backend in ['torch']:
torch.manual_seed(1234)
set_efficient_attention_backend(backend)
tr = StreamingTransformer(16, 4, 2, causal=True, dropout=0., custom=True)