diff --git a/src/fairseq2/assets/cards/s2t_transformer_covost_st_en_de.yaml b/src/fairseq2/assets/cards/s2t_transformer_covost_st_en_de.yaml
deleted file mode 100644
index a111ec69d..000000000
--- a/src/fairseq2/assets/cards/s2t_transformer_covost_st_en_de.yaml
+++ /dev/null
@@ -1,14 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-name: s2t_transformer_covost_st_en_de
-model_type: s2t_transformer
-model_arch: small
-task: translation
-tgt_langs: [de]
-tokenizer: "https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_de_st_vocab_char.zip"
-tokenizer_file: "spm_char.model"
-checkpoint: "https://dl.fbaipublicfiles.com/fairseq/s2t/covost2_en_de_st_transformer_s.pt"
diff --git a/src/fairseq2/assets/cards/s2t_transformer_mustc_asr_de_s.yaml b/src/fairseq2/assets/cards/s2t_transformer_mustc_asr_de_s.yaml
index 59e56df05..a98c6d7de 100644
--- a/src/fairseq2/assets/cards/s2t_transformer_mustc_asr_de_s.yaml
+++ b/src/fairseq2/assets/cards/s2t_transformer_mustc_asr_de_s.yaml
@@ -7,6 +7,9 @@
 name: s2t_transformer_mustc_asr_de_s
 model_type: s2t_transformer
 model_arch: small
+model_config:
+  target_vocab_info:
+    size: 5000
 task: transcription
 tgt_langs: [en]
 tokenizer: "https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_de_asr_vocab_unigram5000.zip"
diff --git a/src/fairseq2/assets/cards/s2t_transformer_mustc_asr_es_s.yaml b/src/fairseq2/assets/cards/s2t_transformer_mustc_asr_es_s.yaml
index 1e9947fe9..87d3b1992 100644
--- a/src/fairseq2/assets/cards/s2t_transformer_mustc_asr_es_s.yaml
+++ b/src/fairseq2/assets/cards/s2t_transformer_mustc_asr_es_s.yaml
@@ -7,6 +7,9 @@
 name: s2t_transformer_mustc_asr_es_s
 model_type: s2t_transformer
 model_arch: small
+model_config:
+  target_vocab_info:
+    size: 5000
 task: transcription
 tgt_langs: [en]
 tokenizer: "https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_es_asr_vocab_unigram5000.zip"
diff --git a/src/fairseq2/assets/cards/s2t_transformer_mustc_st_de_s.yaml b/src/fairseq2/assets/cards/s2t_transformer_mustc_st_de_s.yaml
index eb4ccd0c7..d6c464d90 100644
--- a/src/fairseq2/assets/cards/s2t_transformer_mustc_st_de_s.yaml
+++ b/src/fairseq2/assets/cards/s2t_transformer_mustc_st_de_s.yaml
@@ -7,6 +7,9 @@
 name: s2t_transformer_mustc_st_de_s
 model_type: s2t_transformer
 model_arch: small
+model_config:
+  target_vocab_info:
+    size: 8000
 task: translation
 tgt_langs: [de]
 tokenizer: "https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_de_st_vocab_unigram8000.zip"
diff --git a/src/fairseq2/models/decoder.py b/src/fairseq2/models/decoder.py
index d6f829d7b..a57cd8e47 100644
--- a/src/fairseq2/models/decoder.py
+++ b/src/fairseq2/models/decoder.py
@@ -9,6 +9,7 @@
 
 from torch import Tensor
 
+from fairseq2.data import VocabularyInfo
 from fairseq2.models.sequence import SequenceBatch, SequenceModel, SequenceModelOutput
 from fairseq2.nn.incremental_state import IncrementalStateBag
 from fairseq2.nn.padding import PaddingMask
@@ -69,12 +70,14 @@ class DecoderModel(SequenceModel, SequenceDecoder):
 
     model_dim: int
 
-    def __init__(self, model_dim: int) -> None:
+    def __init__(self, model_dim: int, vocab_info: VocabularyInfo) -> None:
         """
         :param model_dim:
             The dimensionality of the model.
+        :param vocab_info:
+            The vocabulary information of sequences produced by the model.
         """
-        super().__init__()
+        super().__init__(vocab_info)
 
         self.model_dim = model_dim
 
