Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrating supervised contrastive learning example to Keras 3 #1734

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

sitamgithub-MSIT
Copy link
Contributor

This PR changes the Supervised Contrastive Learning example to keras 3.0 [TF-Only Example] as requested in keras-team/keras-cv#2211.

For example, here is the notebook link provided:
https://colab.research.google.com/drive/1QGiil-RpO55UNESBkilNtYETuX3B_4ZF?usp=sharing

cc: @divyashreepathihalli @fchollet

The following describes the Git difference for the changed files:

Changes:
diff --git a/examples/vision/supervised-contrastive-learning.py b/examples/vision/supervised-contrastive-learning.py
index 4803e671..6b45d568 100644
--- a/examples/vision/supervised-contrastive-learning.py
+++ b/examples/vision/supervised-contrastive-learning.py
@@ -20,22 +20,20 @@ Learning is performed in two phases:
 that representations of images in the same class will be more similar compared to
 representations of images in different classes.
 2. Training a classifier on top of the frozen encoder.
+"""
 
-Note that this example requires [TensorFlow Addons](https://www.tensorflow.org/addons),
-which you can install using the following command:
-
-```python
-pip install tensorflow-addons
-```
-
+"""
 ## Setup
 """
 
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+import keras
+from keras import ops
+from keras import layers
 import tensorflow as tf
-import tensorflow_addons as tfa
 import numpy as np
-from tensorflow import keras
-from tensorflow.keras import layers
 
 """
 ## Prepare the data
@@ -159,6 +157,19 @@ softmax are optimized.
 """
 
 
+def npairs_loss(y_true, y_pred):
+    y_pred = ops.convert_to_tensor(y_pred)
+    y_true = ops.cast(y_true, dtype="uint32")
+
+    # Expand to [batch_size, 1]
+    y_true = ops.expand_dims(y_true, -1)
+    y_true = ops.cast(ops.equal(y_true, ops.transpose(y_true)), dtype="uint32")
+    y_true /= ops.sum(y_true, 1, keepdims=True)
+
+    loss = tf.nn.softmax_cross_entropy_with_logits(logits=y_pred, labels=y_true)
+    return ops.mean(loss)
+
+
 class SupervisedContrastiveLoss(keras.losses.Loss):
     def __init__(self, temperature=1, name=None):
         super().__init__(name=name)
@@ -174,7 +185,7 @@ class SupervisedContrastiveLoss(keras.losses.Loss):
             ),
             self.temperature,
         )
-        return tfa.losses.npairs_loss(tf.squeeze(labels), logits)
+        return npairs_loss(ops.squeeze(labels), logits)
 
 
 def add_projection_head(encoder):
(END)

@sitamgithub-MSIT
Copy link
Contributor Author

So, the updated code is working fine with the TensorFlow backend. With other backends, like Jax, I tested, it is giving errors like TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[265,128].. Although there are two instances of the tf method (not tf.data). The changes I made were basically removing the TensorFlow addons. Tensorflow addons are used for implementing the npair loss. So, I made a function implementation by following this code. And it also worked with the entire example code successfully.

Copy link
Contributor

@divyashreepathihalli divyashreepathihalli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sitamgithub-MSIT Thank you for the PR!
This example can be converted to Keras 3 and it should work with all 3 backends. You will need to replace all the Tensorflow ops with Keras ops - here is a list of all Keras ops. Here is a guide for migration

@sitamgithub-MSIT
Copy link
Contributor Author

The two only instances where TF methods are used are:

def npairs_loss(y_true, y_pred):

    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.cast(y_true, dtype="uint32")

    # Expand to [batch_size, 1]
    y_true = ops.expand_dims(y_true, -1)
    y_true = ops.cast(ops.equal(y_true, ops.transpose(y_true)), dtype="uint32")
    y_true /= ops.sum(y_true, 1, keepdims=True)

    loss = tf.nn.softmax_cross_entropy_with_logits(logits=y_pred, labels=y_true)

    return ops.mean(loss)

Now I know loss = tf.nn.softmax_cross_entropy_with_logits(logits=y_pred, labels=y_true) can be substituted with keras.ops as described here. But whenever I am doing this, an error shows up: ValueError: Cannot take the length of shape with unknown rank..

Another instance is this tf.math.l2_normalize(feature_vectors, axis=1). So, I need some help in only these two instances to deal with.

@divyashreepathihalli
Copy link
Contributor

divyashreepathihalli commented Jan 16, 2024

here

do you have the repro for this ValueError?

For l2_normalize - please use this normalize function and set the order to 2.

@sitamgithub-MSIT
Copy link
Contributor Author

here

do you have the repro for this ValueError?

For l2_normalize - please use this normalize function and set the order to 2.

Thanks for the help! I was actually able to omit both TF instances. But other backends are showing errors. This is the modified code that worked well with Tensorflow.

def npairs_loss(y_true, y_pred):

    y_pred = ops.convert_to_tensor(y_pred)
    y_true = ops.cast(y_true, dtype="uint32")

    # Expand to [batch_size, 1]
    y_true = ops.expand_dims(y_true, -1)
    y_true = ops.cast(ops.equal(y_true, ops.transpose(y_true)), dtype="uint32")
    y_true /= ops.sum(y_true, 1, keepdims=True)

    y_true.set_shape([None, None])
    loss = ops.categorical_crossentropy(y_pred, y_true, from_logits=True, axis=-1)

    return ops.mean(loss)

But I see this set_shape is not going to work with other backends showing attribute errors. For example, jax AttributeError: 'ArrayImpl' object has no attribute 'set_shape' is showing this.

@sitamgithub-MSIT
Copy link
Contributor Author

Any updates on this?

cc: @divyashreepathihalli

@fchollet
Copy link
Member

I think you can just omit y_true.set_shape([None, None]).

@sitamgithub-MSIT
Copy link
Contributor Author

I think you can just omit y_true.set_shape([None, None]).

Then it is giving this error with the TensorFlow backend ValueError: Cannot take the length of shape with unknown rank.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants