-
Notifications
You must be signed in to change notification settings - Fork 2.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Convert 'Vision Transformer without Attention' to Keras 3. #1855
Conversation
On Tensorflow, I am able to train and test the model, but hit this issue when loading the saved model to do inference on it. It may be the same issue as keras-team/keras#19492 but I am not 100% sure.
|
On PyTorch, I am hitting this issue when compiling the initial model before training starts.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR!
I believe you may be able to make the code fully-backend agnostic without implementing backend-specific train_steps. Instead, you could override compute_loss()
and make it work with all backends. The train step is generic and only loss computation appears to be custom.
Thx for the review and suggestion Francois! I dropped the custom train and test steps. The combination of overriding call() method and the native compute_loss() method was equivalent to the custom loss method. Current issues I am debugging:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the updates.
torch issue
This caught a bug: ops.split
is supposed to return a list, but for torch it returns a tuple (same as torch.split
). I fixed it. You can route around it by creating an output list and appending elements to it. Once done, the code runs with torch.
tf issue with deserialization
You need to call deserialize_keras_object
on the models/layers passed to constructors, to enable deserialization / model loading.
e.g.
self.data_augmentation = keras.saving.deserialize_keras_object(data_augmentation)
jax issue
This one has to do with tracer leaks. Those problems are unique to JAX and can be tricky to debug. A first problem is using ops.linspace
instead of np.linspace
in build(). There are further issues down the line however.
# Update the metrics | ||
self.compiled_metrics.update_state(labels, logits) | ||
return {m.name: m.result() for m in self.metrics} | ||
|
||
def call(self, images): | ||
augmented_images = self.data_augmentation(images) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Surely this should only be applied at training time? Also, we may consider moving it to the data pipeline instead of inside the model.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Correct. Only active at training time. see here for full context, but block level comment summarizes this well.
The augmentation pipeline consists of:
Rescaling
Resizing
Random cropping
Random horizontal flipping
Note: The image data augmentation layers do not apply data transformations at inference time. This means that when these layers are called withtraining=False
they behave differently. Refer to the [documentation (https://keras.io/api/layers/preprocessing_layers/image_augmentation/) for more details.
Tensorflow and PyTorch only compatibilty.