The pre-trained models (base network with a linear classifier layer) can be found below.
| Model checkpoint and hub-module | ImageNet Top-1 |
|---------------------------------|----------------|
| ResNet50 (1x)                   | 69.1           |
| ResNet50 (2x)                   | 74.2           |
| ResNet50 (4x)                   | 76.6           |
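To use a downloaded checkpoint for feature extraction, a sketch along the following lines should work. This is a minimal example assuming a TF1-style hub module whose default signature returns image representations; the module path is a placeholder for wherever you saved the checkpoint, and the exact signature may differ.

```python
# Minimal sketch, assuming a TF1-style hub module whose default signature
# returns representation vectors; the module path below is a placeholder.
import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow_hub as hub

tf.disable_eager_execution()

images = tf.placeholder(tf.float32, shape=(None, 224, 224, 3))
module = hub.Module("/path/to/hub_module", trainable=False)  # placeholder path
features = module(images)  # default signature -> representations

with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
    feats = sess.run(features, feed_dict={images: np.zeros((1, 224, 224, 3))})
    print(feats.shape)
```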
Our models are trained with TPUs. It is recommended to run distributed training with TPUs when using our code for pretraining.
Our code can also run on a single GPU. It does not support multiple GPUs, for reasons such as global BatchNorm and the contrastive loss across cores.
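To make that constraint concrete, here is a small NumPy sketch of the NT-Xent contrastive objective (an illustration, not the repo's TF implementation): every other example in the global batch serves as a negative, so computing the loss on per-GPU shards would silently shrink the negative set unless representations are gathered across devices.

```python
# Illustrative NumPy sketch of NT-Xent: negatives come from the entire
# global batch, which is why the loss couples all cores/devices.
import numpy as np

def nt_xent(z1, z2, temperature=0.5):
    """z1, z2: (N, D) L2-normalized embeddings of two views of the same batch."""
    z = np.concatenate([z1, z2], axis=0)        # (2N, D) global batch
    sim = z @ z.T / temperature                 # pairwise cosine similarities
    np.fill_diagonal(sim, -np.inf)              # mask self-similarity
    n = z1.shape[0]
    pos = np.concatenate([np.arange(n, 2 * n),  # positive of row i is i + N ...
                          np.arange(0, n)])     # ... and of row i + N is i
    logsumexp = np.log(np.exp(sim).sum(axis=1))
    return float(np.mean(logsumexp - sim[np.arange(2 * n), pos]))

rng = np.random.default_rng(0)
z1 = rng.normal(size=(8, 32))
z2 = z1 + 0.1 * rng.normal(size=(8, 32))        # crude stand-in for two views
z1 /= np.linalg.norm(z1, axis=1, keepdims=True)
z2 /= np.linalg.norm(z2, axis=1, keepdims=True)
print(nt_xent(z1, z2))
```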
The code is compatible with both TensorFlow v1 and v2. See requirements.txt for all prerequisites; you can install them with the following command:
pip install -r requirements.txt
To pretrain the model on CIFAR-10 with a single GPU, try the following command:
python run.py --train_mode=pretrain \
--train_batch_size=512 --train_epochs=1000 \
--learning_rate=1.0 --weight_decay=1e-6 --temperature=0.5 \
--dataset=cifar10 --image_size=32 --eval_split=test --resnet_depth=18 \
--use_blur=False --color_jitter_strength=0.5 \
--model_dir=/tmp/simclr_test --use_tpu=False
To pretrain the model on ImageNet with Cloud TPUs, you should also set the following flags.
--use_tpu=True
--tpu_name=$TPU_NAME
Please see the Google Cloud TPU tutorial for how to use Cloud TPUs. More instructions on running with Cloud TPUs will be released soon!
To fine-tune a linear head (with a single GPU), try the following command:
python run.py --mode=train_then_eval --train_mode=finetune \
--fine_tune_after_block=4 --zero_init_logits_layer=True \
--variable_schema='(?!global_step|(?:.*/|^)LARSOptimizer|head)' \
--global_bn=False --optimizer=momentum --learning_rate=0.1 --weight_decay=0.0 \
--train_epochs=100 --train_batch_size=512 --warmup_epochs=0 \
--dataset=cifar10 --image_size=32 --eval_split=test --resnet_depth=18 \
--checkpoint=/tmp/simclr_test --model_dir=/tmp/simclr_test_ft --use_tpu=False
You can check the results with TensorBoard, for example:
python -m tensorboard.main --logdir=/tmp/simclr_test
For reference, the above CIFAR-10 runs should give you around 91% accuracy, though this can be further optimized.
Image IDs of the ImageNet 1% and 10% subsets used for semi-supervised learning can be found in imagenet_subsets/.
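As a sketch of how those lists can be consumed, assuming each file in imagenet_subsets/ holds one image ID per line (the filename 1percent.txt below is an assumption; adjust it to the actual files in that directory):

```python
# Minimal sketch, assuming one image ID per line; the filename is an
# assumed example of what imagenet_subsets/ contains.
def load_subset_ids(path):
    """Return the set of image IDs listed in `path`, one per line."""
    with open(path) as f:
        return {line.strip() for line in f if line.strip()}

ids_1pct = load_subset_ids("imagenet_subsets/1percent.txt")  # assumed filename
print(f"{len(ids_1pct)} image IDs in the 1% subset")
```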
Please cite our arXiv paper:
@article{chen2020simple,
  title={A Simple Framework for Contrastive Learning of Visual Representations},
  author={Chen, Ting and Kornblith, Simon and Norouzi, Mohammad and Hinton, Geoffrey},
  journal={arXiv preprint arXiv:2002.05709},
  year={2020}
}
This is not an official Google product.
- 4-12-2020:
  /Volumes/Bo500G32MCache/Cervical/png2cifar10/cifar-10-binary.tar-04122020.gz:
  |--- /Volumes/Bo500G32MCache/Cervical/Training_Testing_Datasets/cells_towclass_1230
       |--- data_train
            |--- cells_N_test (26987)
            |--- cells_P_test (14365)
       |--- data_test
            |--- cells_N_test (1001)
            |--- cells_P_test (1001)
Evaluation (./simclr_test_ft/result.json):
{
  "contrast_loss": 0.0,
  "contrastive_top_1_accuracy": 1.0,
  "contrastive_top_5_accuracy": 1.0,
  "label_top_1_accuracy": 0.8486999869346619,
  "label_top_5_accuracy": 0.9953,
  "loss": 0.42983800172805786,
  "regularization_loss": 0.0,
  "global_step": 9766.0
}