-
Notifications
You must be signed in to change notification settings - Fork 432
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
28 changed files
with
1,246 additions
and
63 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
# dataset settings | ||
data_source = 'ImageNet' | ||
train_dataset_type = 'MultiViewDataset' | ||
extract_dataset_type = 'SingleViewDataset' | ||
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | ||
train_pipeline = [ | ||
dict(type='RandomResizedCrop', size=224, scale=(0.2, 1.)), | ||
dict( | ||
type='RandomAppliedTrans', | ||
transforms=[ | ||
dict( | ||
type='ColorJitter', | ||
brightness=0.4, | ||
contrast=0.4, | ||
saturation=0.4, | ||
hue=0.1) | ||
], | ||
p=0.8), | ||
dict(type='RandomGrayscale', p=0.2), | ||
dict(type='GaussianBlur', sigma_min=0.1, sigma_max=2.0, p=0.5), | ||
dict(type='RandomHorizontalFlip'), | ||
] | ||
extract_pipeline = [ | ||
dict(type='Resize', size=256), | ||
dict(type='CenterCrop', size=224), | ||
] | ||
|
||
# prefetch | ||
prefetch = False | ||
if not prefetch: | ||
train_pipeline.extend( | ||
[dict(type='ToTensor'), | ||
dict(type='Normalize', **img_norm_cfg)]) | ||
extract_pipeline.extend( | ||
[dict(type='ToTensor'), | ||
dict(type='Normalize', **img_norm_cfg)]) | ||
|
||
# dataset summary | ||
data = dict( | ||
samples_per_gpu=32, # total 32*8=256 | ||
replace=True, | ||
workers_per_gpu=4, | ||
drop_last=True, | ||
train=dict( | ||
type=train_dataset_type, | ||
data_source=dict( | ||
type=data_source, | ||
data_prefix='data/imagenet/train', | ||
ann_file='data/imagenet/meta/train.txt', | ||
), | ||
num_views=[2], | ||
pipelines=[train_pipeline], | ||
prefetch=prefetch)) | ||
|
||
# additional hooks | ||
num_classes = 10000 | ||
custom_hooks = [ | ||
dict( | ||
type='InterCLRHook', | ||
extractor=dict( | ||
samples_per_gpu=256, | ||
workers_per_gpu=8, | ||
dataset=dict( | ||
type=extract_dataset_type, | ||
data_source=dict( | ||
type=data_source, | ||
data_prefix='data/imagenet/train', | ||
ann_file='data/imagenet/meta/train.txt', | ||
), | ||
pipeline=extract_pipeline, | ||
prefetch=prefetch), | ||
prefetch=prefetch, | ||
img_norm_cfg=img_norm_cfg), | ||
clustering=dict(type='Kmeans', k=num_classes, pca_dim=-1), # no pca | ||
centroids_update_interval=10, # iter | ||
deal_with_small_clusters_interval=1, | ||
evaluate_interval=50, | ||
warmup_epochs=0, | ||
init_memory=True, | ||
initial=True, # call initially | ||
online_labels=True, | ||
interval=10) # same as the checkpoint interval | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# model settings | ||
model = dict( | ||
type='InterCLRMoCo', | ||
queue_len=65536, | ||
feat_dim=128, | ||
momentum=0.999, | ||
backbone=dict( | ||
type='ResNet', | ||
depth=50, | ||
in_channels=3, | ||
out_indices=[4], # 0: conv-1, x: stage-x | ||
norm_cfg=dict(type='BN')), | ||
neck=dict( | ||
type='MoCoV2Neck', | ||
in_channels=2048, | ||
hid_channels=2048, | ||
out_channels=128, | ||
with_avg_pool=True), | ||
head=dict(type='ContrastiveHead', temperature=0.2), | ||
memory_bank=dict( | ||
type='InterCLRMemory', | ||
length=1281167, | ||
feat_dim=128, | ||
momentum=1., | ||
num_classes=10000, | ||
min_cluster=20, | ||
debug=False), | ||
online_labels=True, | ||
neg_num=16384, | ||
neg_sampling='semihard', # 'hard', 'semihard', 'random', 'semieasy' | ||
semihard_neg_pool_num=128000, | ||
semieasy_neg_pool_num=128000, | ||
intra_cos_marign_loss=False, | ||
intra_cos_margin=0, | ||
inter_cos_marign_loss=True, | ||
inter_cos_margin=-0.5, | ||
intra_loss_weight=0.75, | ||
inter_loss_weight=0.25, | ||
share_neck=True, | ||
num_classes=10000) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# InterCLR | ||
|
||
> [Delving into Inter-Image Invariance for Unsupervised Visual Representations](https://arxiv.org/abs/2008.11702) | ||
<!-- [ALGORITHM] --> | ||
|
||
## Abstract | ||
|
||
Contrastive learning has recently shown immense | ||
potential in unsupervised visual representation learning. Existing studies in this track mainly focus on intra-image invariance learning. The learning typically uses rich intraimage transformations to construct positive pairs and then | ||
maximizes agreement using a contrastive loss. The merits | ||
of inter-image invariance, conversely, remain much less explored. One major obstacle to exploit inter-image invariance | ||
is that it is unclear how to reliably construct inter-image | ||
positive pairs, and further derive effective supervision from | ||
them since no pair annotations are available. In this work, | ||
we present a comprehensive empirical study to better understand the role of inter-image invariance learning from three main constituting components: pseudo-label maintenance, | ||
sampling strategy, and decision boundary design. To facilitate the study, we introduce a unified and generic framework that supports the integration of unsupervised intra- and | ||
inter-image invariance learning. Through carefully-designed | ||
comparisons and analysis, multiple valuable observations | ||
are revealed: 1) online labels converge faster and perform | ||
better than offline labels; 2) semi-hard negative samples are more reliable and unbiased than hard negative samples; 3) a | ||
less stringent decision boundary is more favorable for interimage invariance learning. With all the obtained recipes, our final model, namely InterCLR, shows consistent improvements over state-of-the-art intra-image invariance learning methods on multiple standard benchmarks. We hope this | ||
work will provide useful experience for devising effective unsupervised inter-image invariance learning. | ||
|
||
<div align="center"> | ||
<img src="https://user-images.githubusercontent.com/52497952/205854109-2385b765-e12b-4e22-b7b8-45db6292895b.png" width="800" /> | ||
</div> | ||
|
||
## Results and Models | ||
|
||
In this page, we provide benchmarks as much as possible to evaluate our pre-trained models. If not mentioned, all models are pre-trained on ImageNet-1k dataset. Here, we use MoCov2-InterCLR as an example. More models and results are coming soon. | ||
|
||
### Classification | ||
|
||
#### VOC SVM / Low-shot SVM | ||
|
||
The **Best Layer** indicates that the best results are obtained from which layers feature map. For example, if the **Best Layer** is **feature3**, its best result is obtained from the second stage of ResNet (1 for stem layer, 2-5 for 4 stage layers). | ||
|
||
Besides, k=1 to 96 indicates the hyper-parameter of Low-shot SVM. | ||
|
||
| Self-Supervised Config | Best Layer | SVM | k=1 | k=2 | k=4 | k=8 | k=16 | k=32 | k=64 | k=96 | | ||
| ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ---- | ----- | | ||
| [interclr-moco_resnet50_8xb32-coslr-200e](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/interclr/interclr-moco_resnet50_8xb32-coslr-200e_in1k.py) | feature5 | 85.24 | 45.08 | 59.25 | 65.99 | 74.31 | 77.95 | 80.68 | 82.7 | 83.49 | | ||
|
||
#### ImageNet Linear Evaluation | ||
|
||
The **Feature1 - Feature5** don't have the GlobalAveragePooling, the feature map is pooled to the specific dimensions and then follows a Linear layer to do the classification. Please refer to [resnet50_mhead_linear-8xb32-steplr-90e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/benchmarks/classification/imagenet/resnet50_mhead_linear-8xb32-steplr-90e_in1k.py) for details of config. | ||
|
||
| Self-Supervised Config | Feature1 | Feature2 | Feature3 | Feature4 | Feature5 | | ||
| ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------- | -------- | -------- | -------- | -------- | | ||
| [interclr-moco_resnet50_8xb32-coslr-200e](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/interclr/interclr-moco_resnet50_8xb32-coslr-200e_in1k.py) | 15.59 | 35.10 | 47.36 | 62.86 | 68.04 | | ||
|
||
## Citation | ||
|
||
```bibtex | ||
@article{xie2022delving, | ||
title={Delving into inter-image invariance for unsupervised visual representations}, | ||
author={Xie, Jiahao and Zhan, Xiaohang and Liu, Ziwei and Ong, Yew-Soon and Loy, Chen Change}, | ||
journal={International Journal of Computer Vision}, | ||
year={2022} | ||
} | ||
``` |
18 changes: 18 additions & 0 deletions
18
configs/selfsup/interclr/interclr-moco_resnet50_8xb32-coslr-200e_in1k.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
_base_ = [ | ||
'../_base_/models/interclr-moco.py', | ||
'../_base_/datasets/imagenet_interclr-moco.py', | ||
'../_base_/schedules/sgd_coslr-200e_in1k.py', | ||
'../_base_/default_runtime.py', | ||
] | ||
|
||
# model settings | ||
model = dict( | ||
memory_bank=dict(num_classes={{_base_.num_classes}}), | ||
num_classes={{_base_.num_classes}}, | ||
) | ||
|
||
# runtime settings | ||
# the max_keep_ckpts controls the max number of ckpt file in your work_dirs | ||
# if it is 3, when CheckpointHook (in mmcv) saves the 4th ckpt | ||
# it will remove the oldest one to keep the number of total ckpts as 3 | ||
checkpoint_config = dict(interval=10, max_keep_ckpts=3) |
Oops, something went wrong.