Merge pull request #313 from 0xlws/main
refactor: rm unused parameter and redundant return statement + docs: typos, formatting
JadeCopet authored Jan 11, 2024
2 parents e1f2b18 + e2d3f45 commit da8a3ac
Showing 5 changed files with 55 additions and 59 deletions.
8 changes: 4 additions & 4 deletions audiocraft/models/encodec.py
@@ -26,7 +26,7 @@


class CompressionModel(ABC, nn.Module):
"""Base API for all compression model that aim at being used as audio tokenizers
"""Base API for all compression models that aim at being used as audio tokenizers
with a language model.
"""

@@ -112,7 +112,7 @@ def get_pretrained(
logger.info("Getting pretrained compression model for debug")
model = builders.get_debug_compression_model()
elif Path(name).exists():
# We assume here if the paths exist that it is in fact an AC checkpoint
# We assume here if the path exists that it is in fact an AC checkpoint
# that was exported using `audiocraft.utils.export` functions.
model = loaders.load_compression_model(name, device=device)
else:
@@ -228,8 +228,8 @@ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Te
Returns:
codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of:
codes a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
scale a float tensor containing the scale for audio renormalizealization.
codes: a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
scale: a float tensor containing the scale for audio renormalization.
"""
assert x.dim() == 3
x, scale = self.preprocess(x)
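The `CompressionModel` pieces touched here (the base docstring, `get_pretrained`, and `encode`) form the audio tokenizer API used by the language models. A minimal usage sketch, where the checkpoint name and the dummy input are assumptions rather than anything taken from this commit:

```python
import torch
from audiocraft.models.encodec import CompressionModel

# Checkpoint name is an assumption: any exported AC checkpoint path or a
# known pretrained identifier should be accepted by `get_pretrained`.
model = CompressionModel.get_pretrained('facebook/encodec_32khz', device='cpu')

# Dummy input: [B, C, T] audio at the model's sample rate (1 second here).
wav = torch.randn(1, model.channels, model.sample_rate)

with torch.no_grad():
    codes, scale = model.encode(wav)    # codes: [B, K, T'], scale: optional rescaling factor
    recon = model.decode(codes, scale)  # waveform reconstructed from the codes
```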
53 changes: 24 additions & 29 deletions audiocraft/models/multibanddiffusion.py
@@ -30,8 +30,6 @@ class DiffusionProcess:
noise_schedule (NoiseSchedule): Noise schedule for diffusion process.
"""
def __init__(self, model: DiffusionUnet, noise_schedule: NoiseSchedule) -> None:
"""
"""
self.model = model
self.schedule = noise_schedule

@@ -40,8 +38,8 @@ def generate(self, condition: torch.Tensor, initial_noise: torch.Tensor,
"""Perform one diffusion process to generate one of the bands.
Args:
condition (tensor): The embeddings form the compression model.
initial_noise (tensor): The initial noise to start the process/
condition (torch.Tensor): The embeddings from the compression model.
initial_noise (torch.Tensor): The initial noise to start the process.
"""
return self.schedule.generate_subsampled(model=self.model, initial=initial_noise, step_list=step_list,
condition=condition)
@@ -80,14 +78,13 @@ def get_mbd_musicgen(device=None):
return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)

@staticmethod
def get_mbd_24khz(bw: float = 3.0, pretrained: bool = True,
def get_mbd_24khz(bw: float = 3.0,
device: tp.Optional[tp.Union[torch.device, str]] = None,
n_q: tp.Optional[int] = None):
"""Get the pretrained Models for MultibandDiffusion.
Args:
bw (float): Bandwidth of the compression model.
pretrained (bool): Whether to use / download if necessary the models.
device (torch.device or str, optional): Device on which the models are loaded.
n_q (int, optional): Number of quantizers to use within the compression model.
"""
@@ -112,14 +109,12 @@ def get_mbd_24khz(bw: float = 3.0, pretrained: bool = True,
DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule))
return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)

return MultiBandDiffusion(DPs, codec_model)

@torch.no_grad()
def get_condition(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
"""Get the conditioning (i.e. latent reprentatios of the compression model) from a waveform.
"""Get the conditioning (i.e. latent representations of the compression model) from a waveform.
Args:
wav (torch.Tensor): The audio that we want to extract the conditioning from
sample_rate (int): sample rate of the audio"""
wav (torch.Tensor): The audio that we want to extract the conditioning from.
sample_rate (int): Sample rate of the audio."""
if sample_rate != self.sample_rate:
wav = julius.resample_frac(wav, sample_rate, self.sample_rate)
codes, scale = self.codec_model.encode(wav)
@@ -129,20 +124,20 @@ def get_condition(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:

@torch.no_grad()
def get_emb(self, codes: torch.Tensor):
"""Get latent representation from the discrete codes
Argrs:
codes (torch.Tensor): discrete tokens"""
"""Get latent representation from the discrete codes.
Args:
codes (torch.Tensor): Discrete tokens."""
emb = self.codec_model.decode_latent(codes)
return emb

def generate(self, emb: torch.Tensor, size: tp.Optional[torch.Size] = None,
step_list: tp.Optional[tp.List[int]] = None):
"""Generate Wavform audio from the latent embeddings of the compression model
"""Generate waveform audio from the latent embeddings of the compression model.
Args:
emb (torch.Tensor): Conditioning embeddinds
size (none torch.Size): size of the output
if None this is computed from the typical upsampling of the model
step_list (optional list[int]): list of Markov chain steps, defaults to 50 linearly spaced step.
emb (torch.Tensor): Conditioning embeddings
size (None, torch.Size): Size of the output
if None this is computed from the typical upsampling of the model.
step_list (list[int], optional): list of Markov chain steps, defaults to 50 linearly spaced step.
"""
if size is None:
upsampling = int(self.codec_model.sample_rate / self.codec_model.frame_rate)
@@ -154,12 +149,12 @@ def generate(self, emb: torch.Tensor, size: tp.Optional[torch.Size] = None,
return out

def re_eq(self, wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 32, strictness: float = 1):
"""match the eq to the encodec output by matching the standard deviation of some frequency bands
"""Match the eq to the encodec output by matching the standard deviation of some frequency bands.
Args:
wav (torch.Tensor): audio to equalize
ref (torch.Tensor):refenrence audio from which we match the spectrogram.
n_bands (int): number of bands of the eq
strictness (float): how strict the the matching. 0 is no matching, 1 is exact matching.
wav (torch.Tensor): Audio to equalize.
ref (torch.Tensor): Reference audio from which we match the spectrogram.
n_bands (int): Number of bands of the eq.
strictness (float): How strict the matching. 0 is no matching, 1 is exact matching.
"""
split = julius.SplitBands(n_bands=n_bands, sample_rate=self.codec_model.sample_rate).to(wav.device)
bands = split(wav)
@@ -170,10 +165,10 @@ def re_eq(self, wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 32, strictn
return out

def regenerate(self, wav: torch.Tensor, sample_rate: int):
"""Regenerate a wavform through compression and diffusion regeneration.
"""Regenerate a waveform through compression and diffusion regeneration.
Args:
wav (torch.Tensor): Original 'ground truth' audio
sample_rate (int): sample rate of the input (and output) wav
wav (torch.Tensor): Original 'ground truth' audio.
sample_rate (int): Sample rate of the input (and output) wav.
"""
if sample_rate != self.codec_model.sample_rate:
wav = julius.resample_frac(wav, sample_rate, self.codec_model.sample_rate)
@@ -187,8 +182,8 @@ def regenerate(self, wav: torch.Tensor, sample_rate: int):
def tokens_to_wav(self, tokens: torch.Tensor, n_bands: int = 32):
"""Generate Waveform audio with diffusion from the discrete codes.
Args:
tokens (torch.Tensor): discrete codes
n_bands (int): bands for the eq matching.
tokens (torch.Tensor): Discrete codes.
n_bands (int): Bands for the eq matching.
"""
wav_encodec = self.codec_model.decode(tokens)
condition = self.get_emb(tokens)
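With the unused `pretrained` flag and the unreachable second `return` gone, `MultiBandDiffusion` is driven roughly as in the sketch below; the placeholder audio, the bandwidth, and the default device handling are assumptions:

```python
import torch
from audiocraft.models import MultiBandDiffusion

mbd = MultiBandDiffusion.get_mbd_24khz(bw=3.0)  # bandwidth of the underlying codec

wav = torch.randn(1, 1, 24000)                  # placeholder mono audio, 1 s at 24 kHz

# Round-trip through compression and diffusion decoding.
regenerated = mbd.regenerate(wav, sample_rate=24000)

# Or decode discrete tokens produced elsewhere (e.g. by a language model).
codes, _ = mbd.codec_model.encode(wav)
decoded = mbd.tokens_to_wav(codes)
```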
18 changes: 9 additions & 9 deletions docs/CONDITIONING.md
@@ -15,7 +15,7 @@ For now, we support 3 main types of conditioning within AudioCraft:
* Joint embedding conditioning methods for text and audio projected in a shared latent space.

The Language Model relies on 2 core components that handle processing information:
* The `ConditionProvider` class, that maps metadata to processed conditions leveraging
* The `ConditionProvider` class, that maps metadata to processed conditions, leveraging
all the defined conditioners for the given task.
* The `ConditionFuser` class, that takes preprocessed conditions and properly fuse the
conditioning embedding to the language model inputs following a given fusing strategy.
@@ -29,7 +29,7 @@ conditioning signals and feed them to the language model.

### Conditioners

The `BaseConditioner` torch module is the base implementation for all conditioners in audiocraft.
The `BaseConditioner` torch module is the base implementation for all conditioners in AudioCraft.

Each conditioner is expected to implement 2 methods:
* The `tokenize` method that is used as a preprocessing method that contains all processing
@@ -45,10 +45,10 @@ The ConditionProvider prepares and provides conditions given a dictionary of con
Conditioners are specified as a dictionary of attributes and the corresponding conditioner
providing the processing logic for the given attribute.

Similarly to the conditioners, the condition provider works in two steps to avoid sychronization points:
Similarly to the conditioners, the condition provider works in two steps to avoid synchronization points:
* A `tokenize` method that takes a list of conditioning attributes for the batch,
and run all tokenize steps for the set of conditioners.
* A `forward` method that takes the output of the tokenize step and run all the forward steps
and runs all tokenize steps for the set of conditioners.
* A `forward` method that takes the output of the tokenize step and runs all the forward steps
for the set of conditioners.

The list of conditioning attributes is passed as a list of `ConditioningAttributes`
@@ -111,15 +111,15 @@ frozen or fine-tuned at train time to extract the text embeddings.
### Waveform conditioners

All waveform conditioners are expected to inherit from the `WaveformConditioner` class and
consists of conditioning method that takes a waveform as input. The waveform conditioner
consist of a conditioning method that takes a waveform as input. The waveform conditioner
must implement the logic to extract the embedding from the waveform and define the downsampling
factor from the waveform to the resulting embedding.
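As an illustration of that contract, a toy conditioner might look like the sketch below. It is not the actual `WaveformConditioner` base class; the class and method names are invented for the example:

```python
import torch
import torch.nn as nn

class EnvelopeConditioner(nn.Module):
    """Toy waveform conditioner sketched from the description above.

    Illustrative only: the real `WaveformConditioner` lives in
    `audiocraft/modules/conditioners.py` and its hooks may be named differently.
    The two responsibilities are (1) mapping a waveform to an embedding
    sequence and (2) exposing the waveform-to-embedding downsampling factor.
    """

    def __init__(self, dim: int = 512, hop_length: int = 320):
        super().__init__()
        self.hop_length = hop_length
        # Stand-in feature extractor: a strided conv yielding one frame per hop.
        self.encoder = nn.Conv1d(1, dim, kernel_size=hop_length, stride=hop_length)

    def downsampling_factor(self) -> int:
        # One embedding frame per `hop_length` waveform samples.
        return self.hop_length

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        # wav: [B, 1, T] -> embedding: [B, T // hop_length, dim]
        return self.encoder(wav).transpose(1, 2)
```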

The `ChromaStemConditioner` conditioner is a waveform conditioner for the chroma features
conditioning used by MusicGen. It takes a given waveform, extract relevant stems for melody
conditioning used by MusicGen. It takes a given waveform, extracts relevant stems for melody
(namely all non drums and bass stems) using a
[pre-trained Demucs model](https://github.com/facebookresearch/demucs)
and then extract the chromagram bins from the remaining mix of stems.
and then extracts the chromagram bins from the remaining mix of stems.

### Joint embeddings conditioners

@@ -143,4 +143,4 @@ not to expect all conditioning signals to be provided at once.
Conditioners that require some heavy computation on the waveform can be cached, in particular
the `ChromaStemConditioner` or `CLAPEmbeddingConditioner`. You just need to provide the
`cache_path` parameter to them. We recommend running dummy jobs for filling up the cache quickly.
An example is provied in the [musicgen.musicgen_melody_32khz grid](../audiocraft/grids/musicgen/musicgen_melody_32khz.py).
An example is provided in the [musicgen.musicgen_melody_32khz grid](../audiocraft/grids/musicgen/musicgen_melody_32khz.py).
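Putting the two-step flow into code, a rough sketch of driving the condition provider of a pretrained MusicGen model follows; the attribute key and the exact return types are assumptions rather than guarantees of this diff:

```python
from audiocraft.models import MusicGen
from audiocraft.modules.conditioners import ConditioningAttributes

model = MusicGen.get_pretrained('facebook/musicgen-small')
provider = model.lm.condition_provider  # the ConditionProvider described above

attributes = [
    ConditioningAttributes(text={'description': 'upbeat electronic track with piano'}),
    ConditioningAttributes(text={'description': 'slow ambient pad with warm strings'}),
]

tokenized = provider.tokenize(attributes)  # step 1: cheap preprocessing, no synchronization points
conditions = provider(tokenized)           # step 2: batched forward through every conditioner
```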
21 changes: 11 additions & 10 deletions docs/ENCODEC.md
@@ -1,7 +1,7 @@
# EnCodec: High Fidelity Neural Audio Compression

AudioCraft provides the training code for EnCodec, a state-of-the-art deep learning
based audio codec supporting both mono stereo audio, presented in the
based audio codec supporting both mono and stereo audio, presented in the
[High Fidelity Neural Audio Compression][arxiv] paper.
Check out our [sample page][encodec_samples].

@@ -26,7 +26,7 @@ task to train an EnCodec model. Specifically, it trains an encoder-decoder with
bottleneck - a SEANet encoder-decoder with Residual Vector Quantization bottleneck for EnCodec -
using a combination of objective and perceptual losses in the forms of discriminators.

The default configuration matches a causal EnCodec training with at a single bandwidth.
The default configuration matches a causal EnCodec training at a single bandwidth.

### Example configuration and grids

@@ -45,7 +45,7 @@ dora grid compression.encodec_base_24khz
dora grid compression.encodec_musicgen_32khz
```

### Training and valid stages
### Training and validation stages

The model is trained using a combination of objective and perceptual losses.
More specifically, EnCodec is trained with the MS-STFT discriminator along with
@@ -54,7 +54,7 @@ the different losses, in an intuitive manner.

### Evaluation stage

Evaluations metrics for audio generation:
Evaluation metrics for audio generation:
* SI-SNR: Scale-Invariant Signal-to-Noise Ratio.
* ViSQOL: Virtual Speech Quality Objective Listener.
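For reference, the first metric is straightforward to express in code; a small self-contained helper (not the solver's own implementation) computing SI-SNR in dB:

```python
import torch

def si_snr(estimate: torch.Tensor, reference: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Scale-Invariant SNR in dB, computed over the last dimension."""
    estimate = estimate - estimate.mean(dim=-1, keepdim=True)
    reference = reference - reference.mean(dim=-1, keepdim=True)
    # Project the estimate onto the reference to get the scaled target.
    dot = (estimate * reference).sum(dim=-1, keepdim=True)
    s_target = dot * reference / (reference.pow(2).sum(dim=-1, keepdim=True) + eps)
    e_noise = estimate - s_target
    ratio = s_target.pow(2).sum(dim=-1) / (e_noise.pow(2).sum(dim=-1) + eps)
    return 10 * torch.log10(ratio + eps)
```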

@@ -110,8 +110,9 @@ import logging
import os
import sys

# uncomment the following line if you want some detailed logs when loading a Solver.
logging.basicConfig(stream=sys.stderr, level=logging.INFO)
# Uncomment the following line if you want some detailed logs when loading a Solver.
# logging.basicConfig(stream=sys.stderr, level=logging.INFO)

# You must always run the following function from the root directory.
os.chdir(Path(train.__file__).parent.parent)

@@ -126,10 +127,10 @@ solver.dataloaders
### Importing / Exporting models

At the moment we do not have a definitive workflow for exporting EnCodec models, for
instance to Hugging Face (HF). We are working on supporting automatic convertion between
instance to Hugging Face (HF). We are working on supporting automatic conversion between
AudioCraft and Hugging Face implementations.

We still have some support for fine tuning an EnCodec model coming from HF in AudioCraft,
We still have some support for fine-tuning an EnCodec model coming from HF in AudioCraft,
using for instance `continue_from=//pretrained/facebook/encodec_32k`.

An AudioCraft checkpoint can be exported in a more compact format (excluding the optimizer etc.)
@@ -148,11 +149,11 @@ from audiocraft.models import CompressionModel
model = CompressionModel.get_pretrained('/checkpoints/my_audio_lm/compression_state_dict.bin')

from audiocraft.solvers import CompressionSolver
# The two are strictly equivalent, but this function supports also loading from non already exported models.
# The two are strictly equivalent, but this function supports also loading from non-already exported models.
model = CompressionSolver.model_from_checkpoint('//pretrained//checkpoints/my_audio_lm/compression_state_dict.bin')
```
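The export step referenced above would, as a sketch, use the helpers in `audiocraft.utils.export`; the helper name and checkpoint layout here are assumptions to verify against that module:

```python
from pathlib import Path

from audiocraft import train
from audiocraft.utils import export

# Resolve the experiment folder from its Dora signature, then write the
# compact state dict (no optimizer state) to the target location.
xp = train.main.get_xp_from_sig('SIG')  # 'SIG' is a placeholder signature
export.export_encodec(
    Path(xp.folder) / 'checkpoint.th',
    '/checkpoints/my_audio_lm/compression_state_dict.bin',
)
```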

We will see then how to use this model as a tokenizer for MusicGen/Audio gen in the
We will see then how to use this model as a tokenizer for MusicGen/AudioGen in the
[MusicGen documentation](./MUSICGEN.md).

### Learn more
14 changes: 7 additions & 7 deletions docs/TRAINING.md
@@ -9,7 +9,7 @@ AudioCraft training pipelines are designed to be research and experiment-friendl
## Environment setup

For the base installation, follow the instructions from the [README.md](../README.md).
Below are some additional instructions for setting up environment to train new models.
Below are some additional instructions for setting up the environment to train new models.

### Team and cluster configuration

@@ -47,7 +47,7 @@ properly set the `dora_dir` entries.

#### Overriding environment configurations

You can set the following environmet variables to bypass the team's environment configuration:
You can set the following environment variables to bypass the team's environment configuration:
* `AUDIOCRAFT_CONFIG`: absolute path to a team config yaml file.
* `AUDIOCRAFT_DORA_DIR`: absolute path to a custom dora directory.
* `AUDIOCRAFT_REFERENCE_DIR`: absolute path to the shared reference directory.
@@ -199,7 +199,7 @@ Once this configuration is created and used for running experiments, you should

Note that as we are using Dora as our experiment manager, all our experiment tracking is based on
signatures computed from delta between configurations.
**One must therefore ensure backward compatibilty of the configuration at all time.**
**One must therefore ensure backward compatibility of the configuration at all time.**
See [Dora's README](https://github.com/facebookresearch/dora) and the
[section below introduction Dora](#running-experiments-with-dora).

@@ -255,7 +255,7 @@ of those hyper-parameters. We always refer to an XP with its signature, e.g. 935
after that one can retrieve the hyper-params and re-rerun it in a single command.
* In fact, the hash is defined as a delta between the base config and the one obtained
with the config overrides you passed from the command line. This means you must never change
the `conf/**.yaml` files directly., except for editing things like paths. Changing the default values
the `conf/**.yaml` files directly, except for editing things like paths. Changing the default values
in the config files means the XP signature won't reflect that change, and wrong checkpoints might be reused.
I know, this is annoying, but the reason is that otherwise, any change to the config file would mean
that all XPs ran so far would see their signature change.
Expand All @@ -276,7 +276,7 @@ dora run -d -f 81de367c dataset.batch_size=32 # start from the config of XP 81d
dora info -f SIG -t # will tail the log (if the XP has scheduled).
# if you need to access the logs of the process for rank > 0, in particular because a crash didn't happen in the main
# process, then use `dora info -f SIG` to get the main log name (finished into something like `/5037674_0_0_log.out`)
# and worker K can accessed as `/5037674_0_{K}_log.out`.
# and worker K can be accessed as `/5037674_0_{K}_log.out`.
# This is only for scheduled jobs, for local distributed runs with `-d`, then you should go into the XP folder,
# and look for `worker_{K}.log` logs.
```
@@ -290,9 +290,9 @@ a previous checkpoint you can use `dora run --clear [RUN ARGS]`.
If you have a Slurm cluster, you can also use the dora grid command, e.g.

```shell
# run a dummy grid located at `audiocraft/grids/my_grid_folder/my_grid_name.py`
# Run a dummy grid located at `audiocraft/grids/my_grid_folder/my_grid_name.py`
dora grid my_grid_folder.my_grid_name
# Run the following will simply display the grid and also initialized the Dora experiments database.
# The following will simply display the grid and also initialize the Dora experiments database.
# You can then simply refer to a config using its signature (e.g. as `dora run -f SIG`).
dora grid my_grid_folder.my_grid_name --dry_run --init
```
