This is a minimal implementation of DETR using JAX and Flax.
Features:
- Supports Flash Attention (up to 50% faster, on top of the gains from the optimizations below).
- Supports Sinkhorn solver (faster training for roughly the same final AP).
- Parallel bipartite matching for all auxiliary outputs (up to 30% faster training when using the Hungarian matcher).
- Uses the `optax` API.
- Bug fixes over the scenic implementation to match the official DETR implementation.
- Supports BigTransfer (BiT-S) ResNet-50 backbone.
You can read more about these optimizations here.
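The Sinkhorn solver and the parallel matching of auxiliary outputs can be illustrated together with a small NumPy sketch. This is not this repo's implementation: it assumes a square cost matrix (queries vs. padded targets), uniform marginals, hypothetical `epsilon`/`num_iters` defaults, and it reads a hard assignment off the transport plan with a row-wise argmax. Because every operation broadcasts over the leading axes, the final decoder output and all auxiliary outputs are matched in one vectorized call rather than a Python loop over layers.

```python
import numpy as np


def logsumexp(x, axis):
    """Numerically stable log(sum(exp(x))) along `axis`."""
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def sinkhorn_match(cost, epsilon=0.05, num_iters=100):
    """Entropy-regularized matching via log-domain Sinkhorn iterations.

    cost: array of shape [..., N, N]; leading axes (e.g. layers x batch) are
    solved in parallel. Returns, for each prediction, the index of its target.
    `epsilon` and `num_iters` are illustrative defaults, not the repo's values.
    """
    n = cost.shape[-1]
    assert cost.shape[-2] == n, "this sketch assumes a square cost matrix"
    log_k = -cost / epsilon                  # log of the Gibbs kernel
    log_marginal = -np.log(n)                # uniform marginals 1/N
    f = np.zeros(cost.shape[:-1])            # row potentials [..., N]
    g = np.zeros(cost.shape[:-2] + (n,))     # column potentials [..., N]
    for _ in range(num_iters):
        # Alternately rescale rows and columns so the plan's marginals match.
        f = log_marginal - logsumexp(log_k + g[..., None, :], axis=-1)
        g = log_marginal - logsumexp(log_k + f[..., :, None], axis=-2)
    log_plan = log_k + f[..., :, None] + g[..., None, :]
    # Hard assignment: each prediction takes its highest-mass target.
    return np.argmax(log_plan, axis=-1)


# Demo: 6 decoder outputs (final + 5 auxiliary) x batch of 2 x 5 queries.
rng = np.random.default_rng(0)
num_layers, batch, n = 6, 2, 5
perm = np.stack([rng.permutation(n) for _ in range(num_layers * batch)])
perm = perm.reshape(num_layers, batch, n)
cost = 1.0 + rng.random((num_layers, batch, n, n))
np.put_along_axis(cost, perm[..., None], 0.0, axis=-1)  # cheapest pairs = perm
assignment = sinkhorn_match(cost)  # one call covers every layer and image
print(assignment.shape)  # (6, 2, 5)
```

A smaller `epsilon` pushes the plan closer to an exact one-to-one (Hungarian) assignment but needs more iterations; the row-wise argmax is a simplification and is not guaranteed to yield a permutation when the plan is not sharply peaked. A real DETR matcher would also have to handle N queries against M < N targets.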
Setup:

```sh
git clone https://github.com/MasterSkepticista/detr.git && cd detr
python3.10 -m venv venv && source venv/bin/activate
pip install -U pip setuptools wheel
pip install -r requirements.txt
```
You may need to download the MS-COCO dataset in TFDS format. Run the following to download it and create the TFRecords:

```sh
python -c "import tensorflow_datasets as tfds; tfds.load('coco/2017')"
```
Download the torchvision ResNet-50 checkpoint from Google Drive:

```sh
pip install gdown
# gdown <gdrive-file-id> -O <output-dir>
gdown 1q-PYc6ZshX12Nelb30V6Cp1FkmxhUdD2 -O artifacts/
```

Backbone | Top-1 Acc. | Checkpoint |
---|---|---|
BiT-R50x1-i1k | 76.8% | Link |
R50x1-i1k (from torchvision) | 76.1% | Link (created using this gist) |
```sh
# Trains the default DETR-R50-1333 model using float32 precision.
# ~4.5 days on 8x A6000.
python main.py \
  --config configs/hungarian.py --workdir artifacts/`date '+%m-%d_%H%M'`
```
Checkpoints (all non-DC5 variants) using the torchvision R50 backbone:
Checkpoint | GFLOPs | AP | AP50 | AP75 | APS | APM | APL |
---|---|---|---|---|---|---|---|
DETR-R50-1333* | 174.2 | 40.80 | 61.88 | 42.45 | 19.2 | 44.31 | 60.32 |
DETR-R50-640 | 38.5 | 33.14 | 52.89 | 34.00 | 10.54 | 35.10 | 55.53 |
*Official DETR baseline, except that these models were trained for 300 epochs instead of 500.
- Download one of the pretrained checkpoints.
```python
# In configs/common.py (or any other config file):
import ml_collections  # skip if the config already imports it

config.init_from = ml_collections.ConfigDict()
config.init_from.checkpoint_path = '/path/to/checkpoint'
```
- Replace `config.total_epochs` with `config.total_steps = 0` to skip training and go straight to evaluation.
Parts of this codebase are based on scenic.
DETR implementation in PyTorch: facebookresearch/detr.
License: MIT