Skip to content

Commit

Permalink
Signal domain, with pitch gating
Browse files Browse the repository at this point in the history
  • Loading branch information
jmvalin committed Oct 7, 2023
1 parent 95c1416 commit 69c9b34
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 15 deletions.
23 changes: 15 additions & 8 deletions dnn/torch/fargan/adv_train_fargan.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ def fmap_loss(scores_real, scores_gen):

spect_loss = MultiResolutionSTFTLoss(device).to(device)

for param in model.parameters():
param.requires_grad = False

if __name__ == '__main__':
model.to(device)
disc.to(device)
Expand All @@ -153,24 +156,28 @@ def fmap_loss(scores_real, scores_gen):
print(f"training epoch {epoch}...")
with tqdm.tqdm(dataloader, unit='batch') as tepoch:
for i, (features, periods, target, lpc) in enumerate(tepoch):
if epoch == 1 and i == 100:
for param in model.parameters():
param.requires_grad = True

optimizer.zero_grad()
features = features.to(device)
lpc = lpc.to(device)
lpc = lpc*(args.gamma**torch.arange(1,17, device=device))
lpc = fargan.interp_lpc(lpc, 4)
#lpc = lpc.to(device)
#lpc = lpc*(args.gamma**torch.arange(1,17, device=device))
#lpc = fargan.interp_lpc(lpc, 4)
periods = periods.to(device)
if True:
target = target[:, :sequence_length*160]
lpc = lpc[:,:sequence_length*4,:]
#lpc = lpc[:,:sequence_length*4,:]
features = features[:,:sequence_length+4,:]
periods = periods[:,:sequence_length+4]
else:
target=target[::2, :]
lpc=lpc[::2,:]
#lpc=lpc[::2,:]
features=features[::2,:]
periods=periods[::2,:]
target = target.to(device)
target = fargan.analysis_filter(target, lpc[:,:,:], nb_subframes=1, gamma=args.gamma)
#target = fargan.analysis_filter(target, lpc[:,:,:], nb_subframes=1, gamma=args.gamma)

#nb_pre = random.randrange(1, 6)
nb_pre = 2
Expand Down Expand Up @@ -210,15 +217,15 @@ def fmap_loss(scores_real, scores_gen):

cont_loss = fargan.sig_loss(target[:, nb_pre*160:nb_pre*160+80], output[:, nb_pre*160:nb_pre*160+80])
specc_loss = spect_loss(output, target.detach())
reg_loss = args.reg_weight * (.00*cont_loss + specc_loss)
reg_loss = (.00*cont_loss + specc_loss)

loss_gen = 0
for scale in scores_gen:
loss_gen += ((1 - scale[-1]) ** 2).mean() / len(scores_gen)

feat_loss = args.fmap_weight * fmap_loss(scores_real, scores_gen)

gen_loss = reg_loss + feat_loss + loss_gen
gen_loss = args.reg_weight * reg_loss + feat_loss + loss_gen

model.zero_grad()

Expand Down
4 changes: 3 additions & 1 deletion dnn/torch/fargan/test_fargan.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def inverse_perceptual_weighting40 (pw_signal, filters):
buffer[:] = out_sig_frame[-16:]
return signal

from scipy.signal import lfilter

if __name__ == '__main__':
model.to(device)
Expand All @@ -121,7 +122,8 @@ def inverse_perceptual_weighting40 (pw_signal, filters):
sig, _ = model(features, periods, nb_frames - 4)
#weighting_vector = np.array([gamma**i for i in range(16,0,-1)])
sig = sig.detach().numpy().flatten()
sig = inverse_perceptual_weighting40(sig, lpc[0,:,:])
sig = lfilter(np.array([1.]), np.array([1., -.85]), sig)
#sig = inverse_perceptual_weighting40(sig, lpc[0,:,:])

pcm = np.round(32768*np.clip(sig, a_max=.99, a_min=-.99)).astype('int16')
pcm.tofile(signal_file)
12 changes: 6 additions & 6 deletions dnn/torch/fargan/train_fargan.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,23 +116,23 @@
features = features.to(device)
#lpc = torch.tensor(fargan.interp_lpc(lpc.numpy(), 4))
#print("interp size", lpc.shape)
lpc = lpc.to(device)
lpc = lpc*(args.gamma**torch.arange(1,17, device=device))
lpc = fargan.interp_lpc(lpc, 4)
#lpc = lpc.to(device)
#lpc = lpc*(args.gamma**torch.arange(1,17, device=device))
#lpc = fargan.interp_lpc(lpc, 4)
periods = periods.to(device)
if (np.random.rand() > 0.1):
target = target[:, :sequence_length*160]
lpc = lpc[:,:sequence_length*4,:]
#lpc = lpc[:,:sequence_length*4,:]
features = features[:,:sequence_length+4,:]
periods = periods[:,:sequence_length+4]
else:
target=target[::2, :]
lpc=lpc[::2,:]
#lpc=lpc[::2,:]
features=features[::2,:]
periods=periods[::2,:]
target = target.to(device)
#print(target.shape, lpc.shape)
target = fargan.analysis_filter(target, lpc[:,:,:], nb_subframes=1, gamma=args.gamma)
#target = fargan.analysis_filter(target, lpc[:,:,:], nb_subframes=1, gamma=args.gamma)

#nb_pre = random.randrange(1, 6)
nb_pre = 2
Expand Down

0 comments on commit 69c9b34

Please sign in to comment.