feat: add basic optional support for global style token module
roedoejet committed Nov 23, 2024
1 parent 739cedd commit 3bdfd24
Showing 8 changed files with 554 additions and 6 deletions.
2 changes: 1 addition & 1 deletion fs2/attn/attention.py
@@ -239,7 +239,7 @@ def forward(
attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2
# compute log likelihood from a gaussian
attn = -0.0005 * attn.sum(1, keepdim=True)
if attn_prior is not None:
if torch.is_tensor(attn_prior):
attn = self.log_softmax(attn) + torch.log(attn_prior[:, None] + 1e-8)

attn_logprob = attn.clone()
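For reference, the score computed above is the unnormalized log-likelihood of each encoded query under an isotropic Gaussian centred on each encoded key,

$$\log p(q_t \mid k_s) \propto -\frac{1}{2\sigma^2}\lVert q_t - k_s\rVert_2^2,$$

with the hard-coded 0.0005 playing the role of $1/(2\sigma^2)$. When `attn_prior` is supplied, adding $\log(\mathrm{attn\_prior} + 10^{-8})$ to the log-softmaxed scores corresponds to multiplying the soft alignment by the prior in probability space; switching the guard from `is not None` to `torch.is_tensor` keeps that behaviour for tensor priors while treating any non-tensor value as "no prior".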
43 changes: 42 additions & 1 deletion fs2/cli/synthesize.py
@@ -73,6 +73,7 @@ def prepare_data(
model: Any,
text_representation: DatasetTextRepresentation,
duration_control: float,
style_reference: Path | None,
) -> list[dict[str, Any]]:
""""""
from everyvoice.utils import slugify
@@ -154,9 +155,36 @@ def prepare_data(
multi=model.config.model.multispeaker,
)

# We only allow a single style reference right now, so it's fine to load it once here.
if style_reference:
from everyvoice.utils.heavy import get_spectral_transform

spectral_transform = get_spectral_transform(
model.config.preprocessing.audio.spec_type,
model.config.preprocessing.audio.n_fft,
model.config.preprocessing.audio.fft_window_size,
model.config.preprocessing.audio.fft_hop_size,
f_min=model.config.preprocessing.audio.f_min,
f_max=model.config.preprocessing.audio.f_max,
sample_rate=model.config.preprocessing.audio.output_sampling_rate,
n_mels=model.config.preprocessing.audio.n_mels,
)
import torchaudio

style_reference_audio, style_reference_sr = torchaudio.load(style_reference)
if style_reference_sr != model.config.preprocessing.audio.input_sampling_rate:
style_reference_audio = torchaudio.functional.resample(
style_reference_audio,
style_reference_sr,
model.config.preprocessing.audio.input_sampling_rate,
)
style_reference_spec = spectral_transform(style_reference_audio)
# Add duration_control
for item in data:
item["duration_control"] = duration_control
# Add style reference
if style_reference:
item["mel_style_reference"] = style_reference_spec

return data
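For context, here is a minimal standalone sketch of the reference-loading path added above. It substitutes torchaudio's `MelSpectrogram` for everyvoice's `get_spectral_transform` and uses made-up audio settings and a placeholder file name, so treat the specifics as assumptions; the point is the load, resample, spectrogram flow and the `(channels, n_mels, frames)` output shape that the dataset later squeezes and transposes to `(frames, n_mels)`.

```python
import torchaudio

# Illustrative settings only; the real values come from model.config.preprocessing.audio.
input_sr, n_fft, win_length, hop_length, n_mels = 22050, 1024, 1024, 256, 80

spectral_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=input_sr,
    n_fft=n_fft,
    win_length=win_length,
    hop_length=hop_length,
    n_mels=n_mels,
)

# "reference.wav" is a placeholder for the --style-reference audio file.
audio, sr = torchaudio.load("reference.wav")  # (channels, samples)
if sr != input_sr:
    audio = torchaudio.functional.resample(audio, sr, input_sr)

mel_style_reference = spectral_transform(audio)  # (channels, n_mels, frames)
frames_first = mel_style_reference.squeeze(0).transpose(0, 1)  # (frames, n_mels)
```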

@@ -175,6 +203,7 @@ def get_global_step(model_path: Path) -> int:
def synthesize_helper(
model,
texts: list[str],
style_reference: Optional[Path],
language: Optional[str],
speaker: Optional[str],
duration_control: Optional[float],
@@ -227,8 +256,8 @@ def synthesize_helper(
filelist=filelist,
model=model,
text_representation=text_representation,
style_reference=style_reference,
)

from pytorch_lightning import Trainer

from ..prediction_writing_callback import get_synthesis_output_callbacks
@@ -269,6 +298,7 @@ def synthesize_helper(
model.lang2id,
model.speaker2id,
teacher_forcing=teacher_forcing,
style_reference=style_reference is not None,
),
return_predictions=True,
),
@@ -312,6 +342,16 @@ def synthesize( # noqa: C901
"-D",
help="Control the speaking rate of the synthesis. Set a value to multily the durations by, lower numbers produce quicker speaking rates, larger numbers produce slower speaking rates. Default is 1.0",
),
style_reference: Optional[Path] = typer.Option(
None,
"--style-reference",
"-S",
exists=True,
file_okay=True,
dir_okay=False,
help="The path to an audio file containing a style reference. Your text-to-spec must have been trained with the global style token module to use this feature.",
autocompletion=complete_path,
),
speaker: Optional[str] = typer.Option(
None,
"--speaker",
@@ -454,6 +494,7 @@ def synthesize( # noqa: C901
return synthesize_helper(
model=model,
texts=texts,
style_reference=style_reference,
language=language,
speaker=speaker,
duration_control=duration_control,
4 changes: 4 additions & 0 deletions fs2/config/__init__.py
@@ -135,6 +135,10 @@ class FastSpeech2ModelConfig(ConfigModel):
True,
description="Whether to jointly learn alignments using monotonic alignment search module (See Badlani et. al. 2021: https://arxiv.org/abs/2108.10447). If set to False, you will have to provide text/audio alignments separately before training a text-to-spec (feature prediction) model.",
)
use_global_style_token_module: bool = Field(
False,
description="Whether to use the Global Style Token (GST) module from Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End Speech Synthesis (https://arxiv.org/abs/1803.09017)",
)
max_length: int = Field(
1000, description="The maximum length (i.e. number of symbols) for text inputs."
)
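The new flag defaults to `False`. Below is a minimal sketch of enabling it programmatically, assuming the remaining fields of `FastSpeech2ModelConfig` can be left at their defaults; in a real project the flag would normally be set in the text-to-spec configuration before training, since synthesis with `--style-reference` requires a model trained with the GST module.

```python
from fs2.config import FastSpeech2ModelConfig

# Enable the Global Style Token module; every other model field keeps its default.
model_config = FastSpeech2ModelConfig(use_global_style_token_module=True)
print(model_config.use_global_style_token_module)  # True
```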
22 changes: 20 additions & 2 deletions fs2/dataset.py
@@ -35,6 +35,7 @@ def __init__(
speaker2id: LookupTable,
teacher_forcing=False,
inference=False,
style_reference=False,
):
self.dataset = dataset
self.config = config
@@ -43,6 +44,7 @@ def __init__(
self.preprocessed_dir = Path(self.config.preprocessing.save_dir)
self.sampling_rate = self.config.preprocessing.audio.input_sampling_rate
self.teacher_forcing = teacher_forcing
self.style_reference = style_reference
self.inference = inference
self.lang2id = lang2id
self.speaker2id = speaker2id
@@ -56,6 +58,7 @@ def __getitem__(self, index):
"""
Returns dict with keys: {
"mel"
"mel_style_reference"
"duration"
"duration_control"
"pfs"
@@ -103,6 +106,12 @@ def __getitem__(self, index):
) # [mel_bins, frames] -> [frames, mel_bins]
else:
mel = None

if self.style_reference:
mel_style_reference = item["mel_style_reference"].squeeze(0).transpose(0, 1)
else:
mel_style_reference = None

if (
self.teacher_forcing or not self.inference
) and self.config.model.learn_alignment:
@@ -169,9 +178,9 @@ def __getitem__(self, index):
else:
energy = None
pitch = None

return {
"mel": mel,
"mel_style_reference": mel_style_reference,
"duration": duration,
"duration_control": duration_control,
"pfs": pfs,
@@ -201,11 +210,13 @@ def __init__(
inference=False,
teacher_forcing=False,
inference_output_dir=Path("synthesis_output"),
style_reference=False,
):
super().__init__(config=config, inference_output_dir=inference_output_dir)
self.inference = inference
self.prepared = False
self.teacher_forcing = teacher_forcing
self.style_reference = style_reference
self.collate_fn = partial(
self.collate_method, learn_alignment=config.model.learn_alignment
)
@@ -271,6 +282,7 @@ def prepare_data(self):
self.speaker2id,
inference=self.inference,
teacher_forcing=self.teacher_forcing,
style_reference=self.style_reference,
)
torch.save(self.predict_dataset, self.predict_path)
elif not self.prepared:
@@ -320,8 +332,14 @@ def __init__(
lang2id: LookupTable,
speaker2id: LookupTable,
teacher_forcing: bool = False,
style_reference=False,
):
super().__init__(config=config, inference=True, teacher_forcing=teacher_forcing)
super().__init__(
config=config,
inference=True,
teacher_forcing=teacher_forcing,
style_reference=style_reference,
)
self.inference = True
self.data = data
self.collate_fn = partial(
Empty file added fs2/gst/__init__.py
194 changes: 194 additions & 0 deletions fs2/gst/attn.py
@@ -0,0 +1,194 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Multi-Head Attention layer definition."""

import math

import torch
from torch import nn


class LayerNorm(torch.nn.LayerNorm):
"""Layer normalization module.
Args:
nout (int): Output dim size.
dim (int): Dimension to be normalized.
"""

def __init__(self, nout, dim=-1):
"""Construct an LayerNorm object."""
super(LayerNorm, self).__init__(nout, eps=1e-12)
self.dim = dim

def forward(self, x):
"""Apply layer normalization.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Normalized tensor.
"""
if self.dim == -1:
return super(LayerNorm, self).forward(x)
return (
super(LayerNorm, self)
.forward(x.transpose(self.dim, -1))
.transpose(self.dim, -1)
)


class MultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
qk_norm (bool): Normalize q and k before dot product.
use_flash_attn (bool): Use flash_attn implementation.
causal (bool): Apply causal attention.
cross_attn (bool): Cross attention instead of self attention.
"""

def __init__(
self,
n_head,
n_feat,
dropout_rate,
qk_norm=False,
use_flash_attn=False,
causal=False,
cross_attn=False,
):
"""Construct an MultiHeadedAttention object."""
super(MultiHeadedAttention, self).__init__()
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
self.linear_q = nn.Linear(n_feat, n_feat)
self.linear_k = nn.Linear(n_feat, n_feat)
self.linear_v = nn.Linear(n_feat, n_feat)
self.linear_out = nn.Linear(n_feat, n_feat)
self.attn = None
self.dropout = (
nn.Dropout(p=dropout_rate) if not use_flash_attn else nn.Identity()
)
self.dropout_rate = dropout_rate

# LayerNorm for q and k
self.q_norm = LayerNorm(self.d_k) if qk_norm else nn.Identity()
self.k_norm = LayerNorm(self.d_k) if qk_norm else nn.Identity()

self.use_flash_attn = use_flash_attn
self.causal = causal # only used with flash_attn
self.cross_attn = cross_attn # only used with flash_attn

def forward_qkv(self, query, key, value, expand_kv=False):
"""Transform query, key and value.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
expand_kv (bool): Used only for partially autoregressive (PAR) decoding.
Returns:
torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
"""
n_batch = query.size(0)
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)

if expand_kv:
k_shape = key.shape
k = (
self.linear_k(key[:1, :, :])
.expand(n_batch, k_shape[1], k_shape[2])
.view(n_batch, -1, self.h, self.d_k)
)
v_shape = value.shape
v = (
self.linear_v(value[:1, :, :])
.expand(n_batch, v_shape[1], v_shape[2])
.view(n_batch, -1, self.h, self.d_k)
)
else:
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)

q = q.transpose(1, 2) # (batch, head, time1, d_k)
k = k.transpose(1, 2) # (batch, head, time2, d_k)
v = v.transpose(1, 2) # (batch, head, time2, d_k)

q = self.q_norm(q)
k = self.k_norm(k)

return q, k, v

def forward_attention(self, value, scores, mask):
"""Compute attention context vector.
Args:
value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
Returns:
torch.Tensor: Transformed value (#batch, time1, d_model)
weighted by the attention score (#batch, time1, time2).
"""
n_batch = value.size(0)
if mask is not None:
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
min_value = torch.finfo(scores.dtype).min
scores = scores.masked_fill(mask, min_value)
self.attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
) # (batch, head, time1, time2)
else:
self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)

p_attn = self.dropout(self.attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
) # (batch, time1, d_model)

return self.linear_out(x) # (batch, time1, d_model)

def forward(self, query, key, value, mask, expand_kv=False):
"""Compute scaled dot product attention.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2).
expand_kv (bool): Used only for partially autoregressive (PAR) decoding.
When set to `True`, `Linear` layers are computed only for the first batch.
This is useful to reduce the memory usage during decoding when the batch size is
#beam_size x #mask_count, which can be very large. Typically, in single waveform
inference of PAR, `Linear` layers should not be computed for all batches
for source-attention.
Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).
"""
q, k, v = self.forward_qkv(query, key, value, expand_kv)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask)
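A quick usage sketch of the attention module defined above; shapes and hyperparameters are illustrative only.

```python
import torch

mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)

queries = torch.randn(2, 10, 256)              # (batch, time1, n_feat)
keys = torch.randn(2, 15, 256)                 # (batch, time2, n_feat)
values = torch.randn(2, 15, 256)               # (batch, time2, n_feat)
mask = torch.ones(2, 1, 15, dtype=torch.bool)  # (batch, 1, time2); pass None to skip masking

out = mha(queries, keys, values, mask)         # (batch, time1, n_feat)
print(out.shape)                               # torch.Size([2, 10, 256])
```

In the Global Style Token setup of Wang et al. (2018), attention of this kind is used with the reference-encoder output as the query and the learned style-token embeddings as keys and values.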