
Commit

Merge branch 'main' into patch-1
felixkreuk authored Nov 23, 2023
2 parents ba67e4d + c65e9be commit 201b542
Showing 36 changed files with 586 additions and 165 deletions.
2 changes: 2 additions & 0 deletions .github/actions/audiocraft_build/action.yml
@@ -21,6 +21,8 @@ runs:
python3 -m venv env
. env/bin/activate
python -m pip install --upgrade pip
+ pip install torch torchvision torchaudio
+ pip install xformers
pip install -e '.[dev]'
- name: System Dependencies
shell: bash
25 changes: 25 additions & 0 deletions CHANGELOG.md
@@ -4,6 +4,31 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

+ ## [1.2.0a] - TBD
+
+ Adding stereo models.
+
+
+ ## [1.1.0] - 2023-11-06
+
+ Not using torchaudio anymore when writing audio files, relying instead directly on the command-line ffmpeg. Also not using it anymore for reading audio files, for similar reasons.
+
+ Fixed DAC support with a non-default number of codebooks.
+
+ Fixed a bug where `two_step_cfg` was overridden when calling `generate()`.
+
+ Fixed samples always being prompted with audio, rather than having both prompted and unprompted samples.
+
+ **Backward incompatible change:** A `torch.no_grad` around the computation of the conditioning made its way into the public release.
+ The released models were trained without this. It impacts the linear layers applied to the output of the T5 or melody conditioners.
+ We removed it, so you might need to retrain models.
+
+ **Backward incompatible change:** Fixed a wrong sample rate in CLAP (WARNING if you trained a model with CLAP before).
+
+ **Backward incompatible change:** Renamed VALLEPattern to CoarseFirstPattern, as it was wrongly named. Probably no one
+ retrained a model with this pattern, so hopefully this won't impact you!
+
+
## [1.0.0] - 2023-09-07

Major revision, added training code for EnCodec, AudioGen, MusicGen, and MultiBandDiffusion.
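Since the headline item in 1.2.0a is the stereo models, a hedged sketch of loading one follows; the checkpoint id `facebook/musicgen-stereo-small` is an assumption inferred from the release notes, so verify it against the published model cards.

```python
from audiocraft.models import MusicGen

# Assumed checkpoint id for the stereo release -- verify against the model cards.
model = MusicGen.get_pretrained('facebook/musicgen-stereo-small')
model.set_generation_params(duration=8)
# A stereo model should return audio shaped [batch, 2, time].
wav = model.generate(['lofi beat with warm pads'])
```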
12 changes: 6 additions & 6 deletions README.md
@@ -13,11 +13,11 @@ AudioCraft requires Python 3.9, PyTorch 2.0.0. To install AudioCraft, you can run
```shell
# Best to make sure you have torch installed first, in particular before installing xformers.
# Don't run this if you already have PyTorch installed.
- pip install 'torch>=2.0'
+ python -m pip install 'torch>=2.0'
# Then proceed to one of the following
- pip install -U audiocraft  # stable release
- pip install -U git+https://[email protected]/facebookresearch/audiocraft#egg=audiocraft  # bleeding edge
- pip install -e .  # or if you cloned the repo locally (mandatory if you want to train).
+ python -m pip install -U audiocraft  # stable release
+ python -m pip install -U git+https://[email protected]/facebookresearch/audiocraft#egg=audiocraft  # bleeding edge
+ python -m pip install -e .  # or if you cloned the repo locally (mandatory if you want to train).
```

We also recommend having `ffmpeg` installed, either through your system or Anaconda:
@@ -72,11 +72,11 @@ Finally, if you use a model that relies on Demucs (e.g. `musicgen-melody`) and w…

For the general framework of AudioCraft, please cite the following.
```
- @article{copet2023simple,
+ @inproceedings{copet2023simple,
title={Simple and Controllable Music Generation},
author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez},
+ booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
  year={2023},
- journal={arXiv preprint arXiv:2306.05284},
}
```

2 changes: 1 addition & 1 deletion audiocraft/__init__.py
@@ -23,4 +23,4 @@
# flake8: noqa
from . import data, modules, models

- __version__ = '1.0.0'
+ __version__ = '1.2.0a1'
45 changes: 29 additions & 16 deletions audiocraft/data/audio.py
@@ -18,11 +18,11 @@
import soundfile
import torch
from torch.nn import functional as F
- import torchaudio as ta

import av
+ import subprocess as sp

- from .audio_utils import f32_pcm, i16_pcm, normalize_audio
+ from .audio_utils import f32_pcm, normalize_audio


_av_initialized = False
@@ -136,12 +136,6 @@ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
wav = torch.from_numpy(wav).t().contiguous()
if len(wav.shape) == 1:
wav = torch.unsqueeze(wav, 0)
-     elif (
-         fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats()
-         and duration <= 0 and seek_time == 0
-     ):
-         # Torchaudio is faster if we load an entire file at once.
-         wav, sr = ta.load(fp)
else:
wav, sr = _av_read(filepath, seek_time, duration)
if pad and duration > 0:
@@ -150,10 +144,22 @@
return wav, sr


+ def _piping_to_ffmpeg(out_path: tp.Union[str, Path], wav: torch.Tensor, sample_rate: int, flags: tp.List[str]):
+     # ffmpeg is always installed and torchaudio is a bit unstable lately, so let's bypass it entirely.
+     assert wav.dim() == 2, wav.shape
+     command = [
+         'ffmpeg',
+         '-loglevel', 'error',
+         '-y', '-f', 'f32le', '-ar', str(sample_rate), '-ac', str(wav.shape[0]),
+         '-i', '-'] + flags + [str(out_path)]
+     input_ = f32_pcm(wav).t().detach().cpu().numpy().tobytes()
+     sp.run(command, input=input_, check=True)


def audio_write(stem_name: tp.Union[str, Path],
wav: torch.Tensor, sample_rate: int,
- format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
- strategy: str = 'peak', peak_clip_headroom_db: float = 1,
+ format: str = 'wav', mp3_rate: int = 320, ogg_rate: tp.Optional[int] = None,
+ normalize: bool = True, strategy: str = 'peak', peak_clip_headroom_db: float = 1,
rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
loudness_compressor: bool = False,
log_clipping: bool = True, make_parent_dir: bool = True,
@@ -164,8 +170,9 @@ def audio_write(stem_name: tp.Union[str, Path],
stem_name (str or Path): Filename without extension which will be added automatically.
wav (torch.Tensor): Audio data to save.
sample_rate (int): Sample rate of audio data.
format (str): Either "wav" or "mp3".
format (str): Either "wav", "mp3", "ogg", or "flac".
mp3_rate (int): kbps when using mp3s.
+ ogg_rate (int): kbps when using ogg/vorbis. If not provided, let ffmpeg decide for itself.
normalize (bool): if `True` (default), normalizes according to the prescribed
strategy (see after). If `False`, the strategy is only used in case clipping
would happen.
@@ -193,14 +200,20 @@ def audio_write(stem_name: tp.Union[str, Path],
rms_headroom_db, loudness_headroom_db, loudness_compressor,
log_clipping=log_clipping, sample_rate=sample_rate,
stem_name=str(stem_name))
- kwargs: dict = {}
if format == 'mp3':
suffix = '.mp3'
kwargs.update({"compression": mp3_rate})
flags = ['-f', 'mp3', '-c:a', 'libmp3lame', '-b:a', f'{mp3_rate}k']
elif format == 'wav':
- wav = i16_pcm(wav)
suffix = '.wav'
kwargs.update({"encoding": "PCM_S", "bits_per_sample": 16})
flags = ['-f', 'wav', '-c:a', 'pcm_s16le']
elif format == 'ogg':
suffix = '.ogg'
flags = ['-f', 'ogg', '-c:a', 'libvorbis']
if ogg_rate is not None:
flags += ['-b:a', f'{ogg_rate}k']
elif format == 'flac':
suffix = '.flac'
flags = ['-f', 'flac']
else:
raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
if not add_suffix:
@@ -209,7 +222,7 @@
if make_parent_dir:
path.parent.mkdir(exist_ok=True, parents=True)
try:
-     ta.save(path, wav, sample_rate, **kwargs)
+     _piping_to_ffmpeg(path, wav, sample_rate, flags)
except Exception:
if path.exists():
# we do not want to leave half written files around.
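To make the new ffmpeg-backed write path concrete, here is a minimal usage sketch; the tensor is synthetic and the call assumes `ffmpeg` is on `$PATH`, which the code above already requires.

```python
import math

import torch
from audiocraft.data.audio import audio_write

# Synthetic one-second stereo sine at 32 kHz, shaped [channels, time].
sr = 32000
t = torch.arange(sr, dtype=torch.float32) / sr
wav = torch.stack([torch.sin(2 * math.pi * 440 * t)] * 2)

# 'ogg' and 'flac' are the newly supported formats; ogg_rate is optional
# (ffmpeg picks a bitrate when it is omitted). The suffix is added automatically.
audio_write('demo_clip', wav, sr, format='ogg', ogg_rate=128)
audio_write('demo_clip', wav, sr, format='flac')
```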
57 changes: 57 additions & 0 deletions audiocraft/grids/musicgen/musicgen_stereo_finetune_32khz.py
@@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from pathlib import Path
from ._explorers import LMExplorer
from ...environment import AudioCraftEnvironment


@LMExplorer
def explorer(launcher):
partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
launcher.slurm_(gpus=32, partition=partitions)
launcher.bind_(solver='musicgen/musicgen_base_32khz')
# replace this with the desired music dataset, which needs to be stereo
launcher.bind_(dset='audio/example')

fsdp = {'autocast': False, 'fsdp.use': True}
medium = {'model/lm/model_scale': 'medium'}
large = {'model/lm/model_scale': 'large'}

cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
wd_low = {'conditioners.description.t5.word_dropout': 0.2}

adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}

stereo = {
'codebooks_pattern.delay.delays': [0, 0, 1, 1, 2, 2, 3, 3],
'transformer_lm.n_q': 8,
'interleave_stereo_codebooks.use': True,
'channels': 2,
}

# You must follow the instructions in docs/MUSICGEN.md about the creation
# of the proper fine-tuning checkpoints. We will assume they are stored under
# ~/checkpoints/{model_name}.

checkpoints = Path.home() / 'checkpoints'

launcher.bind_(fsdp, stereo, {'optim.epochs': 100})

launcher.slurm_(gpus=32).bind_(label='32gpus')
with launcher.job_array():
sub = launcher.bind({'continue_from': str(checkpoints / 'stereo_finetune_musicgen-small.th')})
sub()

launcher.slurm_(gpus=64).bind_(label='64gpus')
with launcher.job_array():
sub = launcher.bind({'continue_from': str(checkpoints / 'stereo_finetune_musicgen-medium.th')})
sub(medium, adam)

launcher.slurm_(gpus=96).bind_(label='96gpus')
with launcher.job_array():
sub = launcher.bind({'continue_from': str(checkpoints / 'stereo_finetune_musicgen-large.th')})
sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})
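A note on the `stereo` overrides: `transformer_lm.n_q` doubles to 8 because each channel keeps its own four codebook streams, and the paired delays `[0, 0, 1, 1, 2, 2, 3, 3]` give both channels of codebook k the same delay. Below is a toy sketch of one plausible interleaving; the tensors are hypothetical and this is not the actual `InterleaveStereoCompressionModel` implementation.

```python
import torch

# Hypothetical mono token streams, one per channel: [n_q=4, time].
left = torch.zeros(4, 10, dtype=torch.long)
right = torch.ones(4, 10, dtype=torch.long)

# Interleave per codebook (L0, R0, L1, R1, ...) -> [n_q=8, time],
# so delay k applies to both channels of codebook k.
stereo = torch.stack([left, right], dim=1).reshape(8, 10)
assert stereo.shape == (8, 10)
```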
4 changes: 4 additions & 0 deletions audiocraft/models/audiogen.py
@@ -38,6 +38,10 @@ def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
self.name = name
self.compression_model = compression_model
self.lm = lm
+     # Just to be safe, let's put everything in eval mode.
+     self.compression_model.eval()
+     self.lm.eval()

if max_duration is None:
if hasattr(lm, 'cfg'):
max_duration = lm.cfg.dataset.segment_duration # type: ignore
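The added `eval()` calls are a small but real safeguard: they disable training-time behavior such as dropout before generation. A self-contained sketch of the same pattern, with stand-in modules rather than the real ones:

```python
import torch
from torch import nn

# Stand-ins for compression_model and lm.
compression_model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.5))
lm = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.5))

# Same safeguard as the diff: deterministic inference, no dropout.
compression_model.eval()
lm.eval()
with torch.no_grad():
    out = lm(compression_model(torch.randn(1, 8)))
```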
16 changes: 11 additions & 5 deletions audiocraft/models/builders.py
@@ -15,15 +15,15 @@
import omegaconf
import torch

- from .encodec import CompressionModel, EncodecModel
+ from .encodec import CompressionModel, EncodecModel, InterleaveStereoCompressionModel
from .lm import LMModel
from ..modules.codebooks_patterns import (
CodebooksPatternProvider,
DelayedPatternProvider,
MusicLMPattern,
ParallelPatternProvider,
UnrolledPatternProvider,
-     VALLEPattern,
+     CoarseFirstPattern,
)
from ..modules.conditioners import (
BaseConditioner,
@@ -172,7 +172,7 @@ def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
'parallel': ParallelPatternProvider,
'delay': DelayedPatternProvider,
'unroll': UnrolledPatternProvider,
-     'valle': VALLEPattern,
+     'coarse_first': CoarseFirstPattern,
'musiclm': MusicLMPattern,
}
name = cfg.modeling
@@ -196,7 +196,6 @@ def get_debug_compression_model(device='cpu', sample_rate: int = 32000):
'dimension': 32,
'ratios': ratios,
}
- print(seanet_kwargs)
encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs)
decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs)
quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4)
@@ -248,5 +247,12 @@ def get_debug_lm_model(device='cpu'):
def get_wrapped_compression_model(
compression_model: CompressionModel,
cfg: omegaconf.DictConfig) -> CompressionModel:
-     # more to come.
+     if hasattr(cfg, 'interleave_stereo_codebooks'):
+         if cfg.interleave_stereo_codebooks.use:
+             kwargs = dict_from_config(cfg.interleave_stereo_codebooks)
+             kwargs.pop('use')
+             compression_model = InterleaveStereoCompressionModel(compression_model, **kwargs)
+     if hasattr(cfg, 'compression_model_n_q'):
+         if cfg.compression_model_n_q is not None:
+             compression_model.set_num_codebooks(cfg.compression_model_n_q)
return compression_model
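To illustrate the new wrapping branch in `get_wrapped_compression_model`, here is a hedged sketch of the config shape it reads; the values are illustrative assumptions, and `compression_model` stands for any existing mono `CompressionModel`.

```python
from omegaconf import OmegaConf

# Illustrative config: enable stereo interleaving and cap the codebooks used.
cfg = OmegaConf.create({
    'interleave_stereo_codebooks': {'use': True},
    'compression_model_n_q': 4,
})

# With this cfg, get_wrapped_compression_model(compression_model, cfg) wraps
# the model in InterleaveStereoCompressionModel (forwarding any keys besides
# 'use' as kwargs) and then calls set_num_codebooks(4) on the result.
```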
