#!/usr/bin/env python3
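"""Train an audio diffusion model with PyTorch Lightning.

Forked from zqevans/audio-diffusion.

Example invocation (flags per the argparse definitions in main(); the paths
are placeholders):

    python train.py --training-dir /path/to/audio --name my-run \
        --demo-dir /path/to/demos --num-gpus 1
"""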
import argparse
from contextlib import contextmanager
from glob import glob
from pathlib import Path
from random import randint
import sys

# from einops import rearrange
import pytorch_lightning as pl
from pytorch_lightning.utilities.distributed import rank_zero_only
import torch
from torch.utils import data
import torchaudio
from torchaudio import transforms as T
import wandb

from dataset.dataset import SampleDataset
from diffusion.inference import sample
from diffusion.model import LightningDiffusion, AudioPerceiverEncoder, SelfSupervisedLearner, Transpose
from diffusion.pqmf import CachedPQMF as PQMF
from diffusion.utils import MidSideEncoding, PadCrop, RandomGain

# Define utility functions

@contextmanager
def train_mode(model, mode=True):
    """A context manager that places a model into training mode and restores
    the previous mode on exit."""
    modes = [module.training for module in model.modules()]
    try:
        yield model.train(mode)
    finally:
        for i, module in enumerate(model.modules()):
            module.training = modes[i]


def eval_mode(model):
    """A context manager that places a model into evaluation mode and restores
    the previous mode on exit."""
    return train_mode(model, False)
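
# DemoCallback periodically renders demo audio during training: every
# demo_every steps it loads the .wav files under demo_dir, mid-side encodes
# and PQMF-encodes them, runs the sampler for demo_steps denoising steps,
# then decodes the results and logs them to wandb.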
class DemoCallback(pl.Callback):
    def __init__(self, global_args):
        super().__init__()
        self.pqmf = PQMF(2, 70, global_args.pqmf_bands)
        self.demo_dir = global_args.demo_dir
        self.demo_samples = global_args.sample_size
        self.demo_every = global_args.demo_every
        self.demo_steps = global_args.demo_steps
        self.sample_rate = global_args.sample_rate
        self.ms_encoder = MidSideEncoding()
        self.pad_crop = PadCrop(global_args.sample_size)
        # Track the last step we demoed at so a step is never demoed twice
        self.last_demo_step = -1

    @rank_zero_only
    @torch.no_grad()
    def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx, unused=0):
        if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
            return
        self.last_demo_step = trainer.global_step

        # Load the demo files, then pad/crop and mid-side encode them
        # (assumes the files are already at the training sample rate)
        demo_files = glob(f'{self.demo_dir}/**/*.wav', recursive=True)
        audio_batch = torch.zeros(len(demo_files), 2, self.demo_samples)
        for i, demo_file in enumerate(demo_files):
            audio, sr = torchaudio.load(demo_file)
            audio = audio.clamp(-1, 1)
            audio = self.pad_crop(audio)
            audio = self.ms_encoder(audio)
            audio_batch[i] = audio

        # Split the batch into PQMF sub-bands and run the sampler
        audio_batch = self.pqmf(audio_batch)
        audio_batch = audio_batch.to(module.device)
        with eval_mode(module):
            fakes = sample(module, audio_batch, self.demo_steps, 1)

        # Undo the PQMF encoding
        fakes = self.pqmf.inverse(fakes.cpu())

        try:
            log_dict = {}
            for i, fake in enumerate(fakes):
                filename = f'demo_{trainer.global_step:08}_{i:02}.wav'
                # Undo the mid-side encoding and write out 16-bit PCM
                fake = self.ms_encoder(fake).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
                torchaudio.save(filename, fake, self.sample_rate)
                log_dict[f'demo_{i}'] = wandb.Audio(filename,
                                                    sample_rate=self.sample_rate,
                                                    caption=f'Demo {i}')
            trainer.logger.experiment.log(log_dict, step=trainer.global_step)
        except Exception as e:
            print(f'{type(e).__name__}: {e}', file=sys.stderr)
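
# Report any exception raised during training on stderr so it shows up in the run logs.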
class ExceptionCallback(pl.Callback):
    def on_exception(self, trainer, module, err):
        print(f'{type(err).__name__}: {err}', file=sys.stderr)

def main():
    p = argparse.ArgumentParser()
    p.add_argument('--training-dir', type=Path, required=True,
                   help='the training data directory')
    p.add_argument('--name', type=str, required=True,
                   help='the name of the run')
    p.add_argument('--demo-dir', type=Path, required=True,
                   help='path to a directory with audio files for demos')
    p.add_argument('--num-workers', type=int, default=2,
                   help='number of CPU workers for the DataLoader')
    p.add_argument('--batch-size', type=int, default=8,
                   help='number of audio samples per batch')
    p.add_argument('--num-gpus', type=int, default=1,
                   help='number of GPUs to use for training')
    p.add_argument('--pqmf-bands', type=int, default=8,
                   help='number of sub-bands for the PQMF filter')
    p.add_argument('--sample-rate', type=int, default=48000,
                   help='the sample rate of the audio')
    p.add_argument('--sample-size', type=int, default=16384,
                   help='number of audio samples to train on, must be a multiple of 16384')
    p.add_argument('--demo-every', type=int, default=1000,
                   help='number of steps between demos')
    p.add_argument('--demo-steps', type=int, default=500,
                   help='number of denoising steps for the demos')
    p.add_argument('--checkpoint-every', type=int, default=20000,
                   help='number of steps between checkpoints')
    p.add_argument('--data-repeats', type=int, default=1,
                   help='number of times to repeat the dataset, useful to lengthen epochs on small datasets')
    p.add_argument('--style-latent-size', type=int, default=512,
                   help='size of the style latents')
    p.add_argument('--accum-batches', type=int, default=8,
                   help='number of batches for gradient accumulation')
    args = p.parse_args()
    # Build the dataset and DataLoader
    train_set = SampleDataset([args.training_dir], args)
    train_dl = data.DataLoader(train_set, args.batch_size, shuffle=True,
                               num_workers=args.num_workers, persistent_workers=True, pin_memory=True)

    wandb_logger = pl.loggers.WandbLogger(project=args.name)
    exc_callback = ExceptionCallback()

    encoder = AudioPerceiverEncoder(args)  # TODO: Get pretrained encoder

    ckpt_callback = pl.callbacks.ModelCheckpoint(every_n_train_steps=args.checkpoint_every, save_top_k=-1)
    demo_callback = DemoCallback(args)

    diffusion_model = LightningDiffusion(encoder, args)
    wandb_logger.watch(diffusion_model.diffusion)

    diffusion_trainer = pl.Trainer(
        gpus=args.num_gpus,
        strategy='ddp',
        precision=16,
        # Start without accumulation, switch to args.accum_batches from epoch 1
        accumulate_grad_batches={0: 1, 1: args.accum_batches},
        callbacks=[ckpt_callback, demo_callback, exc_callback],
        logger=wandb_logger,
        log_every_n_steps=1,
        max_epochs=10000000,
    )

    diffusion_trainer.fit(diffusion_model, train_dl)

if __name__ == '__main__':
    main()