The official implementation for Bi-directional Training for Composed Image Retrieval via Text Prompt Learning.
If you find this code useful for your research, please consider citing our work.
@InProceedings{Liu_2024_WACV,
author = {Liu, Zheyuan and Sun, Weixuan and Hong, Yicong and Teney, Damien and Gould, Stephen},
title = {Bi-Directional Training for Composed Image Retrieval via Text Prompt Learning},
booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)},
month = {January},
year = {2024},
pages = {5753-5762}
}
If you are interested in this task, feel free to check out our new work, Candidate Set Re-ranking for Composed Image Retrieval with Dual Multi-modal Encoder, published in TMLR (code and checkpoints are released here).
@article{liu2024candidate,
title = {Candidate Set Re-ranking for Composed Image Retrieval with Dual Multi-modal Encoder},
author = {Zheyuan Liu and Weixuan Sun and Damien Teney and Stephen Gould},
journal = {Transactions on Machine Learning Research},
issn = {2835-8856},
year = {2024},
url = {https://openreview.net/forum?id=fJAwemcvpL}
}
News and upcoming updates
- Jan-2024 Code and pre-trained checkpoints released for our new work above.
- Nov-2023 Code and pre-trained checkpoints released for our WACV 2024 paper.
- Nov-2023 Readme instructions released.
Existing approaches to composed image retrieval (CIR) learn a mapping from a (reference image, modification text) pair to an image embedding, which is then matched against a large image corpus.
One direction that has not yet been explored is the reverse one, which asks: what reference image, when modified as described by the text, would produce the given target image?
We propose a bi-directional training scheme that leverages such reversed queries, can be applied to existing CIR architectures with minimal changes, and improves the performance of the model.
Our method is tested on BLIP4CIR, a two-stage approach. BLIP4CIR is a new BLIP-based baseline that we propose on top of the existing method CLIP4Cir (footnote 1). For details, please check out our paper.
In the first stage (stage-I), to encode the bi-directional query, we prepend a learnable token to the modification text that designates the direction of the query, and then finetune the parameters of the BLIP text embedding module.
We make no other changes to the network architecture, which allows us to train the second stage (stage-II) as-is, but with queries in both directions.
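To make the two query directions concrete, here is a minimal sketch of how one training triplet could yield a forward and a reversed query on the fly. The `[FWD]`/`[REV]` strings, the `Triplet`/`Query` types and `make_queries` are hypothetical placeholders: in the actual model the direction marker is a learnable token embedding inside the BLIP text encoder, not literal text, and whether both directions (or only the reversed one) receive a token is an assumption here. The sketch only mirrors the data flow.

```python
from dataclasses import dataclass
from typing import List, Tuple

# Hypothetical placeholder strings; the real direction marker is a learnable
# token embedding, not literal text.
FORWARD_TOKEN = "[FWD]"
REVERSED_TOKEN = "[REV]"


@dataclass
class Triplet:
    reference: str  # reference image id
    text: str       # modification text
    target: str     # target image id


@dataclass
class Query:
    image: str   # image that is composed with the text
    text: str    # modification text with a direction token prepended
    answer: str  # image the query is supposed to retrieve


def make_queries(batch: List[Triplet]) -> Tuple[List[Query], List[Query]]:
    """Build forward and reversed queries from one batch of triplets."""
    # Forward: reference image + text should retrieve the target image.
    forward = [Query(t.reference, f"{FORWARD_TOKEN} {t.text}", t.target) for t in batch]
    # Reversed: target image + text should retrieve the reference image.
    reversed_ = [Query(t.target, f"{REVERSED_TOKEN} {t.text}", t.reference) for t in batch]
    return forward, reversed_


if __name__ == "__main__":
    fwd, rev = make_queries([Triplet("ref_0", "make it red", "tgt_0")])
    print(fwd[0])
    print(rev[0])
```

The same modification text is reused for both directions; only the direction marker and the roles of the two images change, which is why the stage-II architecture can stay as-is.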
First, clone the repository to a desired location.
Prerequisites
The following commands will create a local Anaconda environment with the necessary packages installed.
conda create -n cirr_dev -y python=3.8
conda activate cirr_dev
pip install -r requirements.txt
BLIP pre-trained checkpoint
Download the BLIP pre-trained checkpoint and verify it with SHA1: 5f1d8cdfae91e22a35e98a4bbb4c43be7bd0ac50.
By default, we recommend storing the downloaded checkpoint file at models/model_base.pth.
Here, we use BLIP w/ ViT-B. For BLIP checkpoint options, see here.
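One way to verify the download is to hash it with Python's standard `hashlib`. The sketch below is not part of the repository; it assumes the checkpoint sits at the recommended models/model_base.pth unless a path is passed on the command line, and it reads the file in chunks so the whole checkpoint does not need to fit in memory.

```python
import hashlib
import sys
from pathlib import Path

# SHA1 listed in this README for the BLIP (ViT-B) pre-trained checkpoint.
EXPECTED_SHA1 = "5f1d8cdfae91e22a35e98a4bbb4c43be7bd0ac50"


def sha1_of_file(path: Path, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA1 of a file, read in chunks of chunk_size bytes."""
    digest = hashlib.sha1()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


if __name__ == "__main__":
    ckpt = Path(sys.argv[1]) if len(sys.argv) > 1 else Path("models/model_base.pth")
    if not ckpt.is_file():
        print(f"{ckpt} not found; download the checkpoint first.")
    else:
        actual = sha1_of_file(ckpt)
        print("OK" if actual == EXPECTED_SHA1 else f"MISMATCH: got {actual}")
```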
Datasets
Experiments are conducted on two standard datasets, Fashion-IQ and CIRR; please see their repositories for download instructions.
The downloaded file structure should look like this.
Optional -- Set up Comet
We use Comet to log the experiments. If you are unfamiliar with it, see the quick start guide. You will need to obtain an API key for --api-key and create a personal workspace for --workspace.
If these arguments are not provided, the experiment will be logged only locally.
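The sketch below shows, with `argparse` only, how the two flags map to attributes and how a script could decide between remote and local logging. `build_parser` and `comet_enabled` are hypothetical helpers; requiring *both* flags before enabling Comet is an assumption (it matches the note above that missing arguments mean local-only logging), and in the real scripts the enabled branch would create a Comet experiment, which is omitted here so the sketch runs without `comet_ml` installed.

```python
import argparse
from typing import List, Optional


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Comet logging flags (sketch)")
    # argparse stores --api-key as args.api_key and --workspace as args.workspace.
    parser.add_argument("--api-key", type=str, default=None)
    parser.add_argument("--workspace", type=str, default=None)
    return parser


def comet_enabled(argv: Optional[List[str]] = None) -> bool:
    """Assumed rule: log to Comet only when both flags are given, else stay local."""
    args = build_parser().parse_args(argv)
    return bool(args.api_key) and bool(args.workspace)


if __name__ == "__main__":
    print(comet_enabled([]))
    print(comet_enabled(["--api-key", "KEY", "--workspace", "me"]))
```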
Note
Our code has been tested on torch 1.11.0 and 2.1.1; any version in between should presumably work.
Modify requirements.txt to specify the PyTorch+CUDA wheel that matches your setup.
Note that BLIP supports transformers<=4.25; newer versions will cause errors.
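Before a long training run, a quick sanity check of the installed versions can save time. This is a sketch using the standard `importlib.metadata`, not a script shipped with the repository. It treats the torch range as inclusive of the two tested versions, and it compares only the major and minor components of transformers, so any 4.25.x release counts as supported; that reading of `transformers<=4.25` is an assumption (pip itself would interpret `<=4.25` as `<=4.25.0`).

```python
import re
from importlib import metadata
from typing import Optional, Tuple


def parse_version(version: str) -> Tuple[int, int, int]:
    """Turn e.g. '2.1.1+cu118' into (2, 1, 1); local and pre-release suffixes are ignored."""
    core = version.split("+")[0]
    parts = []
    for piece in core.split(".")[:3]:
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    parts += [0] * (3 - len(parts))
    return tuple(parts)


def torch_in_tested_range(version: str) -> bool:
    # Tested versions from this README: 1.11.0 and 2.1.1 (inclusive range assumed).
    return parse_version("1.11.0") <= parse_version(version) <= parse_version("2.1.1")


def transformers_supported(version: str) -> bool:
    # Assumption: compare (major, minor) only, so 4.25.x is accepted.
    return parse_version(version)[:2] <= (4, 25)


def installed_version(package: str) -> Optional[str]:
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    checks = (("torch", torch_in_tested_range), ("transformers", transformers_supported))
    for package, check in checks:
        found = installed_version(package)
        if found is None:
            print(f"{package}: not installed")
        else:
            print(f"{package} {found}: {'ok' if check(found) else 'outside the noted range'}")
```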
Our code is based on CLIP4Cir (footnote 1) with additional modules from BLIP.
From the perspective of implementation, compared to the original CLIP4Cir codebase, the differences lie mostly in the following two aspects:
- we replaced the CLIP image/text encoders with BLIP, as defined in src/blip_modules/;
- we involve the reversed queries during training, which are constructed on the fly (see the code blocks surrounding loss_r in src/clip_fine_tune.py and src/combiner_train.py).
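For intuition on how a reversed loss can sit next to the usual forward loss, here is a pure-Python sketch of a batch-based contrastive objective (cross-entropy over scaled cosine similarities, with the other targets in the batch as negatives), combined as `loss_f + loss_r_scale * loss_r`. This is a generic illustration, not the repository's code: the additive weighting, the name `loss_r_scale`, and the default logit scale of 100 (a common choice in CLIP-style training) are assumptions. For reversed queries, the query vectors would be the composed features of (target image, reversed text) and the targets the reference-image features.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def _normalize(v: Vector) -> List[float]:
    norm = math.sqrt(sum(x * x for x in v)) or 1.0  # avoid dividing by zero
    return [x / norm for x in v]


def contrastive_loss(queries: List[Vector], targets: List[Vector],
                     logit_scale: float = 100.0) -> float:
    """Mean cross-entropy where query i should score highest against target i."""
    if len(queries) != len(targets) or not queries:
        raise ValueError("need one target per query and a non-empty batch")
    q = [_normalize(v) for v in queries]
    t = [_normalize(v) for v in targets]
    total = 0.0
    for i, qi in enumerate(q):
        logits = [logit_scale * sum(a * b for a, b in zip(qi, tj)) for tj in t]
        m = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(l - m) for l in logits))
        total += log_sum_exp - logits[i]
    return total / len(q)


def bidirectional_loss(fwd_queries: List[Vector], fwd_targets: List[Vector],
                       rev_queries: List[Vector], rev_targets: List[Vector],
                       loss_r_scale: float) -> float:
    """Forward loss plus the reversed-query loss weighted by loss_r_scale."""
    loss_f = contrastive_loss(fwd_queries, fwd_targets)
    loss_r = contrastive_loss(rev_queries, rev_targets)
    return loss_f + loss_r_scale * loss_r


if __name__ == "__main__":
    q = [[1.0, 0.0], [0.0, 1.0]]
    print(bidirectional_loss(q, q, q, q, loss_r_scale=0.4))
```

With `loss_r_scale=0` the objective reduces to the forward-only loss, which is the sketch's counterpart of commenting out the loss_r sections.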
A brief introduction to the CLIP4Cir codebase is available in CLIP4Cir - Usage. The code structure is mostly preserved; we made some minor changes (e.g., to experiment logging) that should be easy to follow.
The following configurations are used for training on one NVIDIA A100 80GB. In practice, we observe a maximum VRAM usage of approx. 36GB (CIRR, stage-II training). You can reduce the batch size to lower VRAM consumption.
Stage-I BLIP text encoder finetuning (Fashion-IQ)
# Optional: comet experiment logging --api-key and --workspace
python src/clip_fine_tune.py --dataset FashionIQ \
--api-key <your comet api> --workspace <your comet workspace> \
--num-epochs 20 --batch-size 128 \
--blip-max-epoch 10 --blip-min-lr 0 \
--blip-learning-rate 5e-5 \
--transform targetpad --target-ratio 1.25 \
--save-training --save-best --validation-frequency 1 \
--experiment-name BLIP_cos10_loss_r.40_5e-5
Stage-II Combiner training (Fashion-IQ)
# Optional: comet experiment logging --api-key and --workspace
# Required: Load the blip text encoder weights finetuned in the previous step in --blip-model-path
python src/combiner_train.py --dataset FashionIQ \
--api-key <your comet api> --workspace <your comet workspace> \
--num-epochs 300 --batch-size 512 --blip-bs 32 \
--projection-dim 2560 --hidden-dim 5120 --combiner-lr 2e-5 \
--transform targetpad --target-ratio 1.25 \
--save-training --save-best --validation-frequency 1 \
--blip-model-path <BLIP text encoder finetuned weights path>/saved_models/tuned_blip_best.pt \
--experiment-name Combiner_loss_r.50_2e-5__BLIP_cos10_loss_r_.40_5e-5
Stage-I BLIP text encoder finetuning (CIRR)
# Optional: comet experiment logging --api-key and --workspace
python src/clip_fine_tune.py --dataset CIRR \
--api-key <your comet api> --workspace <your comet workspace> \
--num-epochs 20 --batch-size 128 \
--blip-max-epoch 10 --blip-min-lr 0 \
--blip-learning-rate 5e-5 \
--transform targetpad --target-ratio 1.25 \
--save-training --save-best --validation-frequency 1 \
--experiment-name BLIP_5e-5_cos10_loss_r.1
Stage-II Combiner training (CIRR)
# Optional: comet experiment logging --api-key and --workspace
# Required: Load the blip text encoder weights finetuned in the previous step in --blip-model-path
python src/combiner_train.py --dataset CIRR \
--api-key <your comet api> --workspace <your comet workspace> \
--num-epochs 300 --batch-size 512 --blip-bs 32 \
--projection-dim 2560 --hidden-dim 5120 --combiner-lr 2e-5 \
--transform targetpad --target-ratio 1.25 \
--save-training --save-best --validation-frequency 1 \
--blip-model-path <BLIP text encoder finetuned weights path>/saved_models/tuned_blip_mean.pt \
--experiment-name Combiner_loss_r.10__BLIP_5e-5_cos10_loss_r.1
The following weights should reproduce our results reported in Tables 1 and 2 (hosted on OneDrive; check each file's SHA1 hash against the listed value):
| Checkpoints | Combiner (for --combiner-path) | BLIP text encoder (for --blip-model-path) |
|---|---|---|
| Fashion-IQ (SHA1) | combiner.pt (4a1ba45bf52033c245c420b30873f68bc8e60732) | tuned_blip_best.pt (80f0db536f588253fca416af83cb50fab709edda) |
| CIRR (SHA1) | combiner_mean.pt (327703361117400de83936674d5c3032af37bd7a) | tuned_blip_mean.pt (67dca8a1905802cfd4cd02f640abb0579f1f88fd) |
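To check all four released checkpoints in one pass, the sketch below compares each file against the SHA1 values from the table. The folder layout (one folder per dataset under a hypothetical checkpoints/ directory) is an assumption; adjust the paths to wherever you saved the downloads. Files that are absent are reported as missing rather than raising an error.

```python
import hashlib
from pathlib import Path
from typing import Dict

# SHA1 values as listed in the checkpoint table of this README.
EXPECTED = {
    "Fashion-IQ": {
        "combiner.pt": "4a1ba45bf52033c245c420b30873f68bc8e60732",
        "tuned_blip_best.pt": "80f0db536f588253fca416af83cb50fab709edda",
    },
    "CIRR": {
        "combiner_mean.pt": "327703361117400de83936674d5c3032af37bd7a",
        "tuned_blip_mean.pt": "67dca8a1905802cfd4cd02f640abb0579f1f88fd",
    },
}


def sha1_of_file(path: Path) -> str:
    digest = hashlib.sha1()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):  # 1 MiB chunks
            digest.update(chunk)
    return digest.hexdigest()


def check_folder(folder: Path, expected: Dict[str, str]) -> Dict[str, str]:
    """Return 'ok', 'mismatch' or 'missing' for every expected file name."""
    status = {}
    for name, sha1 in expected.items():
        path = folder / name
        if not path.is_file():
            status[name] = "missing"
        else:
            status[name] = "ok" if sha1_of_file(path) == sha1 else "mismatch"
    return status


if __name__ == "__main__":
    # Hypothetical layout: checkpoints/Fashion-IQ/ and checkpoints/CIRR/
    for dataset, files in EXPECTED.items():
        print(dataset, check_folder(Path("checkpoints") / dataset, files))
```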
To validate saved checkpoints, please see below.
For Fashion-IQ, obtain results on the validation split by:
python src/validate.py --dataset fashionIQ \
--combining-function combiner \
--combiner-path <combiner trained weights path>/combiner.pt \
--blip-model-path <BLIP text encoder finetuned weights path>/tuned_blip_best.pt
For CIRR, obtain results on the validation split by:
python src/validate.py --dataset CIRR \
--combining-function combiner \
--combiner-path <combiner trained weights path>/combiner_mean.pt \
--blip-model-path <BLIP text encoder finetuned weights path>/tuned_blip_mean.pt
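Both benchmarks are evaluated with Recall@K: the percentage of queries whose ground-truth target appears among the top-K retrieved images (Fashion-IQ results are commonly reported at K=10 and 50; CIRR at K=1, 5, 10, 50, plus a subset variant). The exact set of K values printed by src/validate.py is not stated here, so the sketch below, which is not the repository's implementation, takes K as a parameter and assumes each ranking is a list of image ids sorted by decreasing similarity.

```python
from typing import Dict, Sequence


def recall_at_k(rankings: Sequence[Sequence[str]], targets: Sequence[str],
                ks: Sequence[int] = (1, 5, 10, 50)) -> Dict[int, float]:
    """Percentage of queries whose ground-truth target is in the top-k results."""
    if len(rankings) != len(targets) or not targets:
        raise ValueError("need one ranking per query and at least one query")
    results = {}
    for k in ks:
        hits = sum(1 for ranked, gt in zip(rankings, targets) if gt in ranked[:k])
        results[k] = 100.0 * hits / len(targets)
    return results


if __name__ == "__main__":
    rankings = [["a", "b", "c"], ["c", "a", "b"]]
    print(recall_at_k(rankings, ["a", "b"], ks=(1, 2, 3)))
```

In the example, the first query finds its target at rank 1 and the second only at rank 3, so recall is 50% at K=1 and K=2 and 100% at K=3.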
For the CIRR test split, the following command will generate recall_submission_combiner-bi.json and recall_subset_submission_combiner-bi.json at /submission/CIRR/ for submission:
python src/cirr_test_submission.py --submission-name combiner-bi \
--combining-function combiner \
--combiner-path <combiner trained weights path>/combiner_mean.pt \
--blip-model-path <BLIP text encoder finetuned weights path>/tuned_blip_mean.pt
Our generated .json files are also available here. To try submitting and receiving the test split results, please refer to the CIRR test split server.
Tuning hyperparameters
The following hyperparameters may warrant further tuning for better performance:
- the reversed loss scale in both stages (see paper, supplementary material, Section A);
- the learning rate and cosine learning rate schedule in stage-I.
Note that this is not a comprehensive list.
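When tuning the stage-I schedule, it can help to see the learning-rate curve that the flags imply. The stage-I commands pass --blip-learning-rate 5e-5, --blip-min-lr 0 and --blip-max-epoch 10 (the "cos10" in the experiment names). The sketch below assumes these describe a per-epoch cosine decay from the base rate to the minimum over the first 10 epochs, held at the minimum afterwards; the exact step granularity and post-decay behaviour in the repository are assumptions, and `cosine_lr` is a hypothetical helper.

```python
import math


def cosine_lr(epoch: int, base_lr: float = 5e-5, min_lr: float = 0.0,
              max_epoch: int = 10) -> float:
    """Cosine-annealed learning rate; assumed to stay at min_lr once epoch >= max_epoch."""
    if epoch >= max_epoch:
        return min_lr
    progress = epoch / max_epoch
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    for epoch in range(0, 21, 5):
        print(epoch, f"{cosine_lr(epoch):.2e}")
```

Under this reading, epochs 10 to 19 of the 20-epoch stage-I run would train at the minimum rate, which is one place where the schedule length could be tuned.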
Additionally, we found that extending stage-I finetuning, even when validation shows no sign of overfitting, may not benefit stage-II training.
Applying CLIP4Cir combiner upgrades
This implementation and our WACV 2024 paper are based on the combiner architecture in CLIP4Cir (v2).
Since our work, the authors of CLIP4Cir have released an upgrade to their combiner architecture, termed CLIP4Cir (v3).
We anticipate that applying the model upgrade to our method (while still replacing CLIP with BLIP encoders) will yield a performance increase.
Finetuning BLIP image encoder
In our work, we elect to freeze the BLIP image encoder during stage-I finetuning. However, it is also possible to finetune it alongside the BLIP text encoder.
Note that finetuning the BLIP image encoder would require much more VRAM.
BLIP4Cir baseline -- Training without bi-directional queries
Simply comment out the sections related to loss_r in both stages. The model can then be used as a BLIP4Cir baseline for future research.
This project is released under the MIT License, in line with the licenses of CLIP4Cir and BLIP.
Our implementation is based on CLIP4Cir and BLIP.
- Raise a new GitHub issue
- Contact us
Footnotes
1. Our code is based on this specific commit of CLIP4Cir. Note that their code has since been updated to a newer version; see Directions for Further Development 🔭 -- Applying CLIP4Cir combiner upgrades.