diff --git a/src/fairseq2/models/encoder_decoder.py b/src/fairseq2/models/encoder_decoder.py
index 24f6a732b..e18143206 100644
--- a/src/fairseq2/models/encoder_decoder.py
+++ b/src/fairseq2/models/encoder_decoder.py
@@ -9,6 +9,7 @@
 
 from torch import Tensor
 
+from fairseq2.data import VocabularyInfo
 from fairseq2.models.seq2seq import Seq2SeqBatch, Seq2SeqModel
 from fairseq2.models.sequence import SequenceModelOutput
 from fairseq2.nn.incremental_state import IncrementalStateBag
@@ -83,12 +84,14 @@ class EncoderDecoderModel(Seq2SeqModel, Seq2SeqDecoder):
 
     model_dim: int
 
-    def __init__(self, model_dim: int) -> None:
+    def __init__(self, model_dim: int, target_vocab_info: VocabularyInfo) -> None:
         """
         :param model_dim:
             The dimensionality of the model.
+        :param target_vocab_info:
+            The vocabulary information of sequences produced by the model.
         """
-        super().__init__()
+        super().__init__(target_vocab_info)
 
         self.model_dim = model_dim
 
diff --git a/src/fairseq2/models/llama/builder.py b/src/fairseq2/models/llama/builder.py
index 8c91836d5..26f16ebeb 100644
--- a/src/fairseq2/models/llama/builder.py
+++ b/src/fairseq2/models/llama/builder.py
@@ -7,6 +7,7 @@
 from dataclasses import dataclass
 from typing import Optional
 
+from fairseq2.data import VocabularyInfo
 from fairseq2.models.transformer import (
     TransformerDecoderModel,
     TransformerEmbeddingFrontend,
@@ -44,8 +45,8 @@ class LLaMAConfig:
     max_seq_len: int
     """The expected maximum sequence length."""
 
-    vocabulary_size: int
-    """The size of the vocabulary."""
+    vocab_info: VocabularyInfo
+    """The vocabulary information."""
 
     num_layers: int
     """The number of Transformer decoder layers."""
@@ -79,7 +80,9 @@ def _7b() -> LLaMAConfig:
     return LLaMAConfig(
         model_dim=4096,
         max_seq_len=2048,
-        vocabulary_size=32000,
+        vocab_info=VocabularyInfo(
+            size=32000, unk_idx=0, bos_idx=1, eos_idx=2, pad_idx=None
+        ),
         num_layers=32,
         num_attn_heads=32,
         num_key_value_heads=32,
@@ -94,7 +97,9 @@ def _13b() -> LLaMAConfig:
     return LLaMAConfig(
         model_dim=5120,
         max_seq_len=2048,
-        vocabulary_size=32000,
+        vocab_info=VocabularyInfo(
+            size=32000, unk_idx=0, bos_idx=1, eos_idx=2, pad_idx=None
+        ),
         num_layers=40,
         num_attn_heads=40,
         num_key_value_heads=40,
@@ -109,7 +114,9 @@ def _33b() -> LLaMAConfig:
     return LLaMAConfig(
         model_dim=6656,
         max_seq_len=2048,
-        vocabulary_size=32000,
+        vocab_info=VocabularyInfo(
+            size=32000, unk_idx=0, bos_idx=1, eos_idx=2, pad_idx=None
+        ),
         num_layers=60,
         num_attn_heads=52,
         num_key_value_heads=52,
@@ -124,7 +131,9 @@ def _65b() -> LLaMAConfig:
     return LLaMAConfig(
         model_dim=8192,
         max_seq_len=2048,
-        vocabulary_size=32000,
+        vocab_info=VocabularyInfo(
+            size=32000, unk_idx=0, bos_idx=1, eos_idx=2, pad_idx=None
+        ),
         num_layers=80,
         num_attn_heads=64,
         num_key_value_heads=64,
@@ -139,7 +148,9 @@ def _llama2_7b() -> LLaMAConfig:
     return LLaMAConfig(
         model_dim=4096,
         max_seq_len=4096,
-        vocabulary_size=32000,
+        vocab_info=VocabularyInfo(
+            size=32000, unk_idx=0, bos_idx=1, eos_idx=2, pad_idx=None
+        ),
         num_layers=32,
         num_attn_heads=32,
         num_key_value_heads=32,
@@ -154,7 +165,9 @@ def _llama2_13b() -> LLaMAConfig:
     return LLaMAConfig(
         model_dim=5120,
         max_seq_len=4096,
-        vocabulary_size=32000,
+        vocab_info=VocabularyInfo(
+            size=32000, unk_idx=0, bos_idx=1, eos_idx=2, pad_idx=None
+        ),
         num_layers=40,
         num_attn_heads=40,
         num_key_value_heads=40,
@@ -169,7 +182,9 @@ def _llama2_70b() -> LLaMAConfig:
     return LLaMAConfig(
         model_dim=8192,
         max_seq_len=4096,
-        vocabulary_size=32000,
+        vocab_info=VocabularyInfo(
+            size=32000, unk_idx=0, bos_idx=1, eos_idx=2, pad_idx=None
+        ),
         num_layers=80,
         num_attn_heads=64,
         num_key_value_heads=8,
@@ -222,7 +237,7 @@ def build_model(self) -> TransformerDecoderModel:
 
         final_proj = Linear(
             self.config.model_dim,
-            self.config.vocabulary_size,
+            self.config.vocab_info.size,
             bias=False,
             init_fn=init_final_projection,
             device=self.device,
@@ -230,13 +245,13 @@
         )
 
         return TransformerDecoderModel(
-            frontend, decoder, final_proj, target_pad_idx=None
+            frontend, decoder, final_proj, self.config.vocab_info
        )
 
     def build_frontend(self) -> TransformerFrontend:
         """Build a Transformer decoder front-end."""
         embed = StandardEmbedding(
-            num_embeddings=self.config.vocabulary_size,
+            num_embeddings=self.config.vocab_info.size,
             embedding_dim=self.config.model_dim,
             device=self.device,
             dtype=self.dtype,
diff --git a/src/fairseq2/models/mistral/builder.py b/src/fairseq2/models/mistral/builder.py
index 7163e7bca..8eb44e654 100644
--- a/src/fairseq2/models/mistral/builder.py
+++ b/src/fairseq2/models/mistral/builder.py
@@ -7,6 +7,7 @@
 from dataclasses import dataclass
 from typing import Optional
 
+from fairseq2.data import VocabularyInfo
 from fairseq2.models.transformer import (
     TransformerDecoderModel,
     TransformerEmbeddingFrontend,
@@ -45,8 +46,8 @@ class MistralConfig:
     max_seq_len: int
     """The expected maximum sequence length."""
 
-    vocabulary_size: int
-    """The size of the vocabulary."""
+    vocab_info: VocabularyInfo
+    """The vocabulary information."""
 
     attn_window_len: int
     """The local attention window length."""
@@ -79,7 +80,9 @@ def _7b() -> MistralConfig:
     return MistralConfig(
         model_dim=4096,
         max_seq_len=8192,
-        vocabulary_size=32000,
+        vocab_info=VocabularyInfo(
+            size=32000, unk_idx=0, bos_idx=1, eos_idx=2, pad_idx=None
+        ),
         attn_window_len=4096,
         num_layers=32,
         num_attn_heads=32,
@@ -131,7 +134,7 @@ def build_model(self) -> TransformerDecoderModel:
 
         final_proj = Linear(
             self.config.model_dim,
-            self.config.vocabulary_size,
+            self.config.vocab_info.size,
             bias=False,
             init_fn=init_final_projection,
             device=self.device,
@@ -139,13 +142,13 @@
         )
 
         return TransformerDecoderModel(
-            frontend, decoder, final_proj, target_pad_idx=None
+            frontend, decoder, final_proj, self.config.vocab_info
         )
 
     def build_frontend(self) -> TransformerFrontend:
         """Build a Transformer decoder front-end."""
         embed = StandardEmbedding(
-            num_embeddings=self.config.vocabulary_size,
+            num_embeddings=self.config.vocab_info.size,
             embedding_dim=self.config.model_dim,
             device=self.device,
             dtype=self.dtype,
diff --git a/src/fairseq2/models/nllb/builder.py b/src/fairseq2/models/nllb/builder.py
index 59eeba358..c85e5eabd 100644
--- a/src/fairseq2/models/nllb/builder.py
+++ b/src/fairseq2/models/nllb/builder.py
@@ -46,11 +46,8 @@ class NllbConfig:
     max_seq_len: int
     """The expected maximum sequence length."""
 
-    vocabulary_size: int
-    """The size of the vocabulary."""
-
-    pad_idx: Optional[int]
-    """The index of the PAD symbol in the vocabulary."""
+    vocab_info: VocabularyInfo
+    """The vocabulary information."""
 
     num_encoder_layers: int
     """The number of Transformer encoder layers."""
@@ -70,10 +67,6 @@ class NllbConfig:
     dropout_p: float
     """The dropout probability in Transformer layers."""
 
-    def update_vocabulary(self, info: VocabularyInfo) -> None:
-        """Update vocabulary configuration from ``info``."""
-        self.vocabulary_size, self.pad_idx = info.size, info.pad_idx
-
 
 nllb_archs = ArchitectureRegistry[NllbConfig]("nllb")
 
@@ -86,8 +79,9 @@ def _dense_1b() -> NllbConfig:
     return NllbConfig(
         model_dim=1024,
         max_seq_len=1024,
-        vocabulary_size=256206,
-        pad_idx=0,
+        vocab_info=VocabularyInfo(
+            size=256206, unk_idx=1, bos_idx=2, eos_idx=3, pad_idx=0
+        ),
         num_encoder_layers=24,
         num_decoder_layers=24,
         num_encoder_attn_heads=16,
@@ -102,8 +96,9 @@ def _dense_3b() -> NllbConfig:
     return NllbConfig(
         model_dim=2048,
         max_seq_len=1024,
-        vocabulary_size=256206,
-        pad_idx=0,
+        vocab_info=VocabularyInfo(
+            size=256206, unk_idx=1, bos_idx=2, eos_idx=3, pad_idx=0
+        ),
         num_encoder_layers=24,
         num_decoder_layers=24,
         num_encoder_attn_heads=16,
@@ -118,8 +113,9 @@ def _dense_600m() -> NllbConfig:
     return NllbConfig(
         model_dim=1024,
         max_seq_len=1024,
-        vocabulary_size=256206,
-        pad_idx=0,
+        vocab_info=VocabularyInfo(
+            size=256206, unk_idx=1, bos_idx=2, eos_idx=3, pad_idx=0
+        ),
         num_encoder_layers=12,
         num_decoder_layers=12,
         num_encoder_attn_heads=16,
@@ -177,15 +173,15 @@ def build_model(self) -> TransformerModel:
             frontend,
             decoder,
             final_proj,
-            target_pad_idx=self.config.pad_idx,
+            self.config.vocab_info,
         )
 
     def build_embedding(self) -> StandardEmbedding:
         """Build an embedding table."""
         return StandardEmbedding(
-            num_embeddings=self.config.vocabulary_size,
+            num_embeddings=self.config.vocab_info.size,
             embedding_dim=self.config.model_dim,
-            pad_idx=self.config.pad_idx,
+            pad_idx=self.config.vocab_info.pad_idx,
             init_fn=init_scaled_embedding,
             device=self.device,
             dtype=self.dtype,
@@ -196,7 +192,7 @@ def build_frontend(self, embed: Embedding) -> TransformerFrontend:
         pos_encoder = SinusoidalPositionEncoder(
             self.config.model_dim,
             self.config.max_seq_len,
-            _legacy_pad_idx=self.config.pad_idx,
+            _legacy_pad_idx=self.config.vocab_info.pad_idx,
             device=self.device,
         )
 
diff --git a/src/fairseq2/models/s2t_transformer/builder.py b/src/fairseq2/models/s2t_transformer/builder.py
index da246f0f6..16369ed0d 100644
--- a/src/fairseq2/models/s2t_transformer/builder.py
+++ b/src/fairseq2/models/s2t_transformer/builder.py
@@ -58,11 +58,8 @@ class S2TTransformerConfig:
     num_fbank_channels: int
    """The number of source log-mel filterbank channels."""
 
-    target_vocabulary_size: int
-    """The size of the target vocabulary."""
-
-    target_pad_idx: Optional[int]
-    """The index of the PAD symbol in the target vocabulary."""
+    target_vocab_info: VocabularyInfo
+    """The target vocabulary information."""
 
     use_relative_pos: bool
     """If ``True``, uses relative positional encodings for source sequences."""
@@ -91,10 +88,6 @@ class S2TTransformerConfig:
     depthwise_conv_kernel_size: int
     """The kernel size of depthwise convolutions in Conformer blocks."""
 
-    def update_target_vocabulary(self, info: VocabularyInfo) -> None:
-        """Update target vocabulary configuration from ``info``."""
-        self.target_vocabulary_size, self.target_pad_idx = info.size, info.pad_idx
-
 
 s2t_transformer_archs = ArchitectureRegistry[S2TTransformerConfig]("s2t_transformer")
 
@@ -108,8 +101,9 @@ def _tiny() -> S2TTransformerConfig:
         model_dim=256,
         max_seq_len=1024,
         num_fbank_channels=80,
-        target_vocabulary_size=10000,
-        target_pad_idx=1,
+        target_vocab_info=VocabularyInfo(
+            size=10000, unk_idx=0, bos_idx=0, eos_idx=0, pad_idx=1
+        ),
         use_relative_pos=False,
         use_conformer=False,
         num_encoder_layers=6,
@@ -128,8 +122,9 @@ def _small() -> S2TTransformerConfig:
         model_dim=256,
         max_seq_len=1024,
         num_fbank_channels=80,
-        target_vocabulary_size=10000,
-        target_pad_idx=1,
+        target_vocab_info=VocabularyInfo(
+            size=10000, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1
+        ),
         use_relative_pos=False,
         use_conformer=False,
         num_encoder_layers=12,
@@ -148,8 +143,9 @@ def _medium() -> S2TTransformerConfig:
         model_dim=512,
         max_seq_len=1024,
         num_fbank_channels=80,
-        target_vocabulary_size=10000,
-        target_pad_idx=1,
+        target_vocab_info=VocabularyInfo(
+            size=10000, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1
+        ),
         use_relative_pos=False,
         use_conformer=False,
         num_encoder_layers=12,
@@ -168,8 +164,9 @@ def _large() -> S2TTransformerConfig:
         model_dim=1024,
         max_seq_len=1024,
         num_fbank_channels=80,
-        target_vocabulary_size=10000,
-        target_pad_idx=1,
+        target_vocab_info=VocabularyInfo(
+            size=10000, unk_idx=0, bos_idx=0, eos_idx=0, pad_idx=1
+        ),
         use_relative_pos=False,
         use_conformer=False,
         num_encoder_layers=12,
@@ -188,8 +185,9 @@ def _conformer_medium() -> S2TTransformerConfig:
         model_dim=256,
         max_seq_len=6000,
         num_fbank_channels=80,
-        target_vocabulary_size=181,
-        target_pad_idx=1,
+        target_vocab_info=VocabularyInfo(
+            size=181, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1
+        ),
         use_relative_pos=False,
         use_conformer=True,
         num_encoder_layers=12,
@@ -246,7 +244,7 @@ def build_model(self) -> TransformerModel:
 
         final_proj = Linear(
             self.config.model_dim,
-            self.config.target_vocabulary_size,
+            self.config.target_vocab_info.size,
             bias=False,
             init_fn=init_final_projection,
             device=self.device,
@@ -259,7 +257,7 @@
             decoder_frontend,
             decoder,
             final_proj,
-            target_pad_idx=self.config.target_pad_idx,
+            self.config.target_vocab_info,
         )
 
     def build_encoder_frontend(self) -> TransformerFrontend:
@@ -288,9 +286,9 @@ def build_encoder_frontend(self) -> TransformerFrontend:
     def build_decoder_frontend(self) -> TransformerFrontend:
         """Build a Transformer decoder front-end."""
         embed = StandardEmbedding(
-            num_embeddings=self.config.target_vocabulary_size,
+            num_embeddings=self.config.target_vocab_info.size,
             embedding_dim=self.config.model_dim,
-            pad_idx=self.config.target_pad_idx,
+            pad_idx=self.config.target_vocab_info.pad_idx,
             init_fn=init_scaled_embedding,
             device=self.device,
             dtype=self.dtype,
@@ -320,7 +318,7 @@ def build_target_position_encoder(self) -> PositionEncoder:
         return SinusoidalPositionEncoder(
             self.config.model_dim,
             self.config.max_seq_len,
-            _legacy_pad_idx=self.config.target_pad_idx,
+            _legacy_pad_idx=self.config.target_vocab_info.pad_idx,
             device=self.device,
         )
 
diff --git a/src/fairseq2/models/seq2seq.py b/src/fairseq2/models/seq2seq.py
index 901062313..58aeaead5 100644
--- a/src/fairseq2/models/seq2seq.py
+++ b/src/fairseq2/models/seq2seq.py
@@ -14,6 +14,7 @@
 from torch import Tensor
 from torch.nn import Module
 
+from fairseq2.data import VocabularyInfo
 from fairseq2.models.sequence import SequenceModelOutput
 from fairseq2.nn.padding import PaddingMask
 
@@ -21,6 +22,17 @@
 class Seq2SeqModel(Module, ABC):
     """Represents a sequence-to-sequence model."""
 
+    target_vocab_info: VocabularyInfo
+
+    def __init__(self, target_vocab_info: VocabularyInfo) -> None:
+        """
+        :param target_vocab_info:
+            The vocabulary information of sequences produced by the model.
+        """
+        super().__init__()
+
+        self.target_vocab_info = target_vocab_info
+
     @abstractmethod
     def forward(self, batch: Seq2SeqBatch) -> SequenceModelOutput:
         """
diff --git a/src/fairseq2/models/sequence.py b/src/fairseq2/models/sequence.py
index b1bea736c..2e9e33270 100644
--- a/src/fairseq2/models/sequence.py
+++ b/src/fairseq2/models/sequence.py
@@ -15,6 +15,7 @@
 from torch.nn import Module
 from torch.nn.functional import log_softmax
 
+from fairseq2.data import VocabularyInfo
 from fairseq2.nn.functional import nll_loss
 from fairseq2.nn.padding import PaddingMask
 
@@ -22,6 +23,17 @@
 class SequenceModel(Module, ABC):
     """Represents a sequence model."""
 
+    vocab_info: VocabularyInfo
+
+    def __init__(self, vocab_info: VocabularyInfo) -> None:
+        """
+        :param vocab_info:
+            The vocabulary information of sequences produced by the model.
+        """
+        super().__init__()
+
+        self.vocab_info = vocab_info
+
     @abstractmethod
     def forward(self, batch: SequenceBatch) -> SequenceModelOutput:
         """
@@ -66,10 +78,10 @@ class SequenceModelOutput:
     logits: Tensor
     """The logits for next-step prediction. *Shape:* :math:`(N,S,T)`, where
     :math:`N` is the batch size, :math:`S` is the sequence length, and :math:`T`
-    is the size of the target vocabulary."""
+    is the size of the vocabulary."""
 
-    pad_idx: Optional[int] = None
-    """The index of the PAD symbol in the target vocabulary."""
+    vocab_info: VocabularyInfo
+    """The vocabulary information."""
 
     def compute_loss(
         self,
@@ -100,4 +112,6 @@ def compute_loss(
         # For numerical stability run in single precision.
         lprobs = log_softmax(logits, dim=-1, dtype=torch.float32)
 
-        return nll_loss(lprobs, targets, self.pad_idx, label_smoothing=label_smoothing)
+        return nll_loss(
+            lprobs, targets, self.vocab_info.pad_idx, label_smoothing=label_smoothing
+        )
diff --git a/src/fairseq2/models/transformer/decoder_model.py b/src/fairseq2/models/transformer/decoder_model.py
index 18d80a747..9a34f5f34 100644
--- a/src/fairseq2/models/transformer/decoder_model.py
+++ b/src/fairseq2/models/transformer/decoder_model.py
@@ -8,6 +8,7 @@
 
 from torch import Tensor
 
+from fairseq2.data import VocabularyInfo
 from fairseq2.models.decoder import DecoderModel
 from fairseq2.models.sequence import SequenceModelOutput
 from fairseq2.models.transformer.frontend import TransformerFrontend
@@ -25,14 +26,13 @@ class TransformerDecoderModel(DecoderModel):
     decoder_frontend: TransformerFrontend
     decoder: TransformerDecoder
     final_proj: Projection
-    target_pad_idx: Optional[int]
 
     def __init__(
         self,
         decoder_frontend: TransformerFrontend,
         decoder: TransformerDecoder,
         final_proj: Projection,
-        target_pad_idx: Optional[int],
+        vocab_info: VocabularyInfo,
     ) -> None:
         """
         :param decoder_frontend:
@@ -40,21 +40,19 @@ def __init__(
         :param decoder:
             The decoder.
         :param final_proj:
-            The projection to apply to decoder outputs to produce logits.
-        :param target_pad_idx:
-            The index of the PAD symbol in the target vocabulary.
+            The projection to apply to decoder outputs.
+        :param vocab_info:
+            The vocabulary information of sequences produced by the model.
         """
         model_dim = decoder.model_dim
 
-        super().__init__(model_dim)
+        super().__init__(model_dim, vocab_info)
 
         self.decoder_frontend = decoder_frontend
         self.decoder = decoder
 
         self.final_proj = final_proj
 
-        self.target_pad_idx = target_pad_idx
-
     @finaloverride
     def decode(
         self,
@@ -79,4 +77,4 @@ def project(
     ) -> SequenceModelOutput:
         logits = self.final_proj(decoder_output)
 
-        return SequenceModelOutput(logits, self.target_pad_idx)
+        return SequenceModelOutput(logits, self.vocab_info)
diff --git a/src/fairseq2/models/transformer/model.py b/src/fairseq2/models/transformer/model.py
index aa1b65065..2df08de1b 100644
--- a/src/fairseq2/models/transformer/model.py
+++ b/src/fairseq2/models/transformer/model.py
@@ -9,6 +9,7 @@
 import torch.nn as nn
 from torch import Tensor
 
+from fairseq2.data import VocabularyInfo
 from fairseq2.models.encoder_decoder import EncoderDecoderModel
 from fairseq2.models.sequence import SequenceModelOutput
 from fairseq2.models.transformer.frontend import TransformerFrontend
@@ -29,7 +30,6 @@ class TransformerModel(EncoderDecoderModel):
     decoder_frontend: TransformerFrontend
     decoder: TransformerDecoder
     final_proj: Projection
-    target_pad_idx: Optional[int]
 
     def __init__(
         self,
@@ -38,7 +38,7 @@ def __init__(
         decoder_frontend: TransformerFrontend,
         decoder: TransformerDecoder,
         final_proj: Projection,
-        target_pad_idx: Optional[int],
+        target_vocab_info: VocabularyInfo,
     ) -> None:
         """
         :param encoder_frontend:
@@ -50,13 +50,13 @@ def __init__(
         :param decoder:
             The decoder.
         :param final_proj:
-            The projection to apply to decoder outputs to produce logits.
-        :param target_pad_idx:
-            The index of the PAD symbol in the target vocabulary.
+            The projection to apply to decoder outputs.
+        :param target_vocab_info:
+            The vocabulary information of sequences produced by the model.
         """
         model_dim = encoder.model_dim
 
-        super().__init__(model_dim)
+        super().__init__(model_dim, target_vocab_info)
 
         self.encoder_frontend = encoder_frontend
         self.encoder = encoder
@@ -66,8 +66,6 @@ def __init__(
 
         self.final_proj = final_proj
 
-        self.target_pad_idx = target_pad_idx
-
     @finaloverride
     def encode(
         self, seqs: Tensor, padding_mask: Optional[PaddingMask]
@@ -104,7 +102,7 @@ def project(
     ) -> SequenceModelOutput:
         logits = self.final_proj(decoder_output)
 
-        return SequenceModelOutput(logits, self.target_pad_idx)
+        return SequenceModelOutput(logits, self.target_vocab_info)
 
 
 def init_final_projection(proj: Linear) -> None:
diff --git a/tests/integration/models/test_llama_lora.py b/tests/integration/models/test_llama_lora.py
index 6b3693a82..1620cdd19 100644
--- a/tests/integration/models/test_llama_lora.py
+++ b/tests/integration/models/test_llama_lora.py
@@ -6,6 +6,7 @@
 
 import torch
 
+from fairseq2.data import VocabularyInfo
 from fairseq2.models.llama import LLaMAConfig, create_llama_model, get_llama_lora_config
 from fairseq2.nn.lora import (
     freeze_non_lora,
@@ -21,7 +22,9 @@ def test_lora_wrappers_llama_works() -> None:
     llama_config = LLaMAConfig(
         model_dim=1024,
         max_seq_len=2048,
-        vocabulary_size=32001,
+        vocab_info=VocabularyInfo(
+            size=32000, unk_idx=0, bos_idx=1, eos_idx=2, pad_idx=None
+        ),
         num_layers=16,
         num_attn_heads=8,
         num_key_value_heads=8,