Skip to content

Commit

Permalink
Refactor Embedding (facebookresearch#82)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbalioglu authored Oct 3, 2023
1 parent 522d321 commit e081284
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 29 deletions.
4 changes: 2 additions & 2 deletions src/fairseq2/models/llama/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
TransformerFrontend,
)
from fairseq2.models.utils.arch_registry import ArchitectureRegistry
from fairseq2.nn.embedding import Embedding
from fairseq2.nn.embedding import StandardEmbedding
from fairseq2.nn.normalization import LayerNorm, RMSNorm
from fairseq2.nn.position_encoder import RotaryEncoder
from fairseq2.nn.transformer import (
Expand Down Expand Up @@ -240,7 +240,7 @@ def build_model(self) -> TransformerDecoderModel:

def build_frontend(self) -> TransformerFrontend:
"""Build a Transformer decoder front-end."""
embed = Embedding(
embed = StandardEmbedding(
num_embeddings=self.config.vocabulary_size,
embedding_dim=self.config.model_dim,
device=self.device,
Expand Down
4 changes: 2 additions & 2 deletions src/fairseq2/models/nllb/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
TransformerModel,
)
from fairseq2.models.utils.arch_registry import ArchitectureRegistry
from fairseq2.nn.embedding import Embedding
from fairseq2.nn.embedding import Embedding, StandardEmbedding
from fairseq2.nn.position_encoder import SinusoidalPositionEncoder
from fairseq2.nn.projection import TiedProjection
from fairseq2.nn.transformer import (
Expand Down Expand Up @@ -182,7 +182,7 @@ def build_model(self) -> TransformerModel:

def build_embedding(self) -> Embedding:
"""Build an embedding table."""
return Embedding(
return StandardEmbedding(
num_embeddings=self.config.vocabulary_size,
embedding_dim=self.config.model_dim,
pad_idx=self.config.pad_idx,
Expand Down
4 changes: 2 additions & 2 deletions src/fairseq2/models/s2t_transformer/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
TransformerModel,
)
from fairseq2.models.utils.arch_registry import ArchitectureRegistry
from fairseq2.nn.embedding import Embedding
from fairseq2.nn.embedding import StandardEmbedding
from fairseq2.nn.position_encoder import PositionEncoder, SinusoidalPositionEncoder
from fairseq2.nn.transformer import (
SDPA,
Expand Down Expand Up @@ -283,7 +283,7 @@ def build_encoder_frontend(self) -> TransformerFrontend:

def build_decoder_frontend(self) -> TransformerFrontend:
"""Build a Transformer decoder front-end."""
embed = Embedding(
embed = StandardEmbedding(
num_embeddings=self.config.target_vocabulary_size,
embedding_dim=self.config.model_dim,
pad_idx=self.config.target_pad_idx,
Expand Down
78 changes: 55 additions & 23 deletions src/fairseq2/nn/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import Optional, final

import torch
Expand All @@ -13,17 +14,65 @@
from torch.nn.functional import embedding
from torch.nn.parameter import Parameter

from fairseq2.typing import DataType, Device
from fairseq2.typing import DataType, Device, finaloverride


@final
class Embedding(Module):
class Embedding(Module, ABC):
"""Stores embeddings of a fixed dictionary and size."""

num_embeddings: int
embedding_dim: int
pad_idx: Optional[int]
padding_idx: Optional[int] # Compat

def __init__(
self, num_embeddings: int, embedding_dim: int, pad_idx: Optional[int] = None
) -> None:
"""
:param num_embeddings:
The size of the embedding table.
:param embedding_dim:
The dimensionality of returned embeddings.
:param pad_idx:
If not ``None``, entries at ``pad_idx`` do not contribute to the
gradient; therefore, the embedding at ``pad_idx`` is not updated
during training.
"""
super().__init__()

self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.pad_idx = pad_idx

# Alias field for compatibility with `torch.nn.Embedding`.
self.padding_idx = pad_idx

@abstractmethod
def forward(self, x: Tensor) -> Tensor:
"""
:param x:
The embedding indices. *Shape:* Any.
:returns:
The embeddings corresponding to the specified indices. *Shape:*
:math:`(*,E)`, where :math:`*` is the input shape and :math:`E` is
the dimensionality of the embeddings.
"""

def extra_repr(self) -> str:
""":meta private:"""
s = f"num_embeddings={self.num_embeddings}, embedding_dim={self.embedding_dim}"

if self.pad_idx is not None:
s = f"{s}, pad_idx={self.pad_idx}"

return s


@final
class StandardEmbedding(Embedding):
"""Stores embeddings of a fixed dictionary and size in an in-memory table."""

scaled: bool
weight: Parameter

Expand Down Expand Up @@ -51,16 +100,10 @@ def __init__(
:math:`\\mathcal{N}(0, \\frac{1}{\\text{embedding_dim}})`; otherwise,
from :math:`\\mathcal{N}(0, 1)`.
"""
super().__init__()
super().__init__(num_embeddings, embedding_dim, pad_idx)

self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.pad_idx = pad_idx
self.scaled = scaled

# Alias field for compatibility with `torch.nn.Embedding`.
self.padding_idx = pad_idx

self.weight = Parameter(
torch.empty((num_embeddings, embedding_dim), device=device, dtype=dtype)
)
Expand All @@ -78,24 +121,13 @@ def reset_parameters(self) -> None:
with torch.no_grad():
self.weight[self.pad_idx].fill_(0.0)

@finaloverride
def forward(self, x: Tensor) -> Tensor:
"""
:param x:
The embedding indices. *Shape:* Any.
:returns:
The embeddings corresponding to the specified indices. *Shape:*
:math:`(*,E)`, where :math:`*` is the input shape and :math:`E` is
the dimensionality of the embeddings.
"""
return embedding(x, self.weight, self.pad_idx)

def extra_repr(self) -> str:
""":meta private:"""
s = f"num_embeddings={self.num_embeddings}, embedding_dim={self.embedding_dim}"

if self.pad_idx is not None:
s = f"{s}, pad_idx={self.pad_idx}"
s = super().extra_repr()

if self.scaled:
s = f"{s}, scaled=True"
Expand Down

0 comments on commit e081284

Please sign in to comment.