BLIP4CIR with bi-directional training

The official implementation for Bi-directional Training for Composed Image Retrieval via Text Prompt Learning.

Site navigation > Setting up  |  Usage  |  Directions for Performance Increase & Further Development 🔭

cvf arXiv License: MIT

If you find this code useful for your research, please consider citing our work.

@InProceedings{Liu_2024_WACV,
    author    = {Liu, Zheyuan and Sun, Weixuan and Hong, Yicong and Teney, Damien and Gould, Stephen},
    title     = {Bi-Directional Training for Composed Image Retrieval via Text Prompt Learning},
    booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)},
    month     = {January},
    year      = {2024},
    pages     = {5753-5762}
}

If you are interested in this task, feel free to check out our new work, Candidate Set Re-ranking for Composed Image Retrieval with Dual Multi-modal Encoder, published in TMLR (code and checkpoints are released here).

@article{liu2024candidate,
    title     = {Candidate Set Re-ranking for Composed Image Retrieval with Dual Multi-modal Encoder},
    author    = {Zheyuan Liu and Weixuan Sun and Damien Teney and Stephen Gould},
    journal   = {Transactions on Machine Learning Research},
    issn      = {2835-8856},
    year      = {2024},
    url       = {https://openreview.net/forum?id=fJAwemcvpL}
}

News and upcoming updates

  • Jan-2024 Code and pre-trained checkpoints released for our new work above.
  • Nov-2023 Code and pre-trained checkpoints released for our WACV 2024 paper.
  • Nov-2023 Readme instructions released.

Introduction

Existing approaches to composed image retrieval (CIR) learn a mapping from the (reference image, modification text) pair to an image embedding that is then matched against a large image corpus.

One direction that has not yet been explored is the reverse one, which asks: what reference image, when modified as described by the text, would produce the given target image?

We propose a bi-directional training scheme that leverages such reversed queries and can be applied to existing CIR architectures with minimal changes, improving model performance.

Our method is tested on BLIP4CIR, a two-stage approach, as shown below. This is a new BLIP-based baseline that we propose on top of the existing method CLIP4Cir [1]. For details, please check out our paper.


In the first stage (noted as stage-I), to encode the bi-directional query, we prepend a learnable token to the modification text that designates the direction of the query and then finetune the parameters of the BLIP text embedding module.

[Figure: model architecture for the first stage, BLIP text encoder finetuning]

We make no other changes to the network architecture, which allows us to train the second stage (noted as stage-II) as-is, but with queries of both directions.

[Figure: model architecture for the second stage, combiner model training]
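As a rough illustration of the bi-directional queries described above, the sketch below builds a forward and a reversed query from a single training triplet. The token strings `[FWD]` and `[REV]`, the `Query` type, and `make_bidirectional_queries` are hypothetical names chosen for this example only. In the actual model the direction token is a learnable embedding, so treat this as a conceptual sketch rather than the repository's code.

```python
from dataclasses import dataclass
from typing import List

# Hypothetical direction markers; the real model learns token embeddings instead.
FORWARD_TOKEN = "[FWD]"
REVERSED_TOKEN = "[REV]"


@dataclass(frozen=True)
class Query:
    image: str  # image passed to the image encoder
    text: str   # modification text with the direction token prepended
    label: str  # image the composed query should retrieve


def make_bidirectional_queries(reference: str, target: str, text: str) -> List[Query]:
    """Return the forward query and its reversed counterpart for one triplet."""
    forward = Query(image=reference, text=f"{FORWARD_TOKEN} {text}", label=target)
    # Reversed query: same text, with the roles of reference and target swapped.
    reverse = Query(image=target, text=f"{REVERSED_TOKEN} {text}", label=reference)
    return [forward, reverse]


if __name__ == "__main__":
    for query in make_bidirectional_queries("ref.png", "tgt.png", "make the dress longer"):
        print(query)
```

Only the prepended token distinguishes the two directions, which is why the rest of the network can stay unchanged.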

Setting up

First, clone the repository to a desired location.

Prerequisites

The following commands will create a local Anaconda environment with the necessary packages installed.

conda create -n cirr_dev -y python=3.8
conda activate cirr_dev
pip install -r requirements.txt

BLIP pre-trained checkpoint

Download the BLIP pre-trained checkpoint and verify it with SHA1: 5f1d8cdfae91e22a35e98a4bbb4c43be7bd0ac50.

By default, we recommend storing the downloaded checkpoint file at models/model_base.pth.
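One way to check the download is to compute the SHA1 digest with Python's standard `hashlib` module. The snippet below is a minimal sketch that assumes the checkpoint was saved to the default location mentioned above; the helper name `sha1_of_file` is ours and is not part of the repository.

```python
import hashlib
from pathlib import Path

# SHA1 listed above for the BLIP pre-trained checkpoint.
EXPECTED_SHA1 = "5f1d8cdfae91e22a35e98a4bbb4c43be7bd0ac50"


def sha1_of_file(path, chunk_size=1 << 20):
    """Compute the SHA1 hex digest of a file, reading it in chunks."""
    digest = hashlib.sha1()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


if __name__ == "__main__":
    checkpoint = Path("models/model_base.pth")  # assumed default location
    if not checkpoint.is_file():
        print(f"{checkpoint} not found; download it first")
    else:
        actual = sha1_of_file(checkpoint)
        print("OK" if actual == EXPECTED_SHA1 else f"MISMATCH: got {actual}")
```

The same helper can be pointed at any other downloaded weight file whose SHA1 is listed in this README.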

Here, we use BLIP w/ ViT-B. For BLIP checkpoint options, see here.

Datasets

Experiments are conducted on two standard datasets, Fashion-IQ and CIRR. Please see their repositories for download instructions.

The downloaded file structure should look like this.

Optional -- Set up Comet

We use Comet to log the experiments. If you are unfamiliar with it, see the quick start guide. You will need to obtain an API Key for --api-key and create a personal workspace for --workspace.

If these arguments are not provided, the experiment will be logged only locally.

Note

Our code has been tested on torch 1.11.0 and 2.1.1. Any version in between should presumably work as well.

Modify requirements.txt to pin the PyTorch wheel that matches your CUDA version.

Note that BLIP requires transformers<=4.25; newer versions will cause errors.

Code breakdown

Our code is based on CLIP4Cir [1] with additional modules from BLIP.

From an implementation perspective, the differences from the original CLIP4Cir codebase are mostly in the following two aspects:

  • we replace the CLIP image/text encoders with BLIP, as defined in src/blip_modules/;
  • we include reversed queries during training, which are constructed on the fly (see the code blocks surrounding loss_r in src/clip_fine_tune.py and src/combiner_train.py).
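To make the role of loss_r more concrete, here is a framework-free sketch of how a forward loss and a reversed-query loss could be combined. It assumes a batch-wise contrastive (cross-entropy) objective and a simple weighted sum; both the function names and the exact weighting are assumptions for illustration, so check the training scripts for the actual formulation.

```python
import math
from typing import List


def batch_contrastive_loss(similarity: List[List[float]]) -> float:
    """Mean cross-entropy over rows, where entry [i][i] is the matching pair."""
    total = 0.0
    for i, row in enumerate(similarity):
        # Numerically stable log-sum-exp over all candidates in the batch.
        peak = max(row)
        log_sum_exp = peak + math.log(sum(math.exp(s - peak) for s in row))
        total += log_sum_exp - row[i]
    return total / len(similarity)


def bidirectional_loss(sim_forward, sim_reversed, loss_r_scale: float) -> float:
    """Assumed combination: forward loss plus a scaled reversed-query loss."""
    loss = batch_contrastive_loss(sim_forward)
    loss_r = batch_contrastive_loss(sim_reversed)
    return loss + loss_r_scale * loss_r


if __name__ == "__main__":
    # Toy similarity logits for a batch of two queries (made-up numbers).
    sim_f = [[5.0, 1.0], [0.5, 4.0]]
    sim_r = [[3.0, 2.0], [1.0, 2.5]]
    print(bidirectional_loss(sim_f, sim_r, loss_r_scale=0.5))
```

In this sketch, setting `loss_r_scale` to 0 reduces the objective to forward-only training.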

A brief introduction to the CLIP4Cir codebase is in CLIP4Cir - Usage. The structures are mostly preserved. We made some minor changes (to experiment logging etc.) but they should be easy to understand.

Usage

Training

The following configurations are used for training on one NVIDIA A100 80GB. In practice, the maximum VRAM usage we observed was approx. 36 GB (CIRR, stage-II training). You can lower VRAM consumption by reducing the batch size.

on Fashion-IQ

Stage-I BLIP text encoder finetuning

# Optional: comet experiment logging --api-key and --workspace
python src/clip_fine_tune.py --dataset FashionIQ \
                             --api-key <your comet api> --workspace <your comet workspace> \
                             --num-epochs 20 --batch-size 128 \
                             --blip-max-epoch 10 --blip-min-lr 0 \
                             --blip-learning-rate 5e-5 \
                             --transform targetpad --target-ratio 1.25 \
                             --save-training --save-best --validation-frequency 1 \
                             --experiment-name BLIP_cos10_loss_r.40_5e-5

Stage-II Combiner training

# Optional: comet experiment logging --api-key and --workspace
# Required: Load the blip text encoder weights finetuned in the previous step in --blip-model-path
python src/combiner_train.py --dataset FashionIQ \
                             --api-key <your comet api> --workspace <your comet workspace> \
                             --num-epochs 300 --batch-size 512 --blip-bs 32 \
                             --projection-dim 2560 --hidden-dim 5120  --combiner-lr 2e-5 \
                             --transform targetpad --target-ratio 1.25 \
                             --save-training --save-best --validation-frequency 1 \
                             --blip-model-path <BLIP text encoder finetuned weights path>/saved_models/tuned_blip_best.pt \
                             --experiment-name Combiner_loss_r.50_2e-5__BLIP_cos10_loss_r_.40_5e-5

on CIRR

Stage-I BLIP text encoder finetuning

# Optional: comet experiment logging --api-key and --workspace
python src/clip_fine_tune.py --dataset CIRR \
                             --api-key <your comet api> --workspace <your comet workspace> \
                             --num-epochs 20 --batch-size 128 \
                             --blip-max-epoch 10 --blip-min-lr 0 \
                             --blip-learning-rate 5e-5 \
                             --transform targetpad --target-ratio 1.25 \
                             --save-training --save-best --validation-frequency 1 \
                             --experiment-name BLIP_5e-5_cos10_loss_r.1

Stage-II Combiner training

# Optional: comet experiment logging --api-key and --workspace
# Required: Load the blip text encoder weights finetuned in the previous step in --blip-model-path
python src/combiner_train.py --dataset CIRR \
                             --api-key <your comet api> --workspace <your comet workspace> \
                             --num-epochs 300 --batch-size 512 --blip-bs 32 \
                             --projection-dim 2560 --hidden-dim 5120 --combiner-lr 2e-5 \
                             --transform targetpad --target-ratio 1.25 \
                             --save-training --save-best --validation-frequency 1 \
                             --blip-model-path <BLIP text encoder finetuned weights path>/saved_models/tuned_blip_mean.pt \
                             --experiment-name Combiner_loss_r.10__BLIP_5e-5_cos10_loss_r.1

Validating and testing

Checkpoints

The following weights should reproduce the results reported in Tables 1 and 2 of our paper (hosted on OneDrive; check the SHA1 hash of each download against the listed value):

| checkpoints | Combiner (for --combiner-path) | BLIP text encoder (for --blip-model-path) |
|---|---|---|
| Fashion-IQ | combiner.pt | tuned_blip_best.pt |
| SHA1 | 4a1ba45bf52033c245c420b30873f68bc8e60732 | 80f0db536f588253fca416af83cb50fab709edda |
| CIRR | combiner_mean.pt | tuned_blip_mean.pt |
| SHA1 | 327703361117400de83936674d5c3032af37bd7a | 67dca8a1905802cfd4cd02f640abb0579f1f88fd |

Reproducing results

To validate saved checkpoints, please see below.

For Fashion-IQ, obtain results on the validation split by:

python src/validate.py --dataset fashionIQ \
                       --combining-function combiner \
                       --combiner-path <combiner trained weights path>/combiner.pt \
                       --blip-model-path <BLIP text encoder finetuned weights path>/tuned_blip_best.pt

For CIRR, obtain results on the validation split by:

python src/validate.py --dataset CIRR \
                       --combining-function combiner \
                       --combiner-path <combiner trained weights path>/combiner_mean.pt \
                       --blip-model-path <BLIP text encoder finetuned weights path>/tuned_blip_mean.pt

For CIRR test split, the following command will generate recall_submission_combiner-bi.json and recall_subset_submission_combiner-bi.json at /submission/CIRR/ for submission:

python src/cirr_test_submission.py --submission-name combiner-bi \
                                   --combining-function combiner \
                                   --combiner-path <combiner trained weights path>/combiner_mean.pt \
                                   --blip-model-path <BLIP text encoder finetuned weights path>/tuned_blip_mean.pt

Our generated .json files are also available here. To try submitting and receiving the test split results, please refer to CIRR test split server.

Interested in further development? 🔭

Tuning hyperparameters

The following hyperparameters may warrant further tuning for better performance:

  • the reversed loss scale in both stages (see paper, supplementary material, Section A);
  • the learning rate and cosine learning rate schedule in stage-I.

Note that this is not a comprehensive list.
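For intuition about the stage-I schedule, the sketch below evaluates a standard cosine decay using the values from the stage-I commands (--blip-learning-rate 5e-5, --blip-min-lr 0, --blip-max-epoch 10). The function `cosine_lr` is a hypothetical helper, and the assumption that the training script applies exactly this formula per epoch is ours, so verify against the code before relying on the numbers.

```python
import math


def cosine_lr(epoch: int, max_epoch: int, init_lr: float, min_lr: float) -> float:
    """Cosine decay from init_lr (epoch 0) to min_lr (epoch == max_epoch)."""
    return (init_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * epoch / max_epoch)) + min_lr


if __name__ == "__main__":
    # Values taken from the stage-I commands in the Usage section.
    for epoch in range(0, 11):
        lr = cosine_lr(epoch, max_epoch=10, init_lr=5e-5, min_lr=0.0)
        print(epoch, f"{lr:.2e}")
```

Printing the schedule this way makes it easier to see how a different initial learning rate or maximum epoch changes the rate at which the learning rate decays.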

Additionally, we found that extended stage-I finetuning, even when validation shows no sign of overfitting, does not necessarily benefit stage-II training.

Applying CLIP4Cir combiner upgrades

This implementation and our WACV 2024 paper are based on the combiner architecture in CLIP4Cir (v2).

Since our work, the authors of CLIP4Cir have released an upgrade to their combiner architecture, termed CLIP4Cir (v3).

We anticipate that applying the model upgrade to our method (while still replacing CLIP with BLIP encoders) will yield a performance increase.

Finetuning BLIP image encoder

In our work, we elect to freeze the BLIP image encoder during stage-I finetuning. However, it is also possible to finetune it alongside the BLIP text encoder.

Note that finetuning the BLIP image encoder would require much more VRAM.

BLIP4Cir baseline -- Training without bi-directional queries

Simply comment out the sections related to loss_r in both stages. The model can then be used as a BLIP4Cir baseline for future research.

License

This project is released under the MIT License, in line with the licenses of CLIP4Cir and BLIP.

Acknowledgement

Our implementation is based on CLIP4Cir and BLIP.

Contact

Footnotes

  1. Our code is based on this specific commit of CLIP4Cir. Note that their code has since been updated to a newer version; see Directions for Further Development 🔭 -- Applying CLIP4Cir combiner upgrades.