From 3bdfd2409f7aa1bc5bc190687f90c9e594af5556 Mon Sep 17 00:00:00 2001
From: Aidan Pine
Date: Sat, 23 Nov 2024 00:15:29 +0000
Subject: [PATCH] feat: add basic optional support for global style token module

---
 fs2/attn/attention.py  |   2 +-
 fs2/cli/synthesize.py  |  43 ++++++-
 fs2/config/__init__.py |   4 +
 fs2/dataset.py         |  22 +++-
 fs2/gst/__init__.py    |   0
 fs2/gst/attn.py        | 194 ++++++++++++++++++++++++++++
 fs2/gst/model.py       | 280 +++++++++++++++++++++++++++++++++++++++++
 fs2/model.py           |  15 ++-
 8 files changed, 554 insertions(+), 6 deletions(-)
 create mode 100644 fs2/gst/__init__.py
 create mode 100644 fs2/gst/attn.py
 create mode 100644 fs2/gst/model.py

diff --git a/fs2/attn/attention.py b/fs2/attn/attention.py
index 4532651..a6bf8ab 100644
--- a/fs2/attn/attention.py
+++ b/fs2/attn/attention.py
@@ -239,7 +239,7 @@ def forward(
         attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2
         # compute log likelihood from a gaussian
         attn = -0.0005 * attn.sum(1, keepdim=True)
-        if attn_prior is not None:
+        if torch.is_tensor(attn_prior):
             attn = self.log_softmax(attn) + torch.log(attn_prior[:, None] + 1e-8)
 
         attn_logprob = attn.clone()
diff --git a/fs2/cli/synthesize.py b/fs2/cli/synthesize.py
index bb880db..a139a63 100644
--- a/fs2/cli/synthesize.py
+++ b/fs2/cli/synthesize.py
@@ -73,6 +73,7 @@ def prepare_data(
     model: Any,
     text_representation: DatasetTextRepresentation,
     duration_control: float,
+    style_reference: Path | None,
 ) -> list[dict[str, Any]]:
     """"""
     from everyvoice.utils import slugify
@@ -154,9 +155,36 @@ def prepare_data(
         multi=model.config.model.multispeaker,
     )
 
+    # We only allow a single style reference right now, so it's fine to load it once here.
+    if style_reference:
+        from everyvoice.utils.heavy import get_spectral_transform
+
+        spectral_transform = get_spectral_transform(
+            model.config.preprocessing.audio.spec_type,
+            model.config.preprocessing.audio.n_fft,
+            model.config.preprocessing.audio.fft_window_size,
+            model.config.preprocessing.audio.fft_hop_size,
+            f_min=model.config.preprocessing.audio.f_min,
+            f_max=model.config.preprocessing.audio.f_max,
+            sample_rate=model.config.preprocessing.audio.output_sampling_rate,
+            n_mels=model.config.preprocessing.audio.n_mels,
+        )
+        import torchaudio
+
+        style_reference_audio, style_reference_sr = torchaudio.load(style_reference)
+        if style_reference_sr != model.config.preprocessing.audio.input_sampling_rate:
+            style_reference_audio = torchaudio.functional.resample(
+                style_reference_audio,
+                style_reference_sr,
+                model.config.preprocessing.audio.input_sampling_rate,
+            )
+        style_reference_spec = spectral_transform(style_reference_audio)
     # Add duration_control
     for item in data:
         item["duration_control"] = duration_control
+        # Add style reference
+        if style_reference:
+            item["mel_style_reference"] = style_reference_spec
 
     return data
 
@@ -175,6 +203,7 @@ def get_global_step(model_path: Path) -> int:
 def synthesize_helper(
     model,
     texts: list[str],
+    style_reference: Optional[Path],
     language: Optional[str],
     speaker: Optional[str],
     duration_control: Optional[float],
@@ -227,8 +256,8 @@ def synthesize_helper(
         filelist=filelist,
         model=model,
         text_representation=text_representation,
+        style_reference=style_reference,
     )
-
     from pytorch_lightning import Trainer
 
     from ..prediction_writing_callback import get_synthesis_output_callbacks
@@ -269,6 +298,7 @@ def synthesize_helper(
             model.lang2id,
             model.speaker2id,
             teacher_forcing=teacher_forcing,
+            style_reference=style_reference is not None,
         ),
         return_predictions=True,
     ),
@@ -312,6 +342,16 @@ def synthesize(  # noqa: C901
         "-D",
         help="Control the speaking rate of the synthesis. Set a value to multily the durations by, lower numbers produce quicker speaking rates, larger numbers produce slower speaking rates. Default is 1.0",
     ),
+    style_reference: Optional[Path] = typer.Option(
+        None,
+        "--style-reference",
+        "-S",
+        exists=True,
+        file_okay=True,
+        dir_okay=False,
+        help="The path to an audio file containing a style reference. Your text-to-spec must have been trained with the global style token module to use this feature.",
+        autocompletion=complete_path,
+    ),
     speaker: Optional[str] = typer.Option(
         None,
         "--speaker",
@@ -454,6 +494,7 @@ def synthesize(  # noqa: C901
     return synthesize_helper(
         model=model,
         texts=texts,
+        style_reference=style_reference,
         language=language,
         speaker=speaker,
         duration_control=duration_control,
diff --git a/fs2/config/__init__.py b/fs2/config/__init__.py
index 53514b6..9de8569 100644
--- a/fs2/config/__init__.py
+++ b/fs2/config/__init__.py
@@ -135,6 +135,10 @@ class FastSpeech2ModelConfig(ConfigModel):
         True,
         description="Whether to jointly learn alignments using monotonic alignment search module (See Badlani et. al. 2021: https://arxiv.org/abs/2108.10447). If set to False, you will have to provide text/audio alignments separately before training a text-to-spec (feature prediction) model.",
     )
+    use_global_style_token_module: bool = Field(
+        False,
+        description="Whether to use the Global Style Token (GST) module from Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End Speech Synthesis (https://arxiv.org/abs/1803.09017)",
+    )
     max_length: int = Field(
         1000, description="The maximum length (i.e. number of symbols) for text inputs."
     )
diff --git a/fs2/dataset.py b/fs2/dataset.py
index 7281ff6..b6f2f78 100644
--- a/fs2/dataset.py
+++ b/fs2/dataset.py
@@ -35,6 +35,7 @@ def __init__(
         speaker2id: LookupTable,
         teacher_forcing=False,
         inference=False,
+        style_reference=False,
     ):
         self.dataset = dataset
         self.config = config
@@ -43,6 +44,7 @@ def __init__(
         self.preprocessed_dir = Path(self.config.preprocessing.save_dir)
         self.sampling_rate = self.config.preprocessing.audio.input_sampling_rate
         self.teacher_forcing = teacher_forcing
+        self.style_reference = style_reference
         self.inference = inference
         self.lang2id = lang2id
         self.speaker2id = speaker2id
@@ -56,6 +58,7 @@ def __getitem__(self, index):
         """
         Returns dict with keys: {
             "mel"
+            "mel_style_reference"
             "duration"
             "duration_control"
             "pfs"
@@ -103,6 +106,12 @@ def __getitem__(self, index):
             )  # [mel_bins, frames] -> [frames, mel_bins]
         else:
             mel = None
+
+        if self.style_reference:
+            mel_style_reference = item["mel_style_reference"].squeeze(0).transpose(0, 1)
+        else:
+            mel_style_reference = None
+
         if (
             self.teacher_forcing or not self.inference
         ) and self.config.model.learn_alignment:
@@ -169,9 +178,9 @@ def __getitem__(self, index):
         else:
             energy = None
             pitch = None
-
         return {
             "mel": mel,
+            "mel_style_reference": mel_style_reference,
             "duration": duration,
             "duration_control": duration_control,
             "pfs": pfs,
@@ -201,11 +210,13 @@ def __init__(
         inference=False,
         teacher_forcing=False,
         inference_output_dir=Path("synthesis_output"),
+        style_reference=False,
     ):
         super().__init__(config=config, inference_output_dir=inference_output_dir)
         self.inference = inference
         self.prepared = False
         self.teacher_forcing = teacher_forcing
+        self.style_reference = style_reference
         self.collate_fn = partial(
             self.collate_method, learn_alignment=config.model.learn_alignment
         )
@@ -271,6 +282,7 @@ def prepare_data(self):
                 self.speaker2id,
                 inference=self.inference,
                 teacher_forcing=self.teacher_forcing,
+                style_reference=self.style_reference,
             )
             torch.save(self.predict_dataset, self.predict_path)
         elif not self.prepared:
@@ -320,8 +332,14 @@ def __init__(
         lang2id: LookupTable,
         speaker2id: LookupTable,
         teacher_forcing: bool = False,
+        style_reference=False,
     ):
-        super().__init__(config=config, inference=True, teacher_forcing=teacher_forcing)
+        super().__init__(
+            config=config,
+            inference=True,
+            teacher_forcing=teacher_forcing,
+            style_reference=style_reference,
+        )
         self.inference = True
         self.data = data
         self.collate_fn = partial(
diff --git a/fs2/gst/__init__.py b/fs2/gst/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/fs2/gst/attn.py b/fs2/gst/attn.py
new file mode 100644
index 0000000..f6109c5
--- /dev/null
+++ b/fs2/gst/attn.py
@@ -0,0 +1,194 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 Shigeki Karita
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Multi-Head Attention layer definition."""
+
+import math
+
+import torch
+from torch import nn
+
+
+class LayerNorm(torch.nn.LayerNorm):
+    """Layer normalization module.
+
+    Args:
+        nout (int): Output dim size.
+        dim (int): Dimension to be normalized.
+
+    """
+
+    def __init__(self, nout, dim=-1):
+        """Construct a LayerNorm object."""
+        super(LayerNorm, self).__init__(nout, eps=1e-12)
+        self.dim = dim
+
+    def forward(self, x):
+        """Apply layer normalization.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: Normalized tensor.
+
+        """
+        if self.dim == -1:
+            return super(LayerNorm, self).forward(x)
+        return (
+            super(LayerNorm, self)
+            .forward(x.transpose(self.dim, -1))
+            .transpose(self.dim, -1)
+        )
+
+
+class MultiHeadedAttention(nn.Module):
+    """Multi-Head Attention layer.
+
+    Args:
+        n_head (int): The number of heads.
+        n_feat (int): The number of features.
+        dropout_rate (float): Dropout rate.
+        qk_norm (bool): Normalize q and k before dot product.
+        use_flash_attn (bool): Use flash_attn implementation.
+        causal (bool): Apply causal attention.
+        cross_attn (bool): Cross attention instead of self attention.
+
+    """
+
+    def __init__(
+        self,
+        n_head,
+        n_feat,
+        dropout_rate,
+        qk_norm=False,
+        use_flash_attn=False,
+        causal=False,
+        cross_attn=False,
+    ):
+        """Construct a MultiHeadedAttention object."""
+        super(MultiHeadedAttention, self).__init__()
+        assert n_feat % n_head == 0
+        # We assume d_v always equals d_k
+        self.d_k = n_feat // n_head
+        self.h = n_head
+        self.linear_q = nn.Linear(n_feat, n_feat)
+        self.linear_k = nn.Linear(n_feat, n_feat)
+        self.linear_v = nn.Linear(n_feat, n_feat)
+        self.linear_out = nn.Linear(n_feat, n_feat)
+        self.attn = None
+        self.dropout = (
+            nn.Dropout(p=dropout_rate) if not use_flash_attn else nn.Identity()
+        )
+        self.dropout_rate = dropout_rate
+
+        # LayerNorm for q and k
+        self.q_norm = LayerNorm(self.d_k) if qk_norm else nn.Identity()
+        self.k_norm = LayerNorm(self.d_k) if qk_norm else nn.Identity()
+
+        self.use_flash_attn = use_flash_attn
+        self.causal = causal  # only used with flash_attn
+        self.cross_attn = cross_attn  # only used with flash_attn
+
+    def forward_qkv(self, query, key, value, expand_kv=False):
+        """Transform query, key and value.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            expand_kv (bool): Used only for partially autoregressive (PAR) decoding.
+
+        Returns:
+            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
+            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
+            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
+
+        """
+        n_batch = query.size(0)
+        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
+
+        if expand_kv:
+            k_shape = key.shape
+            k = (
+                self.linear_k(key[:1, :, :])
+                .expand(n_batch, k_shape[1], k_shape[2])
+                .view(n_batch, -1, self.h, self.d_k)
+            )
+            v_shape = value.shape
+            v = (
+                self.linear_v(value[:1, :, :])
+                .expand(n_batch, v_shape[1], v_shape[2])
+                .view(n_batch, -1, self.h, self.d_k)
+            )
+        else:
+            k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
+            v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+
+        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
+        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
+        v = v.transpose(1, 2)  # (batch, head, time2, d_k)
+
+        q = self.q_norm(q)
+        k = self.k_norm(k)
+
+        return q, k, v
+
+    def forward_attention(self, value, scores, mask):
+        """Compute attention context vector.
+
+        Args:
+            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Transformed value (#batch, time1, d_model)
+                weighted by the attention score (#batch, time1, time2).
+
+        """
+        n_batch = value.size(0)
+        if mask is not None:
+            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+            min_value = torch.finfo(scores.dtype).min
+            scores = scores.masked_fill(mask, min_value)
+            self.attn = torch.softmax(scores, dim=-1).masked_fill(
+                mask, 0.0
+            )  # (batch, head, time1, time2)
+        else:
+            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+
+        p_attn = self.dropout(self.attn)
+        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
+        x = (
+            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+        )  # (batch, time1, d_model)
+
+        return self.linear_out(x)  # (batch, time1, d_model)
+
+    def forward(self, query, key, value, mask, expand_kv=False):
+        """Compute scaled dot product attention.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+            expand_kv (bool): Used only for partially autoregressive (PAR) decoding.
+                When set to `True`, `Linear` layers are computed only for the first batch.
+                This is useful to reduce the memory usage during decoding when the batch size is
+                #beam_size x #mask_count, which can be very large. Typically, in single waveform
+                inference of PAR, `Linear` layers should not be computed for all batches
+                for source-attention.
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+
+        """
+        q, k, v = self.forward_qkv(query, key, value, expand_kv)
+        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+        return self.forward_attention(v, scores, mask)
diff --git a/fs2/gst/model.py b/fs2/gst/model.py
new file mode 100644
index 0000000..498cf0f
--- /dev/null
+++ b/fs2/gst/model.py
@@ -0,0 +1,280 @@
+# Copyright 2020 Nagoya University (Tomoki Hayashi)
+# Sourced from ESPNet2
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Style encoder of GST-Tacotron."""
+
+from typing import Sequence
+
+import torch
+
+from .attn import MultiHeadedAttention as BaseMultiHeadedAttention
+
+
+class StyleEncoder(torch.nn.Module):
+    """Style encoder.
+
+    This module is the style encoder introduced in `Style Tokens: Unsupervised Style
+    Modeling, Control and Transfer in End-to-End Speech Synthesis`.
+
+    .. _`Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End
+        Speech Synthesis`: https://arxiv.org/abs/1803.09017
+
+    Args:
+        idim (int, optional): Dimension of the input mel-spectrogram.
+        gst_tokens (int, optional): The number of GST embeddings.
+        gst_token_dim (int, optional): Dimension of each GST embedding.
+        gst_heads (int, optional): The number of heads in GST multihead attention.
+        conv_layers (int, optional): The number of conv layers in the reference encoder.
+        conv_chans_list (Sequence[int], optional):
+            List of the number of channels of conv layers in the reference encoder.
+        conv_kernel_size (int, optional):
+            Kernel size of conv layers in the reference encoder.
+        conv_stride (int, optional):
+            Stride size of conv layers in the reference encoder.
+        gru_layers (int, optional): The number of GRU layers in the reference encoder.
+        gru_units (int, optional): The number of GRU units in the reference encoder.
+
+    Todo:
+        * Support manual weight specification in inference.
+
+    """
+
+    def __init__(
+        self,
+        idim: int = 80,
+        gst_tokens: int = 10,
+        gst_token_dim: int = 256,
+        gst_heads: int = 4,
+        conv_layers: int = 6,
+        conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128),
+        conv_kernel_size: int = 3,
+        conv_stride: int = 2,
+        gru_layers: int = 1,
+        gru_units: int = 128,
+    ):
+        """Initialize global style encoder module."""
+        super(StyleEncoder, self).__init__()
+        self.gst_tokens = gst_tokens
+        self.gst_heads = gst_heads
+        self.gst_token_dim = gst_token_dim
+        self.ref_enc = ReferenceEncoder(
+            idim=idim,
+            conv_layers=conv_layers,
+            conv_chans_list=conv_chans_list,
+            conv_kernel_size=conv_kernel_size,
+            conv_stride=conv_stride,
+            gru_layers=gru_layers,
+            gru_units=gru_units,
+        )
+        self.stl = StyleTokenLayer(
+            ref_embed_dim=gru_units,
+            gst_tokens=gst_tokens,
+            gst_token_dim=gst_token_dim,
+            gst_heads=gst_heads,
+        )
+
+    def condition_on_gst_tokens(self, batch_size, index=0):
+        if index >= self.gst_tokens:
+            raise ValueError(
+                f"We can only synthesize by conditioning on one of {self.gst_tokens} GST tokens"
+            )
+        GST = torch.tanh(self.stl.gst_embs)
+        query = torch.zeros(batch_size, 1, self.gst_token_dim // 2).to(GST.device)
+        keys = GST[index].unsqueeze(0).expand(batch_size, -1, -1)
+        return self.stl.mha(query, keys, keys, None).squeeze(1)
+
+    def forward(self, speech: torch.Tensor) -> torch.Tensor:
+        """Calculate forward propagation.
+
+        Args:
+            speech (Tensor): Batch of padded target features (B, Lmax, odim).
+
+        Returns:
+            Tensor: Style token embeddings (B, token_dim).
+
+        """
+        ref_embs = self.ref_enc(speech)
+        style_embs = self.stl(ref_embs)
+
+        return style_embs
+
+
+class ReferenceEncoder(torch.nn.Module):
+    """Reference encoder module.
+
+    This module is the reference encoder introduced in `Style Tokens: Unsupervised Style
+    Modeling, Control and Transfer in End-to-End Speech Synthesis`.
+
+    .. _`Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End
+        Speech Synthesis`: https://arxiv.org/abs/1803.09017
+
+    Args:
+        idim (int, optional): Dimension of the input mel-spectrogram.
+        conv_layers (int, optional): The number of conv layers in the reference encoder.
+        conv_chans_list (Sequence[int], optional):
+            List of the number of channels of conv layers in the reference encoder.
+        conv_kernel_size (int, optional):
+            Kernel size of conv layers in the reference encoder.
+        conv_stride (int, optional):
+            Stride size of conv layers in the reference encoder.
+        gru_layers (int, optional): The number of GRU layers in the reference encoder.
+        gru_units (int, optional): The number of GRU units in the reference encoder.
+
+    """
+
+    def __init__(
+        self,
+        idim=80,
+        conv_layers: int = 6,
+        conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128),
+        conv_kernel_size: int = 3,
+        conv_stride: int = 2,
+        gru_layers: int = 1,
+        gru_units: int = 128,
+    ):
+        """Initialize reference encoder module."""
+        super(ReferenceEncoder, self).__init__()
+
+        # check hyperparameters are valid
+        assert conv_kernel_size % 2 == 1, "kernel size must be odd."
+        assert (
+            len(conv_chans_list) == conv_layers
+        ), "the number of conv layers and length of channels list must be the same."
+
+        convs = []
+        padding = (conv_kernel_size - 1) // 2
+        for i in range(conv_layers):
+            conv_in_chans = 1 if i == 0 else conv_chans_list[i - 1]
+            conv_out_chans = conv_chans_list[i]
+            convs += [
+                torch.nn.Conv2d(
+                    conv_in_chans,
+                    conv_out_chans,
+                    kernel_size=conv_kernel_size,
+                    stride=conv_stride,
+                    padding=padding,
+                    # Do not use bias due to the following batch norm
+                    bias=False,
+                ),
+                torch.nn.BatchNorm2d(conv_out_chans),
+                torch.nn.ReLU(inplace=True),
+            ]
+        self.convs = torch.nn.Sequential(*convs)
+
+        self.conv_layers = conv_layers
+        self.kernel_size = conv_kernel_size
+        self.stride = conv_stride
+        self.padding = padding
+
+        # get the number of GRU input units
+        gru_in_units = idim
+        for i in range(conv_layers):
+            gru_in_units = (
+                gru_in_units - conv_kernel_size + 2 * padding
+            ) // conv_stride + 1
+        gru_in_units *= conv_out_chans
+        self.gru = torch.nn.GRU(gru_in_units, gru_units, gru_layers, batch_first=True)
+
+    def forward(self, speech: torch.Tensor) -> torch.Tensor:
+        """Calculate forward propagation.
+
+        Args:
+            speech (Tensor): Batch of padded target features (B, Lmax, idim).
+
+        Returns:
+            Tensor: Reference embedding (B, gru_units)
+
+        """
+        batch_size = speech.size(0)
+        xs = speech.unsqueeze(1)  # (B, 1, Lmax, idim)
+        hs = self.convs(xs).transpose(1, 2)  # (B, Lmax', conv_out_chans, idim')
+        # NOTE(kan-bayashi): We need to care the length?
+        time_length = hs.size(1)
+        hs = hs.contiguous().view(batch_size, time_length, -1)  # (B, Lmax', gru_units)
+        self.gru.flatten_parameters()
+        _, ref_embs = self.gru(hs)  # (gru_layers, batch_size, gru_units)
+        ref_embs = ref_embs[-1]  # (batch_size, gru_units)
+
+        return ref_embs
+
+
+class StyleTokenLayer(torch.nn.Module):
+    """Style token layer module.
+
+    This module is the style token layer introduced in `Style Tokens: Unsupervised Style
+    Modeling, Control and Transfer in End-to-End Speech Synthesis`.
+
+    .. _`Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End
+        Speech Synthesis`: https://arxiv.org/abs/1803.09017
+
+    Args:
+        ref_embed_dim (int, optional): Dimension of the input reference embedding.
+        gst_tokens (int, optional): The number of GST embeddings.
+        gst_token_dim (int, optional): Dimension of each GST embedding.
+        gst_heads (int, optional): The number of heads in GST multihead attention.
+        dropout_rate (float, optional): Dropout rate in multi-head attention.
+
+    """
+
+    def __init__(
+        self,
+        ref_embed_dim: int = 128,
+        gst_tokens: int = 10,
+        gst_token_dim: int = 256,
+        gst_heads: int = 4,
+        dropout_rate: float = 0.0,
+    ):
+        """Initialize style token layer module."""
+        super(StyleTokenLayer, self).__init__()
+        gst_embs = torch.randn(gst_tokens, gst_token_dim // gst_heads)
+        self.register_parameter("gst_embs", torch.nn.Parameter(gst_embs))
+        self.mha = MultiHeadedAttention(
+            q_dim=ref_embed_dim,
+            k_dim=gst_token_dim // gst_heads,
+            v_dim=gst_token_dim // gst_heads,
+            n_head=gst_heads,
+            n_feat=gst_token_dim,
+            dropout_rate=dropout_rate,
+        )
+
+    def forward(self, ref_embs: torch.Tensor) -> torch.Tensor:
+        """Calculate forward propagation.
+
+        Args:
+            ref_embs (Tensor): Reference embeddings (B, ref_embed_dim).
+
+        Returns:
+            Tensor: Style token embeddings (B, gst_token_dim).
+
+        """
+        batch_size = ref_embs.size(0)
+        # (num_tokens, token_dim) -> (batch_size, num_tokens, token_dim)
+        gst_embs = torch.tanh(self.gst_embs).unsqueeze(0).expand(batch_size, -1, -1)
+        # NOTE(kan-bayashi): Should we apply Tanh?
+        ref_embs = ref_embs.unsqueeze(1)  # (batch_size, 1, ref_embed_dim)
+        style_embs = self.mha(ref_embs, gst_embs, gst_embs, None)
+        return style_embs.squeeze(1)
+
+
+class MultiHeadedAttention(BaseMultiHeadedAttention):
+    """Multi head attention module with different input dimension."""
+
+    def __init__(self, q_dim, k_dim, v_dim, n_head, n_feat, dropout_rate=0.0):
+        """Initialize multi head attention module."""
+        # NOTE(kan-bayashi): Do not use super().__init__() here since we want to
+        #   overwrite BaseMultiHeadedAttention.__init__() method.
+        torch.nn.Module.__init__(self)
+        assert n_feat % n_head == 0
+        # We assume d_v always equals d_k
+        self.d_k = n_feat // n_head
+        self.h = n_head
+        self.linear_q = torch.nn.Linear(q_dim, n_feat)
+        self.linear_k = torch.nn.Linear(k_dim, n_feat)
+        self.linear_v = torch.nn.Linear(v_dim, n_feat)
+        self.linear_out = torch.nn.Linear(n_feat, n_feat)
+        self.attn = None
+        self.dropout = torch.nn.Dropout(p=dropout_rate)
+        self.use_flash_attn = False
+        self.q_norm = torch.nn.Identity()
+        self.k_norm = torch.nn.Identity()
diff --git a/fs2/model.py b/fs2/model.py
index 76f6b4c..ec7bd5d 100644
--- a/fs2/model.py
+++ b/fs2/model.py
@@ -22,6 +22,7 @@
 from torchaudio.models import Conformer
 
 from .config import FastSpeech2Config
+from .gst.model import StyleEncoder
 from .layers import PositionalEmbedding, PostNet
 from .loss import FastSpeech2Loss
 from .noam import NoamLR
@@ -86,7 +87,8 @@ def __init__(
         self.position_embedding = PositionalEmbedding(
             self.config.model.encoder.input_dim
         )
-
+        if self.config.model.use_global_style_token_module:
+            self.gst = StyleEncoder(idim=self.config.preprocessing.audio.n_mels)
         self.encoder = Conformer(
             input_dim=self.config.model.encoder.input_dim,
             num_heads=self.config.model.encoder.heads,
@@ -150,7 +152,6 @@ def forward(self, batch, control=InferenceControl(), inference=False):
         # and update the control accordingly, but this is hacky and should be fixed
         if "duration_control" in batch and batch["duration_control"][0]:
             control.duration = batch["duration_control"][0]
-
         # Determine whether we're teacher forcing or not
         # To do so, we need to be in inference mode and
         # the data loader should have loaded some Mel lengths
@@ -187,6 +188,16 @@ def forward(self, batch, control=InferenceControl(), inference=False):
 
         # Encoder
         x, _ = self.encoder(inputs + enc_pos_emb, src_lens)  # expects B, T, K
+
+        # Add Global Style Token Embedding
+        if self.config.model.use_global_style_token_module:
+            if torch.is_tensor(batch["mel_style_reference"]):
+                # Used in training and also for synthesis with a reference audio
+                style_embs = self.gst(batch["mel_style_reference"])
+            else:
+                style_embs = self.gst.condition_on_gst_tokens(batch["text"].size(0))
+            x = x + style_embs.unsqueeze(1)
+
         # Speaker Embedding
         if self.config.model.multispeaker and self.speaker_embedding:
             speaker_emb = self.speaker_embedding(speaker_ids)
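
A minimal usage sketch (not part of the patch itself) of how the new StyleEncoder is driven, mirroring the two paths in fs2/model.py: encoding a reference mel when "mel_style_reference" is a tensor, and falling back to condition_on_gst_tokens() when no reference is given. The batch size, frame counts, and the assumption that the encoder dimension equals the default gst_token_dim of 256 are illustrative only, and the import path assumes the fs2 package layout used in this patch.

    import torch

    from fs2.gst.model import StyleEncoder  # path assumes this patch's fs2/gst/model.py

    n_mels = 80  # assumed to match preprocessing.audio.n_mels
    gst = StyleEncoder(idim=n_mels)

    # Reference-conditioned path (training, or synthesis with --style-reference):
    # a padded mel spectrogram (B, Lmax, n_mels) is encoded to (B, gst_token_dim).
    mel_style_reference = torch.randn(2, 400, n_mels)
    style_embs = gst(mel_style_reference)  # -> (2, 256)

    # Reference-free path: condition on a single learned GST token instead.
    style_embs_no_ref = gst.condition_on_gst_tokens(batch_size=2, index=0)  # -> (2, 256)

    # As in fs2/model.py, the style embedding is broadcast-added to the encoder output;
    # this only works when encoder.input_dim equals gst_token_dim (256 by default).
    x = torch.randn(2, 120, 256)  # stand-in for the Conformer encoder output
    x = x + style_embs.unsqueeze(1)

Passing a different index to condition_on_gst_tokens selects a different learned token; an index at or above gst_tokens raises a ValueError.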