Skip to content

Commit

Permalink
Added helper to ignore iscrowd
Browse files Browse the repository at this point in the history
  • Loading branch information
justinkay committed Mar 20, 2024
1 parent 5b1a3fa commit e5db124
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 5 deletions.
20 changes: 19 additions & 1 deletion aldi/helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import random
import torch

from detectron2.evaluation import COCOEvaluator


class SaveIO:
"""Simple PyTorch hook to save the output of a nn.module."""
Expand Down Expand Up @@ -58,4 +60,20 @@ def backward(ctx, grad_output):
return ctx.weight*grad_input, None

def grad_reverse(x):
return _GradientScalarLayer.apply(x, -1.0)
return _GradientScalarLayer.apply(x, -1.0)

def _maybe_add_iscrowd_annotations(cocoapi) -> None:
for ann in cocoapi.dataset["annotations"]:
if "iscrowd" not in ann:
ann["iscrowd"] = 0

class Detectron2COCOEvaluatorAdapter(COCOEvaluator):
"""A COCOEvaluator that makes iscrowd optional."""
def __init__(
self,
dataset_name,
output_dir=None,
distributed=True,
):
super().__init__(dataset_name, output_dir=output_dir, distributed=distributed)
_maybe_add_iscrowd_annotations(self._coco_api)
5 changes: 3 additions & 2 deletions aldi/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from detectron2.checkpoint.detection_checkpoint import DetectionCheckpointer
from detectron2.data.build import build_detection_train_loader, get_detection_dataset_dicts
from detectron2.engine import hooks, BestCheckpointer
from detectron2.evaluation import COCOEvaluator, DatasetEvaluators
from detectron2.evaluation import DatasetEvaluators
from detectron2.modeling.meta_arch.build import build_model
from detectron2.solver import build_optimizer
from detectron2.utils.events import get_event_storage
Expand All @@ -17,6 +17,7 @@
from aldi.distill import Distiller
from aldi.dropin import DefaultTrainer, AMPTrainer, SimpleTrainer
from aldi.dataloader import SaveWeakDatasetMapper, UnlabeledDatasetMapper, WeakStrongDataloader
from aldi.helpers import Detectron2COCOEvaluatorAdapter
from aldi.ema import EMA


Expand Down Expand Up @@ -158,7 +159,7 @@ def build_evaluator(cls, cfg, dataset_name, output_folder=None):
"""Just do COCO Evaluation."""
if output_folder is None:
output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
return DatasetEvaluators([COCOEvaluator(dataset_name, output_dir=output_folder)])
return DatasetEvaluators([Detectron2COCOEvaluatorAdapter(dataset_name, output_dir=output_folder)])

def build_hooks(self):
ret = super(DATrainer, self).build_hooks()
Expand Down
4 changes: 2 additions & 2 deletions tools/slurm_train_net.sh
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ module load anaconda/2023a
# run training script inside anaconda environment
srun -N$SLURM_JOB_NUM_NODES bash -c "\
CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES \
conda run --no-capture-output -n aldi \
python train_net.py \
conda run --no-capture-output -n aldi4 \
python tools/train_net.py \
--machine-rank \$SLURM_PROCID \
--num-machines $SLURM_JOB_NUM_NODES \
--num-gpus 2 \
Expand Down

0 comments on commit e5db124

Please sign in to comment.