This repository has been archived by the owner on Jul 24, 2024. It is now read-only.

Is the trained projection head available? #199

Open
lkshrsch opened this issue Mar 30, 2022 · 7 comments

Comments

@lkshrsch

I am interested in downloading a pre-trained SimCLR model with the projection head, so that I can retrieve the latent features z on which the contrastive loss was applied.
Is this layer + pre-trained weights available somewhere?

@chentingpc
Contributor

Yes, the projection head weights should also be included in gs://simclr-checkpoints/simclrv2/pretrained/...

@lkshrsch
Author

lkshrsch commented Apr 4, 2022

In the GitHub README, that link is listed under the description

"Pretrained SimCLRv2 models (with linear eval head):"

I assumed "with linear eval head" refers to the classification layer for ImageNet,

but downloading the model r50_1x_sk0 from:

https://console.cloud.google.com/storage/browser/simclr-checkpoints/simclrv2/pretrained/r50_1x_sk0?pageState=(%22StorageObjectListTable%22:(%22f%22:%22%255B%255D%22))&prefix=&forceOnObjectsSortingFiltering=false

The model output has dimension 2048, which could be either the output of the ResNet-50 or the output of the projection head.

So, to confirm:
Are these the features from the projection head, z = g(h) (as described in the SimCLR paper, Figure 2)?
Or from the ResNet-50, h = f(x) (as described in the SimCLR paper, Figure 2)?
Or from the linear evaluation head for classification (as described in the GitHub README), which should be logits of dimension 1000 for ImageNet?

Thanks!

@chentingpc
Contributor

Both the projection head's and the supervised linear head's weights are available in the checkpoints. I suppose you're using the hub module? If so, you can choose the output by providing a signature that's available in module.get_output_info_dict(); I listed the results below. Note that the projection head's output is not included, so to get it you may need to run the TF code with the checkpoint loaded to build a new graph.

{
 'block_group1': <hub.ParsedTensorInfo shape=(None, None, None, 256) dtype=float32 is_sparse=False>,
 'block_group2': <hub.ParsedTensorInfo shape=(None, None, None, 512) dtype=float32 is_sparse=False>,
 'block_group3': <hub.ParsedTensorInfo shape=(None, None, None, 1024) dtype=float32 is_sparse=False>,
 'block_group4': <hub.ParsedTensorInfo shape=(None, None, None, 2048) dtype=float32 is_sparse=False>,
 'default': <hub.ParsedTensorInfo shape=(None, 2048) dtype=float32 is_sparse=False>,
 'final_avg_pool': <hub.ParsedTensorInfo shape=(None, 2048) dtype=float32 is_sparse=False>,
 'initial_conv': <hub.ParsedTensorInfo shape=(None, None, None, 64) dtype=float32 is_sparse=False>,
 'initial_max_pool': <hub.ParsedTensorInfo shape=(None, None, None, 64) dtype=float32 is_sparse=False>,
 'logits_sup': <hub.ParsedTensorInfo shape=(None, 1000) dtype=float32 is_sparse=False>
}
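
To make the listing above easier to read, here is a small sketch in plain Python that separates the flat per-image vectors from the spatial feature maps. The shapes are copied from the listing; the names OUTPUT_SHAPES and pooled_outputs are hypothetical helpers for illustration, not part of the SimCLR code or of TensorFlow Hub.

```python
# Output names and static shapes, copied from the get_output_info_dict() listing.
OUTPUT_SHAPES = {
    "block_group1": (None, None, None, 256),
    "block_group2": (None, None, None, 512),
    "block_group3": (None, None, None, 1024),
    "block_group4": (None, None, None, 2048),
    "default": (None, 2048),
    "final_avg_pool": (None, 2048),
    "initial_conv": (None, None, None, 64),
    "initial_max_pool": (None, None, None, 64),
    "logits_sup": (None, 1000),
}


def pooled_outputs(shapes):
    """Return {name: feature_dim} for outputs shaped (batch, dim).

    Rank-4 outputs are spatial feature maps (batch, height, width, channels)
    and are skipped.
    """
    return {name: shape[-1] for name, shape in shapes.items() if len(shape) == 2}


print(pooled_outputs(OUTPUT_SHAPES))
# -> {'default': 2048, 'final_avg_pool': 2048, 'logits_sup': 1000}
```

Only three outputs are flat vectors: two 2048-dimensional ones and the 1000-dimensional supervised logits. None of them is the projection head output, which matches the note that it is not included in the module's outputs.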

@ilia10000

I'm struggling to actually get the projection representations and still not quite certain what to do based on the previous comments in this thread. Does anyone have a minimal working example of loading the pre-trained model, pushing an input through, and getting the representation from the projection head?

@collinskatie

collinskatie commented Apr 3, 2023

Thanks for the great repo! I wanted to follow up and ask whether this issue has been resolved.

I'm also trying to access the projection representations. Specifically, I'd like to be able to pass in an image and get out just the representation (prior to the class-level logits). What layer should I use for this?

If I load a model as follows:

import tensorflow as tf

saved_model_path = 'gs://simclr-checkpoints-tf2/simclrv2/pretrained/r50_1x_sk0/saved_model/'
saved_model = tf.saved_model.load(saved_model_path)

The keys available when running inference on a new image are as follows:

saved_model(image, trainable=False).keys()

dict_keys(['logits_sup', 'block_group3', 'block_group4', 'final_avg_pool', 'block_group2', 'block_group1', 'initial_max_pool', 'initial_conv'])

Which of these is the key associated with the representation? final_avg_pool?

Thank you for any insight @chentingpc or others!

@chentingpc
Contributor

Hi, final_avg_pool is the output of the ResNet, which is used for linear probing. Hope that helps.
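
For anyone landing here later, a minimal sketch of reading that output from the TF2 SavedModel mentioned earlier in the thread. It reuses the saved_model(image, trainable=False) call and the path from the comment above; the helper names (load_saved_model, get_representation, demo) are hypothetical, and the 224x224 float input in [0, 1] is an assumption about preprocessing, not something confirmed in this thread. This returns the ResNet representation h, not the projection head output z.

```python
SAVED_MODEL_PATH = (
    "gs://simclr-checkpoints-tf2/simclrv2/pretrained/r50_1x_sk0/saved_model/"
)


def load_saved_model(path=SAVED_MODEL_PATH):
    """Load the SavedModel. TensorFlow is imported lazily, only when needed."""
    import tensorflow as tf

    return tf.saved_model.load(path)


def get_representation(model, images, key="final_avg_pool"):
    """Run the model and return a single named output.

    `model` is any callable that returns a dict of outputs, as the
    SavedModel in this thread does.
    """
    outputs = model(images, trainable=False)
    if key not in outputs:
        raise KeyError(f"{key!r} not in model outputs: {sorted(outputs.keys())}")
    return outputs[key]


def demo():
    """Example run. Needs TensorFlow and read access to the GCS bucket."""
    import tensorflow as tf

    model = load_saved_model()
    # Assumption: a batch of float32 images in [0, 1], sized 224x224x3.
    images = tf.zeros([1, 224, 224, 3], dtype=tf.float32)
    h = get_representation(model, images)
    print(h.shape)
```

Calling demo() runs the whole thing end to end. Passing key="logits_sup" instead would return the supervised classification logits.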

@collinskatie

Thank you @chentingpc !! That's great to know!


4 participants