Skip to content

Commit

Permalink
minor
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Mar 20, 2024
1 parent da593de commit a27ce20
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 22 deletions.
24 changes: 10 additions & 14 deletions examples/cifar/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ def train(
num_train_epochs: int,
learning_rate: float,
weight_decay: float,
disable_tqdm: bool = False,
) -> nn.Module:
train_dataloader = data.DataLoader(
dataset=dataset,
Expand All @@ -116,19 +115,16 @@ def train(
model.train()
for epoch in range(num_train_epochs):
total_loss = 0.0
with tqdm(train_dataloader, unit="batch", disable=disable_tqdm) as tepoch:
for batch in tepoch:
tepoch.set_description(f"Epoch {epoch}")
model.zero_grad()
inputs, labels = batch
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
outputs = model(inputs)
loss = F.cross_entropy(outputs, labels)
loss.backward()
optimizer.step()
scheduler.step()
total_loss += loss.detach().float()
tepoch.set_postfix(loss=total_loss.item() / len(train_dataloader))
for batch in train_dataloader:
model.zero_grad()
inputs, labels = batch
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
outputs = model(inputs)
loss = F.cross_entropy(outputs, labels)
loss.backward()
optimizer.step()
scheduler.step()
total_loss += loss.detach().float()
return model


Expand Down
12 changes: 6 additions & 6 deletions examples/glue/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,12 @@ def train(
total_loss = 0.0
for batch in train_dataloader:
model.zero_grad()
outputs = model(
loss = model(
input_ids=batch["input_ids"].to(device=DEVICE),
attention_mask=batch["attention_mask"].to(device=DEVICE),
token_type_ids=batch["token_type_ids"].to(device=DEVICE),
).logits
loss = F.cross_entropy(outputs, batch["labels"].to(device=DEVICE))
labels=batch["labels"].to(device=DEVICE)
)
loss.backward()
optimizer.step()
total_loss += loss.detach().float()
Expand All @@ -127,14 +127,14 @@ def evaluate_model(model: nn.Module, dataset: data.Dataset, batch_size: int) ->
total_loss = 0.0
for batch in dataloader:
with torch.no_grad():
outputs = model(
logits = model(
batch["input_ids"].to(device=DEVICE),
batch["token_type_ids"].to(device=DEVICE),
batch["attention_mask"].to(device=DEVICE),
).logits
labels = batch["labels"].to(device=DEVICE)
total_loss += F.cross_entropy(outputs, labels, reduction="sum").detach()
predictions = outputs.argmax(dim=-1)
total_loss += F.cross_entropy(logits, labels, reduction="sum").detach()
predictions = logits.argmax(dim=-1)
metric.add_batch(
predictions=predictions,
references=labels,
Expand Down
12 changes: 11 additions & 1 deletion examples/imagenet/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,21 @@ python analyze.py --dataset_dir /mfs1/datasets/imagenet_pytorch/ \
--train_batch_size 256 \
--factor_strategy ekfac
```
On A100 (80GB), it takes roughly 1.5 minutes to compute the pairwise scores (including computing EKFAC factors).
On A100 (80GB), it takes roughly 10 hours to compute the pairwise scores (including computing EKFAC factors).

We can also use query batching to compute influence scores with larger query batch size.
```bash
python analyze.py --dataset_dir /mfs1/datasets/imagenet_pytorch/ \
--query_gradient_rank 32 \
--query_batch_size 500 \
--train_batch_size 512 \
--factor_strategy ekfac
```


## Computing Pairwise Influence Scores with DDP

You can also use DistributedDataParallel to speed up influence computations.
```bash
torchrun --standalone --nnodes=1 --nproc-per-node=4 ddp_analyze.py
```
2 changes: 1 addition & 1 deletion examples/imagenet/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def main():
)

rank = args.query_gradient_rank if args.query_gradient_rank != -1 else None
score_args = ScoreArguments(query_gradient_rank=rank)
score_args = ScoreArguments(query_gradient_rank=rank, query_gradient_svd_dtype=torch.float32)
scores_name = "pairwise"
if rank is not None:
scores_name += f"_qlr{rank}"
Expand Down

0 comments on commit a27ce20

Please sign in to comment.