
knn eval of MAE #6

Open
Dongshengjiang opened this issue Nov 18, 2021 · 13 comments
Labels
enhancement New feature or request

Comments

@Dongshengjiang

I evaluated the vit_base model (500/1600-epoch pretraining on ImageNet-1k) using the kNN metric. Loading all the pretrained parameters and using the ViT GAP features (no cls token needed), the 20-NN result is 33.4 on the ImageNet-100 dataset, which is very low and does not match the linear probing accuracy.
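For reference, a minimal sketch of a weighted 20-NN evaluation on frozen GAP features (the usual cosine-similarity recipe, as in DINO's kNN eval); `encoder` and the loaders are placeholders and the details may differ from my exact setup:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def extract_gap_features(encoder, loader, device="cuda"):
    """Frozen features: global average pool over patch tokens, L2-normalized."""
    feats, labels = [], []
    encoder.eval()
    for images, targets in loader:
        tokens = encoder(images.to(device))                    # (B, num_patches, dim), no cls token
        feats.append(F.normalize(tokens.mean(dim=1), dim=-1))  # GAP + L2 norm
        labels.append(targets)
    return torch.cat(feats), torch.cat(labels).to(device)

@torch.no_grad()
def knn_accuracy(train_feats, train_labels, test_feats, test_labels, k=20, T=0.07):
    """Temperature-weighted kNN vote over cosine similarities."""
    num_classes = int(train_labels.max()) + 1
    correct = 0
    for i in range(0, test_feats.size(0), 256):
        sim = test_feats[i:i + 256] @ train_feats.T            # cosine similarity
        topk_sim, topk_idx = sim.topk(k, dim=1)
        topk_labels = train_labels[topk_idx]                   # (chunk, k)
        weights = (topk_sim / T).exp()
        votes = torch.zeros(topk_labels.size(0), num_classes, device=test_feats.device)
        votes.scatter_add_(1, topk_labels, weights)            # weighted vote per class
        correct += (votes.argmax(dim=1) == test_labels[i:i + 256]).sum().item()
    return 100.0 * correct / test_feats.size(0)
```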

@Dongshengjiang
Author

Have you tried the linear probing eval?

@pengzhiliang
Owner

Emm, how about end-to-end fine-tuning?

@Dongshengjiang
Author

I just tried your latest update of end-to-end fine-tuning, and it seems good. But I still think linear probing is a metric that cannot be avoided.

@pengzhiliang pengzhiliang added the enhancement New feature or request label Nov 18, 2021
@pengzhiliang
Owner

Thanks for your suggestion; we actually overlooked the linear probing metric. In fact, I am not very familiar with linear probing. Can you help me try to implement it? Thank you very much!

@Dongshengjiang
Author

https://github.com/facebookresearch/dino/blob/main/eval_linear.py
DINO contains the code for both kNN and linear evaluation. I am not sure how to treat the cls token: linear probing only fine-tunes the last head, but for MAE the cls token is not pre-trained.
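One possibility is to linear-probe on GAP features instead of the untrained cls token. A minimal sketch, assuming `encoder` is the frozen pretrained ViT returning patch tokens and `train_loader` is a standard ImageNet loader (the BatchNorm before the head and the optimizer settings are only illustrative, not the exact DINO/MAE recipe):

```python
import torch
import torch.nn as nn

embed_dim, num_classes = 768, 1000
# Linear head on frozen features; a parameter-free BatchNorm is often placed before it.
head = nn.Sequential(nn.BatchNorm1d(embed_dim, affine=False),
                     nn.Linear(embed_dim, num_classes)).cuda()
optimizer = torch.optim.SGD(head.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0)
criterion = nn.CrossEntropyLoss()

for images, targets in train_loader:            # placeholder loader
    images, targets = images.cuda(), targets.cuda()
    with torch.no_grad():                        # backbone stays frozen
        feats = encoder(images).mean(dim=1)      # GAP over patch tokens, no cls token
    loss = criterion(head(feats), targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```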

@pengzhiliang
Owner

Ok, thank you~

@pengzhiliang
Owner

Hello, have you finished the end-to-end fine-tuning of vit-base/1600e? Can you tell me the result? Thank you!

@Dongshengjiang
Author

Hi, I finished the 1600-epoch pretraining, but I only got fine-tuning results of 83.15 for epoch 1400 and 82.97 for epoch 1600, which are lower than your reported epoch-400 result and the paper's results.

@Dongshengjiang
Author

Dongshengjiang commented Nov 23, 2021

From your vit_base pretraining log, I found that your max learning rate is 0.0024. Did you run with a 128×32 batch size?
According to the code, `args.lr = args.lr * total_batch_size / 256`, which should give 0.0006 for a batch size of 128×8.

@pengzhiliang
Owner

OK, that is very strange. I ran vit-base with 512 × 8 = 4096, where the lr is 1.5e-4 * 512 * 8 / 256 = 0.0024.
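For clarity, the linear scaling rule from the code, spelled out (the helper name is just for illustration):

```python
# Effective lr = base_lr * total_batch_size / 256 (linear scaling rule in the code).
base_lr = 1.5e-4

def scaled_lr(batch_size_per_gpu, num_gpus, base_lr=base_lr):
    total_batch_size = batch_size_per_gpu * num_gpus
    return base_lr * total_batch_size / 256

print(scaled_lr(512, 8))   # 0.0024 -> total batch size 4096 (my setting)
print(scaled_lr(128, 8))   # 0.0006 -> total batch size 1024
```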

@Dongshengjiang
Author

OK, I will try your settings to reproduce your epoch-400 result. But the epoch-1600 results were already run with batch size 4096 and are still not good enough. The fine-tuning accuracy increases slowly with pretraining epochs:

| Pretraining epochs | 200 | 400 | 600 | 800 | 1000 | 1200 | 1400 | 1600 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Fine-tuning acc. (%) | 82.71 | 82.82 | 82.87 | 83.00 | 82.78 | 82.96 | 83.15 | 82.97 |

@pengzhiliang
Owner

OK, thank you so much for your experiments!
Maybe there are still some problems; I will check it carefully.

@pengzhiliang pengzhiliang pinned this issue Nov 24, 2021
@pengzhiliang pengzhiliang unpinned this issue Nov 24, 2021
@Harick1

Harick1 commented Dec 12, 2021

@Dongshengjiang Have you tried the linear probing evaluation with the cls token?

The paper said: As ViT has a class token [16], to adapt to this design, in our MAE pre-training we append an auxiliary dummy token to the encoder input. This token will be treated as the class token for training the classifier in linear probing and fine-tuning.

It seems that the authors just add a dummy token during pre-training and directly use it as the feature for linear probing.
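A rough sketch of that reading, with hypothetical module and attribute names (only the dummy-token handling is shown; the usual ViT patch embedding, positional embedding, and blocks are elided):

```python
import torch
import torch.nn as nn

class MAEEncoderWithDummyToken(nn.Module):
    """Hypothetical ViT encoder wrapper: an auxiliary dummy (class) token is
    appended during MAE pre-training and later used as the linear-probe feature."""
    def __init__(self, embed_dim=768):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))  # the dummy token

    def forward(self, visible_tokens):
        # visible_tokens: (B, num_visible_patches, embed_dim) after random masking
        cls = self.cls_token.expand(visible_tokens.size(0), -1, -1)
        x = torch.cat([cls, visible_tokens], dim=1)   # prepend the dummy token
        # x = self.blocks(x)                          # transformer blocks would go here
        return x[:, 0]                                # this token is the probing feature
```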
