Skip to content

Commit

Permalink
ignore index in testing (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
VMarsocci authored Sep 17, 2024
1 parent 2b035ee commit 8b5def3
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
3 changes: 2 additions & 1 deletion engine/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(self, args, val_loader, exp_dir, device):
self.split = self.val_loader.dataset.split
self.num_classes = len(self.classes)
self.max_name_len = max([len(name) for name in self.classes])
self.ignore_index = args["dataset"]["ignore_index"]

if args.use_wandb:
import wandb
Expand Down Expand Up @@ -67,7 +68,7 @@ def evaluate(self, model, model_name='model', model_ckpt_path=None):
pred = (torch.sigmoid(logits) > 0.5).type(torch.int64).squeeze(dim=1)
else:
pred = torch.argmax(logits, dim=1)
valid_mask = target != -1
valid_mask = target != self.ignore_index
pred, target = pred[valid_mask], target[valid_mask]
count = torch.bincount((pred * self.num_classes + target), minlength=self.num_classes ** 2)
confusion_matrix += count.view(self.num_classes, self.num_classes)
Expand Down
3 changes: 2 additions & 1 deletion engine/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(self, args, model, train_loader, criterion, optimizer, lr_scheduler
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.evaluator = evaluator
self.ignore_index = args["dataset"]["ignore_index"]
self.logger = logging.getLogger()
self.training_stats = {name: RunningAverageMeter(length=self.batch_per_epoch) for name in ['loss', 'data_time', 'batch_time', 'eval_time']}
self.training_metrics = {}
Expand Down Expand Up @@ -234,7 +235,7 @@ def compute_logging_metrics(self, logits, target):
else:
pred = torch.argmax(logits, dim=1, keepdim=True)
target = target.unsqueeze(1)
ignore_mask = target == -1
ignore_mask = target == self.ignore_index
target[ignore_mask] = 0
ignore_mask = ignore_mask.expand(-1, num_classes if num_classes > 1 else 2, -1, -1)

Expand Down

0 comments on commit 8b5def3

Please sign in to comment.