From 737ec9f136f37531feaac33c6b083c1241901d06 Mon Sep 17 00:00:00 2001
From: Franklin <41602287+fcogidi@users.noreply.github.com>
Date: Tue, 29 Oct 2024 10:15:51 -0400
Subject: [PATCH] Update Zero-shot Classification Task (#27)
---
README.md | 11 +++++
mmlearn/tasks/zero_shot_classification.py | 43 +++++++++++++------
.../zeroshot_classification_eval.yaml | 9 ++--
3 files changed, 44 insertions(+), 19 deletions(-)
diff --git a/README.md b/README.md
index d353106..6baa3a7 100644
--- a/README.md
+++ b/README.md
@@ -156,6 +156,17 @@ Evaluates the quality of the learned representations in retrieving the k
using recall@k metric. This is applicable to any number of pairs of modalities at once, depending on memory constraints.
+<tr>
+<td>
+
+Zero-shot Classification
+ </td>
+<td>
+Evaluates the ability of a pre-trained encoder-based multimodal model to predict classes that were not explicitly seen
+during training. The new classes are given as text prompts, and the query modality can be any of the supported modalities.
+Binary and multi-class classification tasks are supported.
+ </td>
+</tr>
## Components
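For context on what the new README row describes: prompt-based zero-shot classification scores a query embedding against the embeddings of class-name prompts and picks the closest prompt. The sketch below is illustrative only and does not use mmlearn's API; `encode_text` and `encode_image` are hypothetical stand-ins for the pre-trained model's encoders.

```python
# Illustrative sketch of prompt-based zero-shot classification (not mmlearn code).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
EMBED_DIM = 512


def encode_text(prompts):
    """Hypothetical text encoder: one embedding per class prompt."""
    return torch.randn(len(prompts), EMBED_DIM)


def encode_image(images):
    """Hypothetical query-modality encoder: one embedding per image."""
    return torch.randn(images.shape[0], EMBED_DIM)


# Classes never seen during training, expressed as text prompts.
class_names = ["cat", "dog"]
prompts = [f"a photo of a {name}" for name in class_names]

class_embeddings = F.normalize(encode_text(prompts), dim=-1)                        # (C, D)
query_embeddings = F.normalize(encode_image(torch.empty(4, 3, 224, 224)), dim=-1)  # (N, D)

# Cosine similarity of each query against each class prompt -> predicted class.
logits = query_embeddings @ class_embeddings.T                                      # (N, C)
print(logits.argmax(dim=-1))
```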
diff --git a/mmlearn/tasks/zero_shot_classification.py b/mmlearn/tasks/zero_shot_classification.py
index 3eaae53..f5a4c08 100644
--- a/mmlearn/tasks/zero_shot_classification.py
+++ b/mmlearn/tasks/zero_shot_classification.py
@@ -195,7 +195,13 @@ def evaluation_step(
query_embeddings /= query_embeddings.norm(p=2, dim=-1, keepdim=True)
query_embeddings = query_embeddings[matching_indices]
- logits = 100.0 * _safe_matmul(query_embeddings, class_embeddings)
+ if self.all_dataset_info[dataset_index]["num_classes"] == 2:
+ softmax_output = _safe_matmul(
+ query_embeddings, class_embeddings
+ ).softmax(dim=-1)
+ logits = softmax_output[:, 1] - softmax_output[:, 0]
+ else:
+ logits = 100.0 * _safe_matmul(query_embeddings, class_embeddings)
targets = batch[Modalities.get_modality(query_modality).target][
matching_indices
]
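The binary branch introduced above collapses the two per-class similarities into a single score per sample: softmax over the two columns, then p(class 1) minus p(class 0). The result lies in [-1, 1] and its sign gives the predicted class, which is the single-score shape torchmetrics' binary metrics expect. A minimal, self-contained sketch of that reduction on dummy embeddings (a plain matrix product stands in for `_safe_matmul`):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Dummy L2-normalized embeddings: 4 queries, 8 dims, 2 class prompts.
query_embeddings = F.normalize(torch.randn(4, 8), dim=-1)  # (N, D)
class_embeddings = F.normalize(torch.randn(8, 2), dim=0)   # (D, 2)

# Same reduction as the binary branch: softmax over the two similarities,
# then the difference of the two probabilities as one score in [-1, 1].
softmax_output = (query_embeddings @ class_embeddings).softmax(dim=-1)
scores = softmax_output[:, 1] - softmax_output[:, 0]

print(scores)               # one score per query; > 0 means class 1
print((scores > 0).long())  # implied hard predictions
```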
@@ -233,27 +239,36 @@ def _create_metrics(
num_classes: int, top_k: List[int], prefix: str, postfix: str
) -> MetricCollection:
"""Create a collection of classification metrics."""
+ task_type = "binary" if num_classes == 2 else "multiclass"
+ acc_metrics = (
+ {
+ f"top{k}_accuracy": Accuracy(
+ task=task_type, num_classes=num_classes, top_k=k, average="micro"
+ )
+ for k in top_k
+ }
+ if num_classes > 2
+ else {"accuracy": Accuracy(task=task_type, num_classes=num_classes)}
+ )
return MetricCollection(
{
"precision": Precision(
- task="multiclass", num_classes=num_classes, average="macro"
+ task=task_type,
+ num_classes=num_classes,
+ average="macro" if num_classes > 2 else "micro",
),
"recall": Recall(
- task="multiclass", num_classes=num_classes, average="macro"
+ task=task_type,
+ num_classes=num_classes,
+ average="macro" if num_classes > 2 else "micro",
),
"f1_score_macro": F1Score(
- task="multiclass", num_classes=num_classes, average="macro"
+ task=task_type,
+ num_classes=num_classes,
+ average="macro" if num_classes > 2 else "micro",
),
- "aucroc": AUROC(task="multiclass", num_classes=num_classes),
- **{
- f"top{k}_accuracy": Accuracy(
- task="multiclass",
- num_classes=num_classes,
- top_k=k,
- average="micro",
- )
- for k in top_k
- },
+ "aucroc": AUROC(task=task_type, num_classes=num_classes),
+ **acc_metrics,
},
prefix=prefix,
postfix=postfix,
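To see how the refactored helper behaves at both settings, here is a hedged, standalone sketch that mirrors its construction with torchmetrics directly (outside mmlearn) and feeds it dummy inputs; binary metrics consume one score per sample, multiclass metrics consume per-class logits:

```python
import torch
from torchmetrics import AUROC, Accuracy, F1Score, MetricCollection, Precision, Recall


def make_metrics(num_classes: int, top_k=(1,)) -> MetricCollection:
    """Mirror of the patched helper: binary vs. multiclass metric construction."""
    task_type = "binary" if num_classes == 2 else "multiclass"
    average = "micro" if num_classes == 2 else "macro"
    acc_metrics = (
        {"accuracy": Accuracy(task=task_type)}
        if num_classes == 2
        else {
            f"top{k}_accuracy": Accuracy(
                task=task_type, num_classes=num_classes, top_k=k, average="micro"
            )
            for k in top_k
        }
    )
    return MetricCollection(
        {
            "precision": Precision(task=task_type, num_classes=num_classes, average=average),
            "recall": Recall(task=task_type, num_classes=num_classes, average=average),
            "f1_score_macro": F1Score(task=task_type, num_classes=num_classes, average=average),
            "aucroc": AUROC(task=task_type, num_classes=num_classes),
            **acc_metrics,
        }
    )


# Binary: one score per sample (e.g. the softmax difference from the previous hunk).
binary = make_metrics(2)
binary.update(torch.tensor([0.8, -0.3, 0.2, -0.9]), torch.tensor([1, 0, 1, 0]))
print(binary.compute())

# Multiclass: per-class logits of shape (batch, num_classes).
multi = make_metrics(5)
multi.update(torch.randn(8, 5), torch.randint(0, 5, (8,)))
print(multi.compute())
```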
diff --git a/projects/med_benchmarking/configs/experiment/zeroshot_classification_eval.yaml b/projects/med_benchmarking/configs/experiment/zeroshot_classification_eval.yaml
index 3a43ba7..4073b99 100644
--- a/projects/med_benchmarking/configs/experiment/zeroshot_classification_eval.yaml
+++ b/projects/med_benchmarking/configs/experiment/zeroshot_classification_eval.yaml
@@ -143,7 +143,7 @@ datasets:
dataloader:
test:
- batch_size: 64
+ batch_size: 128
num_workers: 4
task:
@@ -153,15 +153,14 @@ task:
task_specs:
- top_k: [1]
query_modality: rgb
- run_on_validation: false
- run_on_test: true
+ run_on_validation: False
+ run_on_test: True
compute_validation_loss: False
compute_test_loss: False
trainer:
precision: 16-mixed
- deterministic: False
- benchmark: True
+ deterministic: True
sync_batchnorm: False # set to True if using DDP with batchnorm
log_every_n_steps: 100
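For reference, assuming this config is ultimately instantiated into a `lightning.pytorch.Trainer` (which the keys suggest, though the exact instantiation path is mmlearn's), the trainer section corresponds roughly to the constructor call below; `deterministic=True` trades the cuDNN autotuning that `benchmark=True` enabled for reproducible evaluation runs.

```python
# Rough equivalent of the trainer section, assuming PyTorch Lightning >= 2.0
# (the standalone `lightning` package; `pytorch_lightning` works analogously).
import lightning.pytorch as pl

trainer = pl.Trainer(
    precision="16-mixed",   # mixed-precision evaluation
    deterministic=True,     # reproducible kernels; replaces benchmark=True
    sync_batchnorm=False,   # set to True if using DDP with batchnorm
    log_every_n_steps=100,
)
```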