feat: add basic optional support for global style token module
Showing 8 changed files with 554 additions and 6 deletions.
@@ -0,0 +1,194 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Multi-Head Attention layer definition."""

import math

import torch
from torch import nn


class LayerNorm(torch.nn.LayerNorm):
    """Layer normalization module.

    Args:
        nout (int): Output dim size.
        dim (int): Dimension to be normalized.

    """

    def __init__(self, nout, dim=-1):
        """Construct a LayerNorm object."""
        super(LayerNorm, self).__init__(nout, eps=1e-12)
        self.dim = dim

    def forward(self, x):
        """Apply layer normalization.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Normalized tensor.

        """
        if self.dim == -1:
            return super(LayerNorm, self).forward(x)
        return (
            super(LayerNorm, self)
            .forward(x.transpose(self.dim, -1))
            .transpose(self.dim, -1)
        )


class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
        qk_norm (bool): Normalize q and k before dot product.
        use_flash_attn (bool): Use flash_attn implementation.
        causal (bool): Apply causal attention.
        cross_attn (bool): Cross attention instead of self attention.

    """

    def __init__(
        self,
        n_head,
        n_feat,
        dropout_rate,
        qk_norm=False,
        use_flash_attn=False,
        causal=False,
        cross_attn=False,
    ):
        """Construct a MultiHeadedAttention object."""
        super(MultiHeadedAttention, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.attn = None
        self.dropout = (
            nn.Dropout(p=dropout_rate) if not use_flash_attn else nn.Identity()
        )
        self.dropout_rate = dropout_rate

        # LayerNorm for q and k
        self.q_norm = LayerNorm(self.d_k) if qk_norm else nn.Identity()
        self.k_norm = LayerNorm(self.d_k) if qk_norm else nn.Identity()

        self.use_flash_attn = use_flash_attn
        self.causal = causal  # only used with flash_attn
        self.cross_attn = cross_attn  # only used with flash_attn

    def forward_qkv(self, query, key, value, expand_kv=False):
        """Transform query, key and value.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            expand_kv (bool): Used only for partially autoregressive (PAR) decoding.

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).

        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)

        if expand_kv:
            k_shape = key.shape
            k = (
                self.linear_k(key[:1, :, :])
                .expand(n_batch, k_shape[1], k_shape[2])
                .view(n_batch, -1, self.h, self.d_k)
            )
            v_shape = value.shape
            v = (
                self.linear_v(value[:1, :, :])
                .expand(n_batch, v_shape[1], v_shape[2])
                .view(n_batch, -1, self.h, self.d_k)
            )
        else:
            k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
            v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)

        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)

        q = self.q_norm(q)
        k = self.k_norm(k)

        return q, k, v

    def forward_attention(self, value, scores, mask):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).

        """
        n_batch = value.size(0)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, query, key, value, mask, expand_kv=False):
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).
            expand_kv (bool): Used only for partially autoregressive (PAR) decoding.
                When set to `True`, `Linear` layers are computed only for the first
                batch. This is useful to reduce the memory usage during decoding
                when the batch size is #beam_size x #mask_count, which can be very
                large. Typically, in single waveform inference of PAR, `Linear`
                layers should not be computed for all batches for source-attention.

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).

        """
        q, k, v = self.forward_qkv(query, key, value, expand_kv)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask)
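
The docstrings above spell out the expected tensor shapes. As a minimal usage sketch (not part of this commit; the sizes below are arbitrary assumptions, and the class is assumed to be in scope from the module above), this is how MultiHeadedAttention with qk_norm enabled would be driven, which also exercises the LayerNorm module on the per-head q/k features:

    # Illustrative only: sizes are assumptions chosen to show the shapes.
    import torch

    n_batch, time1, time2, n_feat, n_head = 2, 5, 7, 64, 4
    mha = MultiHeadedAttention(
        n_head=n_head, n_feat=n_feat, dropout_rate=0.1, qk_norm=True
    )

    query = torch.randn(n_batch, time1, n_feat)
    key = torch.randn(n_batch, time2, n_feat)
    value = torch.randn(n_batch, time2, n_feat)
    # Mask of shape (#batch, 1, time2); zero entries are masked out in forward_attention.
    mask = torch.ones(n_batch, 1, time2, dtype=torch.bool)

    out = mha(query, key, value, mask)
    print(out.shape)       # torch.Size([2, 5, 64])   -> (#batch, time1, n_feat)
    print(mha.attn.shape)  # torch.Size([2, 4, 5, 7]) -> (#batch, n_head, time1, time2)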
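
The expand_kv path described in forward() can be checked the same way. The sketch below (again illustrative, not part of the commit) assumes all batch entries share one encoder memory, as in PAR decoding; projecting key/value only for the first entry and expanding then reproduces the full computation while avoiding the per-batch Linear calls:

    # Illustrative only: identical key/value across the batch, dropout disabled.
    import torch

    n_batch, time1, time2, n_feat, n_head = 8, 3, 50, 64, 4
    mha = MultiHeadedAttention(n_head=n_head, n_feat=n_feat, dropout_rate=0.0)
    mha.eval()

    query = torch.randn(n_batch, time1, n_feat)
    # Same encoder memory repeated for every hypothesis in the batch.
    memory = torch.randn(1, time2, n_feat).expand(n_batch, time2, n_feat)

    with torch.no_grad():
        out_full = mha(query, memory, memory, mask=None, expand_kv=False)
        out_par = mha(query, memory, memory, mask=None, expand_kv=True)

    # Both paths agree when the key/value batch entries are identical.
    print(torch.allclose(out_full, out_par, atol=1e-6))  # expected: True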