Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrating supervised contrastive learning example to Keras 3 #1734

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 22 additions & 11 deletions examples/vision/supervised-contrastive-learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,20 @@
that representations of images in the same class will be more similar compared to
representations of images in different classes.
2. Training a classifier on top of the frozen encoder.
"""

Note that this example requires [TensorFlow Addons](https://www.tensorflow.org/addons),
which you can install using the following command:

```python
pip install tensorflow-addons
```

"""
## Setup
"""

import os

os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
from keras import ops
from keras import layers
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

"""
## Prepare the data
Expand Down Expand Up @@ -159,6 +157,19 @@ def create_classifier(encoder, trainable=True):
"""


def npairs_loss(y_true, y_pred):
    """Compute the npairs loss (multi-class N-pair loss).

    Re-implementation of `tfa.losses.npairs_loss` using Keras 3 ops:
    builds a pairwise label-equality matrix, row-normalizes it into a
    target distribution, and applies softmax cross-entropy against the
    pairwise-similarity logits.

    Args:
        y_true: Integer class labels, shape `[batch_size]`.
        y_pred: Pairwise similarity logits, shape `[batch_size, batch_size]`.

    Returns:
        Scalar mean npairs loss over the batch.
    """
    y_pred = ops.convert_to_tensor(y_pred)
    # Cast labels to float32 (not uint32): the row-normalizing division
    # below is true division, which is undefined/lossy on unsigned ints,
    # and softmax_cross_entropy_with_logits requires float labels whose
    # dtype matches the float32 logits.
    y_true = ops.cast(y_true, dtype="float32")

    # Expand to [batch_size, 1] so broadcasting against the transpose
    # yields the full [batch_size, batch_size] comparison grid.
    y_true = ops.expand_dims(y_true, -1)
    # Pairwise equality matrix: 1.0 where two samples share a label.
    y_true = ops.cast(ops.equal(y_true, ops.transpose(y_true)), dtype="float32")
    # Normalize each row into a probability distribution over positives.
    y_true /= ops.sum(y_true, 1, keepdims=True)

    loss = tf.nn.softmax_cross_entropy_with_logits(logits=y_pred, labels=y_true)
    return ops.mean(loss)


class SupervisedContrastiveLoss(keras.losses.Loss):
def __init__(self, temperature=1, name=None):
super().__init__(name=name)
Expand All @@ -174,7 +185,7 @@ def __call__(self, labels, feature_vectors, sample_weight=None):
),
self.temperature,
)
return tfa.losses.npairs_loss(tf.squeeze(labels), logits)
return npairs_loss(ops.squeeze(labels), logits)


def add_projection_head(encoder):
Expand Down
Loading