From 920deff14a69c55a624c6ff095b4e3fa07fd5ad3 Mon Sep 17 00:00:00 2001
From: Jiahao Xie <52497952+Jiahao000@users.noreply.github.com>
Date: Wed, 21 Dec 2022 13:42:45 +0800
Subject: [PATCH 1/4] [Feature] add InterCLR (#609)
* [Feature] add InterCLR
* [Fix] fix lint
* [Fix] fix lint
* [Fix] fix lint
* [Fix] enhance
* [Fix] reformat bibtex
---
README.md | 1 +
README_zh-CN.md | 1 +
.../_base_/datasets/imagenet_interclr-moco.py | 83 ++++
.../selfsup/_base_/models/interclr-moco.py | 40 ++
configs/selfsup/interclr/README.md | 62 +++
...clr-moco_resnet50_8xb32-coslr-200e_in1k.py | 18 +
configs/selfsup/interclr/metafile.yml | 29 ++
docs/en/model_zoo.md | 2 +
docs/zh_cn/model_zoo.md | 2 +
mmselfsup/apis/train.py | 5 +-
mmselfsup/core/hooks/__init__.py | 5 +-
mmselfsup/core/hooks/interclr_hook.py | 185 +++++++
mmselfsup/datasets/multi_view.py | 2 +-
mmselfsup/models/algorithms/__init__.py | 5 +-
mmselfsup/models/algorithms/interclr_moco.py | 451 ++++++++++++++++++
mmselfsup/models/memories/__init__.py | 3 +-
mmselfsup/models/memories/interclr_memory.py | 254 ++++++++++
tools/train.py | 4 +-
18 files changed, 1142 insertions(+), 10 deletions(-)
create mode 100644 configs/selfsup/_base_/datasets/imagenet_interclr-moco.py
create mode 100644 configs/selfsup/_base_/models/interclr-moco.py
create mode 100644 configs/selfsup/interclr/README.md
create mode 100644 configs/selfsup/interclr/interclr-moco_resnet50_8xb32-coslr-200e_in1k.py
create mode 100644 configs/selfsup/interclr/metafile.yml
create mode 100644 mmselfsup/core/hooks/interclr_hook.py
create mode 100644 mmselfsup/models/algorithms/interclr_moco.py
create mode 100644 mmselfsup/models/memories/interclr_memory.py
diff --git a/README.md b/README.md
index 7abe0e59b..a8374560b 100644
--- a/README.md
+++ b/README.md
@@ -138,6 +138,7 @@ Supported algorithms:
- [x] [SimSiam (CVPR'2021)](https://github.com/open-mmlab/mmselfsup/tree/master/configs/selfsup/simsiam)
- [x] [Barlow Twins (ICML'2021)](https://github.com/open-mmlab/mmselfsup/tree/master/configs/selfsup/barlowtwins)
- [x] [MoCo v3 (ICCV'2021)](https://github.com/open-mmlab/mmselfsup/tree/master/configs/selfsup/mocov3)
+- [x] [InterCLR (IJCV'2022)](https://github.com/open-mmlab/mmselfsup/tree/master/configs/selfsup/interclr)
- [x] [MAE (CVPR'2022)](https://github.com/open-mmlab/mmselfsup/tree/master/configs/selfsup/mae)
- [x] [SimMIM (CVPR'2022)](https://github.com/open-mmlab/mmselfsup/tree/master/configs/selfsup/simmim)
- [x] [MaskFeat (CVPR'2022)](https://github.com/open-mmlab/mmselfsup/tree/master/configs/selfsup/maskfeat)
diff --git a/README_zh-CN.md b/README_zh-CN.md
index a79640963..9c72ce939 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -138,6 +138,7 @@ MMSelfSup 依赖 [PyTorch](https://pytorch.org/), [MMCV](https://github.com/open
- [x] [SimSiam (CVPR'2021)](https://github.com/open-mmlab/mmselfsup/tree/master/configs/selfsup/simsiam)
- [x] [Barlow Twins (ICML'2021)](https://github.com/open-mmlab/mmselfsup/tree/master/configs/selfsup/barlowtwins)
- [x] [MoCo v3 (ICCV'2021)](https://github.com/open-mmlab/mmselfsup/tree/master/configs/selfsup/mocov3)
+- [x] [InterCLR (IJCV'2022)](https://github.com/open-mmlab/mmselfsup/tree/master/configs/selfsup/interclr)
- [x] [MAE (CVPR'2022)](https://github.com/open-mmlab/mmselfsup/tree/master/configs/selfsup/mae)
- [x] [SimMIM (CVPR'2022)](https://github.com/open-mmlab/mmselfsup/tree/master/configs/selfsup/simmim)
- [x] [MaskFeat (CVPR'2022)](https://github.com/open-mmlab/mmselfsup/tree/master/configs/selfsup/maskfeat)
diff --git a/configs/selfsup/_base_/datasets/imagenet_interclr-moco.py b/configs/selfsup/_base_/datasets/imagenet_interclr-moco.py
new file mode 100644
index 000000000..fe53c8867
--- /dev/null
+++ b/configs/selfsup/_base_/datasets/imagenet_interclr-moco.py
@@ -0,0 +1,83 @@
+# dataset settings
+data_source = 'ImageNet'
+train_dataset_type = 'MultiViewDataset'
+extract_dataset_type = 'SingleViewDataset'
+img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+train_pipeline = [
+ dict(type='RandomResizedCrop', size=224, scale=(0.2, 1.)),
+ dict(
+ type='RandomAppliedTrans',
+ transforms=[
+ dict(
+ type='ColorJitter',
+ brightness=0.4,
+ contrast=0.4,
+ saturation=0.4,
+ hue=0.1)
+ ],
+ p=0.8),
+ dict(type='RandomGrayscale', p=0.2),
+ dict(type='GaussianBlur', sigma_min=0.1, sigma_max=2.0, p=0.5),
+ dict(type='RandomHorizontalFlip'),
+]
+extract_pipeline = [
+ dict(type='Resize', size=256),
+ dict(type='CenterCrop', size=224),
+]
+
+# prefetch
+prefetch = False
+if not prefetch:
+ train_pipeline.extend(
+ [dict(type='ToTensor'),
+ dict(type='Normalize', **img_norm_cfg)])
+ extract_pipeline.extend(
+ [dict(type='ToTensor'),
+ dict(type='Normalize', **img_norm_cfg)])
+
+# dataset summary
+data = dict(
+ samples_per_gpu=32, # total 32*8=256
+ replace=True,
+ workers_per_gpu=4,
+ drop_last=True,
+ train=dict(
+ type=train_dataset_type,
+ data_source=dict(
+ type=data_source,
+ data_prefix='data/imagenet/train',
+ ann_file='data/imagenet/meta/train.txt',
+ ),
+ num_views=[2],
+ pipelines=[train_pipeline],
+ prefetch=prefetch))
+
+# additional hooks
+num_classes = 10000
+custom_hooks = [
+ dict(
+ type='InterCLRHook',
+ extractor=dict(
+ samples_per_gpu=256,
+ workers_per_gpu=8,
+ dataset=dict(
+ type=extract_dataset_type,
+ data_source=dict(
+ type=data_source,
+ data_prefix='data/imagenet/train',
+ ann_file='data/imagenet/meta/train.txt',
+ ),
+ pipeline=extract_pipeline,
+ prefetch=prefetch),
+ prefetch=prefetch,
+ img_norm_cfg=img_norm_cfg),
+ clustering=dict(type='Kmeans', k=num_classes, pca_dim=-1), # no pca
+ centroids_update_interval=10, # iter
+ deal_with_small_clusters_interval=1,
+ evaluate_interval=50,
+ warmup_epochs=0,
+ init_memory=True,
+ initial=True, # call initially
+ online_labels=True,
+ interval=10) # same as the checkpoint interval
+]
diff --git a/configs/selfsup/_base_/models/interclr-moco.py b/configs/selfsup/_base_/models/interclr-moco.py
new file mode 100644
index 000000000..7ef23da2d
--- /dev/null
+++ b/configs/selfsup/_base_/models/interclr-moco.py
@@ -0,0 +1,40 @@
+# model settings
+model = dict(
+ type='InterCLRMoCo',
+ queue_len=65536,
+ feat_dim=128,
+ momentum=0.999,
+ backbone=dict(
+ type='ResNet',
+ depth=50,
+ in_channels=3,
+ out_indices=[4], # 0: conv-1, x: stage-x
+ norm_cfg=dict(type='BN')),
+ neck=dict(
+ type='MoCoV2Neck',
+ in_channels=2048,
+ hid_channels=2048,
+ out_channels=128,
+ with_avg_pool=True),
+ head=dict(type='ContrastiveHead', temperature=0.2),
+ memory_bank=dict(
+ type='InterCLRMemory',
+ length=1281167,
+ feat_dim=128,
+ momentum=1.,
+ num_classes=10000,
+ min_cluster=20,
+ debug=False),
+ online_labels=True,
+ neg_num=16384,
+ neg_sampling='semihard', # 'hard', 'semihard', 'random', 'semieasy'
+ semihard_neg_pool_num=128000,
+ semieasy_neg_pool_num=128000,
+ intra_cos_marign_loss=False,
+ intra_cos_margin=0,
+ inter_cos_marign_loss=True,
+ inter_cos_margin=-0.5,
+ intra_loss_weight=0.75,
+ inter_loss_weight=0.25,
+ share_neck=True,
+ num_classes=10000)
diff --git a/configs/selfsup/interclr/README.md b/configs/selfsup/interclr/README.md
new file mode 100644
index 000000000..464297bb1
--- /dev/null
+++ b/configs/selfsup/interclr/README.md
@@ -0,0 +1,62 @@
+# InterCLR
+
+> [Delving into Inter-Image Invariance for Unsupervised Visual Representations](https://arxiv.org/abs/2008.11702)
+
+
+
+## Abstract
+
+Contrastive learning has recently shown immense potential in unsupervised visual
+representation learning. Existing studies in this track mainly focus on intra-image
+invariance learning. The learning typically uses rich intra-image transformations to
+construct positive pairs and then maximizes agreement using a contrastive loss. The
+merits of inter-image invariance, conversely, remain much less explored. One major
+obstacle to exploiting inter-image invariance is that it is unclear how to reliably
+construct inter-image positive pairs, and further derive effective supervision from
+them, since no pair annotations are available. In this work, we present a comprehensive
+empirical study to better understand the role of inter-image invariance learning from
+three main constituting components: pseudo-label maintenance, sampling strategy, and
+decision boundary design. To facilitate the study, we introduce a unified and generic
+framework that supports the integration of unsupervised intra- and inter-image
+invariance learning. Through carefully-designed comparisons and analysis, multiple
+valuable observations are revealed: 1) online labels converge faster and perform better
+than offline labels; 2) semi-hard negative samples are more reliable and unbiased than
+hard negative samples; 3) a less stringent decision boundary is more favorable for
+inter-image invariance learning. With all the obtained recipes, our final model, namely
+InterCLR, shows consistent improvements over state-of-the-art intra-image invariance
+learning methods on multiple standard benchmarks. We hope this work will provide useful
+experience for devising effective unsupervised inter-image invariance learning.
+
+## Results and Models
+
+On this page, we provide as many benchmarks as possible to evaluate our pre-trained models. Unless otherwise mentioned, all models are pre-trained on the ImageNet-1k dataset. Here, we use MoCov2-InterCLR as an example; more models and results are coming soon.
+
+### Classification
+
+#### VOC SVM / Low-shot SVM
+
+The **Best Layer** indicates the layer whose feature map yields the best results. For example, if the **Best Layer** is **feature3**, the best result is obtained from the second stage of ResNet (1 for the stem layer, 2-5 for the four stage layers).
+
+Besides, k=1 to 96 indicates the number of labeled training samples per class used for the Low-shot SVM. A minimal sketch of this protocol follows the table below.
+
+| Self-Supervised Config | Best Layer | SVM | k=1 | k=2 | k=4 | k=8 | k=16 | k=32 | k=64 | k=96 |
+| ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ---- | ----- |
+| [interclr-moco_resnet50_8xb32-coslr-200e](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/interclr/interclr-moco_resnet50_8xb32-coslr-200e_in1k.py) | feature5 | 85.24 | 45.08 | 59.25 | 65.99 | 74.31 | 77.95 | 80.68 | 82.7 | 83.49 |
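+
+As a rough, hypothetical sketch of the Low-shot SVM protocol (not the actual benchmark code): train a linear SVM on k labeled samples per class of frozen-backbone features, then score it on a held-out split. The array names below are placeholders, and single-label accuracy stands in for the VOC mAP reported above.
+
+```python
+import numpy as np
+from sklearn.svm import LinearSVC
+
+
+def low_shot_svm_score(train_feats, train_labels, test_feats, test_labels, k):
+    # sample k labeled training examples per class (assumes >= k per class)
+    rng = np.random.default_rng(0)
+    idx = np.concatenate([
+        rng.choice(np.where(train_labels == c)[0], size=k, replace=False)
+        for c in np.unique(train_labels)
+    ])
+    # fit a linear SVM on frozen features and score on the test split
+    clf = LinearSVC(C=0.5).fit(train_feats[idx], train_labels[idx])
+    return clf.score(test_feats, test_labels)
+```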
+
+#### ImageNet Linear Evaluation
+
+**Feature1 - Feature5** are evaluated without GlobalAveragePooling: each feature map is pooled to a specific dimension and then fed into a linear layer for classification (see the sketch after the table below). Please refer to [resnet50_mhead_linear-8xb32-steplr-90e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/benchmarks/classification/imagenet/resnet50_mhead_linear-8xb32-steplr-90e_in1k.py) for details of the config.
+
+| Self-Supervised Config | Feature1 | Feature2 | Feature3 | Feature4 | Feature5 |
+| ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------- | -------- | -------- | -------- | -------- |
+| [interclr-moco_resnet50_8xb32-coslr-200e](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/interclr/interclr-moco_resnet50_8xb32-coslr-200e_in1k.py) | 15.59 | 35.10 | 47.36 | 62.86 | 68.04 |
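+
+As a simplified illustration of this setting (an assumed sketch, not the actual benchmark implementation), each backbone feature map can be adaptively pooled to a small spatial size and classified by a single linear layer:
+
+```python
+import torch
+import torch.nn as nn
+
+
+class PooledLinearHead(nn.Module):
+    """Pool a feature map to a fixed spatial size, then classify linearly."""
+
+    def __init__(self, in_channels, pool_size, num_classes=1000):
+        super().__init__()
+        self.pool = nn.AdaptiveAvgPool2d(pool_size)
+        self.fc = nn.Linear(in_channels * pool_size * pool_size, num_classes)
+
+    def forward(self, x):  # x: backbone feature map of shape (N, C, H, W)
+        return self.fc(self.pool(x).flatten(1))
+
+
+# e.g. a ResNet-50 stage-4 map (N, 2048, 7, 7) pooled to 2x2 -> 8192-d input
+head = PooledLinearHead(in_channels=2048, pool_size=2)
+logits = head(torch.randn(4, 2048, 7, 7))  # shape (4, 1000)
+```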
+
+## Citation
+
+```bibtex
+@article{xie2022delving,
+ title={Delving into inter-image invariance for unsupervised visual representations},
+ author={Xie, Jiahao and Zhan, Xiaohang and Liu, Ziwei and Ong, Yew-Soon and Loy, Chen Change},
+ journal={International Journal of Computer Vision},
+ year={2022}
+}
+```
diff --git a/configs/selfsup/interclr/interclr-moco_resnet50_8xb32-coslr-200e_in1k.py b/configs/selfsup/interclr/interclr-moco_resnet50_8xb32-coslr-200e_in1k.py
new file mode 100644
index 000000000..1aaac65f7
--- /dev/null
+++ b/configs/selfsup/interclr/interclr-moco_resnet50_8xb32-coslr-200e_in1k.py
@@ -0,0 +1,18 @@
+_base_ = [
+ '../_base_/models/interclr-moco.py',
+ '../_base_/datasets/imagenet_interclr-moco.py',
+ '../_base_/schedules/sgd_coslr-200e_in1k.py',
+ '../_base_/default_runtime.py',
+]
+
+# model settings
+model = dict(
+ memory_bank=dict(num_classes={{_base_.num_classes}}),
+ num_classes={{_base_.num_classes}},
+)
+
+# runtime settings
+# max_keep_ckpts controls the max number of ckpt files kept in your work_dirs:
+# if it is 3, then when CheckpointHook (in mmcv) saves the 4th ckpt,
+# it will remove the oldest one so that the total number of ckpts stays at 3
+checkpoint_config = dict(interval=10, max_keep_ckpts=3)
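+
+# Note: the {{_base_.num_classes}} placeholders above are resolved by mmcv
+# when the config is loaded. A usage sketch (illustrative, not part of the
+# config):
+#   from mmcv import Config
+#   cfg = Config.fromfile('configs/selfsup/interclr/'
+#                         'interclr-moco_resnet50_8xb32-coslr-200e_in1k.py')
+#   assert cfg.model.num_classes == 10000  # substituted from the dataset base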
diff --git a/configs/selfsup/interclr/metafile.yml b/configs/selfsup/interclr/metafile.yml
new file mode 100644
index 000000000..ca2028fce
--- /dev/null
+++ b/configs/selfsup/interclr/metafile.yml
@@ -0,0 +1,29 @@
+Collections:
+ - Name: InterCLR
+ Metadata:
+ Training Data: ImageNet-1k
+ Training Techniques:
+ - SGD with Momentum
+ - Weight Decay
+ Training Resources: 8x V100 GPUs
+ Architecture:
+ - ResNet
+ - InterCLR
+ Paper:
+ URL: https://arxiv.org/abs/2008.11702
+ Title: "Delving into Inter-Image Invariance for Unsupervised Visual Representations"
+ README: configs/selfsup/interclr/README.md
+
+Models:
+ - Name: interclr-moco_resnet50_8xb32-coslr-200e_in1k
+ In Collection: InterCLR
+ Metadata:
+ Epochs: 200
+ Batch Size: 256
+ Results:
+ - Task: Self-Supervised Image Classification
+ Dataset: ImageNet-1k
+ Metrics:
+ Top 1 Accuracy: 68.04
+ Config: configs/selfsup/interclr/interclr-moco_resnet50_8xb32-coslr-200e_in1k.py
+ Weights: https://download.openmmlab.com/mmselfsup/interclr/interclr-moco_resnet50_8xb32-coslr-200e_in1k_20221206-38f8fdaf.pth
diff --git a/docs/en/model_zoo.md b/docs/en/model_zoo.md
index e4fe7341d..55e2ad7cd 100644
--- a/docs/en/model_zoo.md
+++ b/docs/en/model_zoo.md
@@ -23,6 +23,7 @@ All models and part of benchmark results are recorded below.
| | [simsiam_resnet50_8xb32-coslr-200e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simsiam/simsiam_resnet50_8xb32-coslr-200e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/simsiam/simsiam_resnet50_8xb32-coslr-200e_in1k_20220225-2f488143.pth) \| [log](https://download.openmmlab.com/mmselfsup/simsiam/simsiam_resnet50_8xb32-coslr-200e_in1k_20220210_195402.log.json) |
| [BarlowTwins](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/barlowtwins/README.md) | [barlowtwins_resnet50_8xb256-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/barlowtwins/barlowtwins_resnet50_8xb256-coslr-300e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/barlowtwins/barlowtwins_resnet50_8xb256-coslr-300e_in1k_20220419-5ae15f89.pth) \| [log](https://download.openmmlab.com/mmselfsup/barlowtwins/barlowtwins_resnet50_8xb256-coslr-300e_in1k_20220413_111555.log.json) |
| [MoCo v3](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mocov3/README.md) | [mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mocov3/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224.py) | [model](https://download.openmmlab.com/mmselfsup/moco/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224_20220225-e31238dd.pth) \| [log](https://download.openmmlab.com/mmselfsup/moco/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224_20220222_160222.log.json) |
+| [InterCLR](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/interclr/README.md) | [interclr-moco_resnet50_8xb32-coslr-200e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/interclr/interclr-moco_resnet50_8xb32-coslr-200e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/interclr/interclr-moco_resnet50_8xb32-coslr-200e_in1k_20221206-38f8fdaf.pth) \| [log](https://download.openmmlab.com/mmselfsup/interclr/interclr-moco_resnet50_8xb32-coslr-200e_in1k_20221103_085100.log.json) |
| [MAE](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/README.md) | [mae_vit-base-p16_8xb512-coslr-400e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e-fp16_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k-224_20220223-85be947b.pth) \| [log](https://download.openmmlab.com/mmselfsup/mae/mae_vit-base-p16_8xb512-coslr-300e_in1k-224_20220210_140925.log.json) |
| [SimMIM](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simmim/README.md) | [simmim_swin-base_16xb128-coslr-100e_in1k-192](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192.py) | [model](https://download.openmmlab.com/mmselfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192_20220316-1d090125.pth) \| [log](https://download.openmmlab.com/mmselfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192_20220316-1d090125.log.json) |
| [CAE](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/cae/README.md) | [cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/cae/cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/cae/cae_vit-base-p16_16xb256-coslr-300e_in1k-224_20220427-4c786349.pth) \| [log](https://download.openmmlab.com/mmselfsup/cae/cae_vit-base-p16_16xb256-coslr-300e_in1k-224_20220427-4c786349.log.json) |
@@ -61,6 +62,7 @@ If not specified, we use linear evaluation setting from [MoCo](http://openaccess
| | [simsiam_resnet50_8xb32-coslr-200e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simsiam/simsiam_resnet50_8xb32-coslr-200e_in1k.py) | SimSiam paper setting | 69.84 |
| Barlow Twins | [barlowtwins_resnet50_8xb256-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/barlowtwins/barlowtwins_resnet50_8xb256-coslr-300e_in1k.py) | Barlow Twins paper setting | 71.66 |
| MoCo v3 | [mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mocov3/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224.py) | MoCo v3 paper setting | 73.19 |
+| InterCLR | [interclr-moco_resnet50_8xb32-coslr-200e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/interclr/interclr-moco_resnet50_8xb32-coslr-200e_in1k.py) | | 68.04 |
### ImageNet Fine-tuning
diff --git a/docs/zh_cn/model_zoo.md b/docs/zh_cn/model_zoo.md
index 34b75ca4a..a0e10a6bf 100644
--- a/docs/zh_cn/model_zoo.md
+++ b/docs/zh_cn/model_zoo.md
@@ -23,6 +23,7 @@
| | [simsiam_resnet50_8xb32-coslr-200e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simsiam/simsiam_resnet50_8xb32-coslr-200e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/simsiam/simsiam_resnet50_8xb32-coslr-200e_in1k_20220225-2f488143.pth) \| [log](https://download.openmmlab.com/mmselfsup/simsiam/simsiam_resnet50_8xb32-coslr-200e_in1k_20220210_195402.log.json) |
| [BarlowTwins](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/barlowtwins/README.md) | [barlowtwins_resnet50_8xb256-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/barlowtwins/barlowtwins_resnet50_8xb256-coslr-300e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/barlowtwins/barlowtwins_resnet50_8xb256-coslr-300e_in1k_20220419-5ae15f89.pth) \| [log](https://download.openmmlab.com/mmselfsup/barlowtwins/barlowtwins_resnet50_8xb256-coslr-300e_in1k_20220413_111555.log.json) |
| [MoCo v3](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mocov3/README.md) | [mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mocov3/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224.py) | [model](https://download.openmmlab.com/mmselfsup/moco/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224_20220225-e31238dd.pth) \| [log](https://download.openmmlab.com/mmselfsup/moco/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224_20220222_160222.log.json) |
+| [InterCLR](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/interclr/README.md) | [interclr-moco_resnet50_8xb32-coslr-200e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/interclr/interclr-moco_resnet50_8xb32-coslr-200e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/interclr/interclr-moco_resnet50_8xb32-coslr-200e_in1k_20221206-38f8fdaf.pth) \| [log](https://download.openmmlab.com/mmselfsup/interclr/interclr-moco_resnet50_8xb32-coslr-200e_in1k_20221103_085100.log.json) |
| [MAE](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/README.md) | [mae_vit-base-p16_8xb512-coslr-400e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e-fp16_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k-224_20220223-85be947b.pth) \| [log](https://download.openmmlab.com/mmselfsup/mae/mae_vit-base-p16_8xb512-coslr-300e_in1k-224_20220210_140925.log.json) |
| [SimMIM](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simmim/README.md) | [simmim_swin-base_16xb128-coslr-100e_in1k-192](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192.py) | [model](https://download.openmmlab.com/mmselfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192_20220316-1d090125.pth) \| [log](https://download.openmmlab.com/mmselfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192_20220316-1d090125.log.json) |
| [CAE](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/cae/README.md) | [cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/cae/cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/cae/cae_vit-base-p16_16xb256-coslr-300e_in1k-224_20220427-4c786349.pth) \| [log](https://download.openmmlab.com/mmselfsup/cae/cae_vit-base-p16_16xb256-coslr-300e_in1k-224_20220427-4c786349.log.json) |
@@ -61,6 +62,7 @@
| | [simsiam_resnet50_8xb32-coslr-200e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simsiam/simsiam_resnet50_8xb32-coslr-200e_in1k.py) | SimSiam 论文设置 | 69.84 |
| Barlow Twins | [barlowtwins_resnet50_8xb256-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/barlowtwins/barlowtwins_resnet50_8xb256-coslr-300e_in1k.py) | Barlow Twins 论文设置 | 71.66 |
| MoCo v3 | [mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mocov3/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224.py) | MoCo v3 论文设置 | 73.19 |
+| InterCLR | [interclr-moco_resnet50_8xb32-coslr-200e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/interclr/interclr-moco_resnet50_8xb32-coslr-200e_in1k.py) | | 68.04 |
### ImageNet 微调
diff --git a/mmselfsup/apis/train.py b/mmselfsup/apis/train.py
index c4a4db98e..4b2df956a 100644
--- a/mmselfsup/apis/train.py
+++ b/mmselfsup/apis/train.py
@@ -95,7 +95,7 @@ def train_model(model,
# cfg.gpus will be ignored if distributed
num_gpus=len(cfg.gpu_ids),
dist=distributed,
- replace=getattr(cfg.data, 'sampling_replace', False),
+ replace=getattr(cfg.data, 'replace', False),
drop_last=getattr(cfg.data, 'drop_last', False),
prefetch=getattr(cfg, 'prefetch', False),
seed=cfg.get('seed'),
@@ -170,7 +170,8 @@ def train_model(model,
assert isinstance(hook_cfg, dict), \
'Each item in custom_hooks expects dict type, but got ' \
f'{type(hook_cfg)}'
- if hook_cfg.get('type', None) == 'DeepClusterHook':
+ if hook_cfg.get('type',
+ None) in ['DeepClusterHook', 'InterCLRHook']:
common_params = dict(dist_mode=True, data_loaders=data_loaders)
else:
common_params = dict()
diff --git a/mmselfsup/core/hooks/__init__.py b/mmselfsup/core/hooks/__init__.py
index 3993ce4b1..d71bd786c 100644
--- a/mmselfsup/core/hooks/__init__.py
+++ b/mmselfsup/core/hooks/__init__.py
@@ -2,6 +2,7 @@
from .cosine_annealing_hook import StepFixCosineAnnealingLrUpdaterHook
from .deepcluster_hook import DeepClusterHook
from .densecl_hook import DenseCLHook
+from .interclr_hook import InterCLRHook
from .momentum_update_hook import MomentumUpdateHook
from .odc_hook import ODCHook
from .optimizer_hook import DistOptimizerHook, GradAccumFp16OptimizerHook
@@ -10,6 +11,6 @@
__all__ = [
'MomentumUpdateHook', 'DeepClusterHook', 'DenseCLHook', 'ODCHook',
- 'DistOptimizerHook', 'GradAccumFp16OptimizerHook', 'SimSiamHook',
- 'SwAVHook', 'StepFixCosineAnnealingLrUpdaterHook'
+ 'InterCLRHook', 'DistOptimizerHook', 'GradAccumFp16OptimizerHook',
+ 'SimSiamHook', 'SwAVHook', 'StepFixCosineAnnealingLrUpdaterHook'
]
diff --git a/mmselfsup/core/hooks/interclr_hook.py b/mmselfsup/core/hooks/interclr_hook.py
new file mode 100644
index 000000000..2429fa144
--- /dev/null
+++ b/mmselfsup/core/hooks/interclr_hook.py
@@ -0,0 +1,185 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+import torch.distributed as dist
+from mmcv.runner import HOOKS, Hook
+from mmcv.utils import print_log
+
+from mmselfsup.utils import Extractor
+from mmselfsup.utils import clustering as _clustering
+
+
+@HOOKS.register_module()
+class InterCLRHook(Hook):
+ """Hook for InterCLR.
+
+ This hook includes the clustering process in InterCLR.
+
+ Args:
+ extractor (dict): Config dict for feature extraction.
+ clustering (dict): Config dict that specifies the clustering algorithm.
+ centroids_update_interval (int): Frequency of iterations to
+ update centroids.
+ deal_with_small_clusters_interval (int): Frequency of iterations to
+ deal with small clusters.
+ evaluate_interval (int): Frequency of iterations to evaluate clusters.
+ warmup_epochs (int, optional): The number of warmup epochs to set
+ ``intra_loss_weight=1`` and ``inter_loss_weight=0``. Defaults to 0.
+ init_memory (bool): Whether to initialize memory banks used in online
+ labels. Defaults to True.
+ initial (bool): Whether to call the hook initially. Defaults to True.
+ online_labels (bool): Whether to use online labels. Defaults to True.
+ interval (int): Frequency of epochs to call the hook. Defaults to 1.
+ dist_mode (bool): Use distributed training or not. Defaults to True.
+        data_loaders (list[DataLoader]): A list of PyTorch dataloaders; the
+            first one is used when clustering offline. Defaults to None.
+ """
+
+ def __init__(
+ self,
+ extractor,
+ clustering,
+ centroids_update_interval,
+ deal_with_small_clusters_interval,
+ evaluate_interval,
+ warmup_epochs=0,
+ init_memory=True,
+ initial=True,
+ online_labels=True,
+ interval=1, # same as the checkpoint interval
+ dist_mode=True,
+ data_loaders=None):
+ assert dist_mode, 'non-dist mode is not implemented'
+ self.extractor = Extractor(dist_mode=dist_mode, **extractor)
+ self.clustering_type = clustering.pop('type')
+ self.clustering_cfg = clustering
+ self.centroids_update_interval = centroids_update_interval
+ self.deal_with_small_clusters_interval = \
+ deal_with_small_clusters_interval
+ self.evaluate_interval = evaluate_interval
+ self.warmup_epochs = warmup_epochs
+ self.init_memory = init_memory
+ self.initial = initial
+ self.online_labels = online_labels
+ self.interval = interval
+ self.dist_mode = dist_mode
+ self.data_loaders = data_loaders
+
+ def before_run(self, runner):
+        assert hasattr(runner.model.module, 'intra_loss_weight'), \
+            'The InterCLR model must have attribute "intra_loss_weight".'
+        assert hasattr(runner.model.module, 'inter_loss_weight'), \
+            'The InterCLR model must have attribute "inter_loss_weight".'
+ self.intra_loss_weight = runner.model.module.intra_loss_weight
+ self.inter_loss_weight = runner.model.module.inter_loss_weight
+ if self.initial:
+ if runner.epoch > 0 and self.online_labels:
+ if runner.rank == 0:
+ print(f'Resuming memory banks from epoch {runner.epoch}')
+ features = np.load(
+ f'{runner.work_dir}/feature_epoch_{runner.epoch}.npy')
+ else:
+ features = None
+ loaded_labels = np.load(
+ f'{runner.work_dir}/cluster_epoch_{runner.epoch}.npy')
+ runner.model.module.memory_bank.init_memory(
+ features, loaded_labels)
+ return
+
+ self.deepcluster(runner)
+
+ def before_train_epoch(self, runner):
+ cur_epoch = runner.epoch
+ if cur_epoch >= self.warmup_epochs:
+ runner.model.module.intra_loss_weight = self.intra_loss_weight
+ runner.model.module.inter_loss_weight = self.inter_loss_weight
+ else:
+ runner.model.module.intra_loss_weight = 1.
+ runner.model.module.inter_loss_weight = 0.
+
+ def after_train_iter(self, runner):
+ if not self.online_labels:
+ return
+ # centroids update
+ if self.every_n_iters(runner, self.centroids_update_interval):
+ runner.model.module.memory_bank.update_centroids_memory()
+
+ # deal with small clusters
+ if self.every_n_iters(runner, self.deal_with_small_clusters_interval):
+ runner.model.module.memory_bank.deal_with_small_clusters()
+
+ # evaluate
+ if self.every_n_iters(runner, self.evaluate_interval):
+ new_labels = runner.model.module.memory_bank.label_bank
+ if new_labels.is_cuda:
+ new_labels = new_labels.cpu()
+ self.evaluate(runner, new_labels.numpy())
+
+ def after_train_epoch(self, runner):
+ if self.online_labels: # online labels
+ # save cluster
+ if self.every_n_epochs(runner, self.interval) and runner.rank == 0:
+ features = runner.model.module.memory_bank.feature_bank
+ new_labels = runner.model.module.memory_bank.label_bank
+ if new_labels.is_cuda:
+ new_labels = new_labels.cpu()
+ np.save(
+ f'{runner.work_dir}/feature_epoch_{runner.epoch + 1}.npy',
+ features.cpu().numpy())
+ np.save(
+ f'{runner.work_dir}/cluster_epoch_{runner.epoch + 1}.npy',
+ new_labels.numpy())
+ else: # offline labels
+ if self.every_n_epochs(runner, self.interval):
+ self.deepcluster(runner)
+
+ def deepcluster(self, runner):
+ # step 1: get features
+ runner.model.eval()
+ features = self.extractor(runner)
+ runner.model.train()
+
+ # step 2: get labels
+ if not self.dist_mode or (self.dist_mode and runner.rank == 0):
+ clustering_algo = _clustering.__dict__[self.clustering_type](
+ **self.clustering_cfg)
+ # Features are normalized during clustering
+ clustering_algo.cluster(features, verbose=True)
+ assert isinstance(clustering_algo.labels, np.ndarray)
+ new_labels = clustering_algo.labels.astype(np.int64)
+ if self.init_memory:
+ np.save(f'{runner.work_dir}/cluster_epoch_{runner.epoch}.npy',
+ new_labels)
+ else:
+ np.save(
+ f'{runner.work_dir}/cluster_epoch_{runner.epoch + 1}.npy',
+ new_labels)
+ self.evaluate(runner, new_labels)
+ else:
+ new_labels = np.zeros((len(self.data_loaders[0].dataset), ),
+ dtype=np.int64)
+
+ if self.dist_mode:
+ new_labels_tensor = torch.from_numpy(new_labels).cuda()
+ dist.broadcast(new_labels_tensor, 0)
+ new_labels = new_labels_tensor.cpu().numpy()
+
+ # step 3 (optional): assign offline labels
+ if not (self.online_labels or self.init_memory):
+ runner.model.module.memory_bank.assign_label(new_labels)
+
+ # step 4 (before run): initialize memory
+ if self.init_memory:
+ runner.model.module.memory_bank.init_memory(features, new_labels)
+ self.init_memory = False
+
+ def evaluate(self, runner, new_labels):
+ histogram = np.bincount(
+ new_labels, minlength=runner.model.module.memory_bank.num_classes)
+ empty_cls = (histogram == 0).sum()
+ minimal_cls_size, maximal_cls_size = histogram.min(), histogram.max()
+ if runner.rank == 0:
+ print_log(
+ f'empty_num: {empty_cls.item()}\t'
+ f'min_cluster: {minimal_cls_size.item()}\t'
+ f'max_cluster: {maximal_cls_size.item()}',
+ logger='mmselfsup')
diff --git a/mmselfsup/datasets/multi_view.py b/mmselfsup/datasets/multi_view.py
index 2bae2e18f..537180da3 100644
--- a/mmselfsup/datasets/multi_view.py
+++ b/mmselfsup/datasets/multi_view.py
@@ -58,7 +58,7 @@ def __getitem__(self, idx):
multi_views = [
torch.from_numpy(to_numpy(img)) for img in multi_views
]
- return dict(img=multi_views)
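+        # also return the sample index so InterCLR can address its memory bank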
+ return dict(img=multi_views, idx=idx)
def evaluate(self, results, logger=None):
return NotImplemented
diff --git a/mmselfsup/models/algorithms/__init__.py b/mmselfsup/models/algorithms/__init__.py
index 8f8992f3b..f8e5818cc 100644
--- a/mmselfsup/models/algorithms/__init__.py
+++ b/mmselfsup/models/algorithms/__init__.py
@@ -6,6 +6,7 @@
from .classification import Classification
from .deepcluster import DeepCluster
from .densecl import DenseCL
+from .interclr_moco import InterCLRMoCo
from .mae import MAE
from .maskfeat import MaskFeat
from .mmcls_classifier_wrapper import MMClsImageClassifierWrapper
@@ -22,7 +23,7 @@
__all__ = [
'BaseModel', 'BarlowTwins', 'BYOL', 'Classification', 'DeepCluster',
- 'DenseCL', 'MoCo', 'NPID', 'ODC', 'RelativeLoc', 'RotationPred', 'SimCLR',
- 'SimSiam', 'SwAV', 'MAE', 'MoCoV3', 'SimMIM',
+ 'DenseCL', 'InterCLRMoCo', 'MoCo', 'NPID', 'ODC', 'RelativeLoc',
+ 'RotationPred', 'SimCLR', 'SimSiam', 'SwAV', 'MAE', 'MoCoV3', 'SimMIM',
'MMClsImageClassifierWrapper', 'CAE', 'MaskFeat'
]
diff --git a/mmselfsup/models/algorithms/interclr_moco.py b/mmselfsup/models/algorithms/interclr_moco.py
new file mode 100644
index 000000000..00f8348f1
--- /dev/null
+++ b/mmselfsup/models/algorithms/interclr_moco.py
@@ -0,0 +1,451 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from mmselfsup.utils import (batch_shuffle_ddp, batch_unshuffle_ddp,
+ concat_all_gather)
+from ..builder import (ALGORITHMS, build_backbone, build_head, build_memory,
+ build_neck)
+from .base import BaseModel
+
+
+@ALGORITHMS.register_module()
+class InterCLRMoCo(BaseModel):
+ """MoCo-InterCLR.
+
+ Official implementation of `Delving into Inter-Image Invariance for
+    Unsupervised Visual Representations <https://arxiv.org/abs/2008.11702>`_.
+ The clustering operation is in `core/hooks/interclr_hook.py`.
+
+ Args:
+ backbone (dict): Config dict for module of backbone.
+ neck (dict): Config dict for module of deep features to compact feature
+ vectors. Defaults to None.
+ head (dict): Config dict for module of loss functions.
+ Defaults to None.
+ queue_len (int): Number of negative keys maintained in the queue.
+ Defaults to 65536.
+ feat_dim (int): Dimension of compact feature vectors. Defaults to 128.
+ momentum (float): Momentum coefficient for the momentum-updated
+ encoder. Defaults to 0.999.
+ memory_bank (dict): Config dict for module of memory banks.
+ Defaults to None.
+ online_labels (bool): Whether to use online labels. Defaults to True.
+ neg_num (int): Number of negative samples for inter-image branch.
+ Defaults to 16384.
+ neg_sampling (str): Negative sampling strategy. Support 'hard',
+ 'semihard', 'random', 'semieasy'. Defaults to 'semihard'.
+ semihard_neg_pool_num (int): Number of negative samples for semi-hard
+ nearest neighbor pool. Defaults to 128000.
+ semieasy_neg_pool_num (int): Number of negative samples for semi-easy
+ nearest neighbor pool. Defaults to 128000.
+ intra_cos_marign_loss (bool): Whether to use a cosine margin for
+ intra-image branch. Defaults to False.
+        intra_cos_margin (float): Intra-image cosine margin. Defaults to 0.
+ intra_arc_marign_loss (bool): Whether to use an arc margin for
+ intra-image branch. Defaults to False.
+        intra_arc_margin (float): Intra-image arc margin. Defaults to 0.
+ inter_cos_marign_loss (bool): Whether to use a cosine margin for
+ inter-image branch. Defaults to True.
+        inter_cos_margin (float): Inter-image cosine margin. Defaults to -0.5.
+ inter_arc_marign_loss (bool): Whether to use an arc margin for
+ inter-image branch. Defaults to False.
+        inter_arc_margin (float): Inter-image arc margin. Defaults to 0.
+ intra_loss_weight (float): Loss weight for intra-image branch.
+ Defaults to 0.75.
+ inter_loss_weight (float): Loss weight for inter-image branch.
+ Defaults to 0.25.
+ share_neck (bool): Whether to share the neck for intra- and inter-image
+ branches. Defaults to True.
+ num_classes (int): Number of clusters. Defaults to 10000.
+ """
+
+ def __init__(self,
+ backbone,
+ neck=None,
+ head=None,
+ queue_len=65536,
+ feat_dim=128,
+ momentum=0.999,
+ memory_bank=None,
+ online_labels=True,
+ neg_num=16384,
+ neg_sampling='semihard',
+ semihard_neg_pool_num=128000,
+ semieasy_neg_pool_num=128000,
+ intra_cos_marign_loss=False,
+ intra_cos_margin=0,
+ intra_arc_marign_loss=False,
+ intra_arc_margin=0,
+ inter_cos_marign_loss=True,
+ inter_cos_margin=-0.5,
+ inter_arc_marign_loss=False,
+ inter_arc_margin=0,
+ intra_loss_weight=0.75,
+ inter_loss_weight=0.25,
+ share_neck=True,
+ num_classes=10000,
+ init_cfg=None,
+ **kwargs):
+ super(InterCLRMoCo, self).__init__(init_cfg)
+ self.encoder_q = nn.Sequential(
+ build_backbone(backbone), build_neck(neck))
+ self.encoder_k = nn.Sequential(
+ build_backbone(backbone), build_neck(neck))
+ if not share_neck:
+ self.inter_neck_q = build_neck(neck)
+ self.inter_neck_k = build_neck(neck)
+ self.backbone = self.encoder_q[0]
+ self.neck = self.encoder_q[1]
+ self.head = build_head(head)
+ self.memory_bank = build_memory(memory_bank)
+
+ # moco params
+ self.queue_len = queue_len
+ self.momentum = momentum
+
+ # interclr params
+ self.online_labels = online_labels
+ self.neg_num = neg_num
+ self.neg_sampling = neg_sampling
+ self.semihard_neg_pool_num = semihard_neg_pool_num
+ self.semieasy_neg_pool_num = semieasy_neg_pool_num
+ self.intra_cos = intra_cos_marign_loss
+ self.intra_cos_margin = intra_cos_margin
+ self.intra_arc = intra_arc_marign_loss
+ self.intra_arc_margin = intra_arc_margin
+ self.inter_cos = inter_cos_marign_loss
+ self.inter_cos_margin = inter_cos_margin
+ self.inter_arc = inter_arc_marign_loss
+ self.inter_arc_margin = inter_arc_margin
+ self.intra_loss_weight = intra_loss_weight
+ self.inter_loss_weight = inter_loss_weight
+ self.share_neck = share_neck
+ self.num_classes = num_classes
+
+ # create the queue
+ self.register_buffer('queue', torch.randn(feat_dim, queue_len))
+ self.queue = nn.functional.normalize(self.queue, dim=0)
+ self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long))
+
+ def init_weights(self):
+ """Initialize base_encoder with init_cfg defined in backbone."""
+ super(InterCLRMoCo, self).init_weights()
+
+ for param_q, param_k in zip(self.encoder_q.parameters(),
+ self.encoder_k.parameters()):
+ param_k.data.copy_(param_q.data)
+ param_k.requires_grad = False
+ if not self.share_neck:
+ for param_q, param_k in zip(self.inter_neck_q.parameters(),
+ self.inter_neck_k.parameters()):
+ param_k.data.copy_(param_q.data)
+ param_k.requires_grad = False
+
+ @torch.no_grad()
+ def _momentum_update_key_encoder(self):
+ """Momentum update of the key encoder."""
+ for param_q, param_k in zip(self.encoder_q.parameters(),
+ self.encoder_k.parameters()):
+ param_k.data = param_k.data * self.momentum + \
+ param_q.data * (1. - self.momentum)
+ if not self.share_neck:
+ for param_q, param_k in zip(self.inter_neck_q.parameters(),
+ self.inter_neck_k.parameters()):
+ param_k.data = param_k.data * self.momentum + \
+ param_q.data * (1. - self.momentum)
+
+ @torch.no_grad()
+ def _dequeue_and_enqueue(self, keys):
+ """Update queue."""
+ # normalize
+ keys = nn.functional.normalize(keys, dim=1)
+ # gather keys before updating queue
+ keys = concat_all_gather(keys)
+
+ batch_size = keys.shape[0]
+
+ ptr = int(self.queue_ptr)
+ assert self.queue_len % batch_size == 0 # for simplicity
+
+ # replace the keys at ptr (dequeue and enqueue)
+ self.queue[:, ptr:ptr + batch_size] = keys.transpose(0, 1)
+ ptr = (ptr + batch_size) % self.queue_len # move pointer
+
+ self.queue_ptr[0] = ptr
+
+ def contrast_intra(self, q, k):
+ """Intra-image invariance learning.
+
+ Args:
+ q (Tensor): Query features with shape (N, C).
+ k (Tensor): Key features with shape (N, C).
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ # normalize
+ q = nn.functional.normalize(q, dim=1)
+ k = nn.functional.normalize(k, dim=1)
+ # compute logits
+ # Einstein sum is more intuitive
+ # positive logits: Nx1
+ pos_logits = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
+ # negative logits: NxK
+ neg_logits = torch.einsum('nc,ck->nk',
+ [q, self.queue.clone().detach()])
+
+ # use cosine margin
+ if self.intra_cos:
+ cosine = pos_logits.clone()
+ phi = cosine - self.intra_cos_margin
+ pos_logits.copy_(phi)
+ # use arc margin
+ if self.intra_arc:
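+            # cos(theta)*cos(m) - sin(theta)*sin(m) = cos(theta + m): an
+            # ArcFace-style additive angular margin on the positive logits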
+ cosine = pos_logits.clone()
+ sine = torch.sqrt((1.0 - cosine**2).clamp(0, 1))
+ phi = cosine * math.cos(self.intra_arc_margin) - sine * math.sin(
+ self.intra_arc_margin)
+ if self.intra_arc_margin < 0:
+ phi = torch.where(
+ cosine < math.cos(self.intra_arc_margin), phi, cosine +
+ math.sin(self.intra_arc_margin) * self.intra_arc_margin)
+ else:
+ phi = torch.where(
+ cosine > math.cos(math.pi - self.intra_arc_margin), phi,
+ cosine - math.sin(math.pi - self.intra_arc_margin) *
+ self.intra_arc_margin)
+ pos_logits.copy_(phi)
+
+ losses = self.head(pos_logits, neg_logits)
+
+ return losses
+
+ def contrast_inter(self, q, idx):
+ """Inter-image invariance learning.
+
+ Args:
+ q (Tensor): Query features with shape (N, C).
+ idx (Tensor): Index corresponding to each query.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ # normalize
+ feat_norm = nn.functional.normalize(q, dim=1)
+ bs, feat_dim = feat_norm.shape[:2]
+ # positive sampling
+ pos_label = self.memory_bank.label_bank[idx.cpu()]
+ pos_idx_list = []
+ for i, l in enumerate(pos_label):
+ pos_idx_pool = torch.where(
+ self.memory_bank.label_bank == l)[0] # positive index pool
+ pos_i = torch.zeros(
+ 1, dtype=torch.long).random_(0, pos_idx_pool.size(0))
+ pos_idx_list.append(pos_idx_pool[pos_i])
+ pos_idx = torch.cuda.LongTensor(pos_idx_list)
+ # negative sampling
+ if self.neg_sampling == 'random': # random negative sampling
+ pos_label = pos_label.cuda().unsqueeze(1)
+ neg_idx = self.memory_bank.multinomial.draw(
+ bs * self.neg_num).view(bs, -1)
+ while True:
+ neg_label = self.memory_bank.label_bank[neg_idx.cpu()].cuda()
+ pos_in_neg_idx = (neg_label == pos_label)
+ if pos_in_neg_idx.sum().item() > 0:
+ neg_idx[
+ pos_in_neg_idx] = self.memory_bank.multinomial.draw(
+ pos_in_neg_idx.sum().item())
+ else:
+ break
+ elif self.neg_sampling == 'semihard': # semihard negative sampling
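+            # draw negatives uniformly from the query's top-K most similar
+            # samples in the memory bank; the while-loop below re-draws any
+            # draw that shares the query's pseudo-label (a false negative)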
+ pos_label = pos_label.cuda().unsqueeze(1)
+ similarity = torch.mm(feat_norm.detach(),
+ self.memory_bank.feature_bank.permute(1, 0))
+ _, neg_I = torch.topk(
+ similarity, self.semihard_neg_pool_num, dim=1, sorted=False)
+ weights = torch.ones((bs, self.semihard_neg_pool_num),
+ dtype=torch.float,
+ device='cuda')
+ neg_i = torch.multinomial(weights, self.neg_num)
+ neg_idx = torch.gather(neg_I, 1, neg_i)
+ while True:
+ neg_label = self.memory_bank.label_bank[neg_idx.cpu()].cuda()
+ pos_in_neg_idx = (neg_label == pos_label)
+ if pos_in_neg_idx.sum().item() > 0:
+ neg_i = torch.multinomial(weights, self.neg_num)
+ neg_idx[pos_in_neg_idx] = torch.gather(
+ neg_I, 1, neg_i)[pos_in_neg_idx]
+ else:
+ break
+ elif self.neg_sampling == 'semieasy': # semieasy negative sampling
+ pos_label = pos_label.cuda().unsqueeze(1)
+ similarity = torch.mm(feat_norm.detach(),
+ self.memory_bank.feature_bank.permute(1, 0))
+ _, neg_I = torch.topk(
+ similarity,
+ self.semieasy_neg_pool_num,
+ dim=1,
+ largest=False,
+ sorted=False)
+ weights = torch.ones((bs, self.semieasy_neg_pool_num),
+ dtype=torch.float,
+ device='cuda')
+ neg_i = torch.multinomial(weights, self.neg_num)
+ neg_idx = torch.gather(neg_I, 1, neg_i)
+ while True:
+ neg_label = self.memory_bank.label_bank[neg_idx.cpu()].cuda()
+ pos_in_neg_idx = (neg_label == pos_label)
+ if pos_in_neg_idx.sum().item() > 0:
+ neg_i = torch.multinomial(weights, self.neg_num)
+ neg_idx[pos_in_neg_idx] = torch.gather(
+ neg_I, 1, neg_i)[pos_in_neg_idx]
+ else:
+ break
+ elif self.neg_sampling == 'hard': # hard negative sampling
+ similarity = torch.mm(feat_norm.detach(),
+ self.memory_bank.feature_bank.permute(1, 0))
+ maximal_cls_size = np.bincount(
+ self.memory_bank.label_bank.numpy(),
+ minlength=self.num_classes).max().item()
+ _, neg_I = torch.topk(
+ similarity, self.neg_num + maximal_cls_size, dim=1)
+ neg_I = neg_I.cpu()
+ neg_label = self.memory_bank.label_bank[neg_I].numpy()
+ neg_idx_list = []
+ for i, l in enumerate(pos_label):
+ pos_in_neg_idx = np.where(neg_label[i] == l)[0]
+ if len(pos_in_neg_idx) > 0:
+ neg_idx_pool = torch.from_numpy(
+ np.delete(neg_I[i].numpy(), pos_in_neg_idx))
+ else:
+ neg_idx_pool = neg_I[i]
+ neg_idx_list.append(neg_idx_pool[:self.neg_num])
+ neg_idx = torch.stack(neg_idx_list, dim=0).cuda()
+ else:
+            raise ValueError(
+                f'Unsupported negative sampling strategy: {self.neg_sampling}')
+
+ pos_feat = torch.index_select(self.memory_bank.feature_bank, 0,
+ pos_idx) # BXC
+ neg_feat = torch.index_select(self.memory_bank.feature_bank, 0,
+ neg_idx.flatten()).view(
+ bs, self.neg_num, feat_dim) # BxKxC
+
+ pos_logits = torch.einsum('nc,nc->n',
+ [pos_feat, feat_norm]).unsqueeze(-1)
+ neg_logits = torch.bmm(neg_feat, feat_norm.unsqueeze(2)).squeeze(2)
+
+ # use cosine margin
+ if self.inter_cos:
+ cosine = pos_logits.clone()
+ phi = cosine - self.inter_cos_margin
+ pos_logits.copy_(phi)
+ # use arc margin
+ if self.inter_arc:
+ cosine = pos_logits.clone()
+ sine = torch.sqrt((1.0 - cosine**2).clamp(0, 1))
+ phi = cosine * math.cos(self.inter_arc_margin) - sine * math.sin(
+ self.inter_arc_margin)
+ if self.inter_arc_margin < 0:
+ phi = torch.where(
+ cosine < math.cos(self.inter_arc_margin), phi, cosine +
+ math.sin(self.inter_arc_margin) * self.inter_arc_margin)
+ else:
+ phi = torch.where(
+ cosine > math.cos(math.pi - self.inter_arc_margin), phi,
+ cosine - math.sin(math.pi - self.inter_arc_margin) *
+ self.inter_arc_margin)
+ pos_logits.copy_(phi)
+
+ losses = self.head(pos_logits, neg_logits)
+
+ return losses
+
+ def extract_feat(self, img):
+ """Function to extract features from backbone.
+
+ Args:
+ img (Tensor): Input images of shape (N, C, H, W).
+ Typically these should be mean centered and std scaled.
+
+ Returns:
+ tuple[Tensor]: backbone outputs.
+ """
+ x = self.backbone(img)
+ return x
+
+ def forward_train(self, img, idx, **kwargs):
+ """Forward computation during training.
+
+ Args:
+ img (list[Tensor]): A list of input images with shape
+ (N, C, H, W). Typically these should be mean centered
+ and std scaled.
+ idx (Tensor): Index corresponding to each image.
+ kwargs: Any keyword arguments to be used to forward.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ assert isinstance(img, list)
+ im_q = img[0]
+ im_k = img[1]
+
+ # compute query features
+ q_b = self.encoder_q[0](im_q) # backbone features
+ q = self.encoder_q[1](q_b)[0] # queries: NxC
+ if not self.share_neck:
+ q2 = self.inter_neck_q(q_b)[0] # inter queries: NxC
+
+ # compute key features
+ with torch.no_grad(): # no gradient to keys
+ self._momentum_update_key_encoder() # update the key encoder
+
+ # shuffle for making use of BN
+ im_k, idx_unshuffle = batch_shuffle_ddp(im_k)
+
+ k_b = self.encoder_k[0](im_k) # backbone features
+ k = self.encoder_k[1](k_b)[0] # keys: NxC
+ if not self.share_neck:
+ k2 = self.inter_neck_k(k_b)[0] # inter keys: NxC
+
+ # undo shuffle
+ k = batch_unshuffle_ddp(k, idx_unshuffle)
+ if not self.share_neck:
+ k2 = batch_unshuffle_ddp(k2, idx_unshuffle)
+
+ idx = idx.cuda()
+ self.memory_bank.broadcast_feature_bank()
+ # compute intra loss
+ intra_losses = self.contrast_intra(q, k)
+ # compute inter loss
+ if self.share_neck:
+ inter_losses = self.contrast_inter(q, idx)
+ else:
+ inter_losses = self.contrast_inter(q2, idx)
+ losses = dict()
+ losses['intra_loss'] = self.intra_loss_weight * intra_losses['loss']
+ losses['inter_loss'] = self.inter_loss_weight * inter_losses['loss']
+
+ self._dequeue_and_enqueue(k)
+
+ # update memory bank
+ if self.online_labels:
+ if self.share_neck:
+ change_ratio = self.memory_bank.update_samples_memory(
+ idx, k.detach())
+ else:
+ change_ratio = self.memory_bank.update_samples_memory(
+ idx, k2.detach())
+ losses['change_ratio'] = change_ratio
+ else:
+ if self.share_neck:
+ self.memory_bank.update_simple_memory(idx, k.detach())
+ else:
+ self.memory_bank.update_simple_memory(idx, k2.detach())
+
+ return losses
diff --git a/mmselfsup/models/memories/__init__.py b/mmselfsup/models/memories/__init__.py
index 53f07622f..15f499af7 100644
--- a/mmselfsup/models/memories/__init__.py
+++ b/mmselfsup/models/memories/__init__.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from .interclr_memory import InterCLRMemory
from .odc_memory import ODCMemory
from .simple_memory import SimpleMemory
-__all__ = ['ODCMemory', 'SimpleMemory']
+__all__ = ['InterCLRMemory', 'ODCMemory', 'SimpleMemory']
diff --git a/mmselfsup/models/memories/interclr_memory.py b/mmselfsup/models/memories/interclr_memory.py
new file mode 100644
index 000000000..60810a573
--- /dev/null
+++ b/mmselfsup/models/memories/interclr_memory.py
@@ -0,0 +1,254 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from mmcv.runner import BaseModule, get_dist_info
+from sklearn.cluster import KMeans
+
+from mmselfsup.utils import AliasMethod
+from ..builder import MEMORIES
+
+
+@MEMORIES.register_module()
+class InterCLRMemory(BaseModule):
+ """Memory bank for InterCLR.
+
+ Args:
+ length (int): Number of features stored in the memory bank.
+ feat_dim (int): Dimension of stored features.
+ momentum (float): Momentum coefficient for updating features.
+ num_classes (int): Number of clusters.
+ min_cluster (int): Minimal cluster size.
+ """
+
+ def __init__(self, length, feat_dim, momentum, num_classes, min_cluster,
+ **kwargs):
+ super(InterCLRMemory, self).__init__()
+ self.rank, self.num_replicas = get_dist_info()
+ self.feature_bank = torch.zeros((length, feat_dim),
+ dtype=torch.float32,
+ device='cuda')
+ self.label_bank = torch.zeros((length, ), dtype=torch.long)
+ self.centroids = torch.zeros((num_classes, feat_dim),
+ dtype=torch.float32,
+ device='cuda')
+ self.kmeans = KMeans(n_clusters=2, random_state=0, max_iter=20)
+ self.feat_dim = feat_dim
+ self.initialized = False
+ self.momentum = momentum
+ self.num_classes = num_classes
+ self.min_cluster = min_cluster
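+        # AliasMethod supports O(1) multinomial draws; uniform weights here
+        # amount to uniform random sampling over the whole memory bank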
+ self.multinomial = AliasMethod(torch.ones(length))
+ self.multinomial.cuda()
+ self.debug = kwargs.get('debug', False)
+
+ def init_memory(self, feature, label):
+ self.initialized = True
+ self.label_bank.copy_(torch.from_numpy(label).long())
+ # make sure no empty clusters
+ assert (np.bincount(label, minlength=self.num_classes) != 0).all()
+ if self.rank == 0:
+ feature /= (np.linalg.norm(feature, axis=1).reshape(-1, 1) + 1e-12)
+ self.feature_bank.copy_(torch.from_numpy(feature))
+ centroids = self._compute_centroids()
+ self.centroids.copy_(centroids)
+ dist.broadcast(self.centroids, 0)
+
+ def assign_label(self, label):
+ """Assign offline labels for each epoch."""
+ self.label_bank.copy_(torch.from_numpy(label).long())
+ # make sure no empty clusters
+ assert (np.bincount(label, minlength=self.num_classes) != 0).all()
+
+ def broadcast_feature_bank(self):
+ assert self.initialized
+ dist.broadcast(self.feature_bank, 0)
+
+ def _compute_centroids_ind(self, cinds):
+ """Compute a few centroids."""
+ assert self.rank == 0
+ num = len(cinds)
+ centroids = torch.zeros((num, self.feat_dim),
+ dtype=torch.float32,
+ device='cuda')
+ for i, c in enumerate(cinds):
+ ind = np.where(self.label_bank.numpy() == c)[0]
+ centroids[i, :] = self.feature_bank[ind, :].mean(dim=0)
+ return centroids
+
+ def _compute_centroids(self):
+ """Compute all non-empty centroids."""
+ assert self.rank == 0
+ label_bank_np = self.label_bank.numpy()
+ argl = np.argsort(label_bank_np)
+ sortl = label_bank_np[argl]
+ diff_pos = np.where(sortl[1:] - sortl[:-1] != 0)[0] + 1
+ start = np.insert(diff_pos, 0, 0)
+ end = np.insert(diff_pos, len(diff_pos), len(label_bank_np))
+ class_start = sortl[start]
+ # keep empty class centroids unchanged
+ centroids = self.centroids.clone()
+ for i, st, ed in zip(class_start, start, end):
+ centroids[i, :] = self.feature_bank[argl[st:ed], :].mean(dim=0)
+ return centroids
+
+ def _gather(self, ind, feature):
+ """Gather indices and features."""
+ ind_gathered = [
+ torch.ones_like(ind).cuda() for _ in range(self.num_replicas)
+ ]
+ feature_gathered = [
+ torch.ones_like(feature).cuda() for _ in range(self.num_replicas)
+ ]
+ dist.all_gather(ind_gathered, ind)
+ dist.all_gather(feature_gathered, feature)
+ ind_gathered = torch.cat(ind_gathered, dim=0)
+ feature_gathered = torch.cat(feature_gathered, dim=0)
+ return ind_gathered, feature_gathered
+
+ def update_simple_memory(self, ind, feature): # ind, feature: cuda tensor
+ """Update features in the memory bank."""
+ feature_norm = nn.functional.normalize(feature) # normalize
+ ind, feature_norm = self._gather(
+ ind, feature_norm) # ind: (N*w), feature: (N*w)xk, cuda tensor
+ if self.rank == 0:
+ feature_old = self.feature_bank[ind, :]
+ feature_new = (1 - self.momentum) * feature_old + \
+ self.momentum * feature_norm
+ feature_norm = nn.functional.normalize(feature_new)
+ self.feature_bank[ind, :] = feature_norm
+
+ def update_samples_memory(self, ind, feature): # ind, feature: cuda tensor
+ """Update features and labels in the memory bank."""
+ feature_norm = nn.functional.normalize(feature) # normalize
+ ind, feature_norm = self._gather(
+ ind, feature_norm) # ind: (N*w), feature: (N*w)xk, cuda tensor
+ if self.rank == 0:
+ feature_old = self.feature_bank[ind, :]
+ feature_new = (1 - self.momentum) * feature_old + \
+ self.momentum * feature_norm
+ feature_norm = nn.functional.normalize(feature_new)
+ self.feature_bank[ind, :] = feature_norm
+ dist.barrier()
+ dist.broadcast(feature_norm, 0)
+ # compute new labels
+ ind = ind.cpu()
+ similarity_to_centroids = torch.mm(self.centroids,
+ feature_norm.permute(1, 0)) # CxN
+ newlabel = similarity_to_centroids.argmax(dim=0) # cuda tensor
+ newlabel_cpu = newlabel.cpu()
+ change_ratio = (newlabel_cpu != self.label_bank[ind]
+ ).sum().float().cuda() / float(newlabel_cpu.shape[0])
+ self.label_bank[ind] = newlabel_cpu.clone() # copy to cpu
+ return change_ratio
+
+ def deal_with_small_clusters(self):
+ """Deal with small clusters."""
+ # check empty class
+ histogram = np.bincount(
+ self.label_bank.numpy(), minlength=self.num_classes)
+ small_clusters = np.where(histogram < self.min_cluster)[0].tolist()
+ if self.debug and self.rank == 0:
+ print(f'mincluster: {histogram.min()}, '
+ f'num of small class: {len(small_clusters)}')
+ if len(small_clusters) == 0:
+ return
+ # re-assign samples in small clusters to make them empty
+ for s in small_clusters:
+ ind = np.where(self.label_bank.numpy() == s)[0]
+ if len(ind) > 0:
+ inclusion = torch.from_numpy(
+ np.setdiff1d(
+ np.arange(self.num_classes),
+ np.array(small_clusters),
+ assume_unique=True)).cuda()
+ if self.rank == 0:
+ target_ind = torch.mm(
+ self.centroids[inclusion, :],
+ self.feature_bank[ind, :].permute(1, 0)).argmax(dim=0)
+ target = inclusion[target_ind]
+ else:
+ target = torch.zeros((ind.shape[0], ),
+ dtype=torch.int64,
+ device='cuda')
+ dist.all_reduce(target)
+ self.label_bank[ind] = torch.from_numpy(target.cpu().numpy())
+ # deal with empty cluster
+ self._redirect_empty_clusters(small_clusters)
+
+ def update_centroids_memory(self, cinds=None):
+ """Update centroids in the memory bank."""
+ if self.rank == 0:
+ if self.debug:
+ print('updating centroids ...')
+ if cinds is None:
+ centroids = self._compute_centroids()
+ self.centroids.copy_(centroids)
+ else:
+ centroids = self._compute_centroids_ind(cinds)
+ self.centroids[torch.cuda.LongTensor(cinds), :] = centroids
+ dist.broadcast(self.centroids, 0)
+
+ def _partition_max_cluster(self, max_cluster):
+ """Partition the largest cluster into two sub-clusters."""
+ assert self.rank == 0
+ max_cluster_inds = np.where(self.label_bank.numpy() == max_cluster)[0]
+
+ assert len(max_cluster_inds) >= 2
+ max_cluster_features = self.feature_bank[
+ max_cluster_inds, :].cpu().numpy()
+ if np.any(np.isnan(max_cluster_features)):
+            raise ValueError('Found NaN values in features.')
+ kmeans_ret = self.kmeans.fit(max_cluster_features)
+ sub_cluster1_ind = max_cluster_inds[kmeans_ret.labels_ == 0]
+ sub_cluster2_ind = max_cluster_inds[kmeans_ret.labels_ == 1]
+ if not (len(sub_cluster1_ind) > 0 and len(sub_cluster2_ind) > 0):
+            print('Warning: k-means partition failed, '
+                  'resorting to random partition.')
+ sub_cluster1_ind = np.random.choice(
+ max_cluster_inds, len(max_cluster_inds) // 2, replace=False)
+ sub_cluster2_ind = np.setdiff1d(
+ max_cluster_inds, sub_cluster1_ind, assume_unique=True)
+ return sub_cluster1_ind, sub_cluster2_ind
+
+ def _redirect_empty_clusters(self, empty_clusters):
+ """Re-direct empty clusters."""
+ for e in empty_clusters:
+ assert (self.label_bank != e).all().item(), \
+ f'Cluster #{e} is not an empty cluster.'
+ max_cluster = np.bincount(
+ self.label_bank.numpy(),
+ minlength=self.num_classes).argmax().item()
+ # gather partitioning indices
+ if self.rank == 0:
+ sub_cluster1_ind, sub_cluster2_ind = \
+ self._partition_max_cluster(max_cluster)
+ size1 = torch.cuda.LongTensor([len(sub_cluster1_ind)])
+ size2 = torch.cuda.LongTensor([len(sub_cluster2_ind)])
+ sub_cluster1_ind_tensor = torch.from_numpy(
+ sub_cluster1_ind).long().cuda()
+ sub_cluster2_ind_tensor = torch.from_numpy(
+ sub_cluster2_ind).long().cuda()
+ else:
+ size1 = torch.cuda.LongTensor([0])
+ size2 = torch.cuda.LongTensor([0])
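+        # make both sub-cluster sizes known on every rank so ranks != 0 can
+        # allocate receive buffers of the right length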
+ dist.all_reduce(size1)
+ dist.all_reduce(size2)
+ if self.rank != 0:
+ sub_cluster1_ind_tensor = torch.zeros((size1, ),
+ dtype=torch.int64,
+ device='cuda')
+ sub_cluster2_ind_tensor = torch.zeros((size2, ),
+ dtype=torch.int64,
+ device='cuda')
+ dist.broadcast(sub_cluster1_ind_tensor, 0)
+ dist.broadcast(sub_cluster2_ind_tensor, 0)
+ if self.rank != 0:
+ sub_cluster1_ind = sub_cluster1_ind_tensor.cpu().numpy()
+ sub_cluster2_ind = sub_cluster2_ind_tensor.cpu().numpy()
+
+ # reassign samples in partition #2 to the empty class
+ self.label_bank[sub_cluster2_ind] = e
+ # update centroids of max_cluster and e
+ self.update_centroids_memory([max_cluster, e])
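
The distributed plumbing above can obscure the two core steps of the memory bank, so here is a minimal single-process sketch of the EMA feature update and the nearest-centroid relabeling. The sizes, the `momentum` value, and the toy tensors are illustrative assumptions, not values from the InterCLR config.

    import torch
    import torch.nn.functional as F

    N, K, C = 100, 16, 10      # toy sizes: samples, feature dim, clusters
    momentum = 0.1             # illustrative, not the config value
    feature_bank = F.normalize(torch.randn(N, K))
    centroids = F.normalize(torch.randn(C, K))
    label_bank = torch.randint(0, C, (N,))

    ind = torch.arange(8)                     # indices of the current batch
    feature = F.normalize(torch.randn(8, K))  # new batch features
    # EMA update, then re-normalize back onto the unit sphere
    new_feat = F.normalize(
        (1 - momentum) * feature_bank[ind] + momentum * feature)
    feature_bank[ind] = new_feat
    # cosine similarity to each centroid; argmax is the new cluster label
    new_label = torch.mm(centroids, new_feat.t()).argmax(dim=0)
    change_ratio = (new_label != label_bank[ind]).float().mean()
    label_bank[ind] = new_label
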
diff --git a/tools/train.py b/tools/train.py
index bfcf6a757..d963ba46b 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -130,8 +130,8 @@ def main():
if args.launcher == 'none':
distributed = False
assert cfg.model.type not in [
- 'DeepCluster', 'MoCo', 'SimCLR', 'ODC', 'NPID', 'SimSiam',
- 'DenseCL', 'BYOL'
+ 'BYOL', 'DeepCluster', 'DenseCL', 'InterCLRMoCo', 'MoCo', 'NPID',
+ 'ODC', 'SimCLR', 'SimSiam'
], f'{cfg.model.type} does not support non-dist training.'
else:
distributed = True
From d73154650bee5129c28db80300d39bc9a5b35d1a Mon Sep 17 00:00:00 2001
From: Jiahao Xie <52497952+Jiahao000@users.noreply.github.com>
Date: Tue, 27 Dec 2022 14:09:02 +0800
Subject: [PATCH 2/4] [Fix] fix sampling_replace config kwargs bug (#646)
* [Fix] fix sampling_replace config kwargs bug
* fix lint
* fix lint
---
configs/selfsup/_base_/datasets/imagenet_odc.py | 2 +-
mmselfsup/datasets/pipelines/transforms.py | 4 ++--
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/configs/selfsup/_base_/datasets/imagenet_odc.py b/configs/selfsup/_base_/datasets/imagenet_odc.py
index e41b13948..a210480c8 100644
--- a/configs/selfsup/_base_/datasets/imagenet_odc.py
+++ b/configs/selfsup/_base_/datasets/imagenet_odc.py
@@ -32,7 +32,7 @@
# dataset summary
data = dict(
samples_per_gpu=64, # 64*8
- sampling_replace=True,
+ replace=True,
workers_per_gpu=4,
train=dict(
type=dataset_type,
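
For context: judging from this rename, the extra top-level keys of `data` are forwarded as keyword arguments to the dataloader builder, so a key spelled `sampling_replace` would not match the builder's `replace` argument. What `replace=True` requests is sampling with replacement, so an epoch may draw some samples more than once. A generic sketch of the difference (the names below are illustrative, not the mmselfsup sampler API):

    import numpy as np

    rng = np.random.default_rng(0)
    dataset_len = 10
    # with replacement: indices can repeat within one epoch
    with_replacement = rng.choice(dataset_len, size=dataset_len, replace=True)
    # without replacement: a plain permutation, each index exactly once
    without_replacement = rng.permutation(dataset_len)
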
diff --git a/mmselfsup/datasets/pipelines/transforms.py b/mmselfsup/datasets/pipelines/transforms.py
index 13afed8f4..85535d4ac 100644
--- a/mmselfsup/datasets/pipelines/transforms.py
+++ b/mmselfsup/datasets/pipelines/transforms.py
@@ -166,7 +166,7 @@ def _mask(self, mask: np.ndarray, max_mask_patches: int) -> int:
def __call__(
self, img: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor, np.ndarray]:
- mask = np.zeros(shape=self.get_shape(), dtype=np.int)
+ mask = np.zeros(shape=self.get_shape(), dtype=int)
mask_count = 0
while mask_count != self.num_masking_patches:
max_mask_patches = self.num_masking_patches - mask_count
@@ -580,7 +580,7 @@ def __call__(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
Returns:
Tuple[torch.Tensor, torch.Tensor]: Input image and mask.
"""
- mask = np.zeros(shape=self.get_shape(), dtype=np.int)
+ mask = np.zeros(shape=self.get_shape(), dtype=int)
mask_count = 0
while mask_count < self.num_masking_patches:
max_mask_patches = self.num_masking_patches - mask_count
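
Background for these two hunks: `np.int` was only ever an alias for the builtin `int`; NumPy deprecated it in 1.20 and removed it in 1.24, after which `dtype=np.int` raises an `AttributeError`. The builtin `int` (or an explicit width such as `np.int64`) works across versions. The `(14, 14)` shape below is an illustrative stand-in for `self.get_shape()`:

    import numpy as np

    mask = np.zeros(shape=(14, 14), dtype=int)          # portable spelling
    mask64 = np.zeros(shape=(14, 14), dtype=np.int64)   # explicit width, also fine
    # np.zeros(shape=(14, 14), dtype=np.int)  # AttributeError on NumPy >= 1.24
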
From f73306d8934f3469a9f32eea64527579d38b3365 Mon Sep 17 00:00:00 2001
From: Yixiao Fang <36138628+fangyixiao18@users.noreply.github.com>
Date: Tue, 27 Dec 2022 15:48:03 +0800
Subject: [PATCH 3/4] [Fix] fix potential bug (#647)
---
mmselfsup/apis/train.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mmselfsup/apis/train.py b/mmselfsup/apis/train.py
index 4b2df956a..d322b41c9 100644
--- a/mmselfsup/apis/train.py
+++ b/mmselfsup/apis/train.py
@@ -158,7 +158,7 @@ def train_model(model,
# register hooks
runner.register_training_hooks(cfg.lr_config, optimizer_config,
cfg.checkpoint_config, cfg.log_config)
- if distributed:
+ if distributed and cfg.runner.type == 'EpochBasedRunner':
runner.register_hook(DistSamplerSeedHook())
# register custom hooks
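
The extra condition matters because `DistSamplerSeedHook` exists to call `set_epoch` on the distributed sampler at every epoch boundary, and "epoch" is only meaningful for `EpochBasedRunner`; registering it with an iteration-based runner is unnecessary. A hedged sketch of the mechanism, mirroring rather than quoting the mmcv hook:

    class DistSamplerSeedHookSketch:
        """Rough shape of what the real mmcv hook does per epoch."""

        def before_epoch(self, runner):
            # DistributedSampler derives its shuffle order from (seed, epoch);
            # without set_epoch every epoch would replay the same permutation.
            if hasattr(runner.data_loader.sampler, 'set_epoch'):
                runner.data_loader.sampler.set_epoch(runner.epoch)
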
From eae0b84761415e0310243b85a5e2706acd48ddb3 Mon Sep 17 00:00:00 2001
From: Yixiao Fang <36138628+fangyixiao18@users.noreply.github.com>
Date: Fri, 30 Dec 2022 17:03:14 +0800
Subject: [PATCH 4/4] Bump version to v0.11.0 (#648)
* Bump version to v0.11.0
* Update README.md
* Update README.md
* Update README_zh-CN.md
* update
* fix lint
* update
* recover
Co-authored-by: Jiahao Xie <52497952+Jiahao000@users.noreply.github.com>
---
README.md | 43 +++++++++++++++++++-----------------
README_zh-CN.md | 45 ++++++++++++++++++++------------------
docs/en/changelog.md | 21 ++++++++++++++++++
docs/en/faq.md | 3 ++-
docs/zh_cn/changelog.md | 21 ++++++++++++++++++
docs/zh_cn/faq.md | 3 ++-
mmselfsup/version.py | 2 +-
model-index.yml | 1 +
requirements/mminstall.txt | 8 +++----
requirements/runtime.txt | 2 +-
10 files changed, 100 insertions(+), 49 deletions(-)
diff --git a/README.md b/README.md
index a8374560b..c78057fd5 100644
--- a/README.md
+++ b/README.md
@@ -43,9 +43,10 @@ English | [简体中文](README_zh-CN.md)
MMSelfSup is an open source self-supervised representation learning toolbox based on PyTorch. It is a part of the [OpenMMLab](https://openmmlab.com/) project.
-The master branch works with **PyTorch 1.5** or higher.
+The master branch works with **PyTorch 1.5+**.
-### Major features
+
+Major features
- **Methods All in One**
@@ -63,37 +64,39 @@ The master branch works with **PyTorch 1.5** or higher.
Since MMSelfSup adopts similar design of modulars and interfaces as those in other OpenMMLab projects, it supports smooth evaluation on downstream tasks with other OpenMMLab projects like object detection and segmentation.
+
+
## What's New
-### Preview of 1.x version
+### 💎 Stable version
-A brand new version of **MMSelfSup v1.0.0rc1** was released in 01/09/2022:
+MMSelfSup **v0.11.0** was released on 30/12/2022.
Highlights of the new version:
-- Based on [MMEngine](https://github.com/open-mmlab/mmengine) and [MMCV](https://github.com/open-mmlab/mmcv/tree/2.x).
-- Released with refactor.
-- Refine all [documents](https://mmselfsup.readthedocs.io/en/1.x/).
-- Support `MAE`, `SimMIM`, `MoCoV3` with different pre-training epochs and backbones of different scales.
-- More concise API.
-- More powerful data pipeline.
-- Higher accurcy for some algorithms.
+- Support `InterCLR`
+- Fix some bugs
-Find more new features in [1.x branch](https://github.com/open-mmlab/mmselfsup/tree/1.x). Issues and PRs are welcome!
+Please refer to [changelog.md](docs/en/changelog.md) for details and release history.
-### Stable version
+Differences between MMSelfSup and OpenSelfSup codebases can be found in [compatibility.md](docs/en/compatibility.md).
-MMSelfSup **v0.10.1** was released in 01/11/2022.
+### 🌟 Preview of 1.x version
-Highlights of the new version:
+A brand new version of **MMSelfSup v1.0.0rc4** was released on 07/12/2022.
-- Support MaskFeat
-- Update issue form
-- Fix some typo in documents
+Highlights of the new version:
-Please refer to [changelog.md](docs/en/changelog.md) for details and release history.
+- Based on [MMEngine](https://github.com/open-mmlab/mmengine) and [MMCV](https://github.com/open-mmlab/mmcv/tree/2.x).
+- Refine all [documents](https://mmselfsup.readthedocs.io/en/1.x/).
+- Support `BEiT v1`, `BEiT v2`, `MILAN`, `MixMIM`, `EVA`.
+- Support `MAE`, `SimMIM`, `MoCoV3` with different pre-training epochs and backbones of different scales.
+- More concise APIs.
+- More visualization tools.
+- More powerful data pipeline.
+- Higher accuracy for some algorithms.
-Differences between MMSelfSup and OpenSelfSup codebases can be found in [compatibility.md](docs/en/compatibility.md).
+Find more new features in [1.x branch](https://github.com/open-mmlab/mmselfsup/tree/1.x). Issues and PRs are welcome!
## Installation
diff --git a/README_zh-CN.md b/README_zh-CN.md
index 9c72ce939..f60578453 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -39,13 +39,14 @@
-## 介绍
+## 简介
MMSelfSup 是一个基于 PyTorch 实现的开源自监督表征学习工具箱,是 [OpenMMLab](https://openmmlab.com/) 项目成员之一。
主分支代码支持 **PyTorch 1.5** 及以上的版本。
-### 主要特性
+
+主要特性
- **多方法集成**
@@ -63,37 +64,39 @@ MMSelfSup 是一个基于 PyTorch 实现的开源自监督表征学习工具箱
兼容 OpenMMLab 各大算法库,拥有丰富的下游评测任务和预训练模型的应用。
-## 更新
+
-### 1.x 预览版本
+## 最新进展
-全新的 **MMSelfSup v1.0.0rc1** 版本已在 2022.09.01 发布。
+### 💎 稳定版本
+
+最新的 **v0.11.0** 版本已经在 2022.12.30 发布。
新版本亮点:
-- 基于全新的 [MMEngine](https://github.com/open-mmlab/mmengine) 和 [MMCV](https://github.com/open-mmlab/mmcv/tree/2.x)。
-- 代码库重构,统一接口。
-- 完善了新版本 [文档](https://mmselfsup.readthedocs.io/en/1.x/)。
-- 支持了不同训练时间、不同尺寸的 `MAE`, `SimMIM`, `MoCoV3` 的预训练模型。
-- 更加简洁的 API。
-- 更加强大的数据管道。
-- 部分模型具有更高的准确率。
+- 支持 `InterCLR`
+- 修复部分 bugs
-在 [1.x 分支](https://github.com/open-mmlab/mmselfsup/tree/1.x) 查看更多新特性。 欢迎大家提 Issues 和 PRs!
+请参考 [更新日志](docs/zh_cn/changelog.md) 获取更多细节和历史版本信息。
-### 稳定版本
+MMSelfSup 和 OpenSelfSup 的不同点写在 [对比文档](docs/en/compatibility.md) 中。
-最新的 **v0.10.1** 版本已经在 2022.11.1 发布。
+### 🌟 1.x 预览版本
-新版本亮点:
+全新的 **v1.0.0rc4** 版本已经在 2022.12.07 发布:
-- 支持 MaskFeat
-- 更新 issue 模板
-- 修复部分文档的错误
+新版本亮点:
-请参考 [更新日志](docs/zh_cn/changelog.md) 获取更多细节和历史版本信息。
+- 基于全新的 [MMEngine](https://github.com/open-mmlab/mmengine) 和 [MMCV](https://github.com/open-mmlab/mmcv/tree/2.x)。
+- 完善了新版本 [文档](https://mmselfsup.readthedocs.io/en/1.x/)。
+- 支持了 `BEiT v1`, `BEiT v2`, `MILAN`, `MixMIM`, `EVA`。
+- 支持了不同训练时间、不同尺寸的 `MAE`, `SimMIM`, `MoCoV3` 的预训练模型。
+- 更加简洁的 APIs。
+- 更丰富的可视化工具。
+- 更加强大的数据管道。
+- 部分模型具有更高的准确率。
-MMSelfSup 和 OpenSelfSup 的不同点写在 [对比文档](docs/en/compatibility.md) 中。
+在 [1.x 分支](https://github.com/open-mmlab/mmselfsup/tree/1.x) 查看更多新特性。 欢迎大家提 Issues 和 PRs!
## 安装
diff --git a/docs/en/changelog.md b/docs/en/changelog.md
index 8ea62d126..3cec9f41e 100644
--- a/docs/en/changelog.md
+++ b/docs/en/changelog.md
@@ -2,6 +2,27 @@
## MMSelfSup
+### v0.11.0 (30/12/2022)
+
+#### New Features
+
+- Support InterCLR ([#609](https://github.com/open-mmlab/mmselfsup/pull/609))
+
+#### Bug Fixes
+
+- Fix potential bug of hook registration ([#647](https://github.com/open-mmlab/mmselfsup/pull/647))
+- Fix sampling_replace config kwargs bug ([#646](https://github.com/open-mmlab/mmselfsup/pull/646))
+- Change sklearn to scikit-learn in requirements ([#583](https://github.com/open-mmlab/mmselfsup/pull/583))
+
+#### Improvements
+
+- Update CI check rules ([#581](https://github.com/open-mmlab/mmselfsup/pull/581))
+- Update assignee schedule ([#606](https://github.com/open-mmlab/mmselfsup/pull/606))
+
+#### Docs
+
+- Add global notes and the version switcher menu ([#573](https://github.com/open-mmlab/mmselfsup/pull/573))
+
### v0.10.1 (01/11/2022)
#### Improvements
diff --git a/docs/en/faq.md b/docs/en/faq.md
index 5bdb56d26..6ce2ecf87 100644
--- a/docs/en/faq.md
+++ b/docs/en/faq.md
@@ -8,7 +8,8 @@ Compatible MMCV, MMClassification, MMDetection and MMSegmentation versions are s
| MMSelfSup version | MMCV version | MMClassification version | MMSegmentation version | MMDetection version |
| :---------------: | :--------------------------: | :-------------------------: | :--------------------: | :-----------------: |
-| 0.10.1 (master) | mmcv-full >= 1.4.2, \< 1.9.0 | mmcls >= 0.21.0, \< 0.27.0 | mmseg >= 0.20.2 | mmdet >= 2.19.0 |
+| 0.11.0 (master) | mmcv-full >= 1.4.2, \< 1.9.0 | mmcls >= 0.21.0, \< 0.27.0 | mmseg >= 0.20.2 | mmdet >= 2.19.0 |
+| 0.10.1 | mmcv-full >= 1.4.2, \< 1.9.0 | mmcls >= 0.21.0, \< 0.27.0 | mmseg >= 0.20.2 | mmdet >= 2.19.0 |
| 0.10.0 | mmcv-full >= 1.4.2, \< 1.7.0 | mmcls >= 0.21.0 | mmseg >= 0.20.2 | mmdet >= 2.19.0 |
| 0.9.2 | mmcv-full >= 1.4.2, \< 1.7.0 | mmcls >= 0.21.0 | mmseg >= 0.20.2 | mmdet >= 2.19.0 |
| 0.9.1 | mmcv-full >= 1.4.2, \< 1.6.0 | mmcls >= 0.21.0 | mmseg >= 0.20.2 | mmdet >= 2.19.0 |
diff --git a/docs/zh_cn/changelog.md b/docs/zh_cn/changelog.md
index 6bd856f75..e8c9fa220 100644
--- a/docs/zh_cn/changelog.md
+++ b/docs/zh_cn/changelog.md
@@ -2,6 +2,27 @@
## MMSelfSup
+### v0.11.0 (30/12/2022)
+
+#### 新特性
+
+- 支持了算法 InterCLR ([#609](https://github.com/open-mmlab/mmselfsup/pull/609))
+
+#### Bug Fixes
+
+- 修复钩子注册时的潜在 bug ([#647](https://github.com/open-mmlab/mmselfsup/pull/647))
+- 修复 sampling_replace 的字段错误 ([#646](https://github.com/open-mmlab/mmselfsup/pull/646))
+- 更新 scikit-learn 安装包名 ([#583](https://github.com/open-mmlab/mmselfsup/pull/583))
+
+#### Improvements
+
+- 更新 CI 检查策略 ([#581](https://github.com/open-mmlab/mmselfsup/pull/581))
+- 更新值班表 ([#606](https://github.com/open-mmlab/mmselfsup/pull/606))
+
+#### Docs
+
+- 增加全局通知和版本切换按钮 ([#573](https://github.com/open-mmlab/mmselfsup/pull/573))
+
### v0.10.1 (01/11/2022)
#### Improvements
diff --git a/docs/zh_cn/faq.md b/docs/zh_cn/faq.md
index 228ed3077..b972965e5 100644
--- a/docs/zh_cn/faq.md
+++ b/docs/zh_cn/faq.md
@@ -8,7 +8,8 @@
| MMSelfSup version | MMCV version | MMClassification version | MMSegmentation version | MMDetection version |
| :---------------: | :--------------------------: | :-------------------------: | :--------------------: | :-----------------: |
-| 0.10.1 (master) | mmcv-full >= 1.4.2, \< 1.9.0 | mmcls >= 0.21.0, \< 0.27.0 | mmseg >= 0.20.2 | mmdet >= 2.19.0 |
+| 0.11.0 (master) | mmcv-full >= 1.4.2, \< 1.9.0 | mmcls >= 0.21.0, \< 0.27.0 | mmseg >= 0.20.2 | mmdet >= 2.19.0 |
+| 0.10.1 | mmcv-full >= 1.4.2, \< 1.9.0 | mmcls >= 0.21.0, \< 0.27.0 | mmseg >= 0.20.2 | mmdet >= 2.19.0 |
| 0.10.0 | mmcv-full >= 1.4.2, \< 1.7.0 | mmcls >= 0.21.0 | mmseg >= 0.20.2 | mmdet >= 2.19.0 |
| 0.9.2 | mmcv-full >= 1.4.2, \< 1.7.0 | mmcls >= 0.21.0 | mmseg >= 0.20.2 | mmdet >= 2.19.0 |
| 0.9.1 | mmcv-full >= 1.4.2, \< 1.6.0 | mmcls >= 0.21.0 | mmseg >= 0.20.2 | mmdet >= 2.19.0 |
diff --git a/mmselfsup/version.py b/mmselfsup/version.py
index 9178c80cb..20ccbf74b 100644
--- a/mmselfsup/version.py
+++ b/mmselfsup/version.py
@@ -1,6 +1,6 @@
# Copyright (c) Open-MMLab. All rights reserved.
-__version__ = '0.10.1'
+__version__ = '0.11.0'
def parse_version_info(version_str):
diff --git a/model-index.yml b/model-index.yml
index 81897fe9f..1f18c7d13 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -17,3 +17,4 @@ Import:
- configs/selfsup/barlowtwins/metafile.yml
- configs/selfsup/cae/metafile.yml
- configs/selfsup/maskfeat/metafile.yml
+ - configs/selfsup/interclr/metafile.yml
diff --git a/requirements/mminstall.txt b/requirements/mminstall.txt
index 6bdb5cd04..1cafca8cf 100644
--- a/requirements/mminstall.txt
+++ b/requirements/mminstall.txt
@@ -1,4 +1,4 @@
-mmcls >= 0.21.0
-mmcv-full>=1.4.2
-mmdet >= 2.16.0
-mmsegmentation >= 0.20.2
+mmcls>=0.21.0,<0.27.0
+mmcv-full>=1.4.2,<1.9.0
+mmdet>=2.19.0
+mmsegmentation>=0.20.2
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index b1df46d0c..6f9e20145 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -1,7 +1,7 @@
attrs
future
matplotlib
-mmcls
+mmcls>=0.21.0,<0.27.0
numpy
packaging
scikit-learn