Skip to content

Commit

Permalink
Refactor contrastive loss (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
fcogidi authored Dec 3, 2024
1 parent 54fc119 commit 3dea00f
Show file tree
Hide file tree
Showing 12 changed files with 508 additions and 265 deletions.
2 changes: 1 addition & 1 deletion mmlearn/cli/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def main(cfg: MMLearnConf) -> None: # noqa: PLR0912

if is_torch_tf32_available():
torch.backends.cuda.matmul.allow_tf32 = True
if "16-mixed" in cfg.trainer.precision:
if "16-mixed" in str(cfg.trainer.precision):
cfg.trainer.precision = "bf16-mixed"

# setup trainer first so that we can get some variables for distributed training
Expand Down
2 changes: 1 addition & 1 deletion mmlearn/conf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ class MMLearnConf:
job=JobConf(
name=II("experiment_name"),
env_set={
"TORCH_NCCL_ASYNC_ERROR_HANDLING": "3",
"TORCH_NCCL_ASYNC_ERROR_HANDLING": "1",
"HYDRA_FULL_ERROR": "1",
},
),
Expand Down
2 changes: 1 addition & 1 deletion mmlearn/hf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def load_huggingface_model(
return_unused_kwargs=True,
**model_config_kwargs,
)
model = model_type._from_config(config, **kwargs)
model = model_type.from_config(config, **kwargs)

if get_model_attr is not None and hasattr(model, get_model_attr):
model = getattr(model, get_model_attr)
Expand Down
4 changes: 2 additions & 2 deletions mmlearn/modules/encoders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
HFCLIPTextEncoderWithProjection,
HFCLIPVisionEncoder,
HFCLIPVisionEncoderWithProjection,
PubMedBERTForCLIPTextEncoding,
)
from mmlearn.modules.encoders.text import HFTextEncoder
from mmlearn.modules.encoders.vision import TimmViT


__all__ = [
Expand All @@ -16,5 +16,5 @@
"HFCLIPTextEncoderWithProjection",
"HFCLIPVisionEncoder",
"HFCLIPVisionEncoderWithProjection",
"PubMedBERTForCLIPTextEncoding",
"TimmViT",
]
117 changes: 0 additions & 117 deletions mmlearn/modules/encoders/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,123 +474,6 @@ def forward(self, inputs: Dict[str, Any]) -> Tuple[torch.Tensor]:
return (self.model.visual_projection(pooled_output),)


