This is the official repository of the CompoNet baseline defined in *COCOLA: Coherence-Oriented Contrastive Learning of Musical Audio Representations*. The code for the COCOLA model itself can be found at https://github.com/gladia-research-group/cocola.
```bash
conda create --name componet python=3.11
conda activate componet
pip install -r requirements.txt
```
Model Checkpoint | Train Dataset | Train Config | Description |
---|---|---|---|
`musdb-conditional_epoch=423.ckpt` | MusDB | `exp/train_musdb_conditional.yaml` | CompoNet model trained on the MusDB dataset using AudioLDM2-large as the base model, fine-tuning the ControlNet adapter. |
`moisesdb-conditional_epoch=250.ckpt` | MoisesDB | `exp/train_moisesdb_conditional.yaml` | CompoNet model trained on the MoisesDB dataset using AudioLDM2-large as the base model, fine-tuning the ControlNet adapter. |
`slakh-conditional_epoch=93.ckpt` | Slakh2100 | `exp/train_slakh_conditional_attentions.yaml` | CompoNet model trained on the Slakh2100 dataset using AudioLDM2-large as the base model, fine-tuning the ControlNet adapter and the UNet cross-attentions. |
Inference can be performed using `inference.ipynb`. The model is first instantiated and the checkpoint loaded. Specify the model config (the Train Config from the table above, without the `.yaml` extension) as `exp_cfg` and the checkpoint path as `ckpt_path`.
Then, load your input with:
```python
y, sr = torchaudio.load("in.wav")  # load your audio input
```
And specify the inference `prompt`.

Full example with `musdb-conditional`:
```python
import hydra
import torch
import torchaudio

exp_cfg = "train_musdb_conditional"
ckpt_path = "../ckpts/musdb-conditional_epoch=423.ckpt"

# Instantiate the model from the experiment config and load the checkpoint
with hydra.initialize(config_path="..", version_base=None):
    cond_cfg = hydra.compose(config_name="config", overrides=[f"exp={exp_cfg}"])
    model = hydra.utils.instantiate(cond_cfg["model"])
ckpt = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(ckpt["state_dict"], strict=False)
model = model.cuda()

# Load the conditioning audio and define the inference prompt
y, sr = torchaudio.load("in.wav")  # load your audio input
prompt = "in: other_1, vocals_1; out: vocals_1"
assert sr == 16000
y = torch.clip(y, -1, 1)

# Encode the input into VAE latents used as the ControlNet conditioning
y_melspec = model.stft.cpu().mel_spectrogram(y.cpu())[0]
y_latents = model.vae.encode(y_melspec.cuda().permute(0, 2, 1).unsqueeze(1)).latent_dist.sample()
y_latents = y_latents * model.vae.config.scaling_factor

# Generate the output stem(s) and save the result
samples = model.pipeline(
    [prompt],
    num_inference_steps=150,
    guidance_scale=1.0,
    audio_length_in_s=10.23,
    controlnet_cond=y_latents.cuda(),
).audios
torchaudio.save("out.wav", torch.tensor(samples[0]).unsqueeze(0), sample_rate=sr)
```
For `musdb-conditional` and `slakh-conditional` the prompts do not have a genre attribute. For example:
```python
prompt = "in: other_1, vocals_1; out: vocals_1"
```
For `moisesdb-conditional` you have to specify a lowercase genre (e.g., `pop`, `rock`) preceding the input and output tags:
```python
prompt = "genre: pop; in: guitar_1, vocals_1; out: other_keys_1, drums_1"
```
The available stem tags for `musdb-conditional` are:
```python
STEMS = ['bass', 'drums', 'vocals', 'other']
```
The available stem tags for `slakh-conditional` are:
```python
STEMS = ['bass', 'drums', 'guitar', 'piano']
```
The available stem and genre tags for `moisesdb-conditional` are:
```python
STEMS = ['bass', 'bowed_strings', 'drums', 'guitar', 'other', 'other_keys', 'other_plucked', 'percussion', 'piano', 'vocals', 'wind']
GENRES = ['blues', 'bossa_nova', 'country', 'electronic', 'jazz', 'musical_theatre', 'pop', 'rap', 'reggae', 'rock', 'singer_songwriter', 'world_folk']
```
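The prompt grammar (an optional `genre:` field, then comma-separated `in:` and `out:` stem tags carrying an `_<index>` suffix) can also be assembled programmatically. Below is a minimal sketch of a hypothetical helper, `build_prompt`, which is not part of this repository; it only mirrors the format shown in the examples above.

```python
def build_prompt(in_stems, out_stems, genre=None):
    """Assemble a prompt string such as 'genre: pop; in: guitar_1; out: drums_1'.

    Hypothetical helper (not part of this repo). Stem names are expected to
    already carry their index suffix, e.g. 'vocals_1'.
    """
    parts = []
    if genre is not None:
        parts.append(f"genre: {genre}")           # only for moisesdb-conditional
    parts.append("in: " + ", ".join(in_stems))    # conditioning stems
    parts.append("out: " + ", ".join(out_stems))  # stems to generate
    return "; ".join(parts)

# Reproduces the moisesdb-conditional prompt shown above
assert build_prompt(
    ["guitar_1", "vocals_1"], ["other_keys_1", "drums_1"], genre="pop"
) == "genre: pop; in: guitar_1, vocals_1; out: other_keys_1, drums_1"
```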
For training, first copy `.env.tmp` and remove the `.tmp` extension. Then edit the resulting `.env` file, filling in the following fields with your `wandb` data:
```
WANDB_PROJECT=wandbprojectname
WANDB_ENTITY=wandbuser
WANDB_API_KEY=wandbapikey
```
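To quickly check that the variables are set correctly, you can read the file back with `python-dotenv`. This is only a sanity-check sketch; it assumes the `python-dotenv` package is installed and does not imply that `train.py` loads the file this way.

```python
import os
from dotenv import load_dotenv  # assumption: python-dotenv is installed

load_dotenv(".env")  # read the WANDB_* variables from .env into the environment
print(os.environ.get("WANDB_PROJECT"), os.environ.get("WANDB_ENTITY"))
```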
Training can then be run by calling `train.py` with the desired experiment. Data has to be provided as `webdataset` shards: scripts for sharding the datasets will be provided (a rough sketch is shown after the example command below). You also have to specify a `TAG` describing the experiment.
```bash
TAG=moisesdb-conditional python train.py exp=train_moisesdb_conditional datamodule.train_dataset.path=data/moisesdb/{0..18}.tar datamodule.val_dataset.path=data/moisesdb/19.tar
```
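Until the official sharding scripts are released, the sketch below illustrates how `.tar` shards such as `data/moisesdb/0.tar` could be produced with the `webdataset` library. The sample keys (`mix.wav`, `target.wav`, `prompt.txt`) and the per-sample layout are placeholders chosen for illustration only; they are not the schema this repository's datamodules actually expect.

```python
import webdataset as wds

# Hypothetical list of (mix_path, target_path, prompt) triples; replace with your own data.
samples = [
    ("mix_000.wav", "target_000.wav", "in: other_1, vocals_1; out: vocals_1"),
]

# Write one shard; the field names below are assumptions, not the repo's format.
with wds.TarWriter("0.tar") as sink:
    for idx, (mix_path, target_path, prompt) in enumerate(samples):
        sink.write({
            "__key__": f"sample_{idx:06d}",
            "mix.wav": open(mix_path, "rb").read(),       # raw audio bytes
            "target.wav": open(target_path, "rb").read(),  # raw audio bytes
            "prompt.txt": prompt,                          # text is encoded automatically
        })
```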