diff --git a/src/fairseq2/models/jepa/classifier/__init__.py b/src/fairseq2/models/jepa/classifier/__init__.py
new file mode 100644
index 000000000..7f321afb0
--- /dev/null
+++ b/src/fairseq2/models/jepa/classifier/__init__.py
@@ -0,0 +1,29 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+import fairseq2.models.jepa.classifier.archs  # Register architectures
+from fairseq2.models.jepa.classifier.factory import (
+    JEPA_CLASSIFIER_FAMILY as JEPA_CLASSIFIER_FAMILY,
+)
+from fairseq2.models.jepa.classifier.factory import (
+    JepaClassifierBuilder as JepaClassifierBuilder,
+)
+from fairseq2.models.jepa.classifier.factory import (
+    JepaClassifierConfig as JepaClassifierConfig,
+)
+from fairseq2.models.jepa.classifier.factory import (
+    create_jepa_classifier_model as create_jepa_classifier_model,
+)
+from fairseq2.models.jepa.classifier.factory import (
+    jepa_classifier_archs as jepa_classifier_archs,
+)
+from fairseq2.models.jepa.classifier.model import (
+    JepaClassifierModel as JepaClassifierModel,
+)
+
+# isort: split
diff --git a/src/fairseq2/models/jepa/classifier/archs.py b/src/fairseq2/models/jepa/classifier/archs.py
new file mode 100644
index 000000000..b131e2342
--- /dev/null
+++ b/src/fairseq2/models/jepa/classifier/archs.py
@@ -0,0 +1,33 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+from fairseq2.models.jepa.archs import base as jepa_base
+from fairseq2.models.jepa.archs import huge as jepa_huge
+from fairseq2.models.jepa.archs import large as jepa_large
+from fairseq2.models.jepa.classifier.factory import (
+    JepaClassifierConfig,
+    jepa_classifier_arch,
+)
+
+
+@jepa_classifier_arch("base")
+def base() -> JepaClassifierConfig:
+    pretrain_config = jepa_base()
+    return JepaClassifierConfig(encoder_config=pretrain_config.encoder_config)
+
+
+@jepa_classifier_arch("large")
+def large() -> JepaClassifierConfig:
+    pretrain_config = jepa_large()
+    return JepaClassifierConfig(encoder_config=pretrain_config.encoder_config)
+
+
+@jepa_classifier_arch("huge")
+def huge() -> JepaClassifierConfig:
+    pretrain_config = jepa_huge()
+    return JepaClassifierConfig(encoder_config=pretrain_config.encoder_config)
diff --git a/src/fairseq2/models/jepa/classifier/factory.py b/src/fairseq2/models/jepa/classifier/factory.py
new file mode 100644
index 000000000..bb7a6ea5d
--- /dev/null
+++ b/src/fairseq2/models/jepa/classifier/factory.py
@@ -0,0 +1,184 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass, field
+from typing import final
+
+from fairseq2.config_registry import ConfigRegistry
+from fairseq2.models.factory import model_factories
+from fairseq2.models.jepa import JepaEncoderBuilder, JepaEncoderConfig
+from fairseq2.models.jepa.classifier.model import (
+    AttentivePooler,
+    CrossAttentionDecoderLayer,
+    JepaClassifierModel,
+)
+from fairseq2.nn.projection import IdentityProjection, Linear, Projection
+from fairseq2.nn.transformer import (
+    MultiheadAttention,
+    StandardMultiheadAttention,
+    create_default_sdpa,
+)
+from fairseq2.typing import DataType, Device
+
+JEPA_CLASSIFIER_FAMILY = "jepa_classifier"
+
+
+@dataclass(kw_only=True)
+class JepaClassifierConfig:
+    encoder_config: JepaEncoderConfig = field(
+        default_factory=lambda: JepaEncoderConfig()
+    )
+    """The configuration of the vision encoder."""
+
+    pool_depth: int = 1
+    """The depth of the attentive pooler (minimum one decoder layer)."""
+
+    decoder_projection: bool = True
+    """If ``True``, the decoder will have a linear layer on top."""
+
+    num_queries: int = 1
+    """The number of query tokens in the attention pool layer."""
+
+    num_classes: int = 1000
+    """The number of classes (i.e. the size of the classification logits)."""
+
+
+jepa_classifier_archs = ConfigRegistry[JepaClassifierConfig]()
+
+jepa_classifier_arch = jepa_classifier_archs.decorator
+
+
+@final
+class JepaClassifierBuilder:
+    """Builds a JEPA model fine-tuned for classification."""
+
+    _config: JepaClassifierConfig
+    _encoder_builder: JepaEncoderBuilder
+    _device: Device | None
+    _dtype: DataType | None
+
+    def __init__(
+        self,
+        config: JepaClassifierConfig,
+        *,
+        device: Device | None = None,
+        dtype: DataType | None = None,
+    ) -> None:
+        self._config = config
+
+        self._encoder_builder = JepaEncoderBuilder(
+            config.encoder_config, device=device, dtype=dtype
+        )
+
+        self._device, self._dtype = device, dtype
+
+    def build_model(self) -> JepaClassifierModel:
+        encoder_frontend = self._encoder_builder.build_frontend()
+        encoder = self._encoder_builder.build_encoder()
+        pooler = self.build_pooler()
+        head = self.build_head()
+
+        return JepaClassifierModel(encoder_frontend, encoder, pooler, head)
+
+    def build_pooler(self) -> AttentivePooler:
+        config = self._config
+
+        if config.pool_depth > 1:
+            encoder = self._encoder_builder.build_encoder(config.pool_depth)
+        else:
+            encoder = None
+
+        decoder = self.build_decoder_layer()
+
+        return AttentivePooler(
+            decoder=decoder,
+            encoder=encoder,
+            num_queries=config.num_queries,
+            init_std=config.encoder_config.init_std,
+            device=self._device,
+            dtype=self._dtype,
+        )
+
+    def build_head(self) -> Projection:
+        config = self._config
+        return Linear(
+            config.encoder_config.model_dim,
+            config.num_classes,
+            device=self._device,
+            dtype=self._dtype,
+            bias=True,
+        )
+
+    def build_decoder_layer(self) -> CrossAttentionDecoderLayer:
+        config = self._config
+
+        cross_attn = self.build_cross_attention()
+
+        ffn = self._encoder_builder.build_ffn(config.pool_depth)
+
+        return CrossAttentionDecoderLayer(
+            cross_attn,
+            ffn,
+            layer_norm_factory=self._encoder_builder.build_layer_norm,
+            device=self._device,
+            dtype=self._dtype,
+        )
+
+    def build_cross_attention(self) -> MultiheadAttention:
+        config = self._config.encoder_config
+
+        model_dim = config.model_dim
+
+        sdpa = create_default_sdpa(attn_dropout_p=config.attn_dropout_p)
+
+        output_proj = self.build_cross_attn_output_projection()
+
+        return StandardMultiheadAttention(
+            model_dim,
+            config.num_encoder_attn_heads,
+            sdpa=sdpa,
+            bias=config.qkv_bias,
+            output_proj=output_proj,
+            device=self._device,
+            dtype=self._dtype,
+        )
+
+    def build_cross_attn_output_projection(self) -> Projection:
+        config = self._config
+
+        model_dim = config.encoder_config.model_dim
+
+        if config.decoder_projection:
+            return Linear(
+                model_dim,
+                model_dim,
+                bias=True,
+                device=self._device,
+                dtype=self._dtype,
+            )
+        else:
+            return IdentityProjection(model_dim, model_dim)
+
+
+def create_jepa_classifier_model(
+    config: JepaClassifierConfig,
+    *,
+    device: Device | None = None,
+    dtype: DataType | None = None,
+) -> JepaClassifierModel:
+    return JepaClassifierBuilder(
+        config,
+        device=device,
+        dtype=dtype,
+    ).build_model()
+
+
+model_factories.register(
+    JEPA_CLASSIFIER_FAMILY,
+    create_jepa_classifier_model,
+    JepaClassifierConfig,
+    jepa_classifier_archs,
+)
diff --git a/src/fairseq2/models/jepa/classifier/model.py b/src/fairseq2/models/jepa/classifier/model.py
new file mode 100644
index 000000000..0ba5541d3
--- /dev/null
+++ b/src/fairseq2/models/jepa/classifier/model.py
@@ -0,0 +1,232 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+from typing import final
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.nn import Module, Parameter
+
+from fairseq2.models.sequence import SequenceBatch
+from fairseq2.models.transformer import TransformerFrontend
+from fairseq2.nn.normalization import LayerNorm
+from fairseq2.nn.projection import Projection
+from fairseq2.nn.transformer import (
+    FeedForwardNetwork,
+    LayerNormFactory,
+    MultiheadAttention,
+    TransformerEncoder,
+    create_standard_layer_norm,
+)
+from fairseq2.typing import DataType, Device
+
+
+@final
+class JepaClassifierModel(Module):
+    """
+    Represents a pretrained JEPA model with an attentive probing layer for
+    classification tasks. See
+    * :cite:t:`https://doi.org/10.48550/arXiv.2301.08243`
+    * :cite:t:`https://doi.org/10.48550/arXiv.2404.08471`
+    """
+
+    model_dim: int
+    encoder_frontend: TransformerFrontend
+    encoder: TransformerEncoder
+    pooler: AttentivePooler
+    head: Projection
+
+    def __init__(
+        self,
+        encoder_frontend: TransformerFrontend,
+        encoder: TransformerEncoder,
+        pooler: AttentivePooler,
+        head: Projection,
+    ) -> None:
+        super().__init__()
+
+        self.model_dim = encoder.model_dim
+
+        self.encoder_frontend = encoder_frontend
+        self.encoder = encoder
+
+        self.pooler = pooler
+
+        self.head = head
+
+    def forward(self, batch: SequenceBatch) -> Tensor:
+        seqs, padding_mask = self.encoder_frontend(batch.seqs, batch.padding_mask)
+
+        seqs, _ = self.encoder(seqs, padding_mask)
+
+        seqs = self.pooler(seqs)
+
+        # (N, P, M)
+        seqs = seqs.squeeze(1)  # TODO: NEEDED?
+
+        return self.head(seqs)  # type: ignore[no-any-return]
+
+    def extra_repr(self) -> str:
+        """:meta private:"""
+        return f"model_dim={self.model_dim}"
+
+
+@final
+class AttentivePooler(Module):
+    """
+    An attentive pooler that takes the output of a JEPA encoder and decodes it
+    into logits for a given task.
+
+    TODO:
+    - Move this into fairseq2.nn to benefit other similar tasks. Internally,
+    this module is just a thin transformer encoder without a self-attention layer.
+    Optionally, it can include some extra transformer encoder layers, depending
+    on the (fine-tuning) task.
+    """
+
+    model_dim: int
+    decoder: CrossAttentionDecoderLayer
+    encoder: TransformerEncoder | None
+    query_tokens: Parameter
+    init_std: float
+
+    def __init__(
+        self,
+        decoder: CrossAttentionDecoderLayer,
+        encoder: TransformerEncoder | None,
+        *,
+        num_queries: int = 1,
+        init_std: float = 0.02,
+        device: Device | None = None,
+        dtype: DataType | None = None,
+    ) -> None:
+        super().__init__()
+
+        self.model_dim = decoder.model_dim
+
+        self.decoder = decoder
+
+        if encoder:
+            self.encoder = encoder
+        else:
+            self.register_module("encoder", None)
+
+        self.query_tokens = Parameter(
+            torch.empty((1, num_queries, self.model_dim), device=device, dtype=dtype)
+        )
+
+        self.init_std = init_std
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        """Reset the parameters and buffers of the module."""
+        nn.init.trunc_normal_(self.query_tokens, std=self.init_std)
+
+    def forward(self, seqs: Tensor) -> Tensor:
+        if self.encoder is not None:
+            seqs, _ = self.encoder(seqs, padding_mask=None)
+
+        batch_size = seqs.size(0)
+
+        # (1, P, M) -> (N, P, M)
+        pool_seqs = self.query_tokens.repeat(batch_size, 1, 1)
+
+        return self.decoder(pool_seqs, seqs)  # type: ignore[no-any-return]
+
+    def extra_repr(self) -> str:
+        """:meta private:"""
+        return f"model_dim={self.model_dim}, num_queries={self.query_tokens.size(1)}"
+
+
+@final
+class CrossAttentionDecoderLayer(Module):
+    """Represents a simple transformer decoder layer with only cross-attention and LayerNorm."""
+
+    model_dim: int
+    cross_attn_layer_norm: LayerNorm
+    cross_attn: MultiheadAttention
+    ffn_layer_norm: LayerNorm
+    ffn: FeedForwardNetwork
+
+    def __init__(
+        self,
+        cross_attn: MultiheadAttention,
+        ffn: FeedForwardNetwork,
+        *,
+        layer_norm_factory: LayerNormFactory | None = None,
+        device: Device | None = None,
+        dtype: DataType | None = None,
+    ) -> None:
+        """
+        :param cross_attn:
+            The encoder-decoder attention layer.
+        :param ffn:
+            The feed-forward network.
+        :param layer_norm_factory:
+            The factory to construct the Layer Normalization modules.
+        """
+        super().__init__()
+
+        model_dim = cross_attn.model_dim
+
+        if layer_norm_factory is None:
+            layer_norm_factory = create_standard_layer_norm
+
+        self.cross_attn_layer_norm = layer_norm_factory(
+            model_dim, device=device, dtype=dtype
+        )
+
+        self.model_dim = model_dim
+
+        self.cross_attn = cross_attn
+
+        self.ffn_layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype)
+
+        self.ffn = ffn
+
+    def forward(self, seqs: Tensor, encoder_output: Tensor) -> Tensor:
+        seqs = self._forward_cross_attn(seqs, encoder_output)
+
+        seqs = self._forward_ffn(seqs)
+
+        return seqs
+
+    def _forward_cross_attn(self, seqs: Tensor, encoder_output: Tensor) -> Tensor:
+        residual = seqs
+
+        # Note that the cross-attention layer norm is applied to the encoder output, not to `seqs`.
+        encoder_output = self.cross_attn_layer_norm(encoder_output)
+
+        seqs = self.cross_attn(
+            seqs,
+            padding_mask=None,
+            keys=encoder_output,
+            key_padding_mask=None,
+            values=encoder_output,
+        )
+
+        seqs = seqs + residual
+
+        return seqs
+
+    def _forward_ffn(self, seqs: Tensor) -> Tensor:
+        residual = seqs
+
+        seqs = self.ffn_layer_norm(seqs)
+
+        seqs = self.ffn(seqs)
+
+        seqs = seqs + residual
+
+        return seqs
+
+    def extra_repr(self) -> str:
+        """:meta private:"""
+        return f"model_dim={self.model_dim}"
diff --git a/src/fairseq2/models/jepa/factory.py b/src/fairseq2/models/jepa/factory.py
index 152df5b22..b0eb74969 100644
--- a/src/fairseq2/models/jepa/factory.py
+++ b/src/fairseq2/models/jepa/factory.py
@@ -130,6 +130,7 @@ class JepaEncoderConfig:
 jepa_arch = jepa_archs.decorator
 
 
+# TODO(balioglu): work in progress. Supports only vision encoder.
 class JepaBuilder:
     """Builds modules of a JEPA model."""
 
@@ -278,10 +279,11 @@ def build_position_encoder(self) -> InterpolatedPositionEncoder:
             f"The length of `input_dims` must be 2 or 3, but is {num_input_dims} instead."
         )
 
-    def build_encoder(self) -> TransformerEncoder:
+    def build_encoder(self, num_layers: int | None = None) -> TransformerEncoder:
         config = self._config
 
-        num_layers = config.num_encoder_layers
+        if num_layers is None:
+            num_layers = config.num_encoder_layers
 
         layers = [self.build_encoder_layer(i) for i in range(num_layers)]
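
For reviewers, a minimal usage sketch of the new factory API introduced in this diff. It is not part of the patch; the 400-class head is an arbitrary illustrative value, and it constructs the config directly with the default encoder settings rather than looking up a registered architecture.

```python
import torch

from fairseq2.models.jepa.classifier import (
    JepaClassifierConfig,
    create_jepa_classifier_model,
)

# Build a classifier with the default vision encoder configuration and a
# hypothetical 400-way classification head.
config = JepaClassifierConfig(num_classes=400)

model = create_jepa_classifier_model(config, dtype=torch.float32)

num_params = sum(p.numel() for p in model.parameters())

print(f"{type(model).__name__} with {num_params:,} parameters")
```

In practice, one would instead start from one of the architectures registered in `archs.py` ("base", "large", "huge") via `jepa_classifier_archs`, which reuse the corresponding pretraining encoder configs.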