diff --git a/examples/imagenet/ddp_analyze.py b/examples/imagenet/ddp_analyze.py index 957467b..eedcfd6 100644 --- a/examples/imagenet/ddp_analyze.py +++ b/examples/imagenet/ddp_analyze.py @@ -67,33 +67,36 @@ def compute_model_output(self, batch: BATCH_DTYPE, model: nn.Module) -> torch.Te def compute_train_loss( self, batch: BATCH_DTYPE, - outputs: torch.Tensor, + model: nn.Module, sample: bool = False, ) -> torch.Tensor: - _, labels = batch + inputs, labels = batch + logits = model(inputs) if not sample: - return F.cross_entropy(outputs, labels, reduction="sum") + return F.cross_entropy(logits, labels, reduction="sum") with torch.no_grad(): - probs = torch.nn.functional.softmax(outputs, dim=-1) + probs = torch.nn.functional.softmax(logits, dim=-1) sampled_labels = torch.multinomial( probs, num_samples=1, ).flatten() - return F.cross_entropy(outputs, sampled_labels.detach(), reduction="sum") + return F.cross_entropy(logits, sampled_labels.detach(), reduction="sum") def compute_measurement( self, batch: BATCH_DTYPE, - outputs: torch.Tensor, + model: nn.Module, ) -> torch.Tensor: - _, labels = batch + # Copied from https://github.com/MadryLab/trak/blob/main/trak/modelout_functions.py. + inputs, labels = batch + logits = model(inputs) - bindex = torch.arange(outputs.shape[0]).to(device=outputs.device, non_blocking=False) - logits_correct = outputs[bindex, labels] + bindex = torch.arange(logits.shape[0]).to(device=logits.device, non_blocking=False) + logits_correct = logits[bindex, labels] - cloned_logits = outputs.clone() - cloned_logits[bindex, labels] = torch.tensor(-torch.inf, device=outputs.device, dtype=outputs.dtype) + cloned_logits = logits.clone() + cloned_logits[bindex, labels] = torch.tensor(-torch.inf, device=logits.device, dtype=logits.dtype) margins = logits_correct - cloned_logits.logsumexp(dim=-1) return -margins.sum()