feat: add basic optional support for global style token module #100

Open · wants to merge 2 commits into base: main
Changes from all commits
2 changes: 1 addition & 1 deletion fs2/attn/attention.py
@@ -239,7 +239,7 @@ def forward(
attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2
# compute log likelihood from a gaussian
attn = -0.0005 * attn.sum(1, keepdim=True)
if attn_prior is not None:
if torch.is_tensor(attn_prior):
attn = self.log_softmax(attn) + torch.log(attn_prior[:, None] + 1e-8)

attn_logprob = attn.clone()
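For context on the guard change above: once inputs are batched, a missing prior can reach this code as something other than None, and only a real tensor should be folded into the attention log-probabilities. A minimal illustration (the non-tensor placeholder is made up):

import torch

for attn_prior in (None, [None, None], torch.rand(1, 5, 7)):
    print(attn_prior is not None, torch.is_tensor(attn_prior))
# None               -> False False
# [None, None]       -> True  False  (the old `is not None` guard would have used it)
# torch.rand(1, 5, 7) -> True  True   (only this case reaches the log-softmax branch)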
43 changes: 42 additions & 1 deletion fs2/cli/synthesize.py
@@ -73,6 +73,7 @@ def prepare_data(
model: Any,
text_representation: DatasetTextRepresentation,
duration_control: float,
style_reference: Path | None,
) -> list[dict[str, Any]]:
""""""
from everyvoice.utils import slugify
@@ -154,9 +155,36 @@ def prepare_data(
multi=model.config.model.multispeaker,
)

# We only allow a single style reference right now, so it's fine to load it once here.
if style_reference:
from everyvoice.utils.heavy import get_spectral_transform

spectral_transform = get_spectral_transform(
model.config.preprocessing.audio.spec_type,
model.config.preprocessing.audio.n_fft,
model.config.preprocessing.audio.fft_window_size,
model.config.preprocessing.audio.fft_hop_size,
f_min=model.config.preprocessing.audio.f_min,
f_max=model.config.preprocessing.audio.f_max,
sample_rate=model.config.preprocessing.audio.output_sampling_rate,
n_mels=model.config.preprocessing.audio.n_mels,
)
import torchaudio

style_reference_audio, style_reference_sr = torchaudio.load(style_reference)
if style_reference_sr != model.config.preprocessing.audio.input_sampling_rate:
style_reference_audio = torchaudio.functional.resample(
style_reference_audio,
style_reference_sr,
model.config.preprocessing.audio.input_sampling_rate,
)
style_reference_spec = spectral_transform(style_reference_audio)
# Add duration_control
for item in data:
item["duration_control"] = duration_control
# Add style reference
if style_reference:
item["mel_style_reference"] = style_reference_spec

return data
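As a standalone sketch of what the style-reference block above does for a single file (the path, sample rate, and mel settings below are placeholders; the real values come from model.config.preprocessing.audio and from everyvoice.utils.heavy.get_spectral_transform):

import torchaudio

target_sr = 22050  # stands in for model.config.preprocessing.audio.input_sampling_rate
wav, sr = torchaudio.load("reference.wav")  # hypothetical style reference file
if sr != target_sr:
    wav = torchaudio.functional.resample(wav, sr, target_sr)

# Stand-in for the transform returned by get_spectral_transform(...)
spectral_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=target_sr, n_fft=1024, win_length=1024, hop_length=256, n_mels=80
)
style_reference_spec = spectral_transform(wav)  # [channels, n_mels, frames]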

@@ -175,6 +203,7 @@ def get_global_step(model_path: Path) -> int:
def synthesize_helper(
model,
texts: list[str],
style_reference: Optional[Path],
language: Optional[str],
speaker: Optional[str],
duration_control: Optional[float],
@@ -227,8 +256,8 @@ def synthesize_helper(
filelist=filelist,
model=model,
text_representation=text_representation,
style_reference=style_reference,
)

from pytorch_lightning import Trainer

from ..prediction_writing_callback import get_synthesis_output_callbacks
@@ -269,6 +298,7 @@ def synthesize_helper(
model.lang2id,
model.speaker2id,
teacher_forcing=teacher_forcing,
style_reference=style_reference is not None,
),
return_predictions=True,
),
@@ -312,6 +342,16 @@ def synthesize( # noqa: C901
"-D",
help="Control the speaking rate of the synthesis. Set a value to multily the durations by, lower numbers produce quicker speaking rates, larger numbers produce slower speaking rates. Default is 1.0",
),
style_reference: Optional[Path] = typer.Option(
None,
"--style-reference",
"-S",
exists=True,
file_okay=True,
dir_okay=False,
help="The path to an audio file containing a style reference. Your text-to-spec must have been trained with the global style token module to use this feature.",
shell_complete=complete_path,
),
speaker: Optional[str] = typer.Option(
None,
"--speaker",
@@ -454,6 +494,7 @@ def synthesize( # noqa: C901
return synthesize_helper(
model=model,
texts=texts,
style_reference=style_reference,
language=language,
speaker=speaker,
duration_control=duration_control,
4 changes: 4 additions & 0 deletions fs2/config/__init__.py
@@ -138,6 +138,10 @@ class FastSpeech2ModelConfig(ConfigModel):
True,
description="Whether to jointly learn alignments using monotonic alignment search module (See Badlani et. al. 2021: https://arxiv.org/abs/2108.10447). If set to False, you will have to provide text/audio alignments separately before training a text-to-spec (feature prediction) model.",
)
use_global_style_token_module: bool = Field(
False,
description="Whether to use the Global Style Token (GST) module from Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End Speech Synthesis (https://arxiv.org/abs/1803.09017)",
)
max_length: int = Field(
1000, description="The maximum length (i.e. number of symbols) for text inputs."
)
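A minimal sketch of turning the new flag on, assuming FastSpeech2ModelConfig is importable from fs2.config (as the file path suggests) and that every other field keeps its default:

from fs2.config import FastSpeech2ModelConfig

model_config = FastSpeech2ModelConfig(use_global_style_token_module=True)
assert model_config.use_global_style_token_module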
22 changes: 20 additions & 2 deletions fs2/dataset.py
@@ -35,6 +35,7 @@ def __init__(
speaker2id: LookupTable,
teacher_forcing=False,
inference=False,
style_reference=False,
):
self.dataset = dataset
self.config = config
@@ -43,6 +44,7 @@ def __init__(
self.preprocessed_dir = Path(self.config.preprocessing.save_dir)
self.sampling_rate = self.config.preprocessing.audio.input_sampling_rate
self.teacher_forcing = teacher_forcing
self.style_reference = style_reference
self.inference = inference
self.lang2id = lang2id
self.speaker2id = speaker2id
@@ -56,6 +58,7 @@ def __getitem__(self, index):
"""
Returns dict with keys: {
"mel"
"mel_style_reference"
"duration"
"duration_control"
"pfs"
@@ -103,6 +106,12 @@ def __getitem__(self, index):
) # [mel_bins, frames] -> [frames, mel_bins]
else:
mel = None

if self.style_reference:
mel_style_reference = item["mel_style_reference"].squeeze(0).transpose(0, 1)
else:
mel_style_reference = None

if (
self.teacher_forcing or not self.inference
) and self.config.model.learn_alignment:
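A quick sketch of the shape handling in the mel_style_reference branch above (sizes are arbitrary, assuming the reference spectrogram was computed from a mono waveform): the stored tensor is [channels, n_mels, frames], and the dataset returns [frames, n_mels] to match the other mel features:

import torch

mel_style_reference = torch.randn(1, 80, 123)  # [channels, n_mels, frames]
out = mel_style_reference.squeeze(0).transpose(0, 1)
print(out.shape)  # torch.Size([123, 80])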
@@ -169,9 +178,9 @@ def __getitem__(self, index):
else:
energy = None
pitch = None

return {
"mel": mel,
"mel_style_reference": mel_style_reference,
"duration": duration,
"duration_control": duration_control,
"pfs": pfs,
@@ -201,11 +210,13 @@ def __init__(
inference=False,
teacher_forcing=False,
inference_output_dir=Path("synthesis_output"),
style_reference=False,
):
super().__init__(config=config, inference_output_dir=inference_output_dir)
self.inference = inference
self.prepared = False
self.teacher_forcing = teacher_forcing
self.style_reference = style_reference
self.collate_fn = partial(
self.collate_method, learn_alignment=config.model.learn_alignment
)
@@ -271,6 +282,7 @@ def prepare_data(self):
self.speaker2id,
inference=self.inference,
teacher_forcing=self.teacher_forcing,
style_reference=self.style_reference,
)
torch.save(self.predict_dataset, self.predict_path)
elif not self.prepared:
@@ -320,8 +332,14 @@ def __init__(
lang2id: LookupTable,
speaker2id: LookupTable,
teacher_forcing: bool = False,
style_reference=False,
):
super().__init__(config=config, inference=True, teacher_forcing=teacher_forcing)
super().__init__(
config=config,
inference=True,
teacher_forcing=teacher_forcing,
style_reference=style_reference,
)
self.inference = True
self.data = data
self.collate_fn = partial(
Empty file added fs2/gst/__init__.py
194 changes: 194 additions & 0 deletions fs2/gst/attn.py
@@ -0,0 +1,194 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
Review comment (Member) on lines +4 to +5: Please update the LICENSE files both in this repo and in EV to list this new exception.


"""Multi-Head Attention layer definition."""

import math

import torch
from torch import nn


class LayerNorm(torch.nn.LayerNorm):
"""Layer normalization module.

Args:
nout (int): Output dim size.
dim (int): Dimension to be normalized.

"""

def __init__(self, nout, dim=-1):
"""Construct an LayerNorm object."""
super(LayerNorm, self).__init__(nout, eps=1e-12)
self.dim = dim

def forward(self, x):
"""Apply layer normalization.

Args:
x (torch.Tensor): Input tensor.

Returns:
torch.Tensor: Normalized tensor.

"""
if self.dim == -1:
return super(LayerNorm, self).forward(x)
return (
super(LayerNorm, self)
.forward(x.transpose(self.dim, -1))
.transpose(self.dim, -1)
)


class MultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.

Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
qk_norm (bool): Normalize q and k before dot product.
use_flash_attn (bool): Use flash_attn implementation.
causal (bool): Apply causal attention.
cross_attn (bool): Cross attention instead of self attention.

"""

def __init__(
self,
n_head,
n_feat,
dropout_rate,
qk_norm=False,
use_flash_attn=False,
causal=False,
cross_attn=False,
):
"""Construct an MultiHeadedAttention object."""
super(MultiHeadedAttention, self).__init__()
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
self.linear_q = nn.Linear(n_feat, n_feat)
self.linear_k = nn.Linear(n_feat, n_feat)
self.linear_v = nn.Linear(n_feat, n_feat)
self.linear_out = nn.Linear(n_feat, n_feat)
self.attn = None
self.dropout = (
nn.Dropout(p=dropout_rate) if not use_flash_attn else nn.Identity()
)
self.dropout_rate = dropout_rate

# LayerNorm for q and k
self.q_norm = LayerNorm(self.d_k) if qk_norm else nn.Identity()
self.k_norm = LayerNorm(self.d_k) if qk_norm else nn.Identity()

self.use_flash_attn = use_flash_attn
self.causal = causal # only used with flash_attn
self.cross_attn = cross_attn # only used with flash_attn

def forward_qkv(self, query, key, value, expand_kv=False):
"""Transform query, key and value.

Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
expand_kv (bool): Used only for partially autoregressive (PAR) decoding.

Returns:
torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).

"""
n_batch = query.size(0)
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)

if expand_kv:
k_shape = key.shape
k = (
self.linear_k(key[:1, :, :])
.expand(n_batch, k_shape[1], k_shape[2])
.view(n_batch, -1, self.h, self.d_k)
)
v_shape = value.shape
v = (
self.linear_v(value[:1, :, :])
.expand(n_batch, v_shape[1], v_shape[2])
.view(n_batch, -1, self.h, self.d_k)
)
else:
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)

q = q.transpose(1, 2) # (batch, head, time1, d_k)
k = k.transpose(1, 2) # (batch, head, time2, d_k)
v = v.transpose(1, 2) # (batch, head, time2, d_k)

q = self.q_norm(q)
k = self.k_norm(k)

return q, k, v

def forward_attention(self, value, scores, mask):
"""Compute attention context vector.

Args:
value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).

Returns:
torch.Tensor: Transformed value (#batch, time1, d_model)
weighted by the attention score (#batch, time1, time2).

"""
n_batch = value.size(0)
if mask is not None:
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
min_value = torch.finfo(scores.dtype).min
scores = scores.masked_fill(mask, min_value)
self.attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
) # (batch, head, time1, time2)
else:
self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)

p_attn = self.dropout(self.attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
) # (batch, time1, d_model)

return self.linear_out(x) # (batch, time1, d_model)

def forward(self, query, key, value, mask, expand_kv=False):
"""Compute scaled dot product attention.

Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2).
expand_kv (bool): Used only for partially autoregressive (PAR) decoding.
When set to `True`, `Linear` layers are computed only for the first batch.
This is useful to reduce the memory usage during decoding when the batch size is
#beam_size x #mask_count, which can be very large. Typically, in single waveform
inference of PAR, `Linear` layers should not be computed for all batches
for source-attention.

Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).

"""
q, k, v = self.forward_qkv(query, key, value, expand_kv)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask)
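A short usage sketch of the MultiHeadedAttention module defined above, with arbitrary shapes and no mask applied:

import torch

mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
queries = torch.randn(2, 10, 256)  # (#batch, time1, size)
keys = torch.randn(2, 20, 256)     # (#batch, time2, size)
values = torch.randn(2, 20, 256)   # (#batch, time2, size)
out = mha(queries, keys, values, mask=None)
print(out.shape)  # torch.Size([2, 10, 256])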