From 7fca3cab5d4120a6a111d4c126cf33148968c6ba Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 19 Dec 2024 13:27:03 +0100 Subject: [PATCH 01/37] add init function to the builders --- src/fairseq2/models/jepa/factory.py | 121 +++++++++++++++++-- src/fairseq2/models/jepa/model.py | 17 ++- src/fairseq2/models/vit/feature_extractor.py | 29 +++++ src/fairseq2/nn/normalization.py | 18 ++- 4 files changed, 161 insertions(+), 24 deletions(-) diff --git a/src/fairseq2/models/jepa/factory.py b/src/fairseq2/models/jepa/factory.py index 0038299a2..329f2055a 100644 --- a/src/fairseq2/models/jepa/factory.py +++ b/src/fairseq2/models/jepa/factory.py @@ -6,10 +6,14 @@ from __future__ import annotations +from collections.abc import Callable from dataclasses import dataclass, field +from functools import partial +import math from typing import Final, cast -from torch.nn import GELU +import torch +from torch.nn import GELU, Module from fairseq2.config_registry import ConfigRegistry from fairseq2.models.jepa.model import JepaModel @@ -23,6 +27,7 @@ from fairseq2.nn import ( InterpolatedPositionEncoder, LayerNorm, + Linear, Sinusoidal2dPositionEncoder, Sinusoidal3dPositionEncoder, StandardLayerNorm, @@ -39,6 +44,7 @@ TransformerNormOrder, create_default_sdpa, ) +from fairseq2.nn.transformer.residual import DropPathResidualConnect from fairseq2.typing import DataType, Device JEPA_FAMILY: Final = "jepa" @@ -96,10 +102,16 @@ class JepaEncoderConfig: The ratio of the dimensionality of the inner projection layers in feed-forward networks to :attr:`model_dim`. """ + + init_std: float = 0.02 + """std to initialize the weights and bias for linear and LayerNorm layers""" dropout_p: float = 0.0 """The dropout probability on outputs of Transformer layers.""" + droppath_p: float = 0.0 + """The probability of output sequence drop.""" + uniform_power: bool = False """ If ``True``, each patch dimension will have equal representation in the @@ -181,6 +193,8 @@ def build_frontend(self) -> TransformerFrontend: def build_feature_extractor(self) -> PatchFeatureExtractor: config = self._config + + conv_init_fn = partial(init_with_explicit_bounds, std=config.init_std) num_patch_dims = len(config.patch_dims) @@ -191,6 +205,7 @@ def build_feature_extractor(self) -> PatchFeatureExtractor: config.num_input_channels, config.model_dim, patch_3d_dims, + init_fn=conv_init_fn, device=self._device, dtype=self._dtype, ) @@ -201,6 +216,7 @@ def build_feature_extractor(self) -> PatchFeatureExtractor: config.num_input_channels, config.model_dim, patch_2d_dims, + init_fn=conv_init_fn, device=self._device, dtype=self._dtype, ) @@ -255,7 +271,7 @@ def build_encoder(self) -> TransformerEncoder: num_layers = config.num_encoder_layers - layers = [self.build_encoder_layer() for _ in range(num_layers)] + layers = [self.build_encoder_layer(i) for i in range(num_layers)] return StandardTransformerEncoder( layers, @@ -265,12 +281,14 @@ def build_encoder(self) -> TransformerEncoder: dtype=self._dtype, ) - def build_encoder_layer(self) -> TransformerEncoderLayer: + def build_encoder_layer(self, layer_id: int) -> TransformerEncoderLayer: config = self._config - self_attn = self.build_attention() + self_attn = self.build_attention(layer_id) - ffn = self.build_ffn() + ffn = self.build_ffn(layer_id) + + drop_path = DropPathResidualConnect(drop_p=config.droppath_p) return StandardTransformerEncoderLayer( self_attn, @@ -278,47 +296,87 @@ def build_encoder_layer(self) -> TransformerEncoderLayer: 
dropout_p=config.dropout_p, norm_order=TransformerNormOrder.PRE, layer_norm_factory=self.build_layer_norm, + self_attn_residual=drop_path, + ffn_residual=drop_path, device=self._device, dtype=self._dtype, ) - def build_attention(self) -> MultiheadAttention: + def build_attention(self, layer_id: int) -> MultiheadAttention: config = self._config sdpa = create_default_sdpa(attn_dropout_p=config.attn_dropout_p) + proj = self.build_projection(layer_id) + return StandardMultiheadAttention( config.model_dim, config.num_encoder_attn_heads, sdpa=sdpa, bias=config.qkv_bias, + output_proj=proj, output_proj_bias=True, device=self._device, dtype=self._dtype, ) - def build_ffn(self) -> FeedForwardNetwork: + def build_projection(self, layer_id: int) -> Linear: + config = self._config + + proj_init_fn: Callable[[Linear], None] = partial( + init_with_explicit_bounds, std=config.init_std + ) + + proj = Linear( + config.model_dim, + config.model_dim, + bias=True, + init_fn=proj_init_fn, + device=self._device, + dtype=self._dtype, + ) + + # rescale the linear layer + proj.weight.data.div_(math.sqrt(2.0 * layer_id)) + + return proj + + def build_ffn(self, layer_id: int) -> FeedForwardNetwork: config = self._config - return StandardFeedForwardNetwork( + proj_init_fn = partial(init_with_explicit_bounds, std=config.init_std) + + ffn = StandardFeedForwardNetwork( config.model_dim, int(config.model_dim * config.ffn_inner_dim_ratio), bias=True, inner_activation=GELU(), + proj_init_fn=proj_init_fn, norm_order=TransformerNormOrder.PRE, device=self._device, dtype=self._dtype, ) - @staticmethod + # rescale the last layer + proj = ffn.output_proj + assert isinstance(proj, Linear), f"Invalid projection type: {type(proj)}" + proj.weight.data.div_(math.sqrt(2.0 * layer_id)) + + return ffn + def build_layer_norm( + self, model_dim: int, *, device: Device | None = None, dtype: DataType | None = None, ) -> LayerNorm: + config = self._config + + layer_norm_init_fn = partial(init_with_explicit_bounds, std=config.init_std) + return StandardLayerNorm( - model_dim, bias=True, eps=1e-6, device=device, dtype=dtype + model_dim, bias=True, eps=1e-6, init_fn=layer_norm_init_fn, device=device, dtype=dtype ) @@ -329,3 +387,46 @@ def create_jepa_model( dtype: DataType | None = None, ) -> JepaModel: return JepaBuilder(config, device=device, dtype=dtype).build_model() + + +def _norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + +def normalize_truncate( + tensor: torch.Tensor, + *, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0, +) -> None: + + lower = _norm_cdf((a - mean) / std) + upper = _norm_cdf((b - mean) / std) + + tensor.uniform_(2 * lower - 1, 2 * upper - 1) + tensor.erfinv_() + + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + tensor.clamp_(min=a, max=b) + + +def init_with_explicit_bounds( + m: Module, + *, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0, +): + if not hasattr(m, "weight") or not hasattr(m, "bias"): + raise ValueError(f"Cannot initialize weights and bias of a {type(m)}") + + with torch.no_grad(): + normalize_truncate(m.weight, mean=mean, std=std, a=a, b=b) + if m.bias is not None: + torch.nn.init.zeros_(m.bias) diff --git a/src/fairseq2/models/jepa/model.py b/src/fairseq2/models/jepa/model.py index 8312901e0..281e8187b 100644 --- a/src/fairseq2/models/jepa/model.py +++ b/src/fairseq2/models/jepa/model.py @@ -6,7 +6,6 @@ from __future__ import annotations -from dataclasses 
import dataclass from typing import final from torch.nn import Module @@ -43,11 +42,11 @@ def __init__( self.encoder_frontend = encoder_frontend self.encoder = encoder - def forward(self, batch: SequenceBatch) -> JepaOutput: - raise NotImplementedError() - - -@final -@dataclass -class JepaOutput: - pass + def forward(self, batch: SequenceBatch) -> SequenceBatch: + seqs, padding_mask = self.encoder_frontend(batch.seqs, batch.padding_mask) + out_seqs, out_mask = self.encoder(seqs, padding_mask) # type: ignore[no-any-return] + + return SequenceBatch( + seqs=out_seqs, + padding_mask=out_mask, + ) diff --git a/src/fairseq2/models/vit/feature_extractor.py b/src/fairseq2/models/vit/feature_extractor.py index dce324d40..b3ec35033 100644 --- a/src/fairseq2/models/vit/feature_extractor.py +++ b/src/fairseq2/models/vit/feature_extractor.py @@ -7,6 +7,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Callable from typing import final from torch import Tensor @@ -54,6 +55,7 @@ class Conv2dPatchFeatureExtractor(PatchFeatureExtractor): """Extracts patch features from 2-dimensional inputs using convolution.""" conv: Conv2d + init_fn: Callable[[Conv2d], None] | None def __init__( self, @@ -61,6 +63,7 @@ def __init__( feature_dim: int, patch_dims: tuple[int, int], *, + init_fn: Callable[[Conv2d], None] | None = None, device: Device | None = None, dtype: DataType | None = None, ) -> None: @@ -81,6 +84,18 @@ def __init__( dtype=dtype, ) + self.init_fn = init_fn + + self.reset_parameters() + + def reset_parameters(self) -> None: + """Reset the parameters and buffers of the module.""" + if self.init_fn is not None: + self.init_fn(self.conv) + else: + self.conv.reset_parameters() + + @override def forward(self, x: Tensor) -> Tensor: # (N, C, H_inp, W_inp) -> (N, H_out, W_out, E) @@ -92,6 +107,7 @@ class Conv3dPatchFeatureExtractor(PatchFeatureExtractor): """Extracts patch features from 3-dimensional inputs using convolution.""" conv: Conv3d + init_fn: Callable[[Conv3d], None] | None def __init__( self, @@ -99,6 +115,7 @@ def __init__( feature_dim: int, patch_dims: tuple[int, int, int], *, + init_fn: Callable[[Conv2d], None] | None = None, device: Device | None = None, dtype: DataType | None = None, ) -> None: @@ -118,6 +135,18 @@ def __init__( device=device, dtype=dtype, ) + + self.init_fn = init_fn + + self.reset_parameters() + + def reset_parameters(self) -> None: + """Reset the parameters and buffers of the module.""" + if self.init_fn is not None: + self.init_fn(self.conv) + else: + self.conv.reset_parameters() + @override def forward(self, x: Tensor) -> Tensor: diff --git a/src/fairseq2/nn/normalization.py b/src/fairseq2/nn/normalization.py index 44fe37287..64656507e 100644 --- a/src/fairseq2/nn/normalization.py +++ b/src/fairseq2/nn/normalization.py @@ -7,7 +7,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections.abc import Sequence +from collections.abc import Callable, Sequence from typing import Any, Literal, final import torch @@ -39,6 +39,8 @@ class LayerNorm(Module, ABC): elementwise_affine: bool weight: Parameter | None bias: Parameter | None + init_fn: Callable[[LayerNorm], None] | None + def __init__( self, @@ -47,6 +49,7 @@ def __init__( *, eps: float = 1e-5, elementwise_affine: bool = True, + init_fn: Callable[[LayerNorm], None] | None = None, device: Device | None = None, dtype: DataType | None = None, ) -> None: @@ -87,16 +90,21 @@ def __init__( ) else: self.register_parameter("bias", None) + + 
self.init_fn = init_fn self.reset_parameters() def reset_parameters(self) -> None: """Reset the parameters and buffers of the module.""" - if self.weight is not None: - nn.init.ones_(self.weight) + if self.init_fn is not None: + self.init_fn(self) + else: + if self.weight is not None: + nn.init.ones_(self.weight) - if self.bias is not None: - nn.init.zeros_(self.bias) + if self.bias is not None: + nn.init.zeros_(self.bias) @abstractmethod def forward(self, x: Tensor) -> Tensor: From 7623e1ba606a994dd68c2e2b36a79b17dfa6ce98 Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 19 Dec 2024 13:31:35 +0100 Subject: [PATCH 02/37] add builder skeleton for the AttentivePooler --- src/fairseq2/nn/transformer/encoder_layer.py | 241 ++++++++++++++++ .../recipes/jepa/attentive_factory.py | 271 ++++++++++++++++++ 2 files changed, 512 insertions(+) create mode 100644 src/fairseq2/recipes/jepa/attentive_factory.py diff --git a/src/fairseq2/nn/transformer/encoder_layer.py b/src/fairseq2/nn/transformer/encoder_layer.py index 580af8748..96b48fc66 100644 --- a/src/fairseq2/nn/transformer/encoder_layer.py +++ b/src/fairseq2/nn/transformer/encoder_layer.py @@ -246,3 +246,244 @@ def extra_repr(self) -> str: s = super().extra_repr() return f"{s}, norm_order={self.norm_order.name}" + + +class CrossAttentionTransformerEncoderLayer(TransformerEncoderLayer): + """Represents a Transformer encoder layer as described in + :cite:t:`https://doi.org/10.48550/arxiv.1706.03762`. + """ + + self_attn: MultiheadAttention + self_attn_norm: LayerNorm | None + self_attn_dropout: Dropout | None + self_attn_residual: ResidualConnect + self_attn_layer_norm: LayerNorm + ffn: FeedForwardNetwork + ffn_dropout: Dropout | None + ffn_residual: ResidualConnect + ffn_layer_norm: LayerNorm + norm_order: TransformerNormOrder + + def __init__( + self, + self_attn: MultiheadAttention, + ffn: FeedForwardNetwork, + *, + dropout_p: float = 0.0, + norm_order: TransformerNormOrder = TransformerNormOrder.POST, + layer_norm_factory: LayerNormFactory | None = None, + self_attn_residual: ResidualConnect | None = None, + ffn_residual: ResidualConnect | None = None, + device: Device | None = None, + dtype: DataType | None = None, + ) -> None: + """ + :param self_attn: + The self attention layer. + :param ffn: + The feed-forward network. + :param dropout_p: + The dropout probability on outputs of the self attention layer and + the feed-forward network. + :param norm_order: + The Layer Normalization order. + :param layer_norm_factory: + The factory to construct the Layer Normalization modules. + :param self_attn_residual: + The residual connection between the input and output of the self + attention layer. + :param ffn_residual: + The residual connection between the input and output of the + feed-forward network. 
+ """ + model_dim = self_attn.model_dim + + super().__init__(model_dim) + + if layer_norm_factory is None: + layer_norm_factory = make_standard_layer_norm + + self_attn_layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype) + + if norm_order != TransformerNormOrder.POST: + self.self_attn_layer_norm = self_attn_layer_norm + + self.self_attn = self_attn + + if norm_order == TransformerNormOrder.PRE_WITH_NORMFORMER: + self.self_attn_norm = layer_norm_factory( + model_dim, device=device, dtype=dtype + ) + else: + self.register_module("self_attn_norm", None) + + if dropout_p > 0.0: + self.self_attn_dropout = Dropout(dropout_p) + else: + self.register_module("self_attn_dropout", None) + + if self_attn_residual is None: + self_attn_residual = StandardResidualConnect() + + self.self_attn_residual = self_attn_residual + + if norm_order == TransformerNormOrder.POST: + self.self_attn_layer_norm = self_attn_layer_norm + + ffn_layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype) + + if norm_order != TransformerNormOrder.POST: + self.ffn_layer_norm = ffn_layer_norm + + self.ffn = ffn + + if dropout_p > 0.0: + self.ffn_dropout = Dropout(dropout_p) + else: + self.register_module("ffn_dropout", None) + + if ffn_residual is None: + ffn_residual = StandardResidualConnect() + + self.ffn_residual = ffn_residual + + if norm_order == TransformerNormOrder.POST: + self.ffn_layer_norm = ffn_layer_norm + + self.norm_order = norm_order + + @abstractmethod + def forward( + self, + seqs: Tensor, + padding_mask: PaddingMask | None, + self_attn_mask: AttentionMask | None = None, + ) -> tuple[Tensor, PaddingMask | None]: + pass + + def _forward_self_attn( + self, + seqs: Tensor, + padding_mask: PaddingMask | None, + self_attn_mask: AttentionMask | None, + ) -> Tensor: + residual = seqs + + if self.norm_order != TransformerNormOrder.POST: + seqs = self.self_attn_layer_norm(seqs) + + seqs = self.self_attn( + seqs, + padding_mask, + keys=seqs, + key_padding_mask=padding_mask, + values=seqs, + attn_mask=self_attn_mask, + ) + + if self.self_attn_norm is not None: + seqs = self.self_attn_norm(seqs) + + if self.self_attn_dropout is not None: + seqs = self.self_attn_dropout(seqs) + + seqs = self.self_attn_residual(seqs, residual) + + if self.norm_order == TransformerNormOrder.POST: + seqs = self.self_attn_layer_norm(seqs) + + return seqs + + def _forward_ffn(self, seqs: Tensor) -> Tensor: + residual = seqs + + if self.norm_order != TransformerNormOrder.POST: + seqs = self.ffn_layer_norm(seqs) + + seqs = self.ffn(seqs) + + if self.ffn_dropout is not None: + seqs = self.ffn_dropout(seqs) + + seqs = self.ffn_residual(seqs, residual) + + if self.norm_order == TransformerNormOrder.POST: + seqs = self.ffn_layer_norm(seqs) + + return seqs + + def extra_repr(self) -> str: + """:meta private:""" + s = super().extra_repr() + + return f"{s}, norm_order={self.norm_order.name}" + + +class CrossAttentionEncoderLayer(StandardTransformerEncoderLayer): + """Represents a Transformer encoder layer as described in + :cite:t:`https://doi.org/10.48550/arxiv.1706.03762`. 
+ """ + + @override + def forward( + self, + seqs: Tensor, + padding_mask: PaddingMask | None, + self_attn_mask: AttentionMask | None = None, + ) -> tuple[Tensor, PaddingMask | None]: + seqs = self._forward_self_attn(seqs, padding_mask, self_attn_mask) + + seqs = self._forward_ffn(seqs) + + return seqs, padding_mask + + def _forward_self_attn( + self, + seqs: Tensor, + padding_mask: PaddingMask | None, + self_attn_mask: AttentionMask | None, + ) -> Tensor: + residual = seqs + + if self.norm_order != TransformerNormOrder.POST: + seqs = self.self_attn_layer_norm(seqs) + + seqs = self.self_attn( + seqs, + padding_mask, + keys=seqs, + key_padding_mask=padding_mask, + values=seqs, + attn_mask=self_attn_mask, + ) + + if self.self_attn_norm is not None: + seqs = self.self_attn_norm(seqs) + + if self.self_attn_dropout is not None: + seqs = self.self_attn_dropout(seqs) + + seqs = self.self_attn_residual(seqs, residual) + + if self.norm_order == TransformerNormOrder.POST: + seqs = self.self_attn_layer_norm(seqs) + + return seqs + + def _forward_ffn(self, seqs: Tensor) -> Tensor: + residual = seqs + + if self.norm_order != TransformerNormOrder.POST: + seqs = self.ffn_layer_norm(seqs) + + seqs = self.ffn(seqs) + + if self.ffn_dropout is not None: + seqs = self.ffn_dropout(seqs) + + seqs = self.ffn_residual(seqs, residual) + + if self.norm_order == TransformerNormOrder.POST: + seqs = self.ffn_layer_norm(seqs) + + return seqs \ No newline at end of file diff --git a/src/fairseq2/recipes/jepa/attentive_factory.py b/src/fairseq2/recipes/jepa/attentive_factory.py new file mode 100644 index 000000000..db6949ca5 --- /dev/null +++ b/src/fairseq2/recipes/jepa/attentive_factory.py @@ -0,0 +1,271 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Sequence, final + +from torch import Tensor +from torch.nn import Module, ModuleList + +from fairseq2.models.jepa.factory import JepaEncoderConfig +from fairseq2.models.jepa.model import JepaModel +from fairseq2.models.sequence import SequenceBatch +from fairseq2.models.transformer.factory import TransformerConfig +from fairseq2.nn.projection import Linear +from fairseq2.nn.normalization import LayerNorm +from fairseq2.nn.transformer.encoder import TransformerEncoder +from fairseq2.nn.transformer.encoder_layer import StandardTransformerEncoderLayer +from fairseq2.nn.transformer.layer_norm import ( + LayerNormFactory, + make_standard_layer_norm, +) +from fairseq2.typing import Device, DataType + + +@dataclass(kw_only=True) +class AttentivePoolerConfig: + """Holds the configuration for attentive classifier + + The default value is from the AttentiveClassifier + (https://github.com/facebookresearch/jepa/blob/main/src/models/attentive_pooler.py) + """ + + model_dim: int = 768 + """Embedding dimension""" + + depth: int = 1 + """The depth of attention layers. 
The first one is a thin cross-attention module and all + others are standard multi-head attention""" + + num_queries: int = 1 + """Number of queries for the cross attention layer""" + + num_encoder_attn_heads: int = 12 + """The number of attention heads in encoder layers.""" + + attn_dropout_p: float = 0.0 + """The dropout probability on attention weights.""" + + ffn_inner_dim_ratio: float = 4.0 + """ + The ratio of the dimensionality of the inner projection layers in + feed-forward networks to :attr:`embed_dim`. + """ + + dropout_p: float = 0.0 + """The dropout probability on outputs of Transformer layers.""" + + init_std: float = 0.02 + """std to initialize the weights and bias for linear and LayerNorm layers""" + + +@final +class AttentivePoolerBuilder: + """Build an attentive pooler used for attentive probing evaluation""" + + _config: AttentivePoolerConfig + _device: Device | None + _dtype: DataType | None + + def __init__( + self, + config: AttentivePoolerConfig, + *, + device: Device | None = None, + dtype: DataType | None = None, + ) -> None: + self._config = config + + self._device, self._dtype = device, dtype + + def build_model(self) -> AttentivePooler: + config = self._config + + cross_attn = self.build_cross_attention() + + if config.depth > 1: + attn_layers = [self.build_encoder_layer(i) for i in range(config.depth - 1)] + else: + attn_layers = None + + return AttentivePooler( + cross_attn, + attn_layers, + num_queries=config.num_queries, + layer_norm_factory=self.build_layer_norm, + device=self._device, + dtype=self._dtype, + ) + + +@final +class AttentivePooler(TransformerEncoder): + def __init__( + self, + config: AttentivePoolerConfig, + *, + depth: int = 1, + dropout_p: float = 0.0, + layer_norm_factory: LayerNormFactory | None = None, + layernorm_init_fn: Callable[[LayerNorm], None] | None = None, + proj_init_fn: Callable[[Linear], None] | None = None, + device: Device | None = None, + dtype: DataType | None = None, + ) -> None: + """ + :param layers: + The encoder layers. + :param dropout_p: + The dropout probability on encoder outputs. + :param layer_norm_factory: + The factory to construct the Layer Normalization module. + """ + model_dim = layer_list[0].model_dim + + super().__init__(model_dim) + + if layer_norm_factory is None: + layer_norm_factory = make_standard_layer_norm + + layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype) + + cross_attn_layer = StandardTransformerEncoderLayer( + model_dim, bias=True, eps=1e-6, init_fn=layernorm_init_fn, device=device, dtype=dtype + ) + + + if dropout_p > 0.0: + self.dropout = Dropout(dropout_p) + else: + self.register_module("dropout", None) + + self.norm_order = norm_order + + @override + def forward( + self, seqs: Tensor, padding_mask: PaddingMask | None + ) -> tuple[Tensor, PaddingMask | None]: + if self._layer_output_hooks and self.layer_drop_p > 0.0 and self.training: + raise InvalidOperationError( + "The layer output hooks cannot be run when LayerDrop is enabled." 
+ ) + + num_layers = len(self.layers) + + if self.self_attn_mask_factory is None: + self_attn_mask = None + else: + self_attn_mask = self.self_attn_mask_factory( + seqs, keys=seqs, training=self.training + ) + + for layer_idx, (layer, drop) in enumerate(self._drop_iter()): + layer_output, layer_padding_mask = layer(seqs, padding_mask, self_attn_mask) + + if drop: + seqs = _record_drop_for_backward(seqs, layer_output) + + continue + + seqs, padding_mask = layer_output, layer_padding_mask + + for hook in self._layer_output_hooks.values(): + if not hook(layer_idx, seqs, padding_mask, num_layers): + break + + if self.layer_norm is not None: + seqs = self.layer_norm(seqs) + + if self.dropout is not None: + seqs = self.dropout(seqs) + + return seqs, padding_mask + + def _drop_iter(self) -> Iterator[tuple[Module, bool]]: + if self.training and self.layer_drop_p > 0.0: + prob_dist = torch.rand( + len(self.layers), generator=self.generator, device=CPU + ) + else: + prob_dist = None + + for idx, m in enumerate(self.layers): + drop = prob_dist is not None and float(prob_dist[idx]) <= self.layer_drop_p + + yield m, drop + + def extra_repr(self) -> str: + """:meta private:""" + s = super().extra_repr() + + if self.self_attn_mask_factory is not None: + self_attn_mask_factory = getattr( + self.self_attn_mask_factory, "__name__", self.self_attn_mask_factory + ) + + s = f"{s}, self_attn_mask_factory={self_attn_mask_factory}" + + if self.layer_drop_p > 0.0: + s = f"{s}, layer_drop_p={self.layer_drop_p:G}" + + return f"{s}, norm_order={self.norm_order.name}" + + +@final +class JepaForClassification(Module): + """ + Represents a pretrained Jepa model, with an attentive probing layer for + classfication tasks. See + * :cite:t:`https://doi.org/10.48550/arXiv.2301.08243` + * :cite:t:`https://doi.org/10.48550/arXiv.2404.08471` + """ + jepa: JepaModel + attentive_pooler: AttentivePooler + head: Linear + + def __init__( + self, + jepa: JepaModel, + attentive_pooler: TransformerEncoder, + head: Linear, + ) -> None: + super().__init__() + + self.model_dim = jepa.model_dim + + self.jepa = jepa + self.attentive_pooler = attentive_pooler + self.head = head + + # TODO: Move to builder + # normalize_truncate(self.query_tokens, std=init_std) + + def forward(self, batch: SequenceBatch) -> Tensor: + seqs = self.jepa(batch) + seqs = self.attentive_pooler(seqs) + output = self.head(self) + return output + + +@dataclass(kw_only=True) +class JepaProbeConfig: + """ + Holds the configuration of a probing model + + TODO: Move to fairseq2.models.jepa + """ + + encoder_config: JepaEncoderConfig = field( + default_factory=lambda: JepaEncoderConfig() + ) + """The configuration of the Vision Transformer encoder.""" + + attentive_config: TransformerConfig = field( + default_factory=lambda: TransformerConfig() + ) \ No newline at end of file From 7b959fe21e9fe83dbb68e79c7f46fd7704392141 Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 19 Dec 2024 13:58:52 +0100 Subject: [PATCH 03/37] refactor init_module function --- src/fairseq2/models/jepa/factory.py | 57 +++-------------------------- src/fairseq2/nn/utils/module.py | 43 ++++++++++++++++++++++ 2 files changed, 49 insertions(+), 51 deletions(-) diff --git a/src/fairseq2/models/jepa/factory.py b/src/fairseq2/models/jepa/factory.py index 329f2055a..91059ef50 100644 --- a/src/fairseq2/models/jepa/factory.py +++ b/src/fairseq2/models/jepa/factory.py @@ -12,8 +12,7 @@ import math from typing import Final, cast -import torch -from torch.nn import 
GELU, Module +from torch.nn import GELU from fairseq2.config_registry import ConfigRegistry from fairseq2.models.jepa.model import JepaModel @@ -45,6 +44,7 @@ create_default_sdpa, ) from fairseq2.nn.transformer.residual import DropPathResidualConnect +from fairseq2.nn.utils.module import init_truncated_uniforma_weights_and_bias as init_module from fairseq2.typing import DataType, Device JEPA_FAMILY: Final = "jepa" @@ -194,7 +194,7 @@ def build_frontend(self) -> TransformerFrontend: def build_feature_extractor(self) -> PatchFeatureExtractor: config = self._config - conv_init_fn = partial(init_with_explicit_bounds, std=config.init_std) + conv_init_fn = partial(init_module, std=config.init_std) num_patch_dims = len(config.patch_dims) @@ -323,9 +323,7 @@ def build_attention(self, layer_id: int) -> MultiheadAttention: def build_projection(self, layer_id: int) -> Linear: config = self._config - proj_init_fn: Callable[[Linear], None] = partial( - init_with_explicit_bounds, std=config.init_std - ) + proj_init_fn: Callable[[Linear], None] = partial(init_module, std=config.init_std) proj = Linear( config.model_dim, @@ -344,7 +342,7 @@ def build_projection(self, layer_id: int) -> Linear: def build_ffn(self, layer_id: int) -> FeedForwardNetwork: config = self._config - proj_init_fn = partial(init_with_explicit_bounds, std=config.init_std) + proj_init_fn = partial(init_module, std=config.init_std) ffn = StandardFeedForwardNetwork( config.model_dim, @@ -373,7 +371,7 @@ def build_layer_norm( ) -> LayerNorm: config = self._config - layer_norm_init_fn = partial(init_with_explicit_bounds, std=config.init_std) + layer_norm_init_fn = partial(init_module, std=config.init_std) return StandardLayerNorm( model_dim, bias=True, eps=1e-6, init_fn=layer_norm_init_fn, device=device, dtype=dtype @@ -387,46 +385,3 @@ def create_jepa_model( dtype: DataType | None = None, ) -> JepaModel: return JepaBuilder(config, device=device, dtype=dtype).build_model() - - -def _norm_cdf(x): - # Computes standard normal cumulative distribution function - return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 - - -def normalize_truncate( - tensor: torch.Tensor, - *, - mean: float = 0.0, - std: float = 1.0, - a: float = -2.0, - b: float = 2.0, -) -> None: - - lower = _norm_cdf((a - mean) / std) - upper = _norm_cdf((b - mean) / std) - - tensor.uniform_(2 * lower - 1, 2 * upper - 1) - tensor.erfinv_() - - tensor.mul_(std * math.sqrt(2.0)) - tensor.add_(mean) - - tensor.clamp_(min=a, max=b) - - -def init_with_explicit_bounds( - m: Module, - *, - mean: float = 0.0, - std: float = 1.0, - a: float = -2.0, - b: float = 2.0, -): - if not hasattr(m, "weight") or not hasattr(m, "bias"): - raise ValueError(f"Cannot initialize weights and bias of a {type(m)}") - - with torch.no_grad(): - normalize_truncate(m.weight, mean=mean, std=std, a=a, b=b) - if m.bias is not None: - torch.nn.init.zeros_(m.bias) diff --git a/src/fairseq2/nn/utils/module.py b/src/fairseq2/nn/utils/module.py index e76200599..fd464eefa 100644 --- a/src/fairseq2/nn/utils/module.py +++ b/src/fairseq2/nn/utils/module.py @@ -6,6 +6,7 @@ from __future__ import annotations +import math import re from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence from dataclasses import dataclass @@ -570,3 +571,45 @@ def get_module_size(module: Module) -> ModuleSizeInfo: info.total_size_bytes += size_bytes return info + + +def normalize_truncate( + tensor: Tensor, + *, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0, +) -> None: + + def _norm_cdf(x): + # 
Computes standard normal cumulative distribution function
+        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+    lower = _norm_cdf((a - mean) / std)
+    upper = _norm_cdf((b - mean) / std)
+
+    tensor.uniform_(2 * lower - 1, 2 * upper - 1)
+    tensor.erfinv_()
+
+    tensor.mul_(std * math.sqrt(2.0))
+    tensor.add_(mean)
+
+    tensor.clamp_(min=a, max=b)
+
+
+def init_truncated_uniforma_weights_and_bias(
+    m: Module,
+    *,
+    mean: float = 0.0,
+    std: float = 1.0,
+    a: float = -2.0,
+    b: float = 2.0,
+):
+    if not hasattr(m, "weight") or not hasattr(m, "bias"):
+        raise ValueError(f"Cannot initialize weights and bias of a {type(m)}")
+
+    with torch.no_grad():
+        normalize_truncate(m.weight, mean=mean, std=std, a=a, b=b)
+        if m.bias is not None:
+            torch.nn.init.zeros_(m.bias)

From c7a902127c8c67971bef79fd9f5668787658e3e9 Mon Sep 17 00:00:00 2001
From: Tuan Tran <{ID}+{username}@users.noreply.github.com>
Date: Thu, 19 Dec 2024 13:52:35 +0100
Subject: [PATCH 05/37] update cross attention layer

---
 src/fairseq2/nn/transformer/encoder_layer.py | 57 +++++++++++---------
 1 file changed, 33 insertions(+), 24 deletions(-)

diff --git a/src/fairseq2/nn/transformer/encoder_layer.py b/src/fairseq2/nn/transformer/encoder_layer.py
index 96b48fc66..227a1a02f 100644
--- a/src/fairseq2/nn/transformer/encoder_layer.py
+++ b/src/fairseq2/nn/transformer/encoder_layer.py
@@ -9,8 +9,8 @@
 from abc import ABC, abstractmethod
 from typing import final
 
-from torch import Tensor
-from torch.nn import Dropout, Module
+from torch import Tensor, zeros
+from torch.nn import Dropout, Module, Parameter
 from typing_extensions import override
 
 from fairseq2.nn.normalization import LayerNorm
@@ -253,11 +253,11 @@ class CrossAttentionTransformerEncoderLayer(TransformerEncoderLayer):
     :cite:t:`https://doi.org/10.48550/arxiv.1706.03762`.
""" - self_attn: MultiheadAttention - self_attn_norm: LayerNorm | None - self_attn_dropout: Dropout | None - self_attn_residual: ResidualConnect - self_attn_layer_norm: LayerNorm + cross_attn: MultiheadAttention + cross_attn_norm: LayerNorm | None + cross_attn_dropout: Dropout | None + cross_attn_residual: ResidualConnect + cross_attn_layer_norm: LayerNorm ffn: FeedForwardNetwork ffn_dropout: Dropout | None ffn_residual: ResidualConnect @@ -266,69 +266,73 @@ class CrossAttentionTransformerEncoderLayer(TransformerEncoderLayer): def __init__( self, - self_attn: MultiheadAttention, + cross_attn: MultiheadAttention, ffn: FeedForwardNetwork, *, + num_queries: int = 1, dropout_p: float = 0.0, norm_order: TransformerNormOrder = TransformerNormOrder.POST, layer_norm_factory: LayerNormFactory | None = None, - self_attn_residual: ResidualConnect | None = None, + cross_attn_residual: ResidualConnect | None = None, ffn_residual: ResidualConnect | None = None, + device: Device | None = None, dtype: DataType | None = None, ) -> None: """ - :param self_attn: - The self attention layer. + :param cross_attn: + The cross attention layer. + :param num_queries: + The number of queries used to cross-attend on the output of the self :param ffn: The feed-forward network. :param dropout_p: - The dropout probability on outputs of the self attention layer and + The dropout probability on outputs of the cross attention layer and the feed-forward network. :param norm_order: The Layer Normalization order. :param layer_norm_factory: The factory to construct the Layer Normalization modules. - :param self_attn_residual: + :param cross_attn_residual: The residual connection between the input and output of the self attention layer. :param ffn_residual: The residual connection between the input and output of the feed-forward network. 
""" - model_dim = self_attn.model_dim + model_dim = cross_attn.model_dim super().__init__(model_dim) if layer_norm_factory is None: layer_norm_factory = make_standard_layer_norm - self_attn_layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype) + cross_attn_layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype) if norm_order != TransformerNormOrder.POST: - self.self_attn_layer_norm = self_attn_layer_norm + self.cross_attn_layer_norm = cross_attn_layer_norm - self.self_attn = self_attn + self.cross_attn = cross_attn if norm_order == TransformerNormOrder.PRE_WITH_NORMFORMER: - self.self_attn_norm = layer_norm_factory( + self.cross_attn_norm = layer_norm_factory( model_dim, device=device, dtype=dtype ) else: - self.register_module("self_attn_norm", None) + self.register_module("cross_attn_norm", None) if dropout_p > 0.0: - self.self_attn_dropout = Dropout(dropout_p) + self.cross_attn_dropout = Dropout(dropout_p) else: self.register_module("self_attn_dropout", None) - if self_attn_residual is None: - self_attn_residual = StandardResidualConnect() + if cross_attn_residual is None: + cross_attn_residual = StandardResidualConnect() - self.self_attn_residual = self_attn_residual + self.cross_attn_residual = cross_attn_residual if norm_order == TransformerNormOrder.POST: - self.self_attn_layer_norm = self_attn_layer_norm + self.cross_attn_layer_norm = cross_attn_residual ffn_layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype) @@ -352,6 +356,11 @@ def __init__( self.norm_order = norm_order + self.query_tokens = Parameter(zeros(1, num_queries, model_dim)) + + self.init_weight() + + @abstractmethod def forward( self, From ab93be9e01a2c309521d529fe254de8779a37e42 Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 19 Dec 2024 13:55:01 +0100 Subject: [PATCH 06/37] update cross attn layer --- src/fairseq2/nn/transformer/encoder_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fairseq2/nn/transformer/encoder_layer.py b/src/fairseq2/nn/transformer/encoder_layer.py index 227a1a02f..4bd38558c 100644 --- a/src/fairseq2/nn/transformer/encoder_layer.py +++ b/src/fairseq2/nn/transformer/encoder_layer.py @@ -495,4 +495,4 @@ def _forward_ffn(self, seqs: Tensor) -> Tensor: if self.norm_order == TransformerNormOrder.POST: seqs = self.ffn_layer_norm(seqs) - return seqs \ No newline at end of file + return seqs From f745beabfe350770855dcfab514f0ed18c6eff22 Mon Sep 17 00:00:00 2001 From: Can Balioglu Date: Thu, 19 Dec 2024 15:22:13 +0000 Subject: [PATCH 07/37] Cosmetic updates --- src/fairseq2/models/jepa/factory.py | 102 +++++++++++-------- src/fairseq2/models/jepa/model.py | 10 +- src/fairseq2/models/vit/feature_extractor.py | 6 +- src/fairseq2/nn/normalization.py | 3 +- src/fairseq2/nn/utils/module.py | 15 ++- 5 files changed, 74 insertions(+), 62 deletions(-) diff --git a/src/fairseq2/models/jepa/factory.py b/src/fairseq2/models/jepa/factory.py index 91059ef50..48cffbd22 100644 --- a/src/fairseq2/models/jepa/factory.py +++ b/src/fairseq2/models/jepa/factory.py @@ -6,12 +6,12 @@ from __future__ import annotations -from collections.abc import Callable +import math from dataclasses import dataclass, field from functools import partial -import math from typing import Final, cast +import torch from torch.nn import GELU from fairseq2.config_registry import ConfigRegistry @@ -44,7 +44,9 @@ create_default_sdpa, ) from fairseq2.nn.transformer.residual import DropPathResidualConnect -from 
fairseq2.nn.utils.module import init_truncated_uniforma_weights_and_bias as init_module +from fairseq2.nn.utils.module import ( + init_truncated_uniforma_weights_and_bias as init_module, +) from fairseq2.typing import DataType, Device JEPA_FAMILY: Final = "jepa" @@ -102,15 +104,21 @@ class JepaEncoderConfig: The ratio of the dimensionality of the inner projection layers in feed-forward networks to :attr:`model_dim`. """ - + init_std: float = 0.02 - """std to initialize the weights and bias for linear and LayerNorm layers""" + """ + The standard deviation to initialize weights and biases of projection and + normalization layers. + """ dropout_p: float = 0.0 """The dropout probability on outputs of Transformer layers.""" droppath_p: float = 0.0 - """The probability of output sequence drop.""" + """ + The probability of dropping sequences from outputs of multi-head attention + and feed-forward network layers before adding residuals. + """ uniform_power: bool = False """ @@ -193,8 +201,10 @@ def build_frontend(self) -> TransformerFrontend: def build_feature_extractor(self) -> PatchFeatureExtractor: config = self._config - - conv_init_fn = partial(init_module, std=config.init_std) + + init_std = config.init_std + + init_conv = partial(init_module, std=init_std) num_patch_dims = len(config.patch_dims) @@ -205,7 +215,7 @@ def build_feature_extractor(self) -> PatchFeatureExtractor: config.num_input_channels, config.model_dim, patch_3d_dims, - init_fn=conv_init_fn, + init_fn=init_conv, device=self._device, dtype=self._dtype, ) @@ -216,7 +226,7 @@ def build_feature_extractor(self) -> PatchFeatureExtractor: config.num_input_channels, config.model_dim, patch_2d_dims, - init_fn=conv_init_fn, + init_fn=init_conv, device=self._device, dtype=self._dtype, ) @@ -281,13 +291,13 @@ def build_encoder(self) -> TransformerEncoder: dtype=self._dtype, ) - def build_encoder_layer(self, layer_id: int) -> TransformerEncoderLayer: + def build_encoder_layer(self, layer_idx: int) -> TransformerEncoderLayer: config = self._config - self_attn = self.build_attention(layer_id) + self_attn = self.build_attention(layer_idx) + + ffn = self.build_ffn(layer_idx) - ffn = self.build_ffn(layer_id) - drop_path = DropPathResidualConnect(drop_p=config.droppath_p) return StandardTransformerEncoderLayer( @@ -302,66 +312,67 @@ def build_encoder_layer(self, layer_id: int) -> TransformerEncoderLayer: dtype=self._dtype, ) - def build_attention(self, layer_id: int) -> MultiheadAttention: + def build_attention(self, layer_idx: int) -> MultiheadAttention: config = self._config sdpa = create_default_sdpa(attn_dropout_p=config.attn_dropout_p) - proj = self.build_projection(layer_id) + output_proj = self.build_mha_output_projection(layer_idx) return StandardMultiheadAttention( config.model_dim, config.num_encoder_attn_heads, sdpa=sdpa, bias=config.qkv_bias, - output_proj=proj, - output_proj_bias=True, + output_proj=output_proj, device=self._device, dtype=self._dtype, ) - def build_projection(self, layer_id: int) -> Linear: + def build_mha_output_projection(self, layer_idx: int) -> Linear: config = self._config - proj_init_fn: Callable[[Linear], None] = partial(init_module, std=config.init_std) + init_std = config.init_std - proj = Linear( + def init_projection(proj: Linear) -> None: + init_module(proj, std=init_std) + + with torch.no_grad(): + proj.weight.div_(math.sqrt(2.0 * layer_idx)) + + return Linear( config.model_dim, config.model_dim, bias=True, - init_fn=proj_init_fn, + init_fn=init_projection, device=self._device, dtype=self._dtype, ) - 
# rescale the linear layer - proj.weight.data.div_(math.sqrt(2.0 * layer_id)) + def build_ffn(self, layer_idx: int) -> FeedForwardNetwork: + config = self._config - return proj + init_std = config.init_std - def build_ffn(self, layer_id: int) -> FeedForwardNetwork: - config = self._config + def init_projection(proj: Linear) -> None: + init_module(proj, std=init_std) + + with torch.no_grad(): + proj.weight.div_(math.sqrt(2.0 * layer_idx)) - proj_init_fn = partial(init_module, std=config.init_std) + inner_dim = int(config.model_dim * config.ffn_inner_dim_ratio) - ffn = StandardFeedForwardNetwork( + return StandardFeedForwardNetwork( config.model_dim, - int(config.model_dim * config.ffn_inner_dim_ratio), + inner_dim, bias=True, inner_activation=GELU(), - proj_init_fn=proj_init_fn, + proj_init_fn=init_projection, norm_order=TransformerNormOrder.PRE, device=self._device, dtype=self._dtype, ) - # rescale the last layer - proj = ffn.output_proj - assert isinstance(proj, Linear), f"Invalid projection type: {type(proj)}" - proj.weight.data.div_(math.sqrt(2.0 * layer_id)) - - return ffn - def build_layer_norm( self, model_dim: int, @@ -370,11 +381,18 @@ def build_layer_norm( dtype: DataType | None = None, ) -> LayerNorm: config = self._config - - layer_norm_init_fn = partial(init_module, std=config.init_std) - + + init_std = config.init_std + + init_layer_norm = partial(init_module, std=init_std) + return StandardLayerNorm( - model_dim, bias=True, eps=1e-6, init_fn=layer_norm_init_fn, device=device, dtype=dtype + model_dim, + bias=True, + eps=1e-6, + init_fn=init_layer_norm, + device=device, + dtype=dtype, ) diff --git a/src/fairseq2/models/jepa/model.py b/src/fairseq2/models/jepa/model.py index 281e8187b..b1413328a 100644 --- a/src/fairseq2/models/jepa/model.py +++ b/src/fairseq2/models/jepa/model.py @@ -44,9 +44,7 @@ def __init__( def forward(self, batch: SequenceBatch) -> SequenceBatch: seqs, padding_mask = self.encoder_frontend(batch.seqs, batch.padding_mask) - out_seqs, out_mask = self.encoder(seqs, padding_mask) # type: ignore[no-any-return] - - return SequenceBatch( - seqs=out_seqs, - padding_mask=out_mask, - ) + + seqs, padding_mask = self.encoder(seqs, padding_mask) + + return SequenceBatch(seqs, padding_mask) diff --git a/src/fairseq2/models/vit/feature_extractor.py b/src/fairseq2/models/vit/feature_extractor.py index b3ec35033..50747eaaf 100644 --- a/src/fairseq2/models/vit/feature_extractor.py +++ b/src/fairseq2/models/vit/feature_extractor.py @@ -95,7 +95,6 @@ def reset_parameters(self) -> None: else: self.conv.reset_parameters() - @override def forward(self, x: Tensor) -> Tensor: # (N, C, H_inp, W_inp) -> (N, H_out, W_out, E) @@ -115,7 +114,7 @@ def __init__( feature_dim: int, patch_dims: tuple[int, int, int], *, - init_fn: Callable[[Conv2d], None] | None = None, + init_fn: Callable[[Conv3d], None] | None = None, device: Device | None = None, dtype: DataType | None = None, ) -> None: @@ -135,7 +134,7 @@ def __init__( device=device, dtype=dtype, ) - + self.init_fn = init_fn self.reset_parameters() @@ -147,7 +146,6 @@ def reset_parameters(self) -> None: else: self.conv.reset_parameters() - @override def forward(self, x: Tensor) -> Tensor: # (N, C, D_inp, H_inp, W_inp) -> (N, D_out, H_out, W_out, E) diff --git a/src/fairseq2/nn/normalization.py b/src/fairseq2/nn/normalization.py index 64656507e..c371bc156 100644 --- a/src/fairseq2/nn/normalization.py +++ b/src/fairseq2/nn/normalization.py @@ -41,7 +41,6 @@ class LayerNorm(Module, ABC): bias: Parameter | None init_fn: Callable[[LayerNorm], 
None] | None - def __init__( self, normalized_shape: int | Sequence[int] | Size, @@ -90,7 +89,7 @@ def __init__( ) else: self.register_parameter("bias", None) - + self.init_fn = init_fn self.reset_parameters() diff --git a/src/fairseq2/nn/utils/module.py b/src/fairseq2/nn/utils/module.py index fd464eefa..ddf4c7bf4 100644 --- a/src/fairseq2/nn/utils/module.py +++ b/src/fairseq2/nn/utils/module.py @@ -581,22 +581,21 @@ def normalize_truncate( a: float = -2.0, b: float = 2.0, ) -> None: - - def _norm_cdf(x): + def _norm_cdf(x: float) -> float: # Computes standard normal cumulative distribution function return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 lower = _norm_cdf((a - mean) / std) upper = _norm_cdf((b - mean) / std) - + tensor.uniform_(2 * lower - 1, 2 * upper - 1) tensor.erfinv_() - + tensor.mul_(std * math.sqrt(2.0)) tensor.add_(mean) - + tensor.clamp_(min=a, max=b) - + def init_truncated_uniforma_weights_and_bias( m: Module, @@ -605,10 +604,10 @@ def init_truncated_uniforma_weights_and_bias( std: float = 1.0, a: float = -2.0, b: float = 2.0, -): +) -> None: if not hasattr(m, "weight") or not hasattr(m, "bias"): raise ValueError(f"Cannot initialize weights and bias of a {type(m)}") - + with torch.no_grad(): normalize_truncate(m.weight, mean=mean, std=std, a=a, b=b) if m.bias is not None: From 60ec4c13af057b2b6edd6d70f06724220220d4fe Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 19 Dec 2024 17:24:37 +0100 Subject: [PATCH 08/37] add forward() function --- src/fairseq2/nn/transformer/encoder_layer.py | 133 +++++-------------- 1 file changed, 32 insertions(+), 101 deletions(-) diff --git a/src/fairseq2/nn/transformer/encoder_layer.py b/src/fairseq2/nn/transformer/encoder_layer.py index 4bd38558c..487c160da 100644 --- a/src/fairseq2/nn/transformer/encoder_layer.py +++ b/src/fairseq2/nn/transformer/encoder_layer.py @@ -7,6 +7,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Callable from typing import final from torch import Tensor, zeros @@ -247,7 +248,7 @@ def extra_repr(self) -> str: return f"{s}, norm_order={self.norm_order.name}" - +@final class CrossAttentionTransformerEncoderLayer(TransformerEncoderLayer): """Represents a Transformer encoder layer as described in :cite:t:`https://doi.org/10.48550/arxiv.1706.03762`. 
@@ -269,13 +270,11 @@ def __init__( cross_attn: MultiheadAttention, ffn: FeedForwardNetwork, *, - num_queries: int = 1, dropout_p: float = 0.0, norm_order: TransformerNormOrder = TransformerNormOrder.POST, layer_norm_factory: LayerNormFactory | None = None, cross_attn_residual: ResidualConnect | None = None, ffn_residual: ResidualConnect | None = None, - device: Device | None = None, dtype: DataType | None = None, ) -> None: @@ -307,10 +306,10 @@ def __init__( if layer_norm_factory is None: layer_norm_factory = make_standard_layer_norm - cross_attn_layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype) + attn_layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype) if norm_order != TransformerNormOrder.POST: - self.cross_attn_layer_norm = cross_attn_layer_norm + self.attn_layer_norm = attn_layer_norm self.cross_attn = cross_attn @@ -324,7 +323,7 @@ def __init__( if dropout_p > 0.0: self.cross_attn_dropout = Dropout(dropout_p) else: - self.register_module("self_attn_dropout", None) + self.register_module("cross_attn_dropout", None) if cross_attn_residual is None: cross_attn_residual = StandardResidualConnect() @@ -356,50 +355,52 @@ def __init__( self.norm_order = norm_order - self.query_tokens = Parameter(zeros(1, num_queries, model_dim)) - - self.init_weight() - - - @abstractmethod + @override def forward( self, - seqs: Tensor, + queries: Tensor, + keys: Tensor, padding_mask: PaddingMask | None, - self_attn_mask: AttentionMask | None = None, + cross_attn_mask: AttentionMask | None = None, ) -> tuple[Tensor, PaddingMask | None]: - pass + + seqs = self._forward_cross_attn(queries, keys, padding_mask, cross_attn_mask) - def _forward_self_attn( + seqs = self._forward_ffn(seqs) + + return seqs, padding_mask + + def _forward_cross_attn( self, - seqs: Tensor, + queries: Tensor, + keys: Tensor, padding_mask: PaddingMask | None, - self_attn_mask: AttentionMask | None, + cross_attn_mask: AttentionMask | None, ) -> Tensor: - residual = seqs + residual = queries if self.norm_order != TransformerNormOrder.POST: - seqs = self.self_attn_layer_norm(seqs) + keys = self.attn_layer_norm(keys) - seqs = self.self_attn( - seqs, + seqs = self.cross_attn( + queries, padding_mask, - keys=seqs, + keys=keys, key_padding_mask=padding_mask, - values=seqs, - attn_mask=self_attn_mask, + values=keys, + attn_mask=cross_attn_mask, ) - if self.self_attn_norm is not None: - seqs = self.self_attn_norm(seqs) + if self.cross_attn_norm is not None: + seqs = self.cross_attn_norm(seqs) - if self.self_attn_dropout is not None: - seqs = self.self_attn_dropout(seqs) + if self.cross_attn_dropout is not None: + seqs = self.cross_attn_dropout(seqs) - seqs = self.self_attn_residual(seqs, residual) + seqs = self.cross_attn_residual(seqs, residual) if self.norm_order == TransformerNormOrder.POST: - seqs = self.self_attn_layer_norm(seqs) + seqs = self.cross_attn_layer_norm(seqs) return seqs @@ -426,73 +427,3 @@ def extra_repr(self) -> str: s = super().extra_repr() return f"{s}, norm_order={self.norm_order.name}" - - -class CrossAttentionEncoderLayer(StandardTransformerEncoderLayer): - """Represents a Transformer encoder layer as described in - :cite:t:`https://doi.org/10.48550/arxiv.1706.03762`. 
- """ - - @override - def forward( - self, - seqs: Tensor, - padding_mask: PaddingMask | None, - self_attn_mask: AttentionMask | None = None, - ) -> tuple[Tensor, PaddingMask | None]: - seqs = self._forward_self_attn(seqs, padding_mask, self_attn_mask) - - seqs = self._forward_ffn(seqs) - - return seqs, padding_mask - - def _forward_self_attn( - self, - seqs: Tensor, - padding_mask: PaddingMask | None, - self_attn_mask: AttentionMask | None, - ) -> Tensor: - residual = seqs - - if self.norm_order != TransformerNormOrder.POST: - seqs = self.self_attn_layer_norm(seqs) - - seqs = self.self_attn( - seqs, - padding_mask, - keys=seqs, - key_padding_mask=padding_mask, - values=seqs, - attn_mask=self_attn_mask, - ) - - if self.self_attn_norm is not None: - seqs = self.self_attn_norm(seqs) - - if self.self_attn_dropout is not None: - seqs = self.self_attn_dropout(seqs) - - seqs = self.self_attn_residual(seqs, residual) - - if self.norm_order == TransformerNormOrder.POST: - seqs = self.self_attn_layer_norm(seqs) - - return seqs - - def _forward_ffn(self, seqs: Tensor) -> Tensor: - residual = seqs - - if self.norm_order != TransformerNormOrder.POST: - seqs = self.ffn_layer_norm(seqs) - - seqs = self.ffn(seqs) - - if self.ffn_dropout is not None: - seqs = self.ffn_dropout(seqs) - - seqs = self.ffn_residual(seqs, residual) - - if self.norm_order == TransformerNormOrder.POST: - seqs = self.ffn_layer_norm(seqs) - - return seqs From 854b68e1304f5a607842c8b9741df13ab3fe919a Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 19 Dec 2024 17:41:23 +0100 Subject: [PATCH 09/37] Can's comments --- src/fairseq2/models/jepa/factory.py | 32 ++++++++++++++++++++++++++++- src/fairseq2/nn/utils/module.py | 4 +++- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/src/fairseq2/models/jepa/factory.py b/src/fairseq2/models/jepa/factory.py index 48cffbd22..7b51ab8af 100644 --- a/src/fairseq2/models/jepa/factory.py +++ b/src/fairseq2/models/jepa/factory.py @@ -11,8 +11,14 @@ from functools import partial from typing import Final, cast +<<<<<<< HEAD import torch from torch.nn import GELU +======= +from torch import Tensor +import torch +from torch.nn import GELU, Module +>>>>>>> a09ad4fe (Can's comments) from fairseq2.config_registry import ConfigRegistry from fairseq2.models.jepa.model import JepaModel @@ -338,7 +344,7 @@ def init_projection(proj: Linear) -> None: init_module(proj, std=init_std) with torch.no_grad(): - proj.weight.div_(math.sqrt(2.0 * layer_idx)) + proj.weight.div_(math.sqrt(2.0 * (layer_idx + 1))) return Linear( config.model_dim, @@ -373,6 +379,13 @@ def init_projection(proj: Linear) -> None: dtype=self._dtype, ) + # rescale the last layer + proj = ffn.output_proj + assert isinstance(proj, Linear), f"Invalid projection type: {type(proj)}" + proj.weight.data.div_(math.sqrt(2.0 * (layer_id + 1))) + + return ffn + def build_layer_norm( self, model_dim: int, @@ -403,3 +416,20 @@ def create_jepa_model( dtype: DataType | None = None, ) -> JepaModel: return JepaBuilder(config, device=device, dtype=dtype).build_model() + + +def init_truncated_uniforma_weights_and_bias( + m: Module, + *, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0, +): + if not hasattr(m, "weight") or not hasattr(m, "bias"): + raise ValueError(f"Cannot initialize weights and bias of a {type(m)}") + + with torch.no_grad(): + torch.nn.init.trunc_normal_(m.weight, mean=mean, std=std, a=a, b=b) + if m.bias is not None: + torch.nn.init.zeros_(m.bias) diff --git 
a/src/fairseq2/nn/utils/module.py b/src/fairseq2/nn/utils/module.py index ddf4c7bf4..baa240d85 100644 --- a/src/fairseq2/nn/utils/module.py +++ b/src/fairseq2/nn/utils/module.py @@ -6,7 +6,6 @@ from __future__ import annotations -import math import re from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence from dataclasses import dataclass @@ -571,6 +570,7 @@ def get_module_size(module: Module) -> ModuleSizeInfo: info.total_size_bytes += size_bytes return info +<<<<<<< HEAD def normalize_truncate( @@ -612,3 +612,5 @@ def init_truncated_uniforma_weights_and_bias( normalize_truncate(m.weight, mean=mean, std=std, a=a, b=b) if m.bias is not None: torch.nn.init.zeros_(m.bias) +======= +>>>>>>> a09ad4fe (Can's comments) From 328f8ca1256d1b39a92abd8fd7bf20cd3c4e7847 Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 19 Dec 2024 17:43:17 +0100 Subject: [PATCH 10/37] fix git rebase --- src/fairseq2/models/jepa/factory.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/fairseq2/models/jepa/factory.py b/src/fairseq2/models/jepa/factory.py index 7b51ab8af..5d6dbc79b 100644 --- a/src/fairseq2/models/jepa/factory.py +++ b/src/fairseq2/models/jepa/factory.py @@ -11,14 +11,8 @@ from functools import partial from typing import Final, cast -<<<<<<< HEAD -import torch -from torch.nn import GELU -======= -from torch import Tensor import torch from torch.nn import GELU, Module ->>>>>>> a09ad4fe (Can's comments) from fairseq2.config_registry import ConfigRegistry from fairseq2.models.jepa.model import JepaModel From f57108d96243605a9c0779b3e76b4c0d8ce29153 Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 19 Dec 2024 17:45:43 +0100 Subject: [PATCH 11/37] fix git rebase --- src/fairseq2/nn/utils/module.py | 29 +---------------------------- 1 file changed, 1 insertion(+), 28 deletions(-) diff --git a/src/fairseq2/nn/utils/module.py b/src/fairseq2/nn/utils/module.py index baa240d85..981ab0ac5 100644 --- a/src/fairseq2/nn/utils/module.py +++ b/src/fairseq2/nn/utils/module.py @@ -570,31 +570,6 @@ def get_module_size(module: Module) -> ModuleSizeInfo: info.total_size_bytes += size_bytes return info -<<<<<<< HEAD - - -def normalize_truncate( - tensor: Tensor, - *, - mean: float = 0.0, - std: float = 1.0, - a: float = -2.0, - b: float = 2.0, -) -> None: - def _norm_cdf(x: float) -> float: - # Computes standard normal cumulative distribution function - return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 - - lower = _norm_cdf((a - mean) / std) - upper = _norm_cdf((b - mean) / std) - - tensor.uniform_(2 * lower - 1, 2 * upper - 1) - tensor.erfinv_() - - tensor.mul_(std * math.sqrt(2.0)) - tensor.add_(mean) - - tensor.clamp_(min=a, max=b) def init_truncated_uniforma_weights_and_bias( @@ -609,8 +584,6 @@ def init_truncated_uniforma_weights_and_bias( raise ValueError(f"Cannot initialize weights and bias of a {type(m)}") with torch.no_grad(): - normalize_truncate(m.weight, mean=mean, std=std, a=a, b=b) + torch.nn.init.trunc_normal_(m.weight, mean=mean, std=std, a=a, b=b) if m.bias is not None: torch.nn.init.zeros_(m.bias) -======= ->>>>>>> a09ad4fe (Can's comments) From 2d27bab861b1d66bb85386ba3e179741ce72f741 Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 19 Dec 2024 17:54:01 +0100 Subject: [PATCH 12/37] lint --- src/fairseq2/models/jepa/factory.py | 17 +++++++---------- src/fairseq2/nn/utils/module.py | 17 ----------------- 2 files changed, 7 insertions(+), 27 
deletions(-) diff --git a/src/fairseq2/models/jepa/factory.py b/src/fairseq2/models/jepa/factory.py index 5d6dbc79b..88ae20283 100644 --- a/src/fairseq2/models/jepa/factory.py +++ b/src/fairseq2/models/jepa/factory.py @@ -44,9 +44,6 @@ create_default_sdpa, ) from fairseq2.nn.transformer.residual import DropPathResidualConnect -from fairseq2.nn.utils.module import ( - init_truncated_uniforma_weights_and_bias as init_module, -) from fairseq2.typing import DataType, Device JEPA_FAMILY: Final = "jepa" @@ -204,7 +201,7 @@ def build_feature_extractor(self) -> PatchFeatureExtractor: init_std = config.init_std - init_conv = partial(init_module, std=init_std) + init_conv = partial(init_truncated_uniforma_weights_and_bias, std=init_std) num_patch_dims = len(config.patch_dims) @@ -335,7 +332,7 @@ def build_mha_output_projection(self, layer_idx: int) -> Linear: init_std = config.init_std def init_projection(proj: Linear) -> None: - init_module(proj, std=init_std) + init_truncated_uniforma_weights_and_bias(proj, std=init_std) with torch.no_grad(): proj.weight.div_(math.sqrt(2.0 * (layer_idx + 1))) @@ -355,14 +352,14 @@ def build_ffn(self, layer_idx: int) -> FeedForwardNetwork: init_std = config.init_std def init_projection(proj: Linear) -> None: - init_module(proj, std=init_std) + init_truncated_uniforma_weights_and_bias(proj, std=init_std) with torch.no_grad(): proj.weight.div_(math.sqrt(2.0 * layer_idx)) inner_dim = int(config.model_dim * config.ffn_inner_dim_ratio) - return StandardFeedForwardNetwork( + ffn = StandardFeedForwardNetwork( config.model_dim, inner_dim, bias=True, @@ -376,7 +373,7 @@ def init_projection(proj: Linear) -> None: # rescale the last layer proj = ffn.output_proj assert isinstance(proj, Linear), f"Invalid projection type: {type(proj)}" - proj.weight.data.div_(math.sqrt(2.0 * (layer_id + 1))) + proj.weight.data.div_(math.sqrt(2.0 * (layer_idx + 1))) return ffn @@ -391,7 +388,7 @@ def build_layer_norm( init_std = config.init_std - init_layer_norm = partial(init_module, std=init_std) + init_layer_norm = partial(init_truncated_uniforma_weights_and_bias, std=init_std) return StandardLayerNorm( model_dim, @@ -419,7 +416,7 @@ def init_truncated_uniforma_weights_and_bias( std: float = 1.0, a: float = -2.0, b: float = 2.0, -): +) -> None: if not hasattr(m, "weight") or not hasattr(m, "bias"): raise ValueError(f"Cannot initialize weights and bias of a {type(m)}") diff --git a/src/fairseq2/nn/utils/module.py b/src/fairseq2/nn/utils/module.py index 981ab0ac5..e76200599 100644 --- a/src/fairseq2/nn/utils/module.py +++ b/src/fairseq2/nn/utils/module.py @@ -570,20 +570,3 @@ def get_module_size(module: Module) -> ModuleSizeInfo: info.total_size_bytes += size_bytes return info - - -def init_truncated_uniforma_weights_and_bias( - m: Module, - *, - mean: float = 0.0, - std: float = 1.0, - a: float = -2.0, - b: float = 2.0, -) -> None: - if not hasattr(m, "weight") or not hasattr(m, "bias"): - raise ValueError(f"Cannot initialize weights and bias of a {type(m)}") - - with torch.no_grad(): - torch.nn.init.trunc_normal_(m.weight, mean=mean, std=std, a=a, b=b) - if m.bias is not None: - torch.nn.init.zeros_(m.bias) From 8d7dfafedad4fbfbe229e09ce5567727731cf3bb Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 19 Dec 2024 17:55:47 +0100 Subject: [PATCH 13/37] lint --- src/fairseq2/models/jepa/factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fairseq2/models/jepa/factory.py b/src/fairseq2/models/jepa/factory.py index 
88ae20283..006715928 100644 --- a/src/fairseq2/models/jepa/factory.py +++ b/src/fairseq2/models/jepa/factory.py @@ -355,7 +355,7 @@ def init_projection(proj: Linear) -> None: init_truncated_uniforma_weights_and_bias(proj, std=init_std) with torch.no_grad(): - proj.weight.div_(math.sqrt(2.0 * layer_idx)) + proj.weight.div_(math.sqrt(2.0 * (layer_idx + 1))) inner_dim = int(config.model_dim * config.ffn_inner_dim_ratio) From e86afaa3fa4737165cbcad43afd0126ff4464cff Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 19 Dec 2024 18:02:33 +0100 Subject: [PATCH 14/37] flake8 --- src/fairseq2/models/jepa/factory.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/fairseq2/models/jepa/factory.py b/src/fairseq2/models/jepa/factory.py index 006715928..a06bc8c74 100644 --- a/src/fairseq2/models/jepa/factory.py +++ b/src/fairseq2/models/jepa/factory.py @@ -374,7 +374,7 @@ def init_projection(proj: Linear) -> None: proj = ffn.output_proj assert isinstance(proj, Linear), f"Invalid projection type: {type(proj)}" proj.weight.data.div_(math.sqrt(2.0 * (layer_idx + 1))) - + return ffn def build_layer_norm( @@ -407,7 +407,7 @@ def create_jepa_model( dtype: DataType | None = None, ) -> JepaModel: return JepaBuilder(config, device=device, dtype=dtype).build_model() - + def init_truncated_uniforma_weights_and_bias( m: Module, @@ -419,7 +419,7 @@ def init_truncated_uniforma_weights_and_bias( ) -> None: if not hasattr(m, "weight") or not hasattr(m, "bias"): raise ValueError(f"Cannot initialize weights and bias of a {type(m)}") - + with torch.no_grad(): torch.nn.init.trunc_normal_(m.weight, mean=mean, std=std, a=a, b=b) if m.bias is not None: From 2a76edddf6ffdb25f4c3b4eacee92cc90e1c4073 Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 19 Dec 2024 18:08:35 +0100 Subject: [PATCH 15/37] remove commits remnant --- src/fairseq2/models/jepa/factory.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/fairseq2/models/jepa/factory.py b/src/fairseq2/models/jepa/factory.py index 31cce8553..0c26bfa73 100644 --- a/src/fairseq2/models/jepa/factory.py +++ b/src/fairseq2/models/jepa/factory.py @@ -360,7 +360,7 @@ def init_projection(proj: Linear) -> None: inner_dim = int(config.model_dim * config.ffn_inner_dim_ratio) - ffn = StandardFeedForwardNetwork( + return StandardFeedForwardNetwork( config.model_dim, inner_dim, bias=True, @@ -371,13 +371,6 @@ def init_projection(proj: Linear) -> None: dtype=self._dtype, ) - # rescale the last layer - proj = ffn.output_proj - assert isinstance(proj, Linear), f"Invalid projection type: {type(proj)}" - proj.weight.data.div_(math.sqrt(2.0 * (layer_idx + 1))) - - return ffn - def build_layer_norm( self, model_dim: int, From 250f1eed4e5ce8d6a2efef12e0a46d073624f36b Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 19 Dec 2024 18:09:40 +0100 Subject: [PATCH 16/37] black --- src/fairseq2/models/jepa/factory.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/fairseq2/models/jepa/factory.py b/src/fairseq2/models/jepa/factory.py index 0c26bfa73..63c28e404 100644 --- a/src/fairseq2/models/jepa/factory.py +++ b/src/fairseq2/models/jepa/factory.py @@ -44,7 +44,9 @@ create_default_sdpa, ) from fairseq2.nn.transformer.residual import DropPathResidualConnect -from fairseq2.nn.utils.module import init_truncated_uniforma_weights_and_bias as init_module +from fairseq2.nn.utils.module 
import ( + init_truncated_uniforma_weights_and_bias as init_module, +) from fairseq2.typing import DataType, Device JEPA_FAMILY: Final = "jepa" @@ -382,7 +384,9 @@ def build_layer_norm( init_std = config.init_std - init_layer_norm = partial(init_truncated_uniforma_weights_and_bias, std=init_std) + init_layer_norm = partial( + init_truncated_uniforma_weights_and_bias, std=init_std + ) return StandardLayerNorm( model_dim, @@ -401,7 +405,7 @@ def create_jepa_model( dtype: DataType | None = None, ) -> JepaModel: return JepaBuilder(config, device=device, dtype=dtype).build_model() - + def init_truncated_uniforma_weights_and_bias( m: Module, @@ -413,7 +417,7 @@ def init_truncated_uniforma_weights_and_bias( ) -> None: if not hasattr(m, "weight") or not hasattr(m, "bias"): raise ValueError(f"Cannot initialize weights and bias of a {type(m)}") - + with torch.no_grad(): torch.nn.init.trunc_normal_(m.weight, mean=mean, std=std, a=a, b=b) if m.bias is not None: From f4aaf33c5ed7e41130560a1848de7779ff8d34ca Mon Sep 17 00:00:00 2001 From: tuantran user Date: Thu, 19 Dec 2024 17:25:39 +0000 Subject: [PATCH 17/37] black --- src/fairseq2/models/jepa/factory.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/fairseq2/models/jepa/factory.py b/src/fairseq2/models/jepa/factory.py index a06bc8c74..63fe44e71 100644 --- a/src/fairseq2/models/jepa/factory.py +++ b/src/fairseq2/models/jepa/factory.py @@ -388,7 +388,9 @@ def build_layer_norm( init_std = config.init_std - init_layer_norm = partial(init_truncated_uniforma_weights_and_bias, std=init_std) + init_layer_norm = partial( + init_truncated_uniforma_weights_and_bias, std=init_std + ) return StandardLayerNorm( model_dim, From b8cfd501b2e8fbddcb0dee4569e4ce721468665b Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Fri, 20 Dec 2024 14:00:43 +0100 Subject: [PATCH 18/37] revert remnant codes --- src/fairseq2/nn/utils/module.py | 42 --------------------------------- 1 file changed, 42 deletions(-) diff --git a/src/fairseq2/nn/utils/module.py b/src/fairseq2/nn/utils/module.py index fd464eefa..89f3a6908 100644 --- a/src/fairseq2/nn/utils/module.py +++ b/src/fairseq2/nn/utils/module.py @@ -571,45 +571,3 @@ def get_module_size(module: Module) -> ModuleSizeInfo: info.total_size_bytes += size_bytes return info - - -def normalize_truncate( - tensor: Tensor, - *, - mean: float = 0.0, - std: float = 1.0, - a: float = -2.0, - b: float = 2.0, -) -> None: - - def _norm_cdf(x): - # Computes standard normal cumulative distribution function - return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 - - lower = _norm_cdf((a - mean) / std) - upper = _norm_cdf((b - mean) / std) - - tensor.uniform_(2 * lower - 1, 2 * upper - 1) - tensor.erfinv_() - - tensor.mul_(std * math.sqrt(2.0)) - tensor.add_(mean) - - tensor.clamp_(min=a, max=b) - - -def init_truncated_uniforma_weights_and_bias( - m: Module, - *, - mean: float = 0.0, - std: float = 1.0, - a: float = -2.0, - b: float = 2.0, -): - if not hasattr(m, "weight") or not hasattr(m, "bias"): - raise ValueError(f"Cannot initialize weights and bias of a {type(m)}") - - with torch.no_grad(): - normalize_truncate(m.weight, mean=mean, std=std, a=a, b=b) - if m.bias is not None: - torch.nn.init.zeros_(m.bias) From a6987abe6bc94017f16929fde5450aa7ea4290c0 Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Fri, 20 Dec 2024 14:01:05 +0100 Subject: [PATCH 19/37] revert remnant codes --- src/fairseq2/nn/utils/module.py | 1 - 1 file changed, 1 
deletion(-) diff --git a/src/fairseq2/nn/utils/module.py b/src/fairseq2/nn/utils/module.py index 89f3a6908..e76200599 100644 --- a/src/fairseq2/nn/utils/module.py +++ b/src/fairseq2/nn/utils/module.py @@ -6,7 +6,6 @@ from __future__ import annotations -import math import re from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence from dataclasses import dataclass From 837cc6c76def638981bb794edfaa09ce3c3089bc Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Fri, 20 Dec 2024 14:01:46 +0100 Subject: [PATCH 20/37] revert remnant codes --- src/fairseq2/nn/transformer/encoder_layer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/fairseq2/nn/transformer/encoder_layer.py b/src/fairseq2/nn/transformer/encoder_layer.py index 44cc1d98f..580af8748 100644 --- a/src/fairseq2/nn/transformer/encoder_layer.py +++ b/src/fairseq2/nn/transformer/encoder_layer.py @@ -7,11 +7,10 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections.abc import Callable from typing import final -from torch import Tensor, zeros -from torch.nn import Dropout, Module, Parameter +from torch import Tensor +from torch.nn import Dropout, Module from typing_extensions import override from fairseq2.nn.normalization import LayerNorm From 25eb3f0c82e352b71da9dcb72270df78e3f42416 Mon Sep 17 00:00:00 2001 From: Tuan Tran Date: Fri, 20 Dec 2024 13:46:49 +0000 Subject: [PATCH 21/37] lint --- src/fairseq2/recipes/jepa/factory.py | 16 +++++++-------- src/fairseq2/recipes/jepa/models.py | 30 +++++++++++++--------------- 2 files changed, 21 insertions(+), 25 deletions(-) diff --git a/src/fairseq2/recipes/jepa/factory.py b/src/fairseq2/recipes/jepa/factory.py index f6bacc26b..7a7281899 100644 --- a/src/fairseq2/recipes/jepa/factory.py +++ b/src/fairseq2/recipes/jepa/factory.py @@ -4,9 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+import math from dataclasses import dataclass, field from functools import partial -import math from typing import final import torch @@ -20,16 +20,13 @@ init_truncated_uniforma_weights_and_bias, ) from fairseq2.nn.normalization import LayerNorm, StandardLayerNorm -from fairseq2.nn.projection import IdentityProjection, Linear +from fairseq2.nn.projection import IdentityProjection, Linear, Projection from fairseq2.nn.transformer import ( FeedForwardNetwork, - TransformerDecoder, MultiheadAttention, StandardMultiheadAttention, create_default_sdpa, ) -from fairseq2.nn.transformer.decoder import StandardTransformerDecoder -from fairseq2.nn.transformer.decoder_layer import StandardTransformerDecoderLayer from fairseq2.nn.transformer.encoder import ( StandardTransformerEncoder, TransformerEncoder, @@ -45,7 +42,7 @@ CrossAttentionDecoder, JepaForClassification, ) -from fairseq2.typing import Device, DataType +from fairseq2.typing import DataType, Device @dataclass(kw_only=True) @@ -174,7 +171,7 @@ class AttentivePoolerBuilder: def __init__( self, - config: JepaForClassificationConfig, + config: AttentivePoolerConfig, *, device: Device | None = None, dtype: DataType | None = None, @@ -188,7 +185,8 @@ def build_pooler(self) -> AttentivePooler: def init_pool(pool: Tensor) -> None: std = config.init_std - init_truncated_uniforma_weights_and_bias(pool, std=std) + with torch.no_grad(): + torch.nn.init.trunc_normal_(pool, std=std) decoder = self.build_decoder() @@ -275,7 +273,7 @@ def build_attention( dtype=self._dtype, ) - def build_mha_output_projection(self, layer_idx: int) -> Linear: + def build_mha_output_projection(self, layer_idx: int) -> Projection: config = self._config init_std = config.init_std diff --git a/src/fairseq2/recipes/jepa/models.py b/src/fairseq2/recipes/jepa/models.py index bf27590c3..e590cf8aa 100644 --- a/src/fairseq2/recipes/jepa/models.py +++ b/src/fairseq2/recipes/jepa/models.py @@ -8,37 +8,35 @@ from collections.abc import Callable from typing import final -from typing_extensions import override -import torch +import torch from torch import Tensor -from torch.nn import Module, Parameter, Dropout +from torch.nn import Dropout, Module, Parameter +from typing_extensions import override from fairseq2.models.jepa.model import JepaModel from fairseq2.models.model import Model from fairseq2.models.sequence import SequenceBatch - from fairseq2.nn.incremental_state import IncrementalStateBag from fairseq2.nn.normalization import LayerNorm from fairseq2.nn.padding import PaddingMask from fairseq2.nn.projection import Projection from fairseq2.nn.transformer import ( - TransformerEncoder, FeedForwardNetwork, - MultiheadAttention, LayerNormFactory, - TransformerNormOrder, + MultiheadAttention, ResidualConnect, + TransformerEncoder, + TransformerNormOrder, make_standard_layer_norm, ) - from fairseq2.nn.transformer.residual import StandardResidualConnect -from fairseq2.typing import Device, DataType +from fairseq2.typing import DataType, Device class CrossAttentionDecoder(Module): """Represents a simple transformer decoder with only cross attention and layernorm""" - + model_dim: int cross_attn: MultiheadAttention cross_attn_dropout: Dropout | None @@ -51,7 +49,7 @@ class CrossAttentionDecoder(Module): def __init__( self, - cross_attn: MultiheadAttention | None, + cross_attn: MultiheadAttention, ffn: FeedForwardNetwork, *, dropout_p: float = 0.0, @@ -147,9 +145,6 @@ def forward( "`encoder_output` must not be `None` for encoder-decoder attention." 
) - assert self.cross_attn_residual is not None - assert self.cross_attn_layer_norm is not None - seqs = self._forward_cross_attn( seqs, padding_mask, encoder_output, encoder_padding_mask, state_bag ) @@ -168,6 +163,9 @@ def _forward_cross_attn( residual = seqs + assert self.cross_attn_residual is not None + assert self.cross_attn_layer_norm is not None + # Note that the cross-attention norm is applief on encoder output and not seqs if self.norm_order != TransformerNormOrder.POST: encoder_output = self.cross_attn_layer_norm(encoder_output) @@ -258,7 +256,7 @@ def __init__( self.num_pools = num_pools self.pool_layer = Parameter( - torch.empty(1, num_pools, self.model_dim), device=device, dtype=dtype + torch.empty(1, num_pools, self.model_dim, device=device, dtype=dtype) ) if init_fn: @@ -312,7 +310,7 @@ def forward(self, batch: SequenceBatch) -> Tensor: encoder_output: SequenceBatch = self.encoder(batch) seqs, _ = self.pooler(encoder_output.seqs, encoder_output.padding_mask) seqs = seqs.squeeze(1) - output = self.head(seqs) + output: Tensor = self.head(seqs) return output def extra_repr(self) -> str: From be3b6e2be70a0c2e45e82981a5153b7c73d87639 Mon Sep 17 00:00:00 2001 From: Tuan Tran Date: Fri, 20 Dec 2024 13:47:54 +0000 Subject: [PATCH 22/37] lint --- src/fairseq2/recipes/jepa/factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fairseq2/recipes/jepa/factory.py b/src/fairseq2/recipes/jepa/factory.py index 7a7281899..cb32ac936 100644 --- a/src/fairseq2/recipes/jepa/factory.py +++ b/src/fairseq2/recipes/jepa/factory.py @@ -260,7 +260,7 @@ def build_attention( sdpa = create_default_sdpa(attn_dropout_p=config.attn_dropout_p) if is_cross_attn: - output_proj = IdentityProjection(config.model_dim, config.model_dim) + output_proj: Projection = IdentityProjection(config.model_dim, config.model_dim) output_proj = self.build_mha_output_projection(layer_idx) return StandardMultiheadAttention( From 2f685e886277684cdc860fcf570cfe73828b9afc Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Sun, 22 Dec 2024 16:01:23 +0100 Subject: [PATCH 23/37] nit import clean --- src/fairseq2/models/jepa/factory.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/fairseq2/models/jepa/factory.py b/src/fairseq2/models/jepa/factory.py index fb7755a81..152df5b22 100644 --- a/src/fairseq2/models/jepa/factory.py +++ b/src/fairseq2/models/jepa/factory.py @@ -8,7 +8,6 @@ import math from dataclasses import dataclass, field -from functools import partial from typing import Final, cast import torch From 919b305693d31188046e83118fa5ffec3fa00ec1 Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Sun, 22 Dec 2024 17:13:59 +0100 Subject: [PATCH 24/37] nit rename layers --- src/fairseq2/recipes/jepa/factory.py | 6 +++--- src/fairseq2/recipes/jepa/models.py | 16 ++++++++-------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/fairseq2/recipes/jepa/factory.py b/src/fairseq2/recipes/jepa/factory.py index cb32ac936..d3894bf81 100644 --- a/src/fairseq2/recipes/jepa/factory.py +++ b/src/fairseq2/recipes/jepa/factory.py @@ -75,8 +75,8 @@ class AttentivePoolerConfig: layers will have an additive bias. 
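
One small but easy-to-miss fix above: torch.nn.Parameter takes no device or dtype keyword arguments, so those have to go to the tensor factory (torch.empty here) that allocates the underlying storage. A rough, self-contained sketch of the learned query-token pattern, with made-up sizes:

    import torch
    from torch import nn

    model_dim, num_queries = 768, 1

    # device/dtype belong on the factory call; Parameter merely wraps the tensor.
    query_tokens = nn.Parameter(
        torch.empty(1, num_queries, model_dim, device="cpu", dtype=torch.float32)
    )
    nn.init.trunc_normal_(query_tokens, std=0.02)

    # At run time the single set of learned queries is repeated across the batch.
    batch_size = 4
    queries = query_tokens.repeat(batch_size, 1, 1)  # shape: (4, 1, 768)
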
""" - num_pools: int = 1 - """Number of attentive pools""" + num_queries: int = 1 + """Number of query tokens in the attention pool layer""" attn_dropout_p: float = 0.0 """The dropout probability on attention weights.""" @@ -198,7 +198,7 @@ def init_pool(pool: Tensor) -> None: return AttentivePooler( decoder=decoder, encoder=encoder, - num_pools=config.num_pools, + num_queries=config.num_queries, init_fn=init_pool, device=self._device, dtype=self._dtype, diff --git a/src/fairseq2/recipes/jepa/models.py b/src/fairseq2/recipes/jepa/models.py index e590cf8aa..4bfe06555 100644 --- a/src/fairseq2/recipes/jepa/models.py +++ b/src/fairseq2/recipes/jepa/models.py @@ -228,7 +228,7 @@ class AttentivePooler(Module): """ model_dim: int - num_pools: int + num_queries: int decoder: CrossAttentionDecoder encoder: TransformerEncoder | None init_fn: Callable[[Tensor], None] | None @@ -238,7 +238,7 @@ def __init__( decoder: CrossAttentionDecoder, encoder: TransformerEncoder | None, *, - num_pools: int = 1, + num_queries: int = 1, init_fn: Callable[[Tensor], None] | None = None, device: Device | None = None, dtype: DataType | None = None, @@ -254,9 +254,9 @@ def __init__( else: self.register_module("encoder", None) - self.num_pools = num_pools - self.pool_layer = Parameter( - torch.empty(1, num_pools, self.model_dim, device=device, dtype=dtype) + self.num_queries = num_queries + self.query_tokens = Parameter( + torch.empty(1, num_queries, self.model_dim, device=device, dtype=dtype) ) if init_fn: @@ -267,15 +267,15 @@ def forward( ) -> tuple[Tensor, PaddingMask | None]: if self.encoder: seqs, padding_mask = self.encoder(seqs, padding_mask) - pool_layer = self.pool_layer.repeat(len(seqs), 1, 1) - seqs, padding_mask = self.decoder(pool_layer, None, seqs, padding_mask) + queries = self.query_tokens.repeat(len(seqs), 1, 1) + seqs, padding_mask = self.decoder(queries, None, seqs, padding_mask) return seqs, padding_mask def extra_repr(self) -> str: """:meta private:""" s = super().extra_repr() - return f"{s}, model_dim={self.model_dim}, pools={self.num_pools}" + return f"{s}, model_dim={self.model_dim}, pools={self.num_queries}" @final From c77176702bd652ae9ba1cc38dfb4048b56e0f07b Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Sun, 22 Dec 2024 17:45:09 +0100 Subject: [PATCH 25/37] update factory --- src/fairseq2/recipes/jepa/factory.py | 61 +++++++++++++--------------- src/fairseq2/recipes/jepa/models.py | 10 ++--- 2 files changed, 32 insertions(+), 39 deletions(-) diff --git a/src/fairseq2/recipes/jepa/factory.py b/src/fairseq2/recipes/jepa/factory.py index d3894bf81..eb4f2989b 100644 --- a/src/fairseq2/recipes/jepa/factory.py +++ b/src/fairseq2/recipes/jepa/factory.py @@ -13,12 +13,8 @@ from torch import Tensor from torch.nn import GELU -from fairseq2.models.jepa.factory import ( - JepaBuilder, - JepaConfig, - JepaEncoderConfig, - init_truncated_uniforma_weights_and_bias, -) +from fairseq2.config_registry import ConfigRegistry +from fairseq2.models.jepa.factory import init_truncated_normal from fairseq2.nn.normalization import LayerNorm, StandardLayerNorm from fairseq2.nn.projection import IdentityProjection, Linear, Projection from fairseq2.nn.transformer import ( @@ -38,9 +34,9 @@ from fairseq2.nn.transformer.ffn import StandardFeedForwardNetwork from fairseq2.nn.transformer.norm_order import TransformerNormOrder from fairseq2.recipes.jepa.models import ( + AttentiveClassifier, AttentivePooler, CrossAttentionDecoder, - JepaForClassification, ) from fairseq2.typing 
import DataType, Device @@ -98,12 +94,7 @@ class AttentivePoolerConfig: @dataclass(kw_only=True) -class JepaForClassificationConfig: - - encoder_config: JepaEncoderConfig = field( - default_factory=lambda: JepaEncoderConfig() - ) - """The configuration of the Vision Transformer encoder.""" +class AttentiveClassifierConfig: pooler_config: AttentivePoolerConfig = field( default_factory=lambda: AttentivePoolerConfig() @@ -113,17 +104,22 @@ class JepaForClassificationConfig: """Size of classification logits""" +attentive_archs = ConfigRegistry[AttentiveClassifierConfig]() + +attentive_arch = attentive_archs.decorator + + @final -class JepaForClassificationBuilder: +class AttentiveClassifierBuilder: """Build a Jepa model that is fine-tuned for classification""" - _config: JepaForClassificationConfig + _config: AttentiveClassifierConfig _device: Device | None _dtype: DataType | None def __init__( self, - config: JepaForClassificationConfig, + config: AttentiveClassifierConfig, *, device: Device | None = None, dtype: DataType | None = None, @@ -131,28 +127,20 @@ def __init__( self._config = config - pretrained_config = JepaConfig(encoder_config=config.encoder_config) - - self.encoder_builder = JepaBuilder( - pretrained_config, device=device, dtype=dtype - ) - self.pooler_builer = AttentivePoolerBuilder( config.pooler_config, device=device, dtype=dtype ) self._device, self._dtype = device, dtype - def build_model(self) -> JepaForClassification: + def build_model(self) -> AttentiveClassifier: config = self._config - encoder = self.encoder_builder.build_model() + pooler = self.pooler_builer.build_model() - pooler = self.pooler_builer.build_pooler() + head = Linear(config.pooler_config.model_dim, config.num_classes, bias=True) - head = Linear(config.encoder_config.model_dim, config.num_classes, bias=True) - - return JepaForClassification(encoder, pooler, head) + return AttentiveClassifier(pooler, head) @final @@ -180,7 +168,7 @@ def __init__( self._device, self._dtype = device, dtype - def build_pooler(self) -> AttentivePooler: + def build_model(self) -> AttentivePooler: config = self._config def init_pool(pool: Tensor) -> None: @@ -279,7 +267,7 @@ def build_mha_output_projection(self, layer_idx: int) -> Projection: init_std = config.init_std def init_projection(proj: Linear) -> None: - init_truncated_uniforma_weights_and_bias(proj, std=init_std) + init_truncated_normal(proj.weight, proj.bias, std=init_std) with torch.no_grad(): proj.weight.div_(math.sqrt(2.0 * (layer_idx + 1))) @@ -299,7 +287,7 @@ def build_ffn(self, layer_idx: int) -> FeedForwardNetwork: init_std = config.init_std def init_projection(proj: Linear) -> None: - init_truncated_uniforma_weights_and_bias(proj, std=init_std) + init_truncated_normal(proj.weight, proj.bias, std=init_std) with torch.no_grad(): proj.weight.div_(math.sqrt(2.0 * (layer_idx))) @@ -329,7 +317,7 @@ def build_layer_norm( init_std = config.init_std init_layer_norm = partial( - init_truncated_uniforma_weights_and_bias, std=init_std + init_truncated_normal, std=init_std ) return StandardLayerNorm( @@ -340,3 +328,12 @@ def build_layer_norm( device=device, dtype=dtype, ) + + +def create_attentive_pooler( + config: AttentivePoolerConfig, + *, + device: Device | None = None, + dtype: DataType | None = None, +) -> AttentivePooler: + return AttentivePoolerBuilder(config, device=device, dtype=dtype).build_model() diff --git a/src/fairseq2/recipes/jepa/models.py b/src/fairseq2/recipes/jepa/models.py index 4bfe06555..52cd9e12a 100644 --- a/src/fairseq2/recipes/jepa/models.py 
+++ b/src/fairseq2/recipes/jepa/models.py @@ -279,7 +279,7 @@ def extra_repr(self) -> str: @final -class JepaForClassification(Model): +class AttentiveClassifier(Model): """ Represents a pretrained Jepa model, with an attentive probing layer for classfication tasks. See @@ -288,27 +288,23 @@ class JepaForClassification(Model): """ model_dim: int - encoder: JepaModel pooler: AttentivePooler head: Projection def __init__( self, - encoder: JepaModel, pooler: AttentivePooler, head: Projection, ) -> None: super().__init__() - self.model_dim = encoder.model_dim + self.model_dim = pooler.model_dim - self.encoder = encoder self.pooler = pooler self.head = head def forward(self, batch: SequenceBatch) -> Tensor: - encoder_output: SequenceBatch = self.encoder(batch) - seqs, _ = self.pooler(encoder_output.seqs, encoder_output.padding_mask) + seqs, _ = self.pooler(batch.seqs, batch.padding_mask) seqs = seqs.squeeze(1) output: Tensor = self.head(seqs) return output From a43f68333794a454bb33c8cffe17926c7f45ec3d Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Sun, 22 Dec 2024 19:06:23 +0100 Subject: [PATCH 26/37] lint --- src/fairseq2/recipes/jepa/models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/fairseq2/recipes/jepa/models.py b/src/fairseq2/recipes/jepa/models.py index 52cd9e12a..83e8e50b9 100644 --- a/src/fairseq2/recipes/jepa/models.py +++ b/src/fairseq2/recipes/jepa/models.py @@ -14,7 +14,6 @@ from torch.nn import Dropout, Module, Parameter from typing_extensions import override -from fairseq2.models.jepa.model import JepaModel from fairseq2.models.model import Model from fairseq2.models.sequence import SequenceBatch from fairseq2.nn.incremental_state import IncrementalStateBag From b76bbe5d64f56d3866d651f126b2f80af69d4784 Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Tue, 24 Dec 2024 08:40:15 +0100 Subject: [PATCH 27/37] fix typo --- src/fairseq2/recipes/jepa/factory.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/fairseq2/recipes/jepa/factory.py b/src/fairseq2/recipes/jepa/factory.py index eb4f2989b..ddcab3dbc 100644 --- a/src/fairseq2/recipes/jepa/factory.py +++ b/src/fairseq2/recipes/jepa/factory.py @@ -103,6 +103,16 @@ class AttentiveClassifierConfig: num_classes: int = 1000 """Size of classification logits""" +@dataclass(kw_only=True) +class JepaForClassificationConfig: + encoder_model_name: str = "" + + attentive_classifier_name: str = "" + + def __post_init__(self): + if not self.encoder_model_name: + raise ValueError("Must specify encoder_model_name") + attentive_archs = ConfigRegistry[AttentiveClassifierConfig]() From 88b2d95a659a035b3c7716c66ca53b24e2c3be21 Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 26 Dec 2024 17:01:53 +0100 Subject: [PATCH 28/37] allow unstricted model loading --- src/fairseq2/models/loader.py | 10 +++++++++- src/fairseq2/nn/utils/module.py | 11 +++++------ 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/fairseq2/models/loader.py b/src/fairseq2/models/loader.py index 552cf482f..c2fdaf428 100644 --- a/src/fairseq2/models/loader.py +++ b/src/fairseq2/models/loader.py @@ -82,6 +82,7 @@ def __call__( dtype: DataType | None = None, force: bool = False, progress: bool = True, + strict_state_dict: bool = True, ) -> ModelT_co: """ :param model_name_or_card: @@ -98,6 +99,9 @@ def __call__( cache. :param progress: If ``True``, displays a progress bar to stderr. 
+ :param strict_state_dict: + If ``True``, checkpoint' parameters and layers must be identical to + the model state dict) :returns: A model loaded from the checkpoint of ``model_name_or_card``. @@ -201,6 +205,8 @@ def __call__( dtype: DataType | None = None, force: bool = False, progress: bool = True, + strict_state_dict: bool = True, + ) -> ModelT: if isinstance(model_name_or_card, AssetCard): card = model_name_or_card @@ -355,7 +361,7 @@ def __call__( consume_prefix_in_state_dict_if_present(state_dict, prefix="module.") try: - load_state_dict(model, state_dict) + load_state_dict(model, state_dict, strict=strict_state_dict) except (KeyError, ValueError) as ex: raise AssetError( f"{card.name} cannot be loaded. See nested exception for details." @@ -396,6 +402,7 @@ def __call__( dtype: DataType | None = None, force: bool = False, progress: bool = True, + strict_state_dict: bool = True, ) -> ModelT: if isinstance(model_name_or_card, AssetCard): card = model_name_or_card @@ -419,6 +426,7 @@ def __call__( dtype=dtype, force=force, progress=progress, + strict_state_dict=strict_state_dict, ) def register(self, family: str, loader: ModelLoader[ModelT]) -> None: diff --git a/src/fairseq2/nn/utils/module.py b/src/fairseq2/nn/utils/module.py index e76200599..02ab7d1d2 100644 --- a/src/fairseq2/nn/utils/module.py +++ b/src/fairseq2/nn/utils/module.py @@ -431,16 +431,15 @@ def broadcast_module( _broadcast_coalesced(pg, tensors, bucket_size, source_rank) -def load_state_dict(module: Module, state_dict: Mapping[str, object]) -> None: +def load_state_dict(module: Module, state_dict: Mapping[str, object], strict: bool = True) -> None: """Copy parameters and buffers from ``state_dict`` into ``module`` and its descendant modules. - This implementation internally calls :meth:`Module.load_state_dict()` with - ``strict`` set to ``True``, and also enforces that ``state_dict`` does not - contain any keys corresponding to descendants that are set to ``None`` via - :meth:`Module.register_module()`. + This implementation internally calls :meth:`Module.load_state_dict()`, and also enforces that + ``state_dict`` does not contain any keys corresponding to descendants that are set to ``None`` + via :meth:`Module.register_module()`. 
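
The new strict_state_dict flag ultimately forwards to PyTorch's Module.load_state_dict(strict=...). A small sketch of the non-strict behaviour with two toy modules (both invented for illustration), useful when a checkpoint covers only part of a model, for example a pretrained encoder loaded into a classifier that adds a fresh head:

    import torch
    from torch import nn

    pretrained = nn.Sequential(nn.Linear(8, 8))          # stands in for a pretrained encoder
    wrapper = nn.ModuleDict(
        {"0": nn.Linear(8, 8), "head": nn.Linear(8, 2)}  # same encoder plus a new head
    )

    # strict=False reports mismatches instead of raising.
    result = wrapper.load_state_dict(pretrained.state_dict(), strict=False)
    print(result.missing_keys)     # ['head.weight', 'head.bias']
    print(result.unexpected_keys)  # []
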
""" - module.load_state_dict(state_dict, strict=True) + module.load_state_dict(state_dict, strict=strict) unexpected_keys = [] From 57987dac3d39b180286fbf748b67f0549c39b63a Mon Sep 17 00:00:00 2001 From: Can Balioglu Date: Mon, 23 Dec 2024 13:48:16 +0000 Subject: [PATCH 29/37] Feedback commit --- .../models/jepa/classifier/__init__.py | 0 .../models/jepa/classifier/factory.py | 151 ++++++++ src/fairseq2/models/jepa/classifier/model.py | 230 ++++++++++++ src/fairseq2/models/jepa/factory.py | 6 +- src/fairseq2/recipes/jepa/factory.py | 349 ------------------ src/fairseq2/recipes/jepa/models.py | 315 ---------------- 6 files changed, 385 insertions(+), 666 deletions(-) create mode 100644 src/fairseq2/models/jepa/classifier/__init__.py create mode 100644 src/fairseq2/models/jepa/classifier/factory.py create mode 100644 src/fairseq2/models/jepa/classifier/model.py delete mode 100644 src/fairseq2/recipes/jepa/factory.py delete mode 100644 src/fairseq2/recipes/jepa/models.py diff --git a/src/fairseq2/models/jepa/classifier/__init__.py b/src/fairseq2/models/jepa/classifier/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/fairseq2/models/jepa/classifier/factory.py b/src/fairseq2/models/jepa/classifier/factory.py new file mode 100644 index 000000000..947d100da --- /dev/null +++ b/src/fairseq2/models/jepa/classifier/factory.py @@ -0,0 +1,151 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field +from typing import final + +from fairseq2.config_registry import ConfigRegistry +from fairseq2.models.jepa.classifier.model import ( + AttentivePooler, + CrossAttentionDecoderLayer, + JepaClassifierModel, +) +from fairseq2.models.jepa.factory import JepaEncoderBuilder, JepaEncoderConfig +from fairseq2.nn.projection import IdentityProjection, Linear, Projection +from fairseq2.nn.transformer import ( + MultiheadAttention, + StandardMultiheadAttention, + create_default_sdpa, +) +from fairseq2.typing import DataType, Device + + +@dataclass(kw_only=True) +class JepaClassifierConfig: + encoder_config: JepaEncoderConfig = field( + default_factory=lambda: JepaEncoderConfig() + ) + """The configuration of the vision encoder.""" + + pool_depth: int = 1 + """The pool depth (minimum 1 decoder layer)""" + + num_queries: int = 1 + """Number of query tokens in the attention pool layer""" + + num_classes: int = 1000 + """Size of classification logits""" + + +jepa_classifier_archs = ConfigRegistry[JepaClassifierConfig]() + +jepa_classifier_arch = jepa_classifier_archs.decorator + + +@final +class JepaClassifierBuilder: + """Build a JEPA model fine-tuned for classification""" + + _config: JepaClassifierConfig + _encoder_builder: JepaEncoderBuilder + _device: Device | None + _dtype: DataType | None + + def __init__( + self, + config: JepaClassifierConfig, + *, + device: Device | None = None, + dtype: DataType | None = None, + ) -> None: + self._config = config + + self._encoder_builder = JepaEncoderBuilder( + config.encoder_config, device=device, dtype=dtype + ) + + self._device, self._dtype = device, dtype + + def build_model(self) -> JepaClassifierModel: + config = self._config + + encoder_frontend = self._encoder_builder.build_frontend() + + encoder = self._encoder_builder.build_encoder() + + pooler = self.build_pooler() + + head = Linear(config.encoder_config.model_dim, config.num_classes, 
bias=True) + + return JepaClassifierModel(encoder_frontend, encoder, pooler, head) + + def build_pooler(self) -> AttentivePooler: + config = self._config + + if config.pool_depth > 1: + encoder = self._encoder_builder.build_encoder(config.pool_depth) + else: + encoder = None + + decoder_layer = self.build_decoder_layer() + + return AttentivePooler( + decoder_layer=decoder_layer, + encoder=encoder, + num_pools=config.num_queries, + init_std=config.encoder_config.init_std, + device=self._device, + dtype=self._dtype, + ) + + def build_decoder_layer(self) -> CrossAttentionDecoderLayer: + config = self._config + + cross_attn = self._encoder_builder.build_attention( + config.pool_depth, is_cross_attn=True + ) + + ffn = self._encoder_builder.build_ffn(config.pool_depth) + + return CrossAttentionDecoderLayer( + cross_attn, + ffn, + layer_norm_factory=self._encoder_builder.build_layer_norm, + device=self._device, + dtype=self._dtype, + ) + +# def build_attention( +# self, layer_idx: int, is_cross_attn: bool = False +# ) -> MultiheadAttention: +# config = self._config +# +# sdpa = create_default_sdpa(attn_dropout_p=config.attn_dropout_p) +# +# if is_cross_attn: +# output_proj: Projection = IdentityProjection( +# config.model_dim, config.model_dim +# ) +# output_proj = self.build_mha_output_projection(layer_idx) +# +# return StandardMultiheadAttention( +# config.model_dim, +# config.num_attn_heads, +# sdpa=sdpa, +# bias=config.qkv_bias, +# output_proj=output_proj, +# device=self._device, +# dtype=self._dtype, +# ) + + +def create_jepa_classifier_model( + config: JepaClassifierConfig, + *, + device: Device | None = None, + dtype: DataType | None = None, +) -> JepaClassifierModel: + return JepaClassifierBuilder(config, device=device, dtype=dtype).build_model() diff --git a/src/fairseq2/models/jepa/classifier/model.py b/src/fairseq2/models/jepa/classifier/model.py new file mode 100644 index 000000000..54ad9c73b --- /dev/null +++ b/src/fairseq2/models/jepa/classifier/model.py @@ -0,0 +1,230 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from typing import final + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import Module, Parameter + +from fairseq2.models.sequence import SequenceBatch +from fairseq2.models.transformer import TransformerFrontend +from fairseq2.nn.normalization import LayerNorm +from fairseq2.nn.projection import Projection +from fairseq2.nn.transformer import ( + FeedForwardNetwork, + LayerNormFactory, + MultiheadAttention, + TransformerEncoder, + create_standard_layer_norm, +) +from fairseq2.typing import DataType, Device + + +@final +class JepaClassifierModel(Module): + """ + Represents a pretrained Jepa model, with an attentive probing layer for + classfication tasks. 
See + * :cite:t:`https://doi.org/10.48550/arXiv.2301.08243` + * :cite:t:`https://doi.org/10.48550/arXiv.2404.08471` + """ + + model_dim: int + encoder_frontend: TransformerFrontend + encoder: TransformerEncoder + pooler: AttentivePooler + head: Projection + + def __init__( + self, + encoder_frontend: TransformerFrontend, + encoder: TransformerEncoder, + pooler: AttentivePooler, + head: Projection, + ) -> None: + super().__init__() + + self.model_dim = encoder.model_dim + + self.encoder_frontend = encoder_frontend + self.encoder = encoder + + self.pooler = pooler + + self.head = head + + def forward(self, batch: SequenceBatch) -> Tensor: + seqs, padding_mask = self.encoder_frontend(batch.seqs, batch.padding_mask) + + seqs, _ = self.encoder(seqs, padding_mask) + + seqs = self.pooler(seqs) + + # (N, P, M) + seqs = seqs.squeeze(1) # TODO: NEEDED? + + return self.head(seqs) # type: ignore[no-any-return] + + def extra_repr(self) -> str: + """:meta private:""" + return f"model_dim={self.model_dim}" + + +@final +class AttentivePooler(Module): + """ + An attentive pooler that gets output of a Jepa encoder and decode it into + a logit of a given task. + + TODO: + - Move this into fairseq2.nn to benefit other similiar tasks. Internally, + this module is just a thin transformer encoder without self attention layer. + Optionally, it can consist of some extra transformer encoders depending on the + (finetuning) task + """ + + model_dim: int + decoder_layer: CrossAttentionDecoderLayer + encoder: TransformerEncoder | None + pool_layer: Parameter + init_std: float + + def __init__( + self, + decoder_layer: CrossAttentionDecoderLayer, + encoder: TransformerEncoder | None, + *, + num_pools: int = 1, + init_std: float = 0.02, + device: Device | None = None, + dtype: DataType | None = None, + ) -> None: + super().__init__() + + self.model_dim = decoder_layer.model_dim + + self.decoder_layer = decoder_layer + + if encoder: + self.encoder = encoder + else: + self.register_module("encoder", None) + + self.pool_layer = Parameter( + torch.empty((1, num_pools, self.model_dim), device=device, dtype=dtype) + ) + + self.init_std = init_std + + self.reset_parameters() + + def reset_parameters(self) -> None: + """Reset the parameters and buffers of the module.""" + nn.init.trunc_normal_(self.pool_layer, std=self.init_std) + + def forward(self, seqs: Tensor) -> Tensor: + if self.encoder is not None: + seqs, _ = self.encoder(seqs, padding_mask=None) + + batch_size = seqs.size(0) + + # (1, P, M) -> (N, P, M) + pool_seqs = self.pool_layer.repeat(batch_size, 1, 1) + + return self.decoder_layer(pool_seqs, seqs) # type: ignore[no-any-return] + + def extra_repr(self) -> str: + """:meta private:""" + return f"model_dim={self.model_dim}, num_pools={self.pool_layer.size(1)}" + + +@final +class CrossAttentionDecoderLayer(Module): + """Represents a simple transformer decoder with only cross attention and layernorm""" + + model_dim: int + cross_attn_layer_norm: LayerNorm + cross_attn: MultiheadAttention + ffn_layer_norm: LayerNorm + ffn: FeedForwardNetwork + + def __init__( + self, + cross_attn: MultiheadAttention, + ffn: FeedForwardNetwork, + *, + layer_norm_factory: LayerNormFactory | None = None, + device: Device | None = None, + dtype: DataType | None = None, + ) -> None: + """ + :param cross_attn: + The encoder-decoder attention layer. + :param ffn: + The feed-forward network. + :param layer_norm_factory: + The factory to construct the Layer Normalization modules. 
+ """ + super().__init__() + + model_dim = cross_attn.model_dim + + if layer_norm_factory is None: + layer_norm_factory = create_standard_layer_norm + + self.cross_attn_layer_norm = layer_norm_factory( + model_dim, device=device, dtype=dtype + ) + + self.cross_attn = cross_attn + + self.ffn_layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype) + + self.ffn = ffn + + def forward(self, seqs: Tensor, encoder_output: Tensor) -> Tensor: + seqs = self._forward_cross_attn(seqs, encoder_output) + + seqs = self._forward_ffn(seqs) + + return seqs + + def _forward_cross_attn(self, seqs: Tensor, encoder_output: Tensor) -> Tensor: + residual = seqs + + # Note that the cross-attention norm is applied on encoder output and not seqs + encoder_output = self.cross_attn_layer_norm(encoder_output) + + seqs = self.cross_attn( + seqs, + padding_mask=None, + keys=encoder_output, + key_padding_mask=None, + values=encoder_output, + ) + + seqs = seqs + residual + + return seqs + + def _forward_ffn(self, seqs: Tensor) -> Tensor: + residual = seqs + + seqs = self.ffn_layer_norm(seqs) + + seqs = self.ffn(seqs) + + seqs = seqs + residual + + return seqs + + def extra_repr(self) -> str: + """:meta private:""" + return f"model_dim={self.model_dim}" diff --git a/src/fairseq2/models/jepa/factory.py b/src/fairseq2/models/jepa/factory.py index 152df5b22..b0eb74969 100644 --- a/src/fairseq2/models/jepa/factory.py +++ b/src/fairseq2/models/jepa/factory.py @@ -130,6 +130,7 @@ class JepaEncoderConfig: jepa_arch = jepa_archs.decorator +# TODO(balioglu): work in progress. Supports only vision encoder. class JepaBuilder: """Builds modules of a JEPA model.""" @@ -278,10 +279,11 @@ def build_position_encoder(self) -> InterpolatedPositionEncoder: f"The length of `input_dims` must be 2 or 3, but is {num_input_dims} instead." ) - def build_encoder(self) -> TransformerEncoder: + def build_encoder(self, num_layers: int | None = None) -> TransformerEncoder: config = self._config - num_layers = config.num_encoder_layers + if num_layers is None: + num_layers = config.num_encoder_layers layers = [self.build_encoder_layer(i) for i in range(num_layers)] diff --git a/src/fairseq2/recipes/jepa/factory.py b/src/fairseq2/recipes/jepa/factory.py deleted file mode 100644 index ddcab3dbc..000000000 --- a/src/fairseq2/recipes/jepa/factory.py +++ /dev/null @@ -1,349 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
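
A note on how a classifier like this is typically trained: in the attentive-probing setup referenced by the docstrings above, the pretrained encoder is usually kept frozen and only the pooler and head receive gradients. The stand-in model below is hypothetical (plain Linear layers in place of the real frontend, encoder, pooler and head); it only illustrates the freeze-and-filter pattern, not the fairseq2 training recipe.

    import torch
    from torch import nn


    class ToyClassifier(nn.Module):
        # Hypothetical stand-in with the same encoder/pooler/head layout.
        def __init__(self) -> None:
            super().__init__()
            self.encoder = nn.Linear(768, 768)
            self.pooler = nn.Linear(768, 768)
            self.head = nn.Linear(768, 1000)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.head(self.pooler(self.encoder(x)))


    model = ToyClassifier()

    # Freeze the pretrained part; train only pooler + head.
    for p in model.encoder.parameters():
        p.requires_grad_(False)

    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad], lr=1e-3
    )

    x, target = torch.randn(4, 768), torch.randint(0, 1000, (4,))
    loss = nn.functional.cross_entropy(model(x), target)
    loss.backward()
    optimizer.step()
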
- -import math -from dataclasses import dataclass, field -from functools import partial -from typing import final - -import torch -from torch import Tensor -from torch.nn import GELU - -from fairseq2.config_registry import ConfigRegistry -from fairseq2.models.jepa.factory import init_truncated_normal -from fairseq2.nn.normalization import LayerNorm, StandardLayerNorm -from fairseq2.nn.projection import IdentityProjection, Linear, Projection -from fairseq2.nn.transformer import ( - FeedForwardNetwork, - MultiheadAttention, - StandardMultiheadAttention, - create_default_sdpa, -) -from fairseq2.nn.transformer.encoder import ( - StandardTransformerEncoder, - TransformerEncoder, -) -from fairseq2.nn.transformer.encoder_layer import ( - StandardTransformerEncoderLayer, - TransformerEncoderLayer, -) -from fairseq2.nn.transformer.ffn import StandardFeedForwardNetwork -from fairseq2.nn.transformer.norm_order import TransformerNormOrder -from fairseq2.recipes.jepa.models import ( - AttentiveClassifier, - AttentivePooler, - CrossAttentionDecoder, -) -from fairseq2.typing import DataType, Device - - -@dataclass(kw_only=True) -class AttentivePoolerConfig: - model_dim: int = 768 - """The dimensionality of the model.""" - - num_input_channels: int = 3 - """The number of input channels per frame.""" - - input_dims: tuple[int, ...] = (224, 224) - """ - The supported native dimensionality of inputs. Expected to be 2-dimensional - (height, width) for images and 3-dimensional (depth, height, width) for - videos. - """ - - patch_dims: tuple[int, ...] = (16, 16) - """The dimensionality of patches to be extracted from inputs.""" - - pool_depth: int = 1 - """The pool depth (minimum 1 decoder layer)""" - - num_attn_heads: int = 12 - """The number of attention heads in encoder layers.""" - - qkv_bias: bool = True - """ - If ``True``, query, key, and value projections in multi-head attention - layers will have an additive bias. - """ - - num_queries: int = 1 - """Number of query tokens in the attention pool layer""" - - attn_dropout_p: float = 0.0 - """The dropout probability on attention weights.""" - - ffn_inner_dim_ratio: float = 4.0 - """ - The ratio of the dimensionality of the inner projection layers in - feed-forward networks to :attr:`model_dim`. - """ - - init_std: float = 0.02 - """ - The standard deviation to initialize weights and biases of projection and - normalization layers. 
- """ - - dropout_p: float = 0.0 - """The dropout probability on outputs of Transformer layers.""" - - -@dataclass(kw_only=True) -class AttentiveClassifierConfig: - - pooler_config: AttentivePoolerConfig = field( - default_factory=lambda: AttentivePoolerConfig() - ) - - num_classes: int = 1000 - """Size of classification logits""" - -@dataclass(kw_only=True) -class JepaForClassificationConfig: - encoder_model_name: str = "" - - attentive_classifier_name: str = "" - - def __post_init__(self): - if not self.encoder_model_name: - raise ValueError("Must specify encoder_model_name") - - -attentive_archs = ConfigRegistry[AttentiveClassifierConfig]() - -attentive_arch = attentive_archs.decorator - - -@final -class AttentiveClassifierBuilder: - """Build a Jepa model that is fine-tuned for classification""" - - _config: AttentiveClassifierConfig - _device: Device | None - _dtype: DataType | None - - def __init__( - self, - config: AttentiveClassifierConfig, - *, - device: Device | None = None, - dtype: DataType | None = None, - ) -> None: - - self._config = config - - self.pooler_builer = AttentivePoolerBuilder( - config.pooler_config, device=device, dtype=dtype - ) - - self._device, self._dtype = device, dtype - - def build_model(self) -> AttentiveClassifier: - config = self._config - - pooler = self.pooler_builer.build_model() - - head = Linear(config.pooler_config.model_dim, config.num_classes, bias=True) - - return AttentiveClassifier(pooler, head) - - -@final -class AttentivePoolerBuilder: - """ - Build an attentive pooler. Many builer functions are similar to JepaEncoderBuilder - since we have an optional transformer encoder in the pool - - TODO: Refactor to have common building blocks for jepa encoder and pooler - in a base builder class (?) - """ - - _config: AttentivePoolerConfig - _device: Device | None - _dtype: DataType | None - - def __init__( - self, - config: AttentivePoolerConfig, - *, - device: Device | None = None, - dtype: DataType | None = None, - ) -> None: - self._config = config - - self._device, self._dtype = device, dtype - - def build_model(self) -> AttentivePooler: - config = self._config - - def init_pool(pool: Tensor) -> None: - std = config.init_std - with torch.no_grad(): - torch.nn.init.trunc_normal_(pool, std=std) - - decoder = self.build_decoder() - - if config.pool_depth > 1: - encoder = self.build_encoder() - else: - encoder = None - - return AttentivePooler( - decoder=decoder, - encoder=encoder, - num_queries=config.num_queries, - init_fn=init_pool, - device=self._device, - dtype=self._dtype, - ) - - def build_decoder(self) -> CrossAttentionDecoder: - config = self._config - - cross_attn = self.build_attention(config.pool_depth, is_cross_attn=True) - - ffn = self.build_ffn(config.pool_depth) - - return CrossAttentionDecoder( - cross_attn, - ffn, - norm_order=TransformerNormOrder.PRE, - layer_norm_factory=self.build_layer_norm, - device=self._device, - dtype=self._dtype, - ) - - def build_encoder(self) -> TransformerEncoder: - config = self._config - - num_layers = config.pool_depth - - layers = [self.build_encoder_layer(i) for i in range(1, num_layers)] - - return StandardTransformerEncoder( - layers, - norm_order=TransformerNormOrder.PRE, - layer_norm_factory=self.build_layer_norm, - device=self._device, - dtype=self._dtype, - ) - - def build_encoder_layer(self, layer_idx: int) -> TransformerEncoderLayer: - config = self._config - - self_attn = self.build_attention(layer_idx) - - ffn = self.build_ffn(layer_idx) - - return StandardTransformerEncoderLayer( - 
self_attn, - ffn, - dropout_p=config.dropout_p, - norm_order=TransformerNormOrder.PRE, - layer_norm_factory=self.build_layer_norm, - device=self._device, - dtype=self._dtype, - ) - - def build_attention( - self, layer_idx: int, is_cross_attn: bool = False - ) -> MultiheadAttention: - config = self._config - - sdpa = create_default_sdpa(attn_dropout_p=config.attn_dropout_p) - - if is_cross_attn: - output_proj: Projection = IdentityProjection(config.model_dim, config.model_dim) - output_proj = self.build_mha_output_projection(layer_idx) - - return StandardMultiheadAttention( - config.model_dim, - config.num_attn_heads, - sdpa=sdpa, - bias=config.qkv_bias, - output_proj=output_proj, - device=self._device, - dtype=self._dtype, - ) - - def build_mha_output_projection(self, layer_idx: int) -> Projection: - config = self._config - - init_std = config.init_std - - def init_projection(proj: Linear) -> None: - init_truncated_normal(proj.weight, proj.bias, std=init_std) - - with torch.no_grad(): - proj.weight.div_(math.sqrt(2.0 * (layer_idx + 1))) - - return Linear( - config.model_dim, - config.model_dim, - bias=True, - init_fn=init_projection, - device=self._device, - dtype=self._dtype, - ) - - def build_ffn(self, layer_idx: int) -> FeedForwardNetwork: - config = self._config - - init_std = config.init_std - - def init_projection(proj: Linear) -> None: - init_truncated_normal(proj.weight, proj.bias, std=init_std) - - with torch.no_grad(): - proj.weight.div_(math.sqrt(2.0 * (layer_idx))) - - inner_dim = int(config.model_dim * config.ffn_inner_dim_ratio) - - return StandardFeedForwardNetwork( - config.model_dim, - inner_dim, - bias=True, - inner_activation=GELU(), - proj_init_fn=init_projection, - norm_order=TransformerNormOrder.PRE, - device=self._device, - dtype=self._dtype, - ) - - def build_layer_norm( - self, - model_dim: int, - *, - device: Device | None = None, - dtype: DataType | None = None, - ) -> LayerNorm: - config = self._config - - init_std = config.init_std - - init_layer_norm = partial( - init_truncated_normal, std=init_std - ) - - return StandardLayerNorm( - model_dim, - bias=True, - eps=1e-6, - init_fn=init_layer_norm, - device=device, - dtype=dtype, - ) - - -def create_attentive_pooler( - config: AttentivePoolerConfig, - *, - device: Device | None = None, - dtype: DataType | None = None, -) -> AttentivePooler: - return AttentivePoolerBuilder(config, device=device, dtype=dtype).build_model() diff --git a/src/fairseq2/recipes/jepa/models.py b/src/fairseq2/recipes/jepa/models.py deleted file mode 100644 index 83e8e50b9..000000000 --- a/src/fairseq2/recipes/jepa/models.py +++ /dev/null @@ -1,315 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -from __future__ import annotations - -from collections.abc import Callable -from typing import final - -import torch -from torch import Tensor -from torch.nn import Dropout, Module, Parameter -from typing_extensions import override - -from fairseq2.models.model import Model -from fairseq2.models.sequence import SequenceBatch -from fairseq2.nn.incremental_state import IncrementalStateBag -from fairseq2.nn.normalization import LayerNorm -from fairseq2.nn.padding import PaddingMask -from fairseq2.nn.projection import Projection -from fairseq2.nn.transformer import ( - FeedForwardNetwork, - LayerNormFactory, - MultiheadAttention, - ResidualConnect, - TransformerEncoder, - TransformerNormOrder, - make_standard_layer_norm, -) -from fairseq2.nn.transformer.residual import StandardResidualConnect -from fairseq2.typing import DataType, Device - - -class CrossAttentionDecoder(Module): - """Represents a simple transformer decoder with only cross attention and layernorm""" - - model_dim: int - cross_attn: MultiheadAttention - cross_attn_dropout: Dropout | None - cross_attn_residual: ResidualConnect | None - cross_attn_layer_norm: LayerNorm | None - ffn: FeedForwardNetwork - ffn_dropout: Dropout | None - ffn_layer_norm: LayerNorm - norm_order: TransformerNormOrder - - def __init__( - self, - cross_attn: MultiheadAttention, - ffn: FeedForwardNetwork, - *, - dropout_p: float = 0.0, - norm_order: TransformerNormOrder = TransformerNormOrder.POST, - layer_norm_factory: LayerNormFactory | None = None, - cross_attn_residual: ResidualConnect | None = None, - ffn_residual: ResidualConnect | None = None, - device: Device | None = None, - dtype: DataType | None = None, - ) -> None: - """ - :param cross_attn: - The encoder-decoder attention layer. - :param ffn: - The feed-forward network. - :param dropout_p: - The dropout probability on outputs of the attention layers and the - feed-forward network. - :param norm_order: - The Layer Normalization order. - :param layer_norm_factory: - The factory to construct the Layer Normalization modules. - :param cross_attn_residual: - The residual connection between the input and output of the - encoder-decoder attention layer. - :param ffn_residual: - The residual connection between the input and output of the - feed-forward network. - attention layer. 
- """ - model_dim = cross_attn.model_dim - - super().__init__(model_dim) - - if layer_norm_factory is None: - layer_norm_factory = make_standard_layer_norm - - cross_attn_layer_norm = layer_norm_factory( - model_dim, device=device, dtype=dtype - ) - - if norm_order != TransformerNormOrder.POST: - self.cross_attn_layer_norm = cross_attn_layer_norm - - self.cross_attn = cross_attn - - if dropout_p > 0.0: - self.cross_attn_dropout = Dropout(dropout_p) - else: - self.register_module("cross_attn_dropout", None) - - if cross_attn_residual is None: - cross_attn_residual = StandardResidualConnect() - - self.cross_attn_residual = cross_attn_residual - - if norm_order == TransformerNormOrder.POST: - self.cross_attn_layer_norm = cross_attn_layer_norm - - ffn_layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype) - - if norm_order != TransformerNormOrder.POST: - self.ffn_layer_norm = ffn_layer_norm - - self.ffn = ffn - - if dropout_p > 0.0: - self.ffn_dropout = Dropout(dropout_p) - else: - self.register_module("ffn_dropout", None) - - if ffn_residual is None: - ffn_residual = StandardResidualConnect() - - self.ffn_residual = ffn_residual - - if norm_order == TransformerNormOrder.POST: - self.ffn_layer_norm = ffn_layer_norm - - self.norm_order = norm_order - - @override - def forward( - self, - seqs: Tensor, - padding_mask: PaddingMask | None, - encoder_output: Tensor | None = None, - encoder_padding_mask: PaddingMask | None = None, - state_bag: IncrementalStateBag | None = None, - ) -> tuple[Tensor, PaddingMask | None]: - if encoder_output is None: - raise ValueError( - "`encoder_output` must not be `None` for encoder-decoder attention." - ) - - seqs = self._forward_cross_attn( - seqs, padding_mask, encoder_output, encoder_padding_mask, state_bag - ) - - seqs = self._forward_ffn(seqs) - return seqs, padding_mask - - def _forward_cross_attn( - self, - seqs: Tensor, - padding_mask: PaddingMask | None, - encoder_output: Tensor | None, - encoder_padding_mask: PaddingMask | None, - state_bag: IncrementalStateBag | None, - ) -> Tensor: - - residual = seqs - - assert self.cross_attn_residual is not None - assert self.cross_attn_layer_norm is not None - - # Note that the cross-attention norm is applief on encoder output and not seqs - if self.norm_order != TransformerNormOrder.POST: - encoder_output = self.cross_attn_layer_norm(encoder_output) - - seqs = self.cross_attn( - seqs, - padding_mask, - keys=encoder_output, - key_padding_mask=encoder_padding_mask, - values=encoder_output, - state_bag=state_bag, - ) - - if self.cross_attn_dropout is not None: - seqs = self.cross_attn_dropout(seqs) - - seqs = self.cross_attn_residual(seqs, residual) - - if self.norm_order == TransformerNormOrder.POST: - seqs = self.cross_attn_layer_norm(seqs) - - return seqs - - def _forward_ffn(self, seqs: Tensor) -> Tensor: - residual = seqs - - if self.norm_order != TransformerNormOrder.POST: - seqs = self.ffn_layer_norm(seqs) - - seqs = self.ffn(seqs) - - if self.ffn_dropout is not None: - seqs = self.ffn_dropout(seqs) - - seqs = self.ffn_residual(seqs, residual) - - if self.norm_order == TransformerNormOrder.POST: - seqs = self.ffn_layer_norm(seqs) - - return seqs - - def extra_repr(self) -> str: - """:meta private:""" - s = super().extra_repr() - - return f"{s}, model_dim={self.model_dim}, norm_order={self.norm_order.name}" - - -@final -class AttentivePooler(Module): - """ - An attentive pooler that gets output of a Jepa encoder and decode it into - a logit of a given task. 
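
Editor's note on the decoder forward pass above: the block below is a minimal pre-norm cross-attention module with the same data flow, written with plain torch.nn modules rather than fairseq2's. As in the code above, the layer norm in the cross-attention step is applied to the encoder output (the keys and values), not to the query sequence. This is an illustrative sketch under those assumptions, not the fairseq2 implementation.

import torch
from torch import nn


class TinyCrossAttnBlock(nn.Module):
    # Pre-norm cross-attention followed by a pre-norm feed-forward network,
    # mirroring _forward_cross_attn / _forward_ffn above.
    def __init__(self, model_dim: int, num_heads: int) -> None:
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(model_dim, num_heads, batch_first=True)
        self.cross_attn_layer_norm = nn.LayerNorm(model_dim)
        self.ffn_layer_norm = nn.LayerNorm(model_dim)
        self.ffn = nn.Sequential(
            nn.Linear(model_dim, 4 * model_dim),
            nn.GELU(),
            nn.Linear(4 * model_dim, model_dim),
        )

    def forward(self, queries: torch.Tensor, encoder_output: torch.Tensor) -> torch.Tensor:
        residual = queries
        # The norm is applied to the encoder output, not to the queries.
        keys = self.cross_attn_layer_norm(encoder_output)
        attn_out, _ = self.cross_attn(queries, keys, keys, need_weights=False)
        queries = residual + attn_out

        residual = queries
        queries = residual + self.ffn(self.ffn_layer_norm(queries))
        return queries


# Usage: two learned queries pooled against a sequence of 256 encoded patches.
block = TinyCrossAttnBlock(model_dim=1024, num_heads=16)
pooled = block(torch.randn(2, 2, 1024), torch.randn(2, 256, 1024))
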
- - TODO: - - Move this into fairseq2.nn to benefit other similiar tasks. Internally, - this module is just a thin transformer encoder without self attention layer. - Optionally, it can consist of some extra transformer encoders depending on the - (finetuning) task - """ - - model_dim: int - num_queries: int - decoder: CrossAttentionDecoder - encoder: TransformerEncoder | None - init_fn: Callable[[Tensor], None] | None - - def __init__( - self, - decoder: CrossAttentionDecoder, - encoder: TransformerEncoder | None, - *, - num_queries: int = 1, - init_fn: Callable[[Tensor], None] | None = None, - device: Device | None = None, - dtype: DataType | None = None, - ) -> None: - super().__init__() - - self.model_dim = decoder.model_dim - - self.decoder = decoder - - if encoder: - self.encoder = encoder - else: - self.register_module("encoder", None) - - self.num_queries = num_queries - self.query_tokens = Parameter( - torch.empty(1, num_queries, self.model_dim, device=device, dtype=dtype) - ) - - if init_fn: - init_fn(self.pool_layer) - - def forward( - self, seqs: Tensor, padding_mask: PaddingMask | None - ) -> tuple[Tensor, PaddingMask | None]: - if self.encoder: - seqs, padding_mask = self.encoder(seqs, padding_mask) - queries = self.query_tokens.repeat(len(seqs), 1, 1) - seqs, padding_mask = self.decoder(queries, None, seqs, padding_mask) - return seqs, padding_mask - - def extra_repr(self) -> str: - """:meta private:""" - s = super().extra_repr() - - return f"{s}, model_dim={self.model_dim}, pools={self.num_queries}" - - -@final -class AttentiveClassifier(Model): - """ - Represents a pretrained Jepa model, with an attentive probing layer for - classfication tasks. See - * :cite:t:`https://doi.org/10.48550/arXiv.2301.08243` - * :cite:t:`https://doi.org/10.48550/arXiv.2404.08471` - """ - - model_dim: int - pooler: AttentivePooler - head: Projection - - def __init__( - self, - pooler: AttentivePooler, - head: Projection, - ) -> None: - super().__init__() - - self.model_dim = pooler.model_dim - - self.pooler = pooler - self.head = head - - def forward(self, batch: SequenceBatch) -> Tensor: - seqs, _ = self.pooler(batch.seqs, batch.padding_mask) - seqs = seqs.squeeze(1) - output: Tensor = self.head(seqs) - return output - - def extra_repr(self) -> str: - """:meta private:""" - s = super().extra_repr() - - return f"{s}, model_dim={self.model_dim}" From 4690e6bd31bfba392d4d8826e9ee0b4918ef8cf6 Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Tue, 24 Dec 2024 08:37:55 +0100 Subject: [PATCH 30/37] update cross_attn build func --- .../models/jepa/classifier/factory.py | 44 ++++++++----------- 1 file changed, 19 insertions(+), 25 deletions(-) diff --git a/src/fairseq2/models/jepa/classifier/factory.py b/src/fairseq2/models/jepa/classifier/factory.py index 947d100da..90af2329a 100644 --- a/src/fairseq2/models/jepa/classifier/factory.py +++ b/src/fairseq2/models/jepa/classifier/factory.py @@ -104,9 +104,7 @@ def build_pooler(self) -> AttentivePooler: def build_decoder_layer(self) -> CrossAttentionDecoderLayer: config = self._config - cross_attn = self._encoder_builder.build_attention( - config.pool_depth, is_cross_attn=True - ) + cross_attn = self.build_cross_attention() ffn = self._encoder_builder.build_ffn(config.pool_depth) @@ -118,28 +116,24 @@ def build_decoder_layer(self) -> CrossAttentionDecoderLayer: dtype=self._dtype, ) -# def build_attention( -# self, layer_idx: int, is_cross_attn: bool = False -# ) -> MultiheadAttention: -# config = self._config -# -# 
sdpa = create_default_sdpa(attn_dropout_p=config.attn_dropout_p) -# -# if is_cross_attn: -# output_proj: Projection = IdentityProjection( -# config.model_dim, config.model_dim -# ) -# output_proj = self.build_mha_output_projection(layer_idx) -# -# return StandardMultiheadAttention( -# config.model_dim, -# config.num_attn_heads, -# sdpa=sdpa, -# bias=config.qkv_bias, -# output_proj=output_proj, -# device=self._device, -# dtype=self._dtype, -# ) + def build_cross_attention(self) -> MultiheadAttention: + config = self._config.encoder_config + + model_dim = config.model_dim + + sdpa = create_default_sdpa(attn_dropout_p=config.attn_dropout_p) + + output_proj: Projection = IdentityProjection(model_dim, model_dim) + + return StandardMultiheadAttention( + model_dim, + config.num_encoder_attn_heads, + sdpa=sdpa, + bias=config.qkv_bias, + output_proj=output_proj, + device=self._device, + dtype=self._dtype, + ) def create_jepa_classifier_model( From a345c4c8f6b72a904dd604a828d240ff72b32b74 Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 26 Dec 2024 17:16:54 +0100 Subject: [PATCH 31/37] lint --- src/fairseq2/models/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fairseq2/models/loader.py b/src/fairseq2/models/loader.py index c2fdaf428..0f4dabfab 100644 --- a/src/fairseq2/models/loader.py +++ b/src/fairseq2/models/loader.py @@ -206,7 +206,7 @@ def __call__( force: bool = False, progress: bool = True, strict_state_dict: bool = True, - + ) -> ModelT: if isinstance(model_name_or_card, AssetCard): card = model_name_or_card From f96344b8019656e632f26644309b740ac3c7c707 Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 26 Dec 2024 17:21:21 +0100 Subject: [PATCH 32/37] update AttentivePooler param names --- .../models/jepa/classifier/__init__.py | 18 ++++++ src/fairseq2/models/jepa/classifier/archs.py | 37 +++++++++++ .../models/jepa/classifier/factory.py | 61 +++++++++++++++---- src/fairseq2/models/jepa/classifier/model.py | 26 ++++---- 4 files changed, 117 insertions(+), 25 deletions(-) create mode 100644 src/fairseq2/models/jepa/classifier/archs.py diff --git a/src/fairseq2/models/jepa/classifier/__init__.py b/src/fairseq2/models/jepa/classifier/__init__.py index e69de29bb..3b315f16d 100644 --- a/src/fairseq2/models/jepa/classifier/__init__.py +++ b/src/fairseq2/models/jepa/classifier/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
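
Editor's note on build_cross_attention in patch 30 above: handing an IdentityProjection to StandardMultiheadAttention as the output projection effectively returns the concatenated head outputs without a final linear mix. A tiny stand-in below shows what that identity projection does; the class name and sizes are illustrative, not fairseq2 API.

import torch
from torch import nn


class PassThrough(nn.Module):
    # Stand-in for the identity output projection handed to the cross-attention
    # above: the attention output is returned unchanged.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


attn_output = torch.randn(2, 1, 1024)  # (batch, queries, model_dim), placeholder sizes
assert torch.equal(PassThrough()(attn_output), attn_output)
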
+ +from __future__ import annotations + +from fairseq2.models.jepa.classifier.factory import create_jepa_classifier_model as create_jepa_classifier_model +from fairseq2.models.jepa.classifier.factory import jepa_classifier_archs as jepa_classifier_archs +from fairseq2.models.jepa.classifier.factory import JEPA_CLASSIFIER_FAMILY as JEPA_CLASSIFIER_FAMILY +from fairseq2.models.jepa.classifier.factory import JepaClassifierConfig as JepaClassifierConfig +from fairseq2.models.jepa.classifier.factory import JepaClassifierBuilder as JepaClassifierBuilder +from fairseq2.models.jepa.classifier.model import JepaClassifierModel as JepaClassifierModel + +# isort: split + +import fairseq2.models.jepa.classifier.archs # Register architectures diff --git a/src/fairseq2/models/jepa/classifier/archs.py b/src/fairseq2/models/jepa/classifier/archs.py new file mode 100644 index 000000000..dc049e9e3 --- /dev/null +++ b/src/fairseq2/models/jepa/classifier/archs.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from fairseq2.models.jepa.archs import ( + base as jepa_base, + large as jepa_large, + huge as jepa_huge, +) +from fairseq2.models.jepa.classifier.factory import JepaClassifierConfig, jepa_classifier_arch + +@jepa_classifier_arch("base") +def base() -> JepaClassifierConfig: + pretrain_config = jepa_base() + return JepaClassifierConfig( + encoder_config=pretrain_config.encoder_config + ) + + +@jepa_classifier_arch("large") +def large() -> JepaClassifierConfig: + pretrain_config = jepa_large() + return JepaClassifierConfig( + encoder_config=pretrain_config.encoder_config + ) + + +@jepa_classifier_arch("huge") +def huge() -> JepaClassifierConfig: + pretrain_config = jepa_huge() + return JepaClassifierConfig( + encoder_config=pretrain_config.encoder_config + ) diff --git a/src/fairseq2/models/jepa/classifier/factory.py b/src/fairseq2/models/jepa/classifier/factory.py index 90af2329a..8839ec20f 100644 --- a/src/fairseq2/models/jepa/classifier/factory.py +++ b/src/fairseq2/models/jepa/classifier/factory.py @@ -13,7 +13,8 @@ CrossAttentionDecoderLayer, JepaClassifierModel, ) -from fairseq2.models.jepa.factory import JepaEncoderBuilder, JepaEncoderConfig +from fairseq2.models.factory import model_factories +from fairseq2.models.jepa import JepaEncoderBuilder, JepaEncoderConfig from fairseq2.nn.projection import IdentityProjection, Linear, Projection from fairseq2.nn.transformer import ( MultiheadAttention, @@ -22,6 +23,8 @@ ) from fairseq2.typing import DataType, Device +JEPA_CLASSIFIER_FAMILY = "jepa_classifier" + @dataclass(kw_only=True) class JepaClassifierConfig: @@ -32,6 +35,9 @@ class JepaClassifierConfig: pool_depth: int = 1 """The pool depth (minimum 1 decoder layer)""" + + decoder_projection: bool = True + """If True, the decoder will have a linear layer on top""" num_queries: int = 1 """Number of query tokens in the attention pool layer""" @@ -70,15 +76,11 @@ def __init__( self._device, self._dtype = device, dtype def build_model(self) -> JepaClassifierModel: - config = self._config encoder_frontend = self._encoder_builder.build_frontend() - encoder = self._encoder_builder.build_encoder() - pooler = self.build_pooler() - - head = Linear(config.encoder_config.model_dim, config.num_classes, bias=True) + head = self.build_head() return JepaClassifierModel(encoder_frontend, encoder, 
pooler, head) @@ -90,17 +92,27 @@ def build_pooler(self) -> AttentivePooler: else: encoder = None - decoder_layer = self.build_decoder_layer() + decoder = self.build_decoder_layer() return AttentivePooler( - decoder_layer=decoder_layer, + decoder=decoder, encoder=encoder, - num_pools=config.num_queries, + num_queries=config.num_queries, init_std=config.encoder_config.init_std, device=self._device, dtype=self._dtype, ) + def build_head(self) -> Projection: + config = self._config + return Linear( + config.encoder_config.model_dim, + config.num_classes, + device=self._device, + dtype=self._dtype, + bias=True, + ) + def build_decoder_layer(self) -> CrossAttentionDecoderLayer: config = self._config @@ -118,12 +130,12 @@ def build_decoder_layer(self) -> CrossAttentionDecoderLayer: def build_cross_attention(self) -> MultiheadAttention: config = self._config.encoder_config - + model_dim = config.model_dim - + sdpa = create_default_sdpa(attn_dropout_p=config.attn_dropout_p) - - output_proj: Projection = IdentityProjection(model_dim, model_dim) + + output_proj = self.build_cross_attn_output_projection() return StandardMultiheadAttention( model_dim, @@ -135,6 +147,22 @@ def build_cross_attention(self) -> MultiheadAttention: dtype=self._dtype, ) + def build_cross_attn_output_projection(self) -> Projection: + config = self._config + + model_dim = config.encoder_config.model_dim + + if config.decoder_projection: + return Linear( + model_dim, + model_dim, + bias=True, + device=self._device, + dtype=self._dtype, + ) + else: + return IdentityProjection(model_dim, model_dim) + def create_jepa_classifier_model( config: JepaClassifierConfig, @@ -143,3 +171,10 @@ def create_jepa_classifier_model( dtype: DataType | None = None, ) -> JepaClassifierModel: return JepaClassifierBuilder(config, device=device, dtype=dtype).build_model() + +model_factories.register( + JEPA_CLASSIFIER_FAMILY, + create_jepa_classifier_model, + JepaClassifierConfig, + jepa_classifier_archs, +) diff --git a/src/fairseq2/models/jepa/classifier/model.py b/src/fairseq2/models/jepa/classifier/model.py index 54ad9c73b..281f230f8 100644 --- a/src/fairseq2/models/jepa/classifier/model.py +++ b/src/fairseq2/models/jepa/classifier/model.py @@ -91,34 +91,34 @@ class AttentivePooler(Module): """ model_dim: int - decoder_layer: CrossAttentionDecoderLayer + decoder: CrossAttentionDecoderLayer encoder: TransformerEncoder | None - pool_layer: Parameter + query_tokens: Parameter init_std: float def __init__( self, - decoder_layer: CrossAttentionDecoderLayer, + decoder: CrossAttentionDecoderLayer, encoder: TransformerEncoder | None, *, - num_pools: int = 1, + num_queries: int = 1, init_std: float = 0.02, device: Device | None = None, dtype: DataType | None = None, ) -> None: super().__init__() - self.model_dim = decoder_layer.model_dim + self.model_dim = decoder.model_dim - self.decoder_layer = decoder_layer + self.decoder = decoder if encoder: self.encoder = encoder else: self.register_module("encoder", None) - self.pool_layer = Parameter( - torch.empty((1, num_pools, self.model_dim), device=device, dtype=dtype) + self.query_tokens = Parameter( + torch.empty((1, num_queries, self.model_dim), device=device, dtype=dtype) ) self.init_std = init_std @@ -127,7 +127,7 @@ def __init__( def reset_parameters(self) -> None: """Reset the parameters and buffers of the module.""" - nn.init.trunc_normal_(self.pool_layer, std=self.init_std) + nn.init.trunc_normal_(self.query_tokens, std=self.init_std) def forward(self, seqs: Tensor) -> Tensor: if self.encoder is not 
None: @@ -136,13 +136,13 @@ def forward(self, seqs: Tensor) -> Tensor: batch_size = seqs.size(0) # (1, P, M) -> (N, P, M) - pool_seqs = self.pool_layer.repeat(batch_size, 1, 1) + pool_seqs = self.query_tokens.repeat(batch_size, 1, 1) - return self.decoder_layer(pool_seqs, seqs) # type: ignore[no-any-return] + return self.decoder(pool_seqs, seqs) # type: ignore[no-any-return] def extra_repr(self) -> str: """:meta private:""" - return f"model_dim={self.model_dim}, num_pools={self.pool_layer.size(1)}" + return f"model_dim={self.model_dim}, num_queries={self.query_tokens.size(1)}" @final @@ -182,6 +182,8 @@ def __init__( self.cross_attn_layer_norm = layer_norm_factory( model_dim, device=device, dtype=dtype ) + + self.model_dim = model_dim self.cross_attn = cross_attn From 75472abe1bbac6e187b7168bb58dc9c6ba515b31 Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 26 Dec 2024 17:36:24 +0100 Subject: [PATCH 33/37] decouple #938 --- src/fairseq2/models/loader.py | 9 +-------- src/fairseq2/nn/utils/module.py | 11 ++++++----- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/src/fairseq2/models/loader.py b/src/fairseq2/models/loader.py index 0f4dabfab..589532e3f 100644 --- a/src/fairseq2/models/loader.py +++ b/src/fairseq2/models/loader.py @@ -82,7 +82,6 @@ def __call__( dtype: DataType | None = None, force: bool = False, progress: bool = True, - strict_state_dict: bool = True, ) -> ModelT_co: """ :param model_name_or_card: @@ -99,9 +98,6 @@ def __call__( cache. :param progress: If ``True``, displays a progress bar to stderr. - :param strict_state_dict: - If ``True``, checkpoint' parameters and layers must be identical to - the model state dict) :returns: A model loaded from the checkpoint of ``model_name_or_card``. @@ -205,7 +201,6 @@ def __call__( dtype: DataType | None = None, force: bool = False, progress: bool = True, - strict_state_dict: bool = True, ) -> ModelT: if isinstance(model_name_or_card, AssetCard): @@ -361,7 +356,7 @@ def __call__( consume_prefix_in_state_dict_if_present(state_dict, prefix="module.") try: - load_state_dict(model, state_dict, strict=strict_state_dict) + load_state_dict(model, state_dict) except (KeyError, ValueError) as ex: raise AssetError( f"{card.name} cannot be loaded. See nested exception for details." @@ -402,7 +397,6 @@ def __call__( dtype: DataType | None = None, force: bool = False, progress: bool = True, - strict_state_dict: bool = True, ) -> ModelT: if isinstance(model_name_or_card, AssetCard): card = model_name_or_card @@ -426,7 +420,6 @@ def __call__( dtype=dtype, force=force, progress=progress, - strict_state_dict=strict_state_dict, ) def register(self, family: str, loader: ModelLoader[ModelT]) -> None: diff --git a/src/fairseq2/nn/utils/module.py b/src/fairseq2/nn/utils/module.py index 02ab7d1d2..e76200599 100644 --- a/src/fairseq2/nn/utils/module.py +++ b/src/fairseq2/nn/utils/module.py @@ -431,15 +431,16 @@ def broadcast_module( _broadcast_coalesced(pg, tensors, bucket_size, source_rank) -def load_state_dict(module: Module, state_dict: Mapping[str, object], strict: bool = True) -> None: +def load_state_dict(module: Module, state_dict: Mapping[str, object]) -> None: """Copy parameters and buffers from ``state_dict`` into ``module`` and its descendant modules. - This implementation internally calls :meth:`Module.load_state_dict()`, and also enforces that - ``state_dict`` does not contain any keys corresponding to descendants that are set to ``None`` - via :meth:`Module.register_module()`. 
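
Editor's note on the docstring being rewritten here (continued just below): with the ``strict_state_dict`` knob removed in this patch, ``load_state_dict`` always delegates to :meth:`Module.load_state_dict` with ``strict`` set to ``True``. The plain-PyTorch snippet below illustrates what strict loading enforces; the module and keys are made up for the example.

import torch
from torch import nn

model = nn.Linear(4, 4)

incomplete = {"weight": torch.zeros(4, 4)}  # "bias" is deliberately missing

try:
    model.load_state_dict(incomplete, strict=True)
except RuntimeError as exc:
    print(f"strict loading rejects mismatched checkpoints: {exc}")
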
+ This implementation internally calls :meth:`Module.load_state_dict()` with + ``strict`` set to ``True``, and also enforces that ``state_dict`` does not + contain any keys corresponding to descendants that are set to ``None`` via + :meth:`Module.register_module()`. """ - module.load_state_dict(state_dict, strict=strict) + module.load_state_dict(state_dict, strict=True) unexpected_keys = [] From 8993698acb75ef56e2c2ebd4c9d37c68c96c5bf2 Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 26 Dec 2024 17:37:15 +0100 Subject: [PATCH 34/37] lint --- src/fairseq2/models/loader.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/fairseq2/models/loader.py b/src/fairseq2/models/loader.py index 589532e3f..552cf482f 100644 --- a/src/fairseq2/models/loader.py +++ b/src/fairseq2/models/loader.py @@ -201,7 +201,6 @@ def __call__( dtype: DataType | None = None, force: bool = False, progress: bool = True, - ) -> ModelT: if isinstance(model_name_or_card, AssetCard): card = model_name_or_card From 4394a0ab452e728d83fdf87f5f05e8b253faf3a7 Mon Sep 17 00:00:00 2001 From: Tuan Tran Date: Thu, 26 Dec 2024 16:42:15 +0000 Subject: [PATCH 35/37] lint --- src/fairseq2/models/loader.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/fairseq2/models/loader.py b/src/fairseq2/models/loader.py index 552cf482f..57e72e333 100644 --- a/src/fairseq2/models/loader.py +++ b/src/fairseq2/models/loader.py @@ -124,8 +124,7 @@ def __call__( model: ModelT_contra, config: ModelConfigT_contra, gangs: Mapping[str, Gang], - ) -> None: - ... + ) -> None: ... @final From 05c979fb71b1be49618c39ec9c8fd5b8080b11bb Mon Sep 17 00:00:00 2001 From: Tuan Tran <{ID}+{username}@users.noreply.github.com> Date: Thu, 26 Dec 2024 17:58:59 +0100 Subject: [PATCH 36/37] lint --- .../models/jepa/classifier/__init__.py | 27 +++++++++++++------ src/fairseq2/models/jepa/classifier/archs.py | 26 ++++++++---------- .../models/jepa/classifier/factory.py | 14 ++++++---- src/fairseq2/models/jepa/classifier/model.py | 2 +- 4 files changed, 40 insertions(+), 29 deletions(-) diff --git a/src/fairseq2/models/jepa/classifier/__init__.py b/src/fairseq2/models/jepa/classifier/__init__.py index 3b315f16d..7f321afb0 100644 --- a/src/fairseq2/models/jepa/classifier/__init__.py +++ b/src/fairseq2/models/jepa/classifier/__init__.py @@ -6,13 +6,24 @@ from __future__ import annotations -from fairseq2.models.jepa.classifier.factory import create_jepa_classifier_model as create_jepa_classifier_model -from fairseq2.models.jepa.classifier.factory import jepa_classifier_archs as jepa_classifier_archs -from fairseq2.models.jepa.classifier.factory import JEPA_CLASSIFIER_FAMILY as JEPA_CLASSIFIER_FAMILY -from fairseq2.models.jepa.classifier.factory import JepaClassifierConfig as JepaClassifierConfig -from fairseq2.models.jepa.classifier.factory import JepaClassifierBuilder as JepaClassifierBuilder -from fairseq2.models.jepa.classifier.model import JepaClassifierModel as JepaClassifierModel +import fairseq2.models.jepa.classifier.archs # Register architectures +from fairseq2.models.jepa.classifier.factory import ( + JEPA_CLASSIFIER_FAMILY as JEPA_CLASSIFIER_FAMILY, +) +from fairseq2.models.jepa.classifier.factory import ( + JepaClassifierBuilder as JepaClassifierBuilder, +) +from fairseq2.models.jepa.classifier.factory import ( + JepaClassifierConfig as JepaClassifierConfig, +) +from fairseq2.models.jepa.classifier.factory import ( + create_jepa_classifier_model as create_jepa_classifier_model, +) +from 
fairseq2.models.jepa.classifier.factory import ( + jepa_classifier_archs as jepa_classifier_archs, +) +from fairseq2.models.jepa.classifier.model import ( + JepaClassifierModel as JepaClassifierModel, +) # isort: split - -import fairseq2.models.jepa.classifier.archs # Register architectures diff --git a/src/fairseq2/models/jepa/classifier/archs.py b/src/fairseq2/models/jepa/classifier/archs.py index dc049e9e3..b131e2342 100644 --- a/src/fairseq2/models/jepa/classifier/archs.py +++ b/src/fairseq2/models/jepa/classifier/archs.py @@ -6,32 +6,28 @@ from __future__ import annotations -from fairseq2.models.jepa.archs import ( - base as jepa_base, - large as jepa_large, - huge as jepa_huge, +from fairseq2.models.jepa.archs import base as jepa_base +from fairseq2.models.jepa.archs import huge as jepa_huge +from fairseq2.models.jepa.archs import large as jepa_large +from fairseq2.models.jepa.classifier.factory import ( + JepaClassifierConfig, + jepa_classifier_arch, ) -from fairseq2.models.jepa.classifier.factory import JepaClassifierConfig, jepa_classifier_arch + @jepa_classifier_arch("base") def base() -> JepaClassifierConfig: pretrain_config = jepa_base() - return JepaClassifierConfig( - encoder_config=pretrain_config.encoder_config - ) + return JepaClassifierConfig(encoder_config=pretrain_config.encoder_config) @jepa_classifier_arch("large") def large() -> JepaClassifierConfig: pretrain_config = jepa_large() - return JepaClassifierConfig( - encoder_config=pretrain_config.encoder_config - ) - + return JepaClassifierConfig(encoder_config=pretrain_config.encoder_config) + @jepa_classifier_arch("huge") def huge() -> JepaClassifierConfig: pretrain_config = jepa_huge() - return JepaClassifierConfig( - encoder_config=pretrain_config.encoder_config - ) + return JepaClassifierConfig(encoder_config=pretrain_config.encoder_config) diff --git a/src/fairseq2/models/jepa/classifier/factory.py b/src/fairseq2/models/jepa/classifier/factory.py index 8839ec20f..bb7a6ea5d 100644 --- a/src/fairseq2/models/jepa/classifier/factory.py +++ b/src/fairseq2/models/jepa/classifier/factory.py @@ -8,13 +8,13 @@ from typing import final from fairseq2.config_registry import ConfigRegistry +from fairseq2.models.factory import model_factories +from fairseq2.models.jepa import JepaEncoderBuilder, JepaEncoderConfig from fairseq2.models.jepa.classifier.model import ( AttentivePooler, CrossAttentionDecoderLayer, JepaClassifierModel, ) -from fairseq2.models.factory import model_factories -from fairseq2.models.jepa import JepaEncoderBuilder, JepaEncoderConfig from fairseq2.nn.projection import IdentityProjection, Linear, Projection from fairseq2.nn.transformer import ( MultiheadAttention, @@ -35,7 +35,7 @@ class JepaClassifierConfig: pool_depth: int = 1 """The pool depth (minimum 1 decoder layer)""" - + decoder_projection: bool = True """If True, the decoder will have a linear layer on top""" @@ -76,7 +76,6 @@ def __init__( self._device, self._dtype = device, dtype def build_model(self) -> JepaClassifierModel: - encoder_frontend = self._encoder_builder.build_frontend() encoder = self._encoder_builder.build_encoder() pooler = self.build_pooler() @@ -170,7 +169,12 @@ def create_jepa_classifier_model( device: Device | None = None, dtype: DataType | None = None, ) -> JepaClassifierModel: - return JepaClassifierBuilder(config, device=device, dtype=dtype).build_model() + return JepaClassifierBuilder( + config, + device=device, + dtype=dtype, + ).build_model() + model_factories.register( JEPA_CLASSIFIER_FAMILY, diff --git 
a/src/fairseq2/models/jepa/classifier/model.py b/src/fairseq2/models/jepa/classifier/model.py index 281f230f8..0ba5541d3 100644 --- a/src/fairseq2/models/jepa/classifier/model.py +++ b/src/fairseq2/models/jepa/classifier/model.py @@ -182,7 +182,7 @@ def __init__( self.cross_attn_layer_norm = layer_norm_factory( model_dim, device=device, dtype=dtype ) - + self.model_dim = model_dim self.cross_attn = cross_attn From fc3b5c93bfe53a0a060e80486c1811240cf50b3d Mon Sep 17 00:00:00 2001 From: Tuan Tran Date: Thu, 26 Dec 2024 17:41:05 +0000 Subject: [PATCH 37/37] lint --- src/fairseq2/models/loader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/fairseq2/models/loader.py b/src/fairseq2/models/loader.py index 57e72e333..552cf482f 100644 --- a/src/fairseq2/models/loader.py +++ b/src/fairseq2/models/loader.py @@ -124,7 +124,8 @@ def __call__( model: ModelT_contra, config: ModelConfigT_contra, gangs: Mapping[str, Gang], - ) -> None: ... + ) -> None: + ... @final
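
Editor's note, end of series: taken together, these patches register the classifier under the ``jepa_classifier`` family and expose it through ``fairseq2.models.jepa.classifier``. The sketch below shows intended usage against the patched tree; the architecture name, class count, and device are placeholders, and since the model's forward signature is not shown in this series, the sketch stops at construction.

import torch

from fairseq2.models.jepa.classifier import create_jepa_classifier_model
from fairseq2.models.jepa.classifier.archs import base

config = base()            # JepaClassifierConfig built from the pretrained "base" encoder config
config.num_classes = 400   # placeholder head size, e.g. Kinetics-400
config.pool_depth = 1      # a single cross-attention decoder layer, no extra pooler encoder

model = create_jepa_classifier_model(config, device=torch.device("cpu"), dtype=torch.float32)

num_params = sum(p.numel() for p in model.parameters())
print(f"jepa_classifier with {config.num_classes} classes: {num_params:,} parameters")
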