Commit b8cd624

Merge pull request #1 from chenxwh/main
Add Replicate demo and API
soujanyaporia authored Nov 16, 2023
2 parents 9d1f4f2 + 1db6e0e commit b8cd624
Showing 4 changed files with 190 additions and 62 deletions.
README.md (2 changes: 1 addition & 1 deletion)
@@ -2,7 +2,7 @@
 
 # Mustango: Toward Controllable Text-to-Music Generation
 
-[Demo]() [Model](https://huggingface.co/declare-lab/mustango) [Website and Examples](https://amaai-lab.github.io/mustango/) [Paper](https://arxiv.org/abs/2311.08355) [Dataset](https://huggingface.co/datasets/amaai-lab/MusicBench)
+[Demo](https://replicate.com/declare-lab/mustango) [Model](https://huggingface.co/declare-lab/mustango) [Website and Examples](https://amaai-lab.github.io/mustango/) [Paper](https://arxiv.org/abs/2311.08355) [Dataset](https://huggingface.co/datasets/amaai-lab/MusicBench)
 </div>
 
 Meet Mustango, an exciting addition to the vibrant landscape of Multimodal Large Language Models designed for controlled music generation. Mustango leverages Latent Diffusion Model (LDM), Flan-T5, and musical features to do the magic!
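
This README change points the Demo link at the hosted Replicate model. As a quick way to exercise the new API, a minimal sketch using the official replicate Python client might look like the following; the prompt, the unpinned version, and the token setup are assumptions, and the input names mirror predict.py further down:

# Hypothetical client-side sketch, not part of this commit.
# Assumes `pip install replicate` and REPLICATE_API_TOKEN set in the environment.
import replicate

output = replicate.run(
    "declare-lab/mustango",  # optionally pin "declare-lab/mustango:<version-hash>" from the model page
    input={
        "prompt": "An upbeat jazz piece with a walking bass line.",
        "steps": 100,   # diffusion steps, as defined in predict.py
        "guidance": 3,  # classifier-free guidance scale
    },
)
print(output)  # URL of the generated WAV file
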
cog.yaml (32 changes: 32 additions & 0 deletions)
@@ -0,0 +1,32 @@
+# Configuration for Cog ⚙️
+# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
+
+build:
+  gpu: true
+  python_version: "3.11"
+  python_packages:
+    - torch==2.0.1
+    - torchaudio==2.0.2
+    - transformers==4.27.0
+    - accelerate==0.18.0
+    - datasets==2.1.0
+    - einops==0.6.1
+    - huggingface_hub==0.13.3
+    - importlib_metadata==6.3.0
+    - librosa==0.9.2
+    - matplotlib==3.5.2
+    - omegaconf==2.3.0
+    - packaging==23.1
+    - progressbar33==2.4
+    - protobuf==3.20.*
+    - resampy==0.4.2
+    - scikit_image==0.22.0
+    - scikit_learn==1.2.2
+    - scipy==1.11.3
+    - soundfile==0.12.1
+    - ssr_eval==0.0.6
+    - torchlibrosa==0.1.0
+    - tqdm==4.63.1
+    - sentencepiece==0.1.99
+    - diffusers==0.15.0
+predict: "predict.py:Predictor"
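
The predict entry wires this build environment to the Predictor class defined in predict.py below. Once an image built from this config is running, Cog serves its standard HTTP prediction API; a rough local sketch, assuming the container listens on port 5000 and the standard /predictions endpoint, might be:

# Hypothetical local test against a running Cog container, not part of this commit.
import requests

resp = requests.post(
    "http://localhost:5000/predictions",
    json={"input": {"prompt": "A gentle piano ballad.", "steps": 100, "guidance": 3}},
)
resp.raise_for_status()
print(resp.json()["output"])  # the generated audio produced by predict()
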
mustango.py (182 changes: 121 additions & 61 deletions)
@@ -11,134 +11,194 @@
 from diffusers import DDPMScheduler
 from models import MusicAudioDiffusion
 
 
 class MusicFeaturePredictor:
-
-    def __init__(self, path, device="cuda:0"):
-        self.beats_tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-large")
-        self.beats_model = DebertaV2ForTokenClassificationRegression.from_pretrained("microsoft/deberta-v3-large")
+    def __init__(self, path, device="cuda:0", cache_dir=None, local_files_only=False):
+        self.beats_tokenizer = AutoTokenizer.from_pretrained(
+            "microsoft/deberta-v3-large",
+            cache_dir=cache_dir,
+            local_files_only=local_files_only,
+        )
+        self.beats_model = DebertaV2ForTokenClassificationRegression.from_pretrained(
+            "microsoft/deberta-v3-large",
+            cache_dir=cache_dir,
+            local_files_only=local_files_only,
+        )
         self.beats_model.eval()
         self.beats_model.to(device)
 
         beats_ckpt = f"{path}/beats/microsoft-deberta-v3-large.pt"
         beats_weight = torch.load(beats_ckpt, map_location="cpu")
         self.beats_model.load_state_dict(beats_weight)
 
-        self.chords_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
-        self.chords_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large")
-
+        self.chords_tokenizer = AutoTokenizer.from_pretrained(
+            "google/flan-t5-large",
+            cache_dir=cache_dir,
+            local_files_only=local_files_only,
+        )
+        self.chords_model = T5ForConditionalGeneration.from_pretrained(
+            "google/flan-t5-large",
+            cache_dir=cache_dir,
+            local_files_only=local_files_only,
+        )
         self.chords_model.eval()
         self.chords_model.to(device)
 
         chords_ckpt = f"{path}/chords/flan-t5-large.bin"
         chords_weight = torch.load(chords_ckpt, map_location="cpu")
         self.chords_model.load_state_dict(chords_weight)
 
-
-
     def generate_beats(self, prompt):
-        tokenized = self.beats_tokenizer(prompt, max_length=512, padding=True, truncation=True, return_tensors="pt")
+        tokenized = self.beats_tokenizer(
+            prompt, max_length=512, padding=True, truncation=True, return_tensors="pt"
+        )
         tokenized = {k: v.to(self.beats_model.device) for k, v in tokenized.items()}
 
         with torch.no_grad():
-          out = self.beats_model(**tokenized)
+            out = self.beats_model(**tokenized)
 
-        max_beat = (1 + torch.argmax(out["logits"][:, 0, :], -1).detach().cpu().numpy()).tolist()[0]
-        intervals = out["values"][:, :, 0].detach().cpu().numpy().astype("float32").round(4).tolist()
-
-        intervals = np.cumsum(intervals)
+        max_beat = (
+            1 + torch.argmax(out["logits"][:, 0, :], -1).detach().cpu().numpy()
+        ).tolist()[0]
+        intervals = (
+            out["values"][:, :, 0]
+            .detach()
+            .cpu()
+            .numpy()
+            .astype("float32")
+            .round(4)
+            .tolist()
+        )
+
+        intervals = np.cumsum(intervals)
         predicted_beats_times = []
         for t in intervals:
             if t < 10:
                 predicted_beats_times.append(round(t, 2))
             else:
                 break
         predicted_beats_times = list(np.array(predicted_beats_times)[:50])
 
-
-
         if len(predicted_beats_times) == 0:
-            predicted_beats = [[],[]]
+            predicted_beats = [[], []]
         else:
             beat_counts = []
             for i in range(len(predicted_beats_times)):
-                beat_counts.append(float(1.0+np.mod(i, max_beat)))
+                beat_counts.append(float(1.0 + np.mod(i, max_beat)))
             predicted_beats = [[predicted_beats_times, beat_counts]]
 
         return max_beat, predicted_beats_times, predicted_beats
 
-
-
     def generate(self, prompt):
         max_beat, predicted_beats_times, predicted_beats = self.generate_beats(prompt)
 
         chords_prompt = "Caption: {} \\n Timestamps: {} \\n Max Beat: {}".format(
-            prompt, " , ".join([str(round(t, 2)) for t in predicted_beats_times]), max_beat
+            prompt,
+            " , ".join([str(round(t, 2)) for t in predicted_beats_times]),
+            max_beat,
         )
 
-        tokenized = self.chords_tokenizer(chords_prompt, max_length=512, padding=True, truncation=True, return_tensors="pt")
+        tokenized = self.chords_tokenizer(
+            chords_prompt,
+            max_length=512,
+            padding=True,
+            truncation=True,
+            return_tensors="pt",
+        )
         tokenized = {k: v.to(self.chords_model.device) for k, v in tokenized.items()}
 
         generated_chords = self.chords_model.generate(
-            input_ids=tokenized["input_ids"], attention_mask=tokenized["attention_mask"],
-            min_length=8, max_length=128, num_beams=5, early_stopping=True, num_return_sequences=1
+            input_ids=tokenized["input_ids"],
+            attention_mask=tokenized["attention_mask"],
+            min_length=8,
+            max_length=128,
+            num_beams=5,
+            early_stopping=True,
+            num_return_sequences=1,
         )
 
         generated_chords = self.chords_tokenizer.decode(
-            generated_chords[0], skip_special_tokens=True, clean_up_tokenization_spaces=True
+            generated_chords[0],
+            skip_special_tokens=True,
+            clean_up_tokenization_spaces=True,
         ).split(" n ")
 
-
-
         predicted_chords, predicted_chords_times = [], []
         for item in generated_chords:
             c, ct = item.split(" at ")
             predicted_chords.append(c)
             predicted_chords_times.append(float(ct))
 
         return predicted_beats, predicted_chords, predicted_chords_times
 
 
 class Mustango:
-
-    def __init__(self, name="declare-lab/mustango", device="cuda:0"):
-
-        path = snapshot_download(repo_id=name)
-
-        self.music_model = MusicFeaturePredictor(path, device)
-
+    def __init__(
+        self,
+        name="declare-lab/mustango",
+        device="cuda:0",
+        cache_dir=None,
+        local_files_only=False,
+    ):
+        path = snapshot_download(repo_id=name, cache_dir=cache_dir)
+
+        self.music_model = MusicFeaturePredictor(
+            path, device, cache_dir=cache_dir, local_files_only=local_files_only
+        )
+
         vae_config = json.load(open(f"{path}/configs/vae_config.json"))
         stft_config = json.load(open(f"{path}/configs/stft_config.json"))
         main_config = json.load(open(f"{path}/configs/main_config.json"))
 
         self.vae = AutoencoderKL(**vae_config).to(device)
         self.stft = TacotronSTFT(**stft_config).to(device)
         self.model = MusicAudioDiffusion(
-            main_config["text_encoder_name"], main_config["scheduler_name"],
-            unet_model_config_path=f"{path}/configs/music_diffusion_model_config.json"
+            main_config["text_encoder_name"],
+            main_config["scheduler_name"],
+            unet_model_config_path=f"{path}/configs/music_diffusion_model_config.json",
         ).to(device)
 
-        vae_weights = torch.load(f"{path}/vae/pytorch_model_vae.bin", map_location=device)
-        stft_weights = torch.load(f"{path}/stft/pytorch_model_stft.bin", map_location=device)
-        main_weights = torch.load(f"{path}/ldm/pytorch_model_ldm.bin", map_location=device)
-
-
+        vae_weights = torch.load(
+            f"{path}/vae/pytorch_model_vae.bin", map_location=device
+        )
+        stft_weights = torch.load(
+            f"{path}/stft/pytorch_model_stft.bin", map_location=device
+        )
+        main_weights = torch.load(
+            f"{path}/ldm/pytorch_model_ldm.bin", map_location=device
+        )
+
         self.vae.load_state_dict(vae_weights)
         self.stft.load_state_dict(stft_weights)
         self.model.load_state_dict(main_weights)
 
-        print ("Successfully loaded checkpoint from:", name)
+        print("Successfully loaded checkpoint from:", name)
 
         self.vae.eval()
         self.stft.eval()
         self.model.eval()
 
-        self.scheduler = DDPMScheduler.from_pretrained(main_config["scheduler_name"], subfolder="scheduler")
-
-
-
+        self.scheduler = DDPMScheduler.from_pretrained(
+            main_config["scheduler_name"], subfolder="scheduler"
+        )
+
     def generate(self, prompt, steps=100, guidance=3, samples=1, disable_progress=True):
-        """ Genrate music for a single prompt string. """
+        """Generate music for a single prompt string."""
 
         with torch.no_grad():
             beats, chords, chords_times = self.music_model.generate(prompt)
             latents = self.model.inference(
-                [prompt], beats, [chords], [chords_times], self.scheduler,
-                steps, guidance, samples, disable_progress
+                [prompt],
+                beats,
+                [chords],
+                [chords_times],
+                self.scheduler,
+                steps,
+                guidance,
+                samples,
+                disable_progress,
             )
             mel = self.vae.decode_first_stage(latents)
             wave = self.vae.decode_to_waveform(mel)
-            return wave[0]
+
+        return wave[0]
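
The refactor above is formatting plus plumbing for cache_dir and local_files_only; the public interface is unchanged. A minimal local-usage sketch, assuming a CUDA GPU and the repository's dependencies (the 16 kHz sample rate matches what predict.py writes):

# Local-usage sketch; the prompt is illustrative.
import soundfile as sf
from mustango import Mustango

model = Mustango("declare-lab/mustango")  # cache_dir / local_files_only are optional
music = model.generate("A new age piece with a flute melody.", steps=100, guidance=3)
sf.write("output.wav", music, samplerate=16000)
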
predict.py (36 changes: 36 additions & 0 deletions)
@@ -0,0 +1,36 @@
+# Prediction interface for Cog ⚙️
+# https://github.com/replicate/cog/blob/main/docs/python.md
+
+import subprocess
+
+subprocess.run("cd diffusers && pip install . && cd ..", shell=True, check=True)
+
+import soundfile as sf
+from cog import BasePredictor, Input, Path
+from mustango import Mustango
+
+
+class Predictor(BasePredictor):
+    def setup(self) -> None:
+        """Load the model into memory to make running multiple predictions efficient"""
+        cache_dir = "model_cache"
+        local_files_only = True  # set to True if models are cached in cache_dir
+        self.model = Mustango(
+            "declare-lab/mustango", cache_dir=cache_dir, local_files_only=local_files_only
+        )
+
+    def predict(
+        self,
+        prompt: str = Input(
+            description="Input prompt.",
+            default="This is a new age piece. There is a flute playing the main melody with a lot of staccato notes. The rhythmic background consists of a medium tempo electronic drum beat with percussive elements all over the spectrum. There is a playful atmosphere to the piece. This piece can be used in the soundtrack of a children's TV show or an advertisement jingle.",
+        ),
+        steps: int = Input(description="inference steps", default=100),
+        guidance: float = Input(description="guidance scale", default=3),
+    ) -> Path:
+        """Run a single prediction on the model"""
+
+        music = self.model.generate(prompt, steps=steps, guidance=guidance)
+        out = "/tmp/output.wav"
+        sf.write(out, music, samplerate=16000)
+        return Path(out)
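
Outside of a Cog container, the predictor can be smoke-tested directly; a rough sketch, assuming the checkpoints are already cached under model_cache since setup() sets local_files_only=True:

# Hypothetical smoke test, run from the repo root outside the Cog container.
from predict import Predictor

predictor = Predictor()
predictor.setup()  # loads Mustango once, as Cog would on container start
wav_path = predictor.predict(
    prompt="A slow blues shuffle with a prominent harmonica.",
    steps=100,
    guidance=3,
)
print(wav_path)  # Path("/tmp/output.wav")
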
