From b9c0cb28d0532412bcec24d04e72fbc5f21e0997 Mon Sep 17 00:00:00 2001 From: Sang Choe Date: Sat, 30 Mar 2024 23:07:34 -0400 Subject: [PATCH] minor fix --- analog/scheduler.py | 6 +++--- examples/mnist_influence/compute_influences_scheduler.py | 4 ---- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/analog/scheduler.py b/analog/scheduler.py index 71eb778b..ed75b0aa 100644 --- a/analog/scheduler.py +++ b/analog/scheduler.py @@ -17,9 +17,9 @@ def __init__( self._lora_epoch = -1 self._analog_state_schedule = [] - self.sanity_check() + self.sanity_check(lora, hessian, save) self.configure_lora_epoch(lora) - self.configure_schedule(ekfac, lora, sample) + self.configure_schedule(lora, hessian, save) self._schedule_iterator = iter(self._analog_state_schedule) @@ -72,4 +72,4 @@ def __next__(self): return self._epoch def __len__(self): - return len(self.analog_state_schedule) + return len(self._analog_state_schedule) diff --git a/examples/mnist_influence/compute_influences_scheduler.py b/examples/mnist_influence/compute_influences_scheduler.py index 4f0e82fa..9d062f1c 100644 --- a/examples/mnist_influence/compute_influences_scheduler.py +++ b/examples/mnist_influence/compute_influences_scheduler.py @@ -50,16 +50,12 @@ if not args.resume: for epoch in al_scheduler: - sample = True if epoch < (len(al_scheduler) - 1) and args.sample else False for inputs, targets in tqdm(train_loader): data_id = id_gen(inputs) with analog(data_id=data_id): inputs, targets = inputs.to(DEVICE), targets.to(DEVICE) model.zero_grad() outs = model(inputs) - if sample: - probs = torch.nn.functional.softmax(outs, dim=-1) - targets = torch.multinomial(probs, 1).flatten().detach() loss = torch.nn.functional.cross_entropy(outs, targets, reduction="sum") loss.backward() analog.finalize()