Warning for CLIP Score on long captions (Lightning-AI#2001)
SkafteNicki authored Aug 19, 2023
1 parent 3b3b997 commit 9fdd57c
Showing 3 changed files with 27 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -35,6 +35,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added argument `extended_summary` to `MeanAveragePrecision` such that precision, recall, iou can be easily returned ([#1983](https://github.com/Lightning-AI/torchmetrics/pull/1983))


- Added warning to `CLIPScore` when long captions are detected; such captions are truncated ([#2001](https://github.com/Lightning-AI/torchmetrics/pull/2001))


### Changed

-
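A minimal usage sketch of the behaviour described by this entry (it assumes `transformers` is installed and the `openai/clip-vit-base-patch16` weights can be downloaded; the caption is purely illustrative):

```python
import torch
from torchmetrics.multimodal.clip_score import CLIPScore

metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
image = torch.randint(255, (3, 224, 224), generator=torch.manual_seed(42))
caption = "a photo of a cat sitting on a red couch in a sunlit living room " * 20

# With this change, a UserWarning is emitted and the caption is truncated to the
# text encoder's maximum sequence length before the score is computed.
metric.update(image, caption)
print(metric.compute())
```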
12 changes: 12 additions & 0 deletions src/torchmetrics/functional/multimodal/clip_score.py
@@ -17,6 +17,7 @@
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.checks import _SKIP_SLOW_DOCTEST, _try_proceed_with_timeout
from torchmetrics.utilities.imports import _TRANSFORMERS_GREATER_EQUAL_4_10

@@ -65,6 +66,17 @@ def _clip_score_update(
    img_features = model.get_image_features(processed_input["pixel_values"].to(device))
    img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True)

    max_position_embeddings = model.config.text_config.max_position_embeddings
    if processed_input["attention_mask"].shape[-1] > max_position_embeddings:
        rank_zero_warn(
            f"Encountered caption longer than {max_position_embeddings=}. Will truncate captions to this length. "
            "If longer captions are needed, initialize argument `model_name_or_path` with a model that supports "
            "longer sequences.",
            UserWarning,
        )
        processed_input["attention_mask"] = processed_input["attention_mask"][..., :max_position_embeddings]
        processed_input["input_ids"] = processed_input["input_ids"][..., :max_position_embeddings]

    txt_features = model.get_text_features(
        processed_input["input_ids"].to(device), processed_input["attention_mask"].to(device)
    )
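The limit used above comes from the text encoder's configuration; a quick way to inspect it for a given checkpoint (a sketch, assuming `transformers` is installed):

```python
from transformers import CLIPModel

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
# Standard CLIP checkpoints use a 77-token context window for the text encoder.
print(model.config.text_config.max_position_embeddings)
```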
12 changes: 12 additions & 0 deletions tests/unittests/multimodal/test_clip_score.py
@@ -127,3 +127,15 @@ def test_plot_method(self, inputs, model_name_or_path):
        fig, ax = metric.plot()
        assert isinstance(fig, plt.Figure)
        assert isinstance(ax, matplotlib.axes.Axes)

    @skip_on_connection_issues()
    def test_warning_on_long_caption(self, inputs, model_name_or_path):
        """Test that a warning is raised on long captions but the metric still works."""
        metric = CLIPScore(model_name_or_path=model_name_or_path)
        preds, target = inputs
        target[0] = [target[0][0], "A 28-year-old chef who recently moved to San Francisco was found dead. " * 100]
        with pytest.warns(
            UserWarning,
            match="Encountered caption longer than max_position_embeddings=77. Will truncate captions to this length.*",
        ):
            metric.update(preds[0], target[0])
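If the truncation is expected and the new warning becomes noisy, it can be silenced with the standard library's warning filters rather than by changing the metric (a sketch, not part of this commit):

```python
import warnings

# Matches the start of the warning message added above; other UserWarnings are unaffected.
warnings.filterwarnings(
    "ignore",
    message="Encountered caption longer than",
    category=UserWarning,
)
```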
