Skip to content

Commit

Permalink
Modify Task for ImageNet
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Mar 12, 2024
1 parent cde8ba3 commit a618164
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions examples/imagenet/ddp_analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,33 +67,36 @@ def compute_model_output(self, batch: BATCH_DTYPE, model: nn.Module) -> torch.Te
def compute_train_loss(
self,
batch: BATCH_DTYPE,
outputs: torch.Tensor,
model: nn.Module,
sample: bool = False,
) -> torch.Tensor:
_, labels = batch
inputs, labels = batch
logits = model(inputs)

if not sample:
return F.cross_entropy(outputs, labels, reduction="sum")
return F.cross_entropy(logits, labels, reduction="sum")
with torch.no_grad():
probs = torch.nn.functional.softmax(outputs, dim=-1)
probs = torch.nn.functional.softmax(logits, dim=-1)
sampled_labels = torch.multinomial(
probs,
num_samples=1,
).flatten()
return F.cross_entropy(outputs, sampled_labels.detach(), reduction="sum")
return F.cross_entropy(logits, sampled_labels.detach(), reduction="sum")

def compute_measurement(
self,
batch: BATCH_DTYPE,
outputs: torch.Tensor,
model: nn.Module,
) -> torch.Tensor:
_, labels = batch
# Copied from https://github.com/MadryLab/trak/blob/main/trak/modelout_functions.py.
inputs, labels = batch
logits = model(inputs)

bindex = torch.arange(outputs.shape[0]).to(device=outputs.device, non_blocking=False)
logits_correct = outputs[bindex, labels]
bindex = torch.arange(logits.shape[0]).to(device=logits.device, non_blocking=False)
logits_correct = logits[bindex, labels]

cloned_logits = outputs.clone()
cloned_logits[bindex, labels] = torch.tensor(-torch.inf, device=outputs.device, dtype=outputs.dtype)
cloned_logits = logits.clone()
cloned_logits[bindex, labels] = torch.tensor(-torch.inf, device=logits.device, dtype=logits.dtype)

margins = logits_correct - cloned_logits.logsumexp(dim=-1)
return -margins.sum()
Expand Down

0 comments on commit a618164

Please sign in to comment.