
Codebase for "Conservative Prediction via Data-Driven Confidence Minimization"


tajwarfahim/dcm


Conservative Prediction via Data-Driven Confidence Minimization


This repository contains the official code for "Conservative Prediction via Data-Driven Confidence Minimization" by Caroline Choi*, Fahim Tajwar*, Yoonho Lee*, Huaxiu Yao, Ananya Kumar, and Chelsea Finn.

Any correspondence about the code should be addressed to Caroline Choi ([email protected]) or Fahim Tajwar ([email protected]).

Citing our paper

If you use our code, you can cite our paper as follows:

@misc{choi2023conservative,
      title={Conservative Prediction via Data-Driven Confidence Minimization}, 
      author={Caroline Choi and Fahim Tajwar and Yoonho Lee and Huaxiu Yao and Ananya Kumar and Chelsea Finn},
      year={2023},
      eprint={2306.04974},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

Set up the Conda environment for the experiments

If you have not installed conda yet, follow the official conda installation instructions for your machine. Then run the following commands in a shell terminal:

conda create -n dcm python=3.9
conda activate dcm
pip install -r requirements.txt

If some packages are missing or version mismatches occur when running the experiments, please install the required packages manually. Most of our code assumes that one GPU is available, and it is not guaranteed to work without a GPU. If you do not have one, please make the necessary changes to the code. Also, follow the instructions on the pytorch or pytorch previous versions pages to install the PyTorch build that matches the CUDA version of your machine. The torch version pinned in the requirements file might not always be suitable.
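Before launching any experiment, it can help to confirm that the installed PyTorch build can see a GPU and was compiled against a CUDA version compatible with your machine. The snippet below is a minimal sketch. The helper name `describe_torch_setup` is hypothetical and not part of this repository. It relies only on `torch.__version__`, `torch.version.cuda` and `torch.cuda.is_available()`.

```python
import importlib.util


def describe_torch_setup():
    """Hypothetical helper (not part of this repository): summarize the local PyTorch/CUDA setup."""
    if importlib.util.find_spec("torch") is None:
        # PyTorch is not installed in the active environment.
        return {
            "torch_installed": False,
            "torch_version": None,
            "built_with_cuda": None,
            "gpu_available": False,
        }

    import torch

    return {
        "torch_installed": True,
        "torch_version": torch.__version__,
        # CUDA version the installed build was compiled against; None for CPU-only builds.
        "built_with_cuda": torch.version.cuda,
        # True only if a CUDA device is visible to this process.
        "gpu_available": torch.cuda.is_available(),
    }


if __name__ == "__main__":
    for key, value in describe_torch_setup().items():
        print(f"{key}: {value}")
```

If `gpu_available` is `False`, expect to adapt the code before running the experiments. If `built_with_cuda` does not match your driver's CUDA version, reinstall PyTorch as described above.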

Download datasets

Please follow the instructions in the OOD Detection Inconsistency, WILDS, Robustness, and Group DRO repositories to download the required datasets.

Pre-trained model weights

Pre-trained model weights for the OOD detection experiments are available at the following link: link. Weights trained with seeds 1 to 5 correspond to models trained on the complete ID set. Weights trained with seeds 11 to 15 correspond to models trained on half of the ID set (classes 0 to 4 for CIFAR-10, classes 0 to 49 for CIFAR-100) for the near-OOD detection task.
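The seed convention above can be expressed as a small lookup, which may be handy when looping over checkpoints. This is a minimal sketch. `checkpoint_split` and `NUM_CLASSES` are hypothetical names that do not come from the repository. Treating the classes left out of training as the near-OOD classes is an assumption inferred from the description, not something stated here.

```python
# Hypothetical helper, not part of this repository: encodes the seed convention
# described above for the released OOD detection checkpoints.
NUM_CLASSES = {"cifar10": 10, "cifar100": 100}


def checkpoint_split(dataset: str, seed: int) -> dict:
    """Describe which ID classes a checkpoint trained with `seed` was trained on."""
    if dataset not in NUM_CLASSES:
        raise ValueError(f"unknown dataset: {dataset!r}")
    num_classes = NUM_CLASSES[dataset]

    if 1 <= seed <= 5:
        # Seeds 1 to 5: trained on the complete ID set.
        split = "full ID set"
        id_classes = list(range(num_classes))
    elif 11 <= seed <= 15:
        # Seeds 11 to 15: trained on the first half of the classes
        # (0-4 for CIFAR-10, 0-49 for CIFAR-100) for near-OOD detection.
        split = "half ID set"
        id_classes = list(range(num_classes // 2))
    else:
        raise ValueError(f"seed {seed} does not match the released checkpoints")

    # Assumption: classes left out of training serve as the near-OOD classes.
    held_out = list(range(len(id_classes), num_classes))
    return {"split": split, "id_classes": id_classes, "held_out_classes": held_out}


if __name__ == "__main__":
    info = checkpoint_split("cifar100", 13)
    print(info["split"], info["id_classes"][-1], info["held_out_classes"][0])
    # prints: half ID set 49 50
```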

Similarly, weights for the selective classification experiments can be found here.

Example scripts for OOD detection

To train models, run the following commands:

cd ood_detection
bash train_script.sh

To test the pre-trained models (on CIFAR-10 in this example), download the appropriate datasets and pre-trained weights, then run the following command from the ood_detection directory:

bash test_script.sh

These scripts contain example Python commands that can be used to reproduce our experimental results.
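OOD detection is commonly evaluated by how well a per-example confidence score separates ID test examples from OOD examples. One standard measure is the area under the ROC curve (AUROC). The sketch below computes AUROC directly from its probabilistic definition: the probability that a randomly chosen ID example scores higher than a randomly chosen OOD example, with ties counted as one half. It is a generic illustration and assumes higher scores mean "more in-distribution". It is not taken from `test_script.sh`, the function name `auroc` is hypothetical, and the scores in the usage example are toy values, not results.

```python
from typing import Sequence


def auroc(id_scores: Sequence[float], ood_scores: Sequence[float]) -> float:
    """AUROC for separating ID (positive) from OOD (negative) examples.

    Assumes a higher score means "more likely in-distribution". This is the
    O(n*m) pairwise definition, fine for a sanity check; use a sorting-based
    or library implementation for large evaluation sets.
    """
    if len(id_scores) == 0 or len(ood_scores) == 0:
        raise ValueError("need at least one ID and one OOD score")
    wins = 0.0
    for s_id in id_scores:
        for s_ood in ood_scores:
            if s_id > s_ood:
                wins += 1.0
            elif s_id == s_ood:
                wins += 0.5  # ties count as half a correct ranking
    return wins / (len(id_scores) * len(ood_scores))


if __name__ == "__main__":
    # Toy confidence scores (e.g. maximum softmax probabilities), not real results.
    id_conf = [0.95, 0.90, 0.80, 0.60]
    ood_conf = [0.70, 0.55, 0.40]
    print(f"AUROC = {auroc(id_conf, ood_conf):.3f}")  # prints AUROC = 0.917
```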

Example scripts for selective classification

To run the selective classification experiments on CIFAR-10, use the following commands from the repository root:

cd selective_classification
bash exp_script.sh
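In selective classification, a model may abstain on inputs where its confidence is low. Results are typically reported as accuracy on the accepted inputs at a given coverage, meaning the fraction of inputs the model does not abstain on. The following is a generic sketch of that trade-off using a single confidence threshold. The function name `selective_metrics` and the thresholding rule are illustrative assumptions, not the evaluation code used by `exp_script.sh`.

```python
from typing import List, Sequence, Tuple


def selective_metrics(
    confidences: Sequence[float],
    correct: Sequence[bool],
    threshold: float,
) -> Tuple[float, float]:
    """Return (coverage, selective_accuracy) when abstaining below `threshold`.

    coverage: fraction of inputs with confidence >= threshold (not abstained).
    selective_accuracy: accuracy on the accepted inputs; 1.0 if nothing is
    accepted, by the convention that an empty prediction set makes no errors.
    """
    if len(confidences) == 0:
        raise ValueError("need at least one prediction")
    if len(confidences) != len(correct):
        raise ValueError("confidences and correct must have the same length")
    accepted: List[bool] = [ok for c, ok in zip(confidences, correct) if c >= threshold]
    coverage = len(accepted) / len(confidences)
    selective_accuracy = sum(accepted) / len(accepted) if accepted else 1.0
    return coverage, selective_accuracy


if __name__ == "__main__":
    # Toy predictions, not real results: raising the threshold lowers coverage.
    confs = [0.99, 0.95, 0.70, 0.60, 0.30]
    correct = [True, True, False, True, False]
    for t in (0.0, 0.65, 0.9):
        cov, acc = selective_metrics(confs, correct, t)
        print(f"threshold={t:.2f}: coverage={cov:.2f}, accuracy={acc:.3f}")
```

Sweeping the threshold traces a coverage and accuracy curve. A better-calibrated, more conservative model keeps accuracy high as coverage shrinks.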

Acknowledgements

We gratefully acknowledge the authors of the following repositories:

  1. Outlier Exposure
  2. Energy based Out-of-distribution Detection
  3. Mahalanobis Method for OOD Detection
  4. Siamese Network
  5. Deep Gamblers
  6. Self-Adaptive Training
  7. pytorch-cifar
  8. OOD Detection Inconsistency

In addition, please use the code provided in the ERD repository to reproduce the performance of ERD/binary classification.

We thank the authors of these repositories for providing us with easy-to-work-with codebases.
