Skip to content

Commit

Permalink
minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
sangkeun00 committed Mar 31, 2024
1 parent 6a87725 commit b9c0cb2
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 7 deletions.
6 changes: 3 additions & 3 deletions analog/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ def __init__(
self._lora_epoch = -1
self._analog_state_schedule = []

self.sanity_check()
self.sanity_check(lora, hessian, save)
self.configure_lora_epoch(lora)
self.configure_schedule(ekfac, lora, sample)
self.configure_schedule(lora, hessian, save)

self._schedule_iterator = iter(self._analog_state_schedule)

Expand Down Expand Up @@ -72,4 +72,4 @@ def __next__(self):
return self._epoch

def __len__(self):
return len(self.analog_state_schedule)
return len(self._analog_state_schedule)
4 changes: 0 additions & 4 deletions examples/mnist_influence/compute_influences_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,12 @@

if not args.resume:
for epoch in al_scheduler:
sample = True if epoch < (len(al_scheduler) - 1) and args.sample else False
for inputs, targets in tqdm(train_loader):
data_id = id_gen(inputs)
with analog(data_id=data_id):
inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
model.zero_grad()
outs = model(inputs)
if sample:
probs = torch.nn.functional.softmax(outs, dim=-1)
targets = torch.multinomial(probs, 1).flatten().detach()
loss = torch.nn.functional.cross_entropy(outs, targets, reduction="sum")
loss.backward()
analog.finalize()
Expand Down

0 comments on commit b9c0cb2

Please sign in to comment.