From 737ec9f136f37531feaac33c6b083c1241901d06 Mon Sep 17 00:00:00 2001 From: Franklin <41602287+fcogidi@users.noreply.github.com> Date: Tue, 29 Oct 2024 10:15:51 -0400 Subject: [PATCH] Update Zero-shot Classification Task (#27) --- README.md | 11 +++++ mmlearn/tasks/zero_shot_classification.py | 43 +++++++++++++------ .../zeroshot_classification_eval.yaml | 9 ++-- 3 files changed, 44 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index d353106..6baa3a7 100644 --- a/README.md +++ b/README.md @@ -156,6 +156,17 @@ Evaluates the quality of the learned representations in retrieving the k using recall@k metric. This is applicable to any number of pairs of modalities at once, depending on memory constraints. + + + +Zero-shot Classification + + +Evaluates the ability of a pre-trained encoder-based multimodal model to predict classes that were not explicitly seen +during training. The new classes are given as text prompts, and the query modality can be any of the supported modalities. +Binary and multi-class classification tasks are supported. + + ## Components diff --git a/mmlearn/tasks/zero_shot_classification.py b/mmlearn/tasks/zero_shot_classification.py index 3eaae53..f5a4c08 100644 --- a/mmlearn/tasks/zero_shot_classification.py +++ b/mmlearn/tasks/zero_shot_classification.py @@ -195,7 +195,13 @@ def evaluation_step( query_embeddings /= query_embeddings.norm(p=2, dim=-1, keepdim=True) query_embeddings = query_embeddings[matching_indices] - logits = 100.0 * _safe_matmul(query_embeddings, class_embeddings) + if self.all_dataset_info[dataset_index]["num_classes"] == 2: + softmax_output = _safe_matmul( + query_embeddings, class_embeddings + ).softmax(dim=-1) + logits = softmax_output[:, 1] - softmax_output[:, 0] + else: + logits = 100.0 * _safe_matmul(query_embeddings, class_embeddings) targets = batch[Modalities.get_modality(query_modality).target][ matching_indices ] @@ -233,27 +239,36 @@ def _create_metrics( num_classes: int, top_k: List[int], prefix: str, postfix: str ) -> MetricCollection: """Create a collection of classification metrics.""" + task_type = "binary" if num_classes == 2 else "multiclass" + acc_metrics = ( + { + f"top{k}_accuracy": Accuracy( + task=task_type, num_classes=num_classes, top_k=k, average="micro" + ) + for k in top_k + } + if num_classes > 2 + else {"accuracy": Accuracy(task=task_type, num_classes=num_classes)} + ) return MetricCollection( { "precision": Precision( - task="multiclass", num_classes=num_classes, average="macro" + task=task_type, + num_classes=num_classes, + average="macro" if num_classes > 2 else "micro", ), "recall": Recall( - task="multiclass", num_classes=num_classes, average="macro" + task=task_type, + num_classes=num_classes, + average="macro" if num_classes > 2 else "micro", ), "f1_score_macro": F1Score( - task="multiclass", num_classes=num_classes, average="macro" + task=task_type, + num_classes=num_classes, + average="macro" if num_classes > 2 else "micro", ), - "aucroc": AUROC(task="multiclass", num_classes=num_classes), - **{ - f"top{k}_accuracy": Accuracy( - task="multiclass", - num_classes=num_classes, - top_k=k, - average="micro", - ) - for k in top_k - }, + "aucroc": AUROC(task=task_type, num_classes=num_classes), + **acc_metrics, }, prefix=prefix, postfix=postfix, diff --git a/projects/med_benchmarking/configs/experiment/zeroshot_classification_eval.yaml b/projects/med_benchmarking/configs/experiment/zeroshot_classification_eval.yaml index 3a43ba7..4073b99 100644 --- a/projects/med_benchmarking/configs/experiment/zeroshot_classification_eval.yaml +++ b/projects/med_benchmarking/configs/experiment/zeroshot_classification_eval.yaml @@ -143,7 +143,7 @@ datasets: dataloader: test: - batch_size: 64 + batch_size: 128 num_workers: 4 task: @@ -153,15 +153,14 @@ task: task_specs: - top_k: [1] query_modality: rgb - run_on_validation: false - run_on_test: true + run_on_validation: False + run_on_test: True compute_validation_loss: False compute_test_loss: False trainer: precision: 16-mixed - deterministic: False - benchmark: True + deterministic: True sync_batchnorm: False # set to True if using DDP with batchnorm log_every_n_steps: 100