Update Zero-shot Classification Task (#27)
fcogidi authored Oct 29, 2024
1 parent a247b4e commit 737ec9f
Showing 3 changed files with 44 additions and 19 deletions.
11 changes: 11 additions & 0 deletions README.md
@@ -156,6 +156,17 @@ Evaluates the quality of the learned representations in retrieving the <i>k</i>
 using recall@k metric. This is applicable to any number of pairs of modalities at once, depending on memory constraints.
 </td>
 </tr>
+<tr>
+<td>
+
+Zero-shot Classification
+</td>
+<td>
+Evaluates the ability of a pre-trained encoder-based multimodal model to predict classes that were not explicitly seen
+during training. The new classes are given as text prompts, and the query modality can be any of the supported modalities.
+Binary and multi-class classification tasks are supported.
+</td>
+</tr>
 </table>

 ## Components
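To make the new README entry concrete, here is a minimal, self-contained sketch of the zero-shot classification idea. It uses random tensors in place of real encoder outputs; the names and shapes are illustrative assumptions, not mmlearn's actual API.

```python
import torch

# Stand-ins for encoder outputs; in practice these come from the pre-trained
# model's text encoder (one embedding per class prompt) and the query-modality
# encoder (e.g. image embeddings). Shapes are arbitrary for illustration.
num_classes, embed_dim, batch_size = 5, 512, 8
class_embeddings = torch.randn(num_classes, embed_dim)
query_embeddings = torch.randn(batch_size, embed_dim)

# L2-normalize so that the dot product below is a cosine similarity.
class_embeddings = class_embeddings / class_embeddings.norm(p=2, dim=-1, keepdim=True)
query_embeddings = query_embeddings / query_embeddings.norm(p=2, dim=-1, keepdim=True)

# Score each query against every class prompt; the most similar prompt wins.
logits = 100.0 * query_embeddings @ class_embeddings.T  # (batch_size, num_classes)
predictions = logits.argmax(dim=-1)                     # predicted class indices
```

The 100.0 scaling mirrors the logit scaling used in the code change below.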
43 changes: 29 additions & 14 deletions mmlearn/tasks/zero_shot_classification.py
@@ -195,7 +195,13 @@ def evaluation_step(
         query_embeddings /= query_embeddings.norm(p=2, dim=-1, keepdim=True)
         query_embeddings = query_embeddings[matching_indices]

-        logits = 100.0 * _safe_matmul(query_embeddings, class_embeddings)
+        if self.all_dataset_info[dataset_index]["num_classes"] == 2:
+            softmax_output = _safe_matmul(
+                query_embeddings, class_embeddings
+            ).softmax(dim=-1)
+            logits = softmax_output[:, 1] - softmax_output[:, 0]
+        else:
+            logits = 100.0 * _safe_matmul(query_embeddings, class_embeddings)
         targets = batch[Modalities.get_modality(query_modality).target][
             matching_indices
         ]
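The binary branch added here collapses the two-class similarity scores into a single score per sample, which is the shape expected by torchmetrics' binary metrics (a float outside [0, 1] is treated as a logit). A small, hypothetical example of the resulting score:

```python
import torch

# Two samples, two classes: similarity scores against the two class prompts.
similarities = torch.tensor([[0.2, 0.9],   # leans towards class 1
                             [0.8, 0.1]])  # leans towards class 0

probs = similarities.softmax(dim=-1)   # per-sample class probabilities
scores = probs[:, 1] - probs[:, 0]     # single score in [-1, 1] per sample

print(scores)      # approximately tensor([ 0.3364, -0.3364])
print(scores > 0)  # tensor([ True, False]) -> predicted class 1, then class 0
```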
@@ -233,27 +239,36 @@ def _create_metrics(
         num_classes: int, top_k: List[int], prefix: str, postfix: str
     ) -> MetricCollection:
         """Create a collection of classification metrics."""
+        task_type = "binary" if num_classes == 2 else "multiclass"
+        acc_metrics = (
+            {
+                f"top{k}_accuracy": Accuracy(
+                    task=task_type, num_classes=num_classes, top_k=k, average="micro"
+                )
+                for k in top_k
+            }
+            if num_classes > 2
+            else {"accuracy": Accuracy(task=task_type, num_classes=num_classes)}
+        )
         return MetricCollection(
             {
                 "precision": Precision(
-                    task="multiclass", num_classes=num_classes, average="macro"
+                    task=task_type,
+                    num_classes=num_classes,
+                    average="macro" if num_classes > 2 else "micro",
                 ),
                 "recall": Recall(
-                    task="multiclass", num_classes=num_classes, average="macro"
+                    task=task_type,
+                    num_classes=num_classes,
+                    average="macro" if num_classes > 2 else "micro",
                 ),
                 "f1_score_macro": F1Score(
-                    task="multiclass", num_classes=num_classes, average="macro"
+                    task=task_type,
+                    num_classes=num_classes,
+                    average="macro" if num_classes > 2 else "micro",
                 ),
-                "aucroc": AUROC(task="multiclass", num_classes=num_classes),
-                **{
-                    f"top{k}_accuracy": Accuracy(
-                        task="multiclass",
-                        num_classes=num_classes,
-                        top_k=k,
-                        average="micro",
-                    )
-                    for k in top_k
-                },
+                "aucroc": AUROC(task=task_type, num_classes=num_classes),
+                **acc_metrics,
             },
             prefix=prefix,
             postfix=postfix,
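For context (not part of the commit), a rough sketch of how a MetricCollection like the one built above is typically used with torchmetrics: it is updated with per-batch prediction scores and integer targets, then computed. The metric names, sizes, and prefix below are made up for illustration.

```python
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import AUROC, Accuracy

# Multiclass case: predictions are (batch, num_classes) logits, targets are indices.
metrics = MetricCollection(
    {
        "top1_accuracy": Accuracy(
            task="multiclass", num_classes=3, top_k=1, average="micro"
        ),
        "aucroc": AUROC(task="multiclass", num_classes=3),
    },
    prefix="example/",
)

logits = torch.randn(8, 3)
targets = torch.randint(0, 3, (8,))
metrics.update(logits, targets)
print(metrics.compute())  # {'example/top1_accuracy': ..., 'example/aucroc': ...}
```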
9 changes: 4 additions & 5 deletions in a third file (a YAML experiment config; filename not shown in this view)
@@ -143,7 +143,7 @@ datasets:

 dataloader:
   test:
-    batch_size: 64
+    batch_size: 128
     num_workers: 4

 task:
@@ -153,15 +153,14 @@ task:
   task_specs:
     - top_k: [1]
       query_modality: rgb
-  run_on_validation: false
-  run_on_test: true
+  run_on_validation: False
+  run_on_test: True
   compute_validation_loss: False
   compute_test_loss: False

 trainer:
   precision: 16-mixed
-  deterministic: False
-  benchmark: True
+  deterministic: True
   sync_batchnorm: False # set to True if using DDP with batchnorm
   log_every_n_steps: 100
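The trainer keys in this config mirror PyTorch Lightning Trainer arguments. Assuming the config is consumed by a Lightning 2.x Trainer (an assumption; the consuming code is not part of this diff), the updated settings correspond roughly to:

```python
from lightning.pytorch import Trainer  # Lightning 2.x import path (assumed)

trainer = Trainer(
    precision="16-mixed",   # mixed-precision evaluation
    deterministic=True,     # use deterministic algorithms; replaces the removed benchmark=True
    sync_batchnorm=False,   # set to True if using DDP with batchnorm
    log_every_n_steps=100,
)
```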