@store(group="modules/encoders", provider="mmlearn", hydra_convert="object")
class PubMedBERTForCLIPTextEncoding(nn.Module):
    """BiomedNLP's PubMedBERT model for CLIP text encoding.

    This module is a wrapper around the PubMedBERT model from huggingface.

    Parameters
    ----------
    pretrained : bool, default=True
        Whether to load the pretrained weights or not.
    pooling_layer : nn.Module, optional, default=None
        Pooling layer to apply to the last hidden state of the model.
    freeze_layers : int | float | List[int] | bool, default=False
        Whether to freeze layers of the model and which layers to freeze. If `True`,
        all model layers are frozen. If it is an integer, the first `N` layers of
        the model are frozen. If it is a float, the first `N` percent of the layers
        are frozen. If it is a list of integers, the layers at the indices in the
        list are frozen. Layer index 0 is the embedding layer, followed by the
        encoder layers in order.
    freeze_layer_norm : bool, default=True
        Whether to freeze the layer normalization layers of the model. This only
        affects parameters inside the layers selected by `freeze_layers`.
    peft_config : PeftConfig, optional, default=None
        The configuration from the `peft` library to use to wrap the model
        for parameter-efficient finetuning.
    model_config_kwargs : Dict[str, Any], optional, default=None
        Additional keyword arguments to pass to the model configuration.

    Warns
    -----
    UserWarning
        If both `peft_config` and `freeze_layers` are set. The `peft_config` will
        override the `freeze_layers` setting.
    """

    def __init__(
        self,
        pretrained: bool = True,
        pooling_layer: Optional[nn.Module] = None,
        freeze_layers: Union[int, float, List[int], bool] = False,
        freeze_layer_norm: bool = True,
        peft_config: Optional["PeftConfig"] = None,
        model_config_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Initialize the model."""
        super().__init__()
        _warn_freeze_with_peft(peft_config, freeze_layers)

        # Only the `bert` attribute of the masked-LM model is kept (when present).
        model = hf_utils.load_huggingface_model(
            transformers.AutoModelForMaskedLM,
            "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext",
            load_pretrained_weights=pretrained,
            get_model_attr="bert",
            model_config_kwargs=model_config_kwargs,
        )

        if isinstance(freeze_layers, bool) and freeze_layers:
            self._freeze_parameters(model, freeze_layer_norm)

        layers = [model.embeddings, *model.encoder.layer]
        if isinstance(freeze_layers, float):
            freeze_layers = int(freeze_layers * len(layers))
        # NOTE: `bool` is a subclass of `int`, so booleans also pass through here:
        # `True` becomes `[0]` (re-applying the same settings to the embeddings,
        # which were already handled above) and `False` becomes `[]` (no-op).
        if isinstance(freeze_layers, int):
            freeze_layers = list(range(freeze_layers))

        if isinstance(freeze_layers, list):
            for idx, layer in enumerate(layers):
                if idx in freeze_layers:
                    self._freeze_parameters(layer, freeze_layer_norm)

        if peft_config is not None:
            model = hf_utils._wrap_peft_model(model, peft_config)

        self.model = model
        self.pooling_layer = pooling_layer

    @staticmethod
    def _freeze_parameters(module: nn.Module, freeze_layer_norm: bool) -> None:
        """Disable gradients for `module`, optionally sparing LayerNorm parameters.

        Parameters whose name contains ``"LayerNorm"`` get
        ``requires_grad = not freeze_layer_norm``; all other parameters get
        ``requires_grad = False``.
        """
        for name, param in module.named_parameters():
            param.requires_grad = (
                (not freeze_layer_norm) if "LayerNorm" in name else False
            )

    def forward(self, inputs: Dict[str, Any]) -> BaseModelOutput:
        """Run the forward pass.

        Parameters
        ----------
        inputs : Dict[str, Any]
            The input data. The `input_ids` will be expected under the
            `Modalities.TEXT.name` key. The attention mask is read from the
            ``"attention_mask"`` key, falling back to the
            `Modalities.TEXT.attention_mask` key. ``"inputs_embeds"`` and
            ``"output_attentions"`` are forwarded to the model if present.

        Returns
        -------
        BaseModelOutput
            The output of the model, including the last hidden state, all hidden states,
            and the attention weights, if `output_attentions` is set to `True`. If a
            pooling layer was provided, `last_hidden_state` holds the pooled output.
        """
        output = self.model(
            input_ids=inputs[Modalities.TEXT.name],
            attention_mask=inputs.get(
                "attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
            ),
            inputs_embeds=inputs.get("inputs_embeds"),
            output_attentions=inputs.get("output_attentions"),
            output_hidden_states=True,
            return_dict=True,
        )
        last_hidden_state = output.last_hidden_state
        if self.pooling_layer is not None:
            last_hidden_state = self.pooling_layer(last_hidden_state)

        return BaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=output.hidden_states,
            attentions=output.attentions,
        )


#### Utility methods ####


Expand Down
2 changes: 1 addition & 1 deletion mmlearn/modules/encoders/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
@store(
group="modules/encoders",
provider="mmlearn",
model_name_or_path="vit_base_patch16_224",
model_name="vit_base_patch16_224",
hydra_convert="object",
)
class TimmViT(nn.Module):
Expand Down
4 changes: 2 additions & 2 deletions mmlearn/modules/losses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Loss functions."""

from mmlearn.modules.losses.contrastive import CLIPLoss
from mmlearn.modules.losses.contrastive import ContrastiveLoss
from mmlearn.modules.losses.data2vec import Data2VecLoss


__all__ = ["CLIPLoss", "Data2VecLoss"]
__all__ = ["ContrastiveLoss", "Data2VecLoss"]
Loading

0 comments on commit 3dea00f

Please sign in to comment.