Improve padding and attention mask handling (facebookresearch#104)
cbalioglu authored Oct 16, 2023
1 parent e8acd38 commit 4604d73
Showing 63 changed files with 1,334 additions and 1,226 deletions.
11 changes: 0 additions & 11 deletions doc/reference/abc.rst
@@ -9,14 +9,3 @@ ABCs and Protocols
:nosignatures:

gang.Gang
nn.PositionEncoder
nn.Projection
nn.transformer.AttentionMaskGenerator
nn.transformer.AttentionWeightHook
nn.transformer.FeedForwardNetwork
nn.transformer.MultiheadAttention
nn.transformer.SDPA
nn.transformer.TransformerDecoder
nn.transformer.TransformerDecoderLayer
nn.transformer.TransformerEncoder
nn.transformer.TransformerEncoderLayer
19 changes: 0 additions & 19 deletions doc/reference/classes.rst
@@ -8,25 +8,6 @@ Classes
:toctree: generated/classes
:nosignatures:

nn.Embedding
nn.IncrementalState
nn.IncrementalStateBag
nn.LearnedPositionEncoder
nn.Linear
nn.ModuleList
nn.RotaryEncoder
nn.SinusoidalPositionEncoder
nn.TiedProjection
nn.transformer.ALiBiAttentionMaskGenerator
nn.transformer.CausalAttentionMaskGenerator
nn.transformer.RelativePositionSDPA
nn.transformer.StandardFeedForwardNetwork
nn.transformer.StandardMultiheadAttention
nn.transformer.StandardTransformerDecoder
nn.transformer.StandardTransformerDecoderLayer
nn.transformer.StandardTransformerEncoder
nn.transformer.StandardTransformerEncoderLayer
nn.transformer.StoreAttentionWeights
optim.lr_scheduler.CosineAnnealingLR
optim.lr_scheduler.LRSchedulerBase
optim.lr_scheduler.MyleLR
1 change: 0 additions & 1 deletion doc/reference/functions.rst
@@ -9,4 +9,3 @@ Functions
:nosignatures:

nn.utils.mask.to_float_mask
nn.utils.mask.to_padding_mask
@@ -63,7 +63,7 @@ waveform_to_fbank_converter::operator()(data &&d) const
at::kCPU, at::kFloat, /*non_blocking=*/false, /*copy=*/false, at::MemoryFormat::Contiguous);

if (!are_close(opts_.waveform_scale(), 1.0F))
waveform = at::multiply(waveform, opts_.waveform_scale());
waveform = waveform.multiply(opts_.waveform_scale());

at::Tensor fbank = computer_->compute(waveform, opts_.pin_memory());

@@ -72,7 +72,7 @@ waveform_to_fbank_converter::operator()(data &&d) const

std::tie(stdev, mean) = at::std_mean(fbank, /*dim=*/0);

fbank = at::divide(at::subtract(fbank, mean), stdev);
fbank = fbank.subtract(mean).divide(stdev);
}

// If no device is specified, we fall back to the device of the waveform
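
A minimal PyTorch sketch of the per-dimension mean/variance normalization refactored in the fbank converter above, using plain torch rather than the ATen C++ calls; the tensor shape and values are illustrative, not taken from the library:

    import torch

    fbank = torch.randn(200, 80)  # (num_frames, num_mel_bins); illustrative shape

    # torch.std_mean mirrors at::std_mean above and returns (std, mean).
    stdev, mean = torch.std_mean(fbank, dim=0)

    # Method chaining, matching the refactored C++ fbank.subtract(mean).divide(stdev).
    fbank = fbank.subtract(mean).divide(stdev)
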
5 changes: 5 additions & 0 deletions fairseq2n/src/fairseq2n/data/collater.cc
@@ -333,6 +333,11 @@ collate_op::pad_tensors(span<at::Tensor> tensors, std::int64_t pad_idx, const co
++i;
}

// We might still need to return as ragged even if all sequences have the
// same length if `seqs` has extra padding due to `pad_to_multiple`.
if (!is_ragged && !tensors.empty() && seq_lens_data[0] != seqs.size(1))
is_ragged = true;

seq_lens = seq_lens.to(seqs.device());

// Pack the sequences and their lengths into a dict.
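
The new check in pad_tensors above exists because pad_to_multiple can widen the batch beyond the longest sequence. A small Python sketch of the condition (assumed names and values; not the collater implementation itself):

    import torch

    seq_lens = torch.tensor([6, 6, 6])  # every sequence has the same length
    batch_width = 8                     # 6 rounded up because pad_to_multiple=8

    is_ragged = bool((seq_lens != seq_lens[0]).any())

    # Even with equal lengths, the padded width no longer matches the lengths,
    # so the batch must still be reported as ragged.
    if not is_ragged and seq_lens.numel() > 0 and int(seq_lens[0]) != batch_width:
        is_ragged = True

    print(is_ragged)  # True
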
50 changes: 26 additions & 24 deletions src/fairseq2/generation/sequence_generator.py
@@ -17,7 +17,8 @@
from fairseq2.generation.logits_processor import LogitsProcessor
from fairseq2.models.encoder_decoder import Seq2SeqDecoder
from fairseq2.nn.incremental_state import IncrementalStateBag
from fairseq2.nn.ops import pad_sequence, repeat_interleave
from fairseq2.nn.ops import repeat_interleave
from fairseq2.nn.padding import PaddingMask, pad_seqs
from fairseq2.typing import Device
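
The import change above is the heart of this commit: callers now exchange a PaddingMask object instead of raw length tensors or float masks. A hedged sketch of the pattern, using only the interfaces visible in this diff (pad_seqs returning a (seqs, padding_mask) pair, PaddingMask exposing seq_lens and being constructed from seq_lens plus batch_seq_len); exact behavior may differ from the real library:

    import torch
    from fairseq2.nn.padding import PaddingMask, pad_seqs

    # Three variable-length token sequences (illustrative values).
    token_seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9]), torch.tensor([10])]

    # Pack them into an (N, S) batch plus an optional PaddingMask.
    seqs, padding_mask = pad_seqs(token_seqs, pad_idx=0)

    if padding_mask is not None:
        # seq_lens holds each row's true length; batch_seq_len is the padded width.
        rebuilt = PaddingMask(padding_mask.seq_lens, batch_seq_len=seqs.size(1))
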


@@ -143,7 +144,7 @@ def __init__(
def __call__(
self,
encoder_output: Tensor,
encoder_padding_mask: Optional[Tensor],
encoder_padding_mask: Optional[PaddingMask],
source_seq_len: Optional[int] = None,
) -> "SequenceGeneratorOutput":
opts = self.opts
@@ -378,11 +379,18 @@ def __call__(
encoder_output = encoder_output[search_indices].flatten(0, 1)

if encoder_padding_mask is not None:
# (N x B, S_enc, M) -> (N, B, S_enc, M)
padding_mask = encoder_padding_mask.unflatten(0, (num_searches, -1))
# (N x B)
seq_lens = encoder_padding_mask.seq_lens

# (N, B, S_enc, M) -> ((N - F) x B, S_enc, M)
encoder_padding_mask = padding_mask[search_indices].flatten(0, 1)
# (N x B) -> (N, B)
seq_lens = seq_lens.unflatten(0, (num_searches, -1))

# (N, B) -> ((N - F) x B)
seq_lens = seq_lens[search_indices].flatten(0, 1)

encoder_padding_mask = PaddingMask(
seq_lens, batch_seq_len=encoder_output.size(1)
)
# fmt: on

num_searches = new_num_searches
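
A plain-PyTorch sketch of the reindexing step above: instead of slicing a dense (N x B, S_enc) mask, the generator now reshapes and reindexes the per-row lengths and rebuilds the mask afterwards (shapes and values are illustrative):

    import torch

    num_searches = 4                                     # N, before filtering
    seq_lens = torch.tensor([7, 7, 5, 5, 9, 9, 3, 3])    # (N x B) with beam size B=2
    search_indices = torch.tensor([0, 2, 3])             # searches still in flight

    seq_lens = seq_lens.unflatten(0, (num_searches, -1))   # (N x B) -> (N, B)
    seq_lens = seq_lens[search_indices].flatten(0, 1)       # (N, B) -> ((N - F) x B)

    print(seq_lens)  # tensor([7, 7, 9, 9, 3, 3])
    # The mask is then rebuilt as
    # PaddingMask(seq_lens, batch_seq_len=encoder_output.size(1)), as in the diff.
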
@@ -471,26 +479,20 @@ def _determine_max_seq_len(self, source_seq_len: Optional[int]) -> int:
return max_seq_len

def _fan_out_encoder_output(
self, encoder_output: Tensor, encoder_padding_mask: Optional[Tensor]
) -> Tuple[Tensor, Optional[Tensor]]:
num_searches = encoder_output.size(0) # i.e. batch size

self, encoder_output: Tensor, encoder_padding_mask: Optional[PaddingMask]
) -> Tuple[Tensor, Optional[PaddingMask]]:
# Fan out `encoder_output` to `num_searches` x `beam_size`.
# (N)
fan_out_indices = torch.arange(num_searches, device=encoder_output.device)

# (N) -> (N x B)
fan_out_indices = repeat_interleave(
fan_out_indices, dim=0, repeat=self.beam_size
)

# (N, S_enc, M) -> (N x B, S_enc, M)
encoder_output = encoder_output.index_select(dim=0, index=fan_out_indices)
encoder_output = repeat_interleave(encoder_output, dim=0, repeat=self.beam_size)

# (N, S_enc, M) -> (N x B, S_enc, M)
if encoder_padding_mask is not None:
encoder_padding_mask = encoder_padding_mask.index_select(
dim=0, index=fan_out_indices
seq_lens = encoder_padding_mask.seq_lens

seq_lens = repeat_interleave(seq_lens, dim=0, repeat=self.beam_size)

encoder_padding_mask = PaddingMask(
seq_lens, batch_seq_len=encoder_output.size(1)
)

return encoder_output, encoder_padding_mask
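
A sketch of the fan-out above in plain PyTorch; torch.repeat_interleave stands in for fairseq2.nn.ops.repeat_interleave, which this diff calls with repeat=beam_size (names and shapes are illustrative):

    import torch

    beam_size = 3
    encoder_output = torch.randn(2, 5, 8)   # (N, S_enc, M)
    seq_lens = torch.tensor([5, 4])         # (N)

    # Repeat every search beam_size times along the batch dimension.
    encoder_output = torch.repeat_interleave(encoder_output, repeats=beam_size, dim=0)
    seq_lens = torch.repeat_interleave(seq_lens, repeats=beam_size, dim=0)

    print(encoder_output.shape)  # torch.Size([6, 5, 8])  -> (N x B, S_enc, M)
    print(seq_lens)              # tensor([5, 5, 5, 4, 4, 4])
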
@@ -500,7 +502,7 @@ def _bootstrap_seqs_and_scores(
seqs: Tensor,
scores: Tensor,
encoder_output: Tensor,
encoder_padding_mask: Optional[Tensor],
encoder_padding_mask: Optional[PaddingMask],
state_bag: IncrementalStateBag,
) -> None:
assert self.prefix_seq_len > 0
@@ -633,7 +635,7 @@ class SequenceGeneratorOutput:

def collate(
self, *, hypo_idx: int = 0, skip_batch: bool = False
) -> Tuple[Tensor, Optional[Tensor]]:
) -> Tuple[Tensor, Optional[PaddingMask]]:
"""Collate the generated sequences at index ``hypo_idx`` in each search
result into a single tensor.
@@ -670,7 +672,7 @@ def collate(
# Return a zero-dimensional (not scalar!) tensor.
return torch.empty((0,), device=self.device, dtype=torch.int64), None

return pad_sequence(seqs, self.pad_idx, pad_to_multiple=2)
return pad_seqs(seqs, self.pad_idx)


@dataclass
40 changes: 19 additions & 21 deletions src/fairseq2/generation/text.py
@@ -18,7 +18,7 @@
SequenceGeneratorOutput,
)
from fairseq2.models.encoder_decoder import EncoderDecoderModel
from fairseq2.nn.ops import pad_sequence
from fairseq2.nn.padding import PaddingMask, pad_seqs
from fairseq2.nn.utils.module import infer_device


@@ -66,11 +66,11 @@ def __init__(

@torch.inference_mode()
def _do_generate(
self, source_seqs: Tensor, source_seq_lens: Optional[Tensor]
self, source_seqs: Tensor, source_padding_mask: Optional[PaddingMask]
) -> "SequenceToTextOutput":
"""A subclass should call this function for the actual text generation."""
encoder_output, encoder_padding_mask = self.model.encode(
source_seqs, source_seq_lens
source_seqs, source_padding_mask
)

gen_output = self.generator(
@@ -100,10 +100,10 @@ class SequenceToTextOutput:
the batch size, :math:`S_{enc}` is the encoder output sequence length, and
:math:`M` is the dimensionality of the model."""

encoder_padding_mask: Optional[Tensor]
"""The float padding mask of :attr:`encoder_output`. *Shape:*
:math:`(N,S_{enc})`, where :math:`N` is the batch size and :math:`S_{enc}`
is the encoder output sequence length."""
encoder_padding_mask: Optional[PaddingMask]
"""The padding mask of :attr:`encoder_output`. *Shape:* :math:`(N,S_{enc})`,
where :math:`N` is the batch size and :math:`S_{enc}` is the encoder output
sequence length."""


class SequenceToTextGenerator(SequenceToTextGeneratorBase):
@@ -114,41 +114,39 @@ class SequenceToTextGenerator(SequenceToTextGeneratorBase):
"""

def __call__(
self, source_seqs: Tensor, source_seq_lens: Optional[Tensor]
self, source_seqs: Tensor, source_padding_mask: Optional[PaddingMask]
) -> List[StringLike]:
"""
:param source_seqs:
The source sequences to use for generation. *Shape:* :math:`(N,S,*)`,
where :math:`N` is the batch size, :math:`S` is the sequence length,
and :math:`*` is any number of sequence-specific dimensions
including none.
:param source_seq_lens:
An array where each element represents the length of the sequence at
the same index in ``source_seqs``. *Shape:* :math:`(N)`, where
:math:`N` is the batch size.
:param source_padding_mask:
The padding mask of ``source_seqs``. *Shape:* :math:`(N,S)`, where
:math:`N` is the batch size and :math:`S` is the sequence length.
:returns:
The generated text sentences.
"""
output = self.generate_ex(source_seqs, source_seq_lens)
output = self.generate_ex(source_seqs, source_padding_mask)

return output.sentences

def generate_ex(
self, source_seqs: Tensor, source_seq_lens: Optional[Tensor]
self, source_seqs: Tensor, source_padding_mask: Optional[PaddingMask]
) -> SequenceToTextOutput:
"""
:param source_seqs:
The source sequences to use for generation. *Shape:* :math:`(N,S,*)`,
where :math:`N` is the batch size, :math:`S` is the sequence length,
and :math:`*` is any number of sequence-specific dimensions
including none.
:param source_seq_lens:
An array where each element represents the length of the sequence at
the same index in ``source_seqs``. *Shape:* :math:`(N)`, where
:math:`N` is the batch size.
:param source_padding_mask:
The padding mask of ``source_seqs``. *Shape:* :math:`(N,S)`, where
:math:`N` is the batch size and :math:`S` is the sequence length.
"""
return self._do_generate(source_seqs, source_seq_lens)
return self._do_generate(source_seqs, source_padding_mask)


class TextTranslator(SequenceToTextGeneratorBase):
@@ -211,8 +209,8 @@ def translate_ex(
:param source_sentences:
The sentences in the source language.
"""
seqs, seq_lens = pad_sequence(
seqs, padding_mask = pad_seqs(
[self.source_encoder(s) for s in source_sentences], pad_idx=self.pad_idx
)

return self._do_generate(seqs, seq_lens)
return self._do_generate(seqs, padding_mask)
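
A hedged usage sketch of the updated text-generation entry point: source sentences are token-encoded, padded with pad_seqs, and the resulting (seqs, padding_mask) pair is what now flows into generation. Only the pad_seqs call and the (source_seqs, source_padding_mask) signature come from this diff; the token values, pad_idx, and the `generator` object are illustrative:

    import torch
    from fairseq2.nn.padding import pad_seqs

    # Pretend token encodings of two source sentences (illustrative values).
    encoded = [torch.tensor([4, 17, 9, 2]), torch.tensor([4, 23, 2])]

    seqs, padding_mask = pad_seqs(encoded, pad_idx=1)

    # With a SequenceToTextGenerator instance `generator` (illustrative), generation is:
    # sentences = generator(seqs, padding_mask)
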
22 changes: 13 additions & 9 deletions src/fairseq2/models/conformer/block.py
@@ -11,12 +11,14 @@

from fairseq2.models.conformer.convolution import ConformerConvolution
from fairseq2.nn.normalization import LayerNorm
from fairseq2.nn.padding import PaddingMask
from fairseq2.nn.transformer import (
AttentionMask,
FeedForwardNetwork,
LayerNormFactory,
MultiheadAttention,
TransformerEncoderLayer,
create_default_layer_norm,
create_standard_layer_norm,
)
from fairseq2.nn.utils.module import check_model_dim
from fairseq2.typing import DataType, Device, finaloverride
@@ -72,7 +74,7 @@ def __init__(
super().__init__(model_dim)

if layer_norm_factory is None:
layer_norm_factory = create_default_layer_norm
layer_norm_factory = create_standard_layer_norm

self.ffn1_layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype)

@@ -120,9 +122,9 @@ def __init__(
def forward(
self,
seqs: Tensor,
padding_mask: Optional[Tensor],
self_attn_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
padding_mask: Optional[PaddingMask],
self_attn_mask: Optional[AttentionMask] = None,
) -> Tuple[Tensor, Optional[PaddingMask]]:
seqs = self._forward_ffn1(seqs)

seqs = self._forward_self_attn(seqs, padding_mask, self_attn_mask)
@@ -150,8 +152,8 @@ def _forward_ffn1(self, seqs: Tensor) -> Tensor:
def _forward_self_attn(
self,
seqs: Tensor,
padding_mask: Optional[Tensor],
self_attn_mask: Optional[Tensor],
padding_mask: Optional[PaddingMask],
self_attn_mask: Optional[AttentionMask],
) -> Tensor:
residual = seqs

@@ -161,17 +163,19 @@ def _forward_self_attn(
seqs,
padding_mask,
keys=seqs,
key_padding_mask=padding_mask,
values=seqs,
attn_mask=self_attn_mask,
key_padding_mask=padding_mask,
)

if self.self_attn_dropout is not None:
seqs = self.self_attn_dropout(seqs)

return seqs + residual

def _forward_conv(self, seqs: Tensor, padding_mask: Optional[Tensor]) -> Tensor:
def _forward_conv(
self, seqs: Tensor, padding_mask: Optional[PaddingMask]
) -> Tensor:
residual = seqs

seqs = self.conv_layer_norm(seqs)
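
Since the Conformer block above now receives both an Optional[PaddingMask] and an Optional[AttentionMask], here is a plain-PyTorch sketch of what the two kinds of mask encode; this is background for the signature change, not fairseq2 code:

    import torch

    seq_lens = torch.tensor([4, 2])   # true lengths in an (N=2, S=4) batch
    positions = torch.arange(4)

    # Padding mask: per batch row, which positions are real (True) vs. padding (False).
    padding_mask = positions.unsqueeze(0) < seq_lens.unsqueeze(1)   # (N, S)

    # Attention mask, e.g. causal: which keys each query position may attend to.
    causal_mask = torch.tril(torch.ones(4, 4, dtype=torch.bool))    # (S, S)

    print(padding_mask)
    print(causal_mask)
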
8 changes: 4 additions & 4 deletions src/fairseq2/models/conformer/convolution.py
@@ -11,7 +11,7 @@
from torch.nn.functional import pad

from fairseq2.nn.normalization import LayerNorm, StandardLayerNorm
from fairseq2.nn.utils.mask import apply_padding_mask
from fairseq2.nn.padding import PaddingMask, apply_padding_mask
from fairseq2.typing import DataType, Device


@@ -112,15 +112,15 @@ def __init__(
model_dim, model_dim, kernel_size=1, bias=False, device=device, dtype=dtype
)

def forward(self, seqs: Tensor, padding_mask: Optional[Tensor]) -> Tensor:
def forward(self, seqs: Tensor, padding_mask: Optional[PaddingMask]) -> Tensor:
"""
:param seqs:
The sequences to process. *Shape:* :math:`(N,S,M)`, where :math:`N`
is the batch size, :math:`S` is the sequence length, and :math:`M`
is the dimensionality of the model.
:param padding_mask:
The float padding mask of ``seqs``. *Shape:* :math:`(N,S)`, where
:math:`N` is the batch size and :math:`S` is the sequence length.
The padding mask of ``seqs``. *Shape:* :math:`(N,S)`, where :math:`N`
is the batch size and :math:`S` is the sequence length.
:returns:
The processed sequences. *Shape:* Same as ``seqs``.
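
The convolution module above now imports apply_padding_mask from fairseq2.nn.padding. A plain-PyTorch sketch of what applying a padding mask to (N, S, M) sequences amounts to, assuming it zeroes the padded time steps (an assumption for illustration, not verified against the library):

    import torch

    seqs = torch.randn(2, 4, 8)        # (N, S, M)
    seq_lens = torch.tensor([4, 2])

    # True where the position is a real (non-padded) time step.
    keep = torch.arange(4).unsqueeze(0) < seq_lens.unsqueeze(1)     # (N, S)

    # Zero out the padded steps before the depthwise convolution sees them.
    seqs = seqs * keep.unsqueeze(-1)
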