Hold vocab_info in DecoderModel and EncoderDecoderModel (facebookrese…
cbalioglu authored Nov 6, 2023
1 parent 6197f6e commit 785c3e4
Showing 15 changed files with 140 additions and 102 deletions.
14 changes: 0 additions & 14 deletions src/fairseq2/assets/cards/s2t_transformer_covost_st_en_de.yaml

This file was deleted.

3 changes: 3 additions & 0 deletions src/fairseq2/assets/cards/s2t_transformer_mustc_asr_de_s.yaml
@@ -7,6 +7,9 @@
 name: s2t_transformer_mustc_asr_de_s
 model_type: s2t_transformer
 model_arch: small
+model_config:
+  target_vocab_info:
+    size: 5000
 task: transcription
 tgt_langs: [en]
 tokenizer: "https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_de_asr_vocab_unigram5000.zip"
3 changes: 3 additions & 0 deletions src/fairseq2/assets/cards/s2t_transformer_mustc_asr_es_s.yaml
@@ -7,6 +7,9 @@
 name: s2t_transformer_mustc_asr_es_s
 model_type: s2t_transformer
 model_arch: small
+model_config:
+  target_vocab_info:
+    size: 5000
 task: transcription
 tgt_langs: [en]
 tokenizer: "https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_es_asr_vocab_unigram5000.zip"
3 changes: 3 additions & 0 deletions src/fairseq2/assets/cards/s2t_transformer_mustc_st_de_s.yaml
@@ -7,6 +7,9 @@
 name: s2t_transformer_mustc_st_de_s
 model_type: s2t_transformer
 model_arch: small
+model_config:
+  target_vocab_info:
+    size: 8000
 task: translation
 tgt_langs: [de]
 tokenizer: "https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_de_st_vocab_unigram8000.zip"
7 changes: 5 additions & 2 deletions src/fairseq2/models/decoder.py
@@ -9,6 +9,7 @@

 from torch import Tensor

+from fairseq2.data import VocabularyInfo
 from fairseq2.models.sequence import SequenceBatch, SequenceModel, SequenceModelOutput
 from fairseq2.nn.incremental_state import IncrementalStateBag
 from fairseq2.nn.padding import PaddingMask
@@ -69,12 +70,14 @@ class DecoderModel(SequenceModel, SequenceDecoder):

     model_dim: int

-    def __init__(self, model_dim: int) -> None:
+    def __init__(self, model_dim: int, vocab_info: VocabularyInfo) -> None:
         """
         :param model_dim:
             The dimensionality of the model.
+        :param vocab_info:
+            The vocabulary information of sequences produced by the model.
         """
-        super().__init__()
+        super().__init__(vocab_info)

         self.model_dim = model_dim

7 changes: 5 additions & 2 deletions src/fairseq2/models/encoder_decoder.py
@@ -9,6 +9,7 @@

 from torch import Tensor

+from fairseq2.data import VocabularyInfo
 from fairseq2.models.seq2seq import Seq2SeqBatch, Seq2SeqModel
 from fairseq2.models.sequence import SequenceModelOutput
 from fairseq2.nn.incremental_state import IncrementalStateBag
@@ -83,12 +84,14 @@ class EncoderDecoderModel(Seq2SeqModel, Seq2SeqDecoder):

     model_dim: int

-    def __init__(self, model_dim: int) -> None:
+    def __init__(self, model_dim: int, target_vocab_info: VocabularyInfo) -> None:
         """
         :param model_dim:
             The dimensionality of the model.
+        :param target_vocab_info:
+            The vocabulary information of sequences produced by the model.
         """
-        super().__init__()
+        super().__init__(target_vocab_info)

         self.model_dim = model_dim

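For context, the value now threaded through both constructors above is a plain VocabularyInfo; a minimal sketch of building one with the same field values the LLaMA presets below use (illustrative only, not part of this commit):

from fairseq2.data import VocabularyInfo

# Vocabulary metadata: the size plus the special-symbol indices. DecoderModel and
# EncoderDecoderModel now receive such a value and forward it to their base classes.
vocab_info = VocabularyInfo(size=32000, unk_idx=0, bos_idx=1, eos_idx=2, pad_idx=None)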
39 changes: 27 additions & 12 deletions src/fairseq2/models/llama/builder.py
@@ -7,6 +7,7 @@
 from dataclasses import dataclass
 from typing import Optional

+from fairseq2.data import VocabularyInfo
 from fairseq2.models.transformer import (
     TransformerDecoderModel,
     TransformerEmbeddingFrontend,
@@ -44,8 +45,8 @@ class LLaMAConfig:
     max_seq_len: int
     """The expected maximum sequence length."""

-    vocabulary_size: int
-    """The size of the vocabulary."""
+    vocab_info: VocabularyInfo
+    """The vocabulary information."""

     num_layers: int
     """The number of Transformer decoder layers."""
@@ -79,7 +80,9 @@ def _7b() -> LLaMAConfig:
     return LLaMAConfig(
         model_dim=4096,
         max_seq_len=2048,
-        vocabulary_size=32000,
+        vocab_info=VocabularyInfo(
+            size=32000, unk_idx=0, bos_idx=1, eos_idx=2, pad_idx=None
+        ),
         num_layers=32,
         num_attn_heads=32,
         num_key_value_heads=32,
@@ -94,7 +97,9 @@ def _13b() -> LLaMAConfig:
     return LLaMAConfig(
         model_dim=5120,
         max_seq_len=2048,
-        vocabulary_size=32000,
+        vocab_info=VocabularyInfo(
+            size=32000, unk_idx=0, bos_idx=1, eos_idx=2, pad_idx=None
+        ),
         num_layers=40,
         num_attn_heads=40,
         num_key_value_heads=40,
@@ -109,7 +114,9 @@ def _33b() -> LLaMAConfig:
     return LLaMAConfig(
         model_dim=6656,
         max_seq_len=2048,
-        vocabulary_size=32000,
+        vocab_info=VocabularyInfo(
+            size=32000, unk_idx=0, bos_idx=1, eos_idx=2, pad_idx=None
+        ),
         num_layers=60,
         num_attn_heads=52,
         num_key_value_heads=52,
@@ -124,7 +131,9 @@ def _65b() -> LLaMAConfig:
     return LLaMAConfig(
         model_dim=8192,
         max_seq_len=2048,
-        vocabulary_size=32000,
+        vocab_info=VocabularyInfo(
+            size=32000, unk_idx=0, bos_idx=1, eos_idx=2, pad_idx=None
+        ),
         num_layers=80,
         num_attn_heads=64,
         num_key_value_heads=64,
@@ -139,7 +148,9 @@ def _llama2_7b() -> LLaMAConfig:
     return LLaMAConfig(
         model_dim=4096,
         max_seq_len=4096,
-        vocabulary_size=32000,
+        vocab_info=VocabularyInfo(
+            size=32000, unk_idx=0, bos_idx=1, eos_idx=2, pad_idx=None
+        ),
         num_layers=32,
         num_attn_heads=32,
         num_key_value_heads=32,
@@ -154,7 +165,9 @@ def _llama2_13b() -> LLaMAConfig:
     return LLaMAConfig(
         model_dim=5120,
         max_seq_len=4096,
-        vocabulary_size=32000,
+        vocab_info=VocabularyInfo(
+            size=32000, unk_idx=0, bos_idx=1, eos_idx=2, pad_idx=None
+        ),
         num_layers=40,
         num_attn_heads=40,
         num_key_value_heads=40,
@@ -169,7 +182,9 @@ def _llama2_70b() -> LLaMAConfig:
     return LLaMAConfig(
         model_dim=8192,
         max_seq_len=4096,
-        vocabulary_size=32000,
+        vocab_info=VocabularyInfo(
+            size=32000, unk_idx=0, bos_idx=1, eos_idx=2, pad_idx=None
+        ),
         num_layers=80,
         num_attn_heads=64,
         num_key_value_heads=8,
@@ -222,21 +237,21 @@ def build_model(self) -> TransformerDecoderModel:

         final_proj = Linear(
             self.config.model_dim,
-            self.config.vocabulary_size,
+            self.config.vocab_info.size,
             bias=False,
             init_fn=init_final_projection,
             device=self.device,
             dtype=self.dtype,
         )

         return TransformerDecoderModel(
-            frontend, decoder, final_proj, target_pad_idx=None
+            frontend, decoder, final_proj, self.config.vocab_info
         )

     def build_frontend(self) -> TransformerFrontend:
         """Build a Transformer decoder front-end."""
         embed = StandardEmbedding(
-            num_embeddings=self.config.vocabulary_size,
+            num_embeddings=self.config.vocab_info.size,
             embedding_dim=self.config.model_dim,
             device=self.device,
             dtype=self.dtype,
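A quick way to see the effect of the config change (a sketch; _7b is the module-private preset defined above and is imported here purely for illustration):

from fairseq2.models.llama.builder import _7b

config = _7b()

# The bare vocabulary_size field is gone; the full vocabulary metadata lives on the
# config, and the builder sizes both the embedding table and the final projection from it.
assert config.vocab_info.size == 32000
assert config.vocab_info.pad_idx is None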
15 changes: 9 additions & 6 deletions src/fairseq2/models/mistral/builder.py
@@ -7,6 +7,7 @@
 from dataclasses import dataclass
 from typing import Optional

+from fairseq2.data import VocabularyInfo
 from fairseq2.models.transformer import (
     TransformerDecoderModel,
     TransformerEmbeddingFrontend,
@@ -45,8 +46,8 @@ class MistralConfig:
     max_seq_len: int
     """The expected maximum sequence length."""

-    vocabulary_size: int
-    """The size of the vocabulary."""
+    vocab_info: VocabularyInfo
+    """The vocabulary information."""

     attn_window_len: int
     """The local attention window length."""
@@ -79,7 +80,9 @@ def _7b() -> MistralConfig:
     return MistralConfig(
         model_dim=4096,
         max_seq_len=8192,
-        vocabulary_size=32000,
+        vocab_info=VocabularyInfo(
+            size=32000, unk_idx=0, bos_idx=1, eos_idx=2, pad_idx=None
+        ),
         attn_window_len=4096,
         num_layers=32,
         num_attn_heads=32,
@@ -131,21 +134,21 @@ def build_model(self) -> TransformerDecoderModel:

         final_proj = Linear(
             self.config.model_dim,
-            self.config.vocabulary_size,
+            self.config.vocab_info.size,
             bias=False,
             init_fn=init_final_projection,
             device=self.device,
             dtype=self.dtype,
         )

         return TransformerDecoderModel(
-            frontend, decoder, final_proj, target_pad_idx=None
+            frontend, decoder, final_proj, self.config.vocab_info
         )

     def build_frontend(self) -> TransformerFrontend:
         """Build a Transformer decoder front-end."""
         embed = StandardEmbedding(
-            num_embeddings=self.config.vocabulary_size,
+            num_embeddings=self.config.vocab_info.size,
             embedding_dim=self.config.model_dim,
             device=self.device,
             dtype=self.dtype,
34 changes: 15 additions & 19 deletions src/fairseq2/models/nllb/builder.py
@@ -46,11 +46,8 @@ class NllbConfig:
     max_seq_len: int
     """The expected maximum sequence length."""

-    vocabulary_size: int
-    """The size of the vocabulary."""
-
-    pad_idx: Optional[int]
-    """The index of the PAD symbol in the vocabulary."""
+    vocab_info: VocabularyInfo
+    """The vocabulary information."""

     num_encoder_layers: int
     """The number of Transformer encoder layers."""
@@ -70,10 +67,6 @@ class NllbConfig:
     dropout_p: float
     """The dropout probability in Transformer layers."""

-    def update_vocabulary(self, info: VocabularyInfo) -> None:
-        """Update vocabulary configuration from ``info``."""
-        self.vocabulary_size, self.pad_idx = info.size, info.pad_idx
-

 nllb_archs = ArchitectureRegistry[NllbConfig]("nllb")

@@ -86,8 +79,9 @@ def _dense_1b() -> NllbConfig:
     return NllbConfig(
         model_dim=1024,
         max_seq_len=1024,
-        vocabulary_size=256206,
-        pad_idx=0,
+        vocab_info=VocabularyInfo(
+            size=256206, unk_idx=1, bos_idx=2, eos_idx=3, pad_idx=0
+        ),
         num_encoder_layers=24,
         num_decoder_layers=24,
         num_encoder_attn_heads=16,
@@ -102,8 +96,9 @@ def _dense_3b() -> NllbConfig:
     return NllbConfig(
         model_dim=2048,
         max_seq_len=1024,
-        vocabulary_size=256206,
-        pad_idx=0,
+        vocab_info=VocabularyInfo(
+            size=256206, unk_idx=1, bos_idx=2, eos_idx=3, pad_idx=0
+        ),
         num_encoder_layers=24,
         num_decoder_layers=24,
         num_encoder_attn_heads=16,
@@ -118,8 +113,9 @@ def _dense_600m() -> NllbConfig:
     return NllbConfig(
         model_dim=1024,
         max_seq_len=1024,
-        vocabulary_size=256206,
-        pad_idx=0,
+        vocab_info=VocabularyInfo(
+            size=256206, unk_idx=1, bos_idx=2, eos_idx=3, pad_idx=0
+        ),
         num_encoder_layers=12,
         num_decoder_layers=12,
         num_encoder_attn_heads=16,
@@ -177,15 +173,15 @@ def build_model(self) -> TransformerModel:
             frontend,
             decoder,
             final_proj,
-            target_pad_idx=self.config.pad_idx,
+            self.config.vocab_info,
         )

     def build_embedding(self) -> StandardEmbedding:
         """Build an embedding table."""
         return StandardEmbedding(
-            num_embeddings=self.config.vocabulary_size,
+            num_embeddings=self.config.vocab_info.size,
             embedding_dim=self.config.model_dim,
-            pad_idx=self.config.pad_idx,
+            pad_idx=self.config.vocab_info.pad_idx,
             init_fn=init_scaled_embedding,
             device=self.device,
             dtype=self.dtype,
@@ -196,7 +192,7 @@ def build_frontend(self, embed: Embedding) -> TransformerFrontend:
         pos_encoder = SinusoidalPositionEncoder(
             self.config.model_dim,
             self.config.max_seq_len,
-            _legacy_pad_idx=self.config.pad_idx,
+            _legacy_pad_idx=self.config.vocab_info.pad_idx,
             device=self.device,
         )

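With update_vocabulary removed, code that previously copied a tokenizer's vocabulary details onto an NLLB config would now set the field directly; a hedged sketch (_dense_600m is the module-private preset defined above, and the VocabularyInfo values simply repeat those of the presets):

from fairseq2.data import VocabularyInfo
from fairseq2.models.nllb.builder import _dense_600m

config = _dense_600m()

# Before this commit: config.update_vocabulary(info)
# Now the vocabulary metadata is assigned directly.
config.vocab_info = VocabularyInfo(size=256206, unk_idx=1, bos_idx=2, eos_idx=3, pad_idx=0)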