Merge branch 'main' into add_data2vec_task
amrit110 authored Dec 18, 2024
2 parents 3d72770 + ad7de05 commit 18e3b9a
Showing 24 changed files with 1,300 additions and 477 deletions.
12 changes: 12 additions & 0 deletions README.md
@@ -143,6 +143,18 @@ unimodal tasks applied to specific modalities.
</td>
</tr>
<tr>
+<td>
+
+I-JEPA
+</td>
+<td>
+The <a href="https://arxiv.org/pdf/2301.08243">Image-based Joint-Embedding Predictive Architecture</a> (I-JEPA) is a unimodal, non-generative
+self-supervised learning method that predicts the <i>representations</i> of several target blocks of an image given a context block
+from the same image. This task can be combined with the contrastive pretraining task to learn multimodal representations from
+paired and unpaired data.
+</td>
+</tr>
+<tr>
<th style="text-align: left; width: 250px"> Evaluation Methods </th>
<th style="text-align: center"> Notes </th>
</tr>
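For context on the I-JEPA task described above, here is a minimal sketch of its training objective. The `context_encoder`, `target_encoder` (an EMA copy of the context encoder), and `predictor` modules are hypothetical stand-ins; their names and call signatures are illustrative, not mmlearn's actual API. It uses the `apply_masks` helper updated later in this commit:

```python
import torch
import torch.nn.functional as F

from mmlearn.datasets.processors.masking import apply_masks


def ijepa_step(context_encoder, target_encoder, predictor, images, enc_masks, pred_masks):
    """One I-JEPA step: predict target-block representations from a context block."""
    with torch.no_grad():
        # Targets are patch representations from the EMA encoder over the full image.
        targets = target_encoder(images)            # (B, N, D)
        targets = apply_masks(targets, pred_masks)  # keep only target-block patches
    # The context encoder sees only the patches selected by the encoder masks.
    context = context_encoder(images, enc_masks)
    # The predictor fills in representations for the masked target blocks.
    preds = predictor(context, enc_masks, pred_masks)
    return F.smooth_l1_loss(preds, targets)
```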
2 changes: 1 addition & 1 deletion mmlearn/cli/run.py
@@ -45,7 +45,7 @@ def main(cfg: MMLearnConf) -> None: # noqa: PLR0912

if is_torch_tf32_available():
torch.backends.cuda.matmul.allow_tf32 = True
if "16-mixed" in cfg.trainer.precision:
if "16-mixed" in str(cfg.trainer.precision):
cfg.trainer.precision = "bf16-mixed"

# setup trainer first so that we can get some variables for distributed training
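A note on the `str(...)` cast above: Lightning accepts `trainer.precision` as an integer (e.g. `16`) as well as a string (e.g. `"16-mixed"`), and Python's `in` operator raises a `TypeError` when the right-hand side is an `int`. A quick illustration:

```python
precision = 16                        # Lightning also accepts integer precision values
# "16-mixed" in precision            -> TypeError: argument of type 'int' is not iterable
print("16-mixed" in str(precision))   # False -- safe for any precision type
print("16-mixed" in "16-mixed")       # True  -- behavior for strings is unchanged
```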
2 changes: 1 addition & 1 deletion mmlearn/conf/__init__.py
@@ -168,7 +168,7 @@ class MMLearnConf:
job=JobConf(
name=II("experiment_name"),
env_set={
"TORCH_NCCL_ASYNC_ERROR_HANDLING": "3",
"TORCH_NCCL_ASYNC_ERROR_HANDLING": "1",
"HYDRA_FULL_ERROR": "1",
},
),
80 changes: 48 additions & 32 deletions mmlearn/datasets/processors/masking.py
@@ -237,31 +237,38 @@ def apply_masks(
Parameters
----------
    x : torch.Tensor
-        Input tensor of shape (B, N, D), where B is the batch size, N is the number
-        of patches, and D is the feature dimension.
+        Input tensor of shape (B, N, D).
    masks : Union[torch.Tensor, List[torch.Tensor]]
-        A list of tensors containing the indices of patches to keep for each sample.
-        Each mask tensor has shape (B, N), where B is the batch size and N is the number
-        of patches.
+        A list of mask tensors of shape (N,), (1, N), or (B, N).

    Returns
    -------
    torch.Tensor
        The masked tensor where only the patches indicated by the masks are kept.
-        The output tensor has shape (B', N', D), where B' is the new batch size
-        (which may be different due to concatenation) and N' is the
-        reduced number of patches.
+        The output tensor has shape (B * num_masks, N', D),
+        where N' is the number of patches kept.

-    Notes
-    -----
-    - The masks should indicate which patches to keep (1 for keep, 0 for discard).
-    - The function uses `torch.gather` to select the patches specified by the masks.
    """
    all_x = []
-    for m in masks:
-        # Expand the mask to match the feature dimension and gather the relevant patches
-        mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1))
-        all_x.append(torch.gather(x, dim=1, index=mask_keep))
+    batch_size = x.size(0)
+    for m_ in masks:
+        m = m_.to(x.device)
+
+        # Ensure mask is at least 2D
+        if m.dim() == 1:
+            m = m.unsqueeze(0)  # Shape: (1, N)
+
+        # Expand mask to match the batch size if needed
+        if m.size(0) == 1 and batch_size > 1:
+            m = m.expand(batch_size, -1)  # Shape: (B, N)
+
+        # Expand mask to match x's dimensions
+        m_expanded = (
+            m.unsqueeze(-1).expand(-1, -1, x.size(-1)).bool()
+        )  # Shape: (B, N, D)
+
+        # Use boolean indexing
+        selected_patches = x[m_expanded].view(batch_size, -1, x.size(-1))
+        all_x.append(selected_patches)

    # Concatenate along the batch dimension
    return torch.cat(all_x, dim=0)
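A usage sketch of the rewritten `apply_masks`, under its assumption that every mask keeps the same number of patches for each sample (which the I-JEPA mask generator below guarantees):

```python
import torch

from mmlearn.datasets.processors.masking import apply_masks

x = torch.randn(2, 196, 64)                # (B=2, N=196 patches, D=64)
mask = torch.zeros(196, dtype=torch.bool)
mask[:49] = True                           # keep the first 49 patches
out = apply_masks(x, [mask, mask])         # a (N,) mask is broadcast over the batch
print(out.shape)                           # torch.Size([4, 49, 64]) -- B * num_masks
```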
@@ -271,40 +278,39 @@ def apply_masks(
class IJEPAMaskGenerator:
"""Generates encoder and predictor masks for preprocessing.

-    This class generates masks dynamically for individual examples and can be passed to
-    a data loader as a preprocessing step.
+    This class generates masks dynamically for batches of examples.

Parameters
----------
input_size : tuple[int, int], default=(224, 224)
Input image size.
patch_size : int, default=16
Size of each patch.
-    min_keep : int, default=4
+    min_keep : int, default=10
Minimum number of patches to keep.
allow_overlap : bool, default=False
Whether to allow overlap between encoder and predictor masks.
-    enc_mask_scale : tuple[float, float], default=(0.2, 0.8)
+    enc_mask_scale : tuple[float, float], default=(0.85, 1.0)
        Scale range for encoder mask.
-    pred_mask_scale : tuple[float, float], default=(0.2, 0.8)
+    pred_mask_scale : tuple[float, float], default=(0.15, 0.2)
        Scale range for predictor mask.
-    aspect_ratio : tuple[float, float], default=(0.3, 3.0)
+    aspect_ratio : tuple[float, float], default=(0.75, 1.0)
Aspect ratio range for mask blocks.
nenc : int, default=1
Number of encoder masks to generate.
-    npred : int, default=2
+    npred : int, default=4
Number of predictor masks to generate.
"""

input_size: Tuple[int, int] = (224, 224)
patch_size: int = 16
-    min_keep: int = 4
+    min_keep: int = 10
    allow_overlap: bool = False
-    enc_mask_scale: Tuple[float, float] = (0.2, 0.8)
-    pred_mask_scale: Tuple[float, float] = (0.2, 0.8)
-    aspect_ratio: Tuple[float, float] = (0.3, 3.0)
+    enc_mask_scale: Tuple[float, float] = (0.85, 1.0)
+    pred_mask_scale: Tuple[float, float] = (0.15, 0.2)
+    aspect_ratio: Tuple[float, float] = (0.75, 1.0)
    nenc: int = 1
-    npred: int = 2
+    npred: int = 4

def __post_init__(self) -> None:
"""Initialize the mask generator."""
@@ -353,8 +359,14 @@ def _sample_block_mask(

def __call__(
self,
+        batch_size: int = 1,
) -> Dict[str, Any]:
"""Generate encoder and predictor masks for a single example.
"""Generate encoder and predictor masks for a batch of examples.
Parameters
----------
batch_size : int, default=1
The batch size for which to generate masks.
Returns
-------
Expand All @@ -378,14 +390,18 @@ def __call__(
masks_pred, masks_enc = [], []
for _ in range(self.npred):
mask_p, _ = self._sample_block_mask(p_size)
+            # Expand mask to match batch size
+            mask_p = mask_p.unsqueeze(0).expand(batch_size, -1)
masks_pred.append(mask_p)

# Generate encoder masks
for _ in range(self.nenc):
mask_e, _ = self._sample_block_mask(e_size)
+            # Expand mask to match batch size
+            mask_e = mask_e.unsqueeze(0).expand(batch_size, -1)
masks_enc.append(mask_e)

return {
"encoder_masks": torch.stack(masks_enc),
"predictor_masks": torch.stack(masks_pred),
"encoder_masks": masks_enc, # List of tensors of shape (batch_size, N)
"predictor_masks": masks_pred, # List of tensors of shape (batch_size, N)
}
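A usage sketch with the new batch-aware `__call__`, using the defaults above (`nenc=1`, `npred=4`, 224x224 input with patch size 16, so N = 196):

```python
from mmlearn.datasets.processors.masking import IJEPAMaskGenerator

gen = IJEPAMaskGenerator()
masks = gen(batch_size=8)
enc = masks["encoder_masks"]      # list of 1 tensor of shape (8, 196)
pred = masks["predictor_masks"]   # list of 4 tensors of shape (8, 196)
```

Note that each mask is sampled once and then expanded, so all samples in a batch share the same mask.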
2 changes: 1 addition & 1 deletion mmlearn/hf_utils.py
@@ -67,7 +67,7 @@ def load_huggingface_model(
return_unused_kwargs=True,
**model_config_kwargs,
)
-    model = model_type._from_config(config, **kwargs)
+    model = model_type.from_config(config, **kwargs)

if get_model_attr is not None and hasattr(model, get_model_attr):
model = getattr(model, get_model_attr)
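For reference, `from_config` is the public constructor on the Hugging Face Auto classes (the change above assumes `model_type` is one of them, e.g. `AutoModel`); unlike `from_pretrained`, it builds the architecture from a config without downloading weights:

```python
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("bert-base-uncased")
model = AutoModel.from_config(config)  # randomly initialized; no pretrained weights
```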
92 changes: 46 additions & 46 deletions mmlearn/modules/ema.py
@@ -52,6 +52,52 @@ def __init__(
self.ema_end_decay = ema_end_decay
self.ema_anneal_end_step = ema_anneal_end_step

+    @staticmethod
+    def deepcopy_model(model: torch.nn.Module) -> torch.nn.Module:
+        """Deep copy the model."""
+        try:
+            return copy.deepcopy(model)
+        except RuntimeError as e:
+            raise RuntimeError("Unable to copy the model ", e) from e
+
+    @staticmethod
+    def get_annealed_rate(
+        start: float,
+        end: float,
+        curr_step: int,
+        total_steps: int,
+    ) -> float:
+        """Calculate EMA annealing rate."""
+        r = end - start
+        pct_remaining = 1 - curr_step / total_steps
+        return end - r * pct_remaining
+
+    def step(self, new_model: torch.nn.Module) -> None:
+        """Perform single EMA update step."""
+        self._update_weights(new_model)
+        self._update_ema_decay()
+
+    def restore(self, model: torch.nn.Module) -> torch.nn.Module:
+        """Reassign weights from another model.
+
+        Parameters
+        ----------
+        model : nn.Module
+            Model to load weights from.
+
+        Returns
+        -------
+        nn.Module
+            Model with new weights.
+        """
+        d = self.model.state_dict()
+        model.load_state_dict(d, strict=False)
+        return model
+
+    def state_dict(self) -> dict[str, Any]:
+        """Return the state dict of the model."""
+        return self.model.state_dict()  # type: ignore[no-any-return]

@torch.no_grad() # type: ignore[misc]
def _update_weights(self, new_model: torch.nn.Module) -> None:
if self.decay < 1:
@@ -98,49 +144,3 @@ def _update_ema_decay(self) -> None:
self.ema_anneal_end_step,
)
self.decay = decay

-    def step(self, new_model: torch.nn.Module) -> None:
-        """Perform single EMA update step."""
-        self._update_weights(new_model)
-        self._update_ema_decay()
-
-    @staticmethod
-    def deepcopy_model(model: torch.nn.Module) -> torch.nn.Module:
-        """Deep copy the model."""
-        try:
-            return copy.deepcopy(model)
-        except RuntimeError as e:
-            raise RuntimeError("Unable to copy the model ", e) from e
-
-    def restore(self, model: torch.nn.Module) -> torch.nn.Module:
-        """Reassign weights from another model.
-
-        Parameters
-        ----------
-        model : nn.Module
-            Model to load weights from.
-
-        Returns
-        -------
-        nn.Module
-            model with new weights
-        """
-        d = self.model.state_dict()
-        model.load_state_dict(d, strict=False)
-        return model
-
-    def state_dict(self) -> dict[str, Any]:
-        """Return the state dict of the model."""
-        return self.model.state_dict()  # type: ignore[no-any-return]
-
-    @staticmethod
-    def get_annealed_rate(
-        start: float,
-        end: float,
-        curr_step: int,
-        total_steps: int,
-    ) -> float:
-        """Calculate EMA annealing rate."""
-        r = end - start
-        pct_remaining = 1 - curr_step / total_steps
-        return end - r * pct_remaining
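A worked example of `get_annealed_rate`, which linearly anneals the EMA decay (same formula as above): annealing from 0.996 toward 1.0 over 10,000 steps, at step 2,500 the remaining fraction is 0.75, so the decay is 1.0 - 0.004 * 0.75 = 0.997. Reproduced standalone for illustration:

```python
def get_annealed_rate(start: float, end: float, curr_step: int, total_steps: int) -> float:
    # Linear interpolation from `start` at step 0 to `end` at `total_steps`.
    r = end - start
    pct_remaining = 1 - curr_step / total_steps
    return end - r * pct_remaining


print(get_annealed_rate(0.996, 1.0, curr_step=2500, total_steps=10000))  # 0.997
```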
4 changes: 2 additions & 2 deletions mmlearn/modules/encoders/__init__.py
@@ -5,9 +5,9 @@
HFCLIPTextEncoderWithProjection,
HFCLIPVisionEncoder,
HFCLIPVisionEncoderWithProjection,
-    PubMedBERTForCLIPTextEncoding,
)
from mmlearn.modules.encoders.text import HFTextEncoder
+from mmlearn.modules.encoders.vision import TimmViT


__all__ = [
@@ -16,5 +16,5 @@
"HFCLIPTextEncoderWithProjection",
"HFCLIPVisionEncoder",
"HFCLIPVisionEncoderWithProjection",
"PubMedBERTForCLIPTextEncoding",
"TimmViT",
]