This repository contains the official code for "Conservative Prediction via Data-Driven Confidence Minimization" by Caroline Choi*, Fahim Tajwar*, Yoonho Lee*, Huaxiu Yao, Ananya Kumar, and Chelsea Finn.
Any correspondence about the code should be addressed to Caroline Choi ([email protected]) or Fahim Tajwar ([email protected]).
If you use our code, you can cite our paper as follows:
```bibtex
@misc{choi2023conservative,
  title={Conservative Prediction via Data-Driven Confidence Minimization},
  author={Caroline Choi and Fahim Tajwar and Yoonho Lee and Huaxiu Yao and Ananya Kumar and Chelsea Finn},
  year={2023},
  eprint={2306.04974},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}
```
If you have not set up conda, please follow the official instructions to install it on your local device. Then run the following commands in a shell terminal:

```shell
conda create -n dcm python=3.9
conda activate dcm
pip install -r requirements.txt
```
If some packages are missing or version mismatches occur when running the experiments, please install the required packages manually. Most of our code assumes that one GPU is available and is not guaranteed to work without a GPU; if you do not have one, please modify the code accordingly. Also, make sure to install a PyTorch build that matches the CUDA version of your local device (see the PyTorch installation instructions, including the page for previous versions), because the torch version pinned in the requirements file may not always be suitable.
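Before launching any experiment, it can help to confirm that the installed PyTorch build sees a GPU and was built against a CUDA version compatible with your machine. The following is a minimal sketch that relies only on standard PyTorch attributes (`torch.__version__`, `torch.version.cuda`, `torch.cuda.is_available()`, `torch.cuda.device_count()`); the helper name `describe_torch_setup` is hypothetical and not part of this repository.

```python
import importlib.util
import sys


def describe_torch_setup() -> dict:
    """Report the Python version and, if PyTorch is installed, its CUDA status.

    Hypothetical sanity-check helper; not part of this repository.
    """
    info = {
        "python": f"{sys.version_info.major}.{sys.version_info.minor}",
        "torch_installed": importlib.util.find_spec("torch") is not None,
    }
    if info["torch_installed"]:
        import torch

        info["torch_version"] = torch.__version__
        # CUDA version PyTorch was built with; None for CPU-only builds.
        info["cuda_build"] = torch.version.cuda
        # False if no GPU is visible or the driver does not support this build.
        info["cuda_available"] = torch.cuda.is_available()
        info["gpu_count"] = torch.cuda.device_count()
    return info


if __name__ == "__main__":
    for key, value in describe_torch_setup().items():
        print(f"{key}: {value}")
```

Inside the `dcm` environment created above, `python` should report `3.9`. If `cuda_available` is `False` on a machine that has a GPU, reinstalling PyTorch with a build that matches the local CUDA version is the likely fix.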
Please see the instructions in the OOD Detection Inconsistency, WILDS, Robustness, and Group DRO repositories to download the required datasets.
Pre-trained model weights for the OOD detection experiments are available at the following link: link. The weights trained with seeds 1 to 5 correspond to models trained on the complete ID set, and the weights trained with seeds 11 to 15 correspond to models trained on half of the ID set (classes 0 to 4 for CIFAR-10, classes 0 to 49 for CIFAR-100) for the near-OOD detection task.
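The seed convention and the class split described above can be summarized in a few lines of Python. This is a minimal sketch assuming labels are the standard integer CIFAR class indices, and further assuming that the remaining classes (5 to 9 for CIFAR-10, 50 to 99 for CIFAR-100) serve as the near-OOD data; the names `training_set_for_seed` and `split_near_ood` are hypothetical and do not come from this repository.

```python
from typing import Iterable, List, Tuple

# Number of ID classes kept for the near-OOD task:
# classes 0 to 4 for CIFAR-10 and classes 0 to 49 for CIFAR-100.
NEAR_OOD_ID_CLASSES = {"cifar10": 5, "cifar100": 50}


def training_set_for_seed(seed: int) -> str:
    """Map the seed of a released checkpoint to the ID set it was trained on."""
    if 1 <= seed <= 5:
        return "full ID set"
    if 11 <= seed <= 15:
        return "half of the ID set (near-OOD task)"
    raise ValueError(f"no released weights for seed {seed}")


def split_near_ood(labels: Iterable[int], dataset: str) -> Tuple[List[int], List[int]]:
    """Split example indices into ID (lower classes) and near-OOD (upper classes).

    Assumption: the classes not used for training are treated as near-OOD.
    """
    num_id = NEAR_OOD_ID_CLASSES[dataset]
    id_idx: List[int] = []
    ood_idx: List[int] = []
    for i, label in enumerate(labels):
        (id_idx if label < num_id else ood_idx).append(i)
    return id_idx, ood_idx
```

For example, `split_near_ood([0, 7, 4, 5, 9], "cifar10")` returns `([0, 2], [1, 3, 4])`.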
Similarly, weights for the selective classification experiments can be found here.
To train models, run the following commands:

```shell
cd ood_detection
bash train_script.sh
```
To test the pre-trained models (here, on CIFAR-10), first download the appropriate datasets and pre-trained weights, then run the following command from the `ood_detection` directory:

```shell
bash test_script.sh
```
These scripts contain example Python commands that can be used to reproduce our experimental results.
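To see which experiments a script will launch before running it, you can list its `python` invocations. Below is a small, hypothetical helper (`extract_python_commands` and `list_python_commands` are not part of this repository); it assumes each command begins with `python` at the start of a line and may continue over lines ending in a backslash, so adjust it if the scripts are structured differently.

```python
from pathlib import Path
from typing import List


def extract_python_commands(script_text: str) -> List[str]:
    """Collect `python ...` commands from bash script text, joining backslash continuations."""
    commands: List[str] = []
    current: List[str] = []
    for raw in script_text.splitlines():
        line = raw.strip()
        if not current and not line.startswith("python"):
            continue  # skip comments, `cd`, variable assignments, blank lines
        continued = line.endswith("\\")
        current.append(line.rstrip("\\").strip())
        if not continued:
            commands.append(" ".join(part for part in current if part))
            current = []
    if current:  # script ended in the middle of a continued command
        commands.append(" ".join(part for part in current if part))
    return commands


def list_python_commands(script_path: str) -> List[str]:
    """Read a script from disk and return its python commands."""
    return extract_python_commands(Path(script_path).read_text())


if __name__ == "__main__":
    for cmd in list_python_commands("train_script.sh"):
        print(cmd)
```

Each returned string can then be run individually, for example to reproduce a single configuration instead of the whole sweep.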
Similarly, to run the selective classification experiments on CIFAR-10, use the following commands:

```shell
cd selective_classification
bash exp_script.sh
```
We gratefully acknowledge the authors of the following repositories:
- Outlier Exposure
- Energy based Out-of-distribution Detection
- Mahalanobis Method for OOD Detection
- Siamese Network
- Deep Gamblers
- Self-Adaptive Training
- pytorch-cifar
- OOD Detection Inconsistency
In addition, please use the code provided in the ERD repository to reproduce the performance of ERD/binary classification.
We thank the authors of these repositories for providing us with easy-to-work-with codebases.