From 301c49e21fa6cd730764a5625ae5aa910cb927bc Mon Sep 17 00:00:00 2001
From: Fabrice Normandin
Date: Wed, 19 Jun 2024 18:30:26 +0000
Subject: [PATCH] Simplify some of the algos some more

Signed-off-by: Fabrice Normandin
---
 project/algorithms/bases/algorithm.py             | 2 +-
 project/algorithms/jax_algo.py                    | 7 +++----
 project/algorithms/manual_optimization_example.py | 5 ++---
 3 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/project/algorithms/bases/algorithm.py b/project/algorithms/bases/algorithm.py
index c37eed53..d65391de 100644
--- a/project/algorithms/bases/algorithm.py
+++ b/project/algorithms/bases/algorithm.py
@@ -107,7 +107,7 @@ def configure_callbacks(self) -> list[Callback]:
     @property
     def device(self) -> torch.device:
         if self._device is None:
-            self._device = next(p.device for p in self.parameters())
+            self._device = next((p.device for p in self.parameters()), torch.device("cpu"))
         device = self._device
         # make this more explicit to always include the index
         if device.type == "cuda" and device.index is None:
diff --git a/project/algorithms/jax_algo.py b/project/algorithms/jax_algo.py
index 5b527e7d..3459ef26 100644
--- a/project/algorithms/jax_algo.py
+++ b/project/algorithms/jax_algo.py
@@ -6,9 +6,6 @@
 
 import flax.linen
 import jax
-import lightning
-import lightning.pytorch
-import lightning.pytorch.callbacks
 import rich
 import rich.logging
 import torch
@@ -196,11 +193,13 @@ def main():
     logging.basicConfig(
         level=logging.INFO, format="%(message)s", handlers=[rich.logging.RichHandler()]
     )
+    from lightning.pytorch.callbacks import RichProgressBar
+
     trainer = Trainer(
         devices="auto",
         max_epochs=10,
         accelerator="auto",
-        callbacks=[lightning.pytorch.callbacks.RichProgressBar()],
+        callbacks=[RichProgressBar()],
     )
     datamodule = MNISTDataModule(num_workers=4, batch_size=512)
     network = CNN(num_classes=datamodule.num_classes)
diff --git a/project/algorithms/manual_optimization_example.py b/project/algorithms/manual_optimization_example.py
index 01c4dc03..89b5d053 100644
--- a/project/algorithms/manual_optimization_example.py
+++ b/project/algorithms/manual_optimization_example.py
@@ -99,11 +99,10 @@ def shared_step(
 
         # NOTE: You don't need to call `loss.backward()`, you could also just set .grads
         # directly!
-        loss.backward()
+        self.manual_backward(loss)
 
         for name, parameter in self.named_parameters():
-            if parameter.grad is None:
-                continue
+            assert parameter.grad is not None, name
             parameter.grad += self.hp.gradient_noise_std * torch.randn_like(parameter.grad)
 
         optimizer.step()
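
Note on the algorithm.py change: passing a default to next() avoids a StopIteration when the
module has no parameters yet, falling back to CPU instead. Below is a minimal standalone sketch
of the same pattern; TinyModule is a hypothetical example class, not part of the project:

    import torch
    from torch import nn


    class TinyModule(nn.Module):
        # Hypothetical module, only to illustrate the next(..., default) fallback.
        def __init__(self, with_params: bool = True):
            super().__init__()
            if with_params:
                self.layer = nn.Linear(4, 2)

        @property
        def device(self) -> torch.device:
            # next() on an empty generator raises StopIteration; the second argument
            # supplies a default, so parameter-free modules report "cpu".
            return next((p.device for p in self.parameters()), torch.device("cpu"))


    print(TinyModule(with_params=False).device)  # cpu (no parameters to inspect)
    print(TinyModule(with_params=True).device)   # device of the first parameter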
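
Note on the manual_optimization_example.py change: self.manual_backward(loss) is Lightning's
entry point for backprop when automatic_optimization is disabled; unlike a bare loss.backward(),
it routes the backward pass through the trainer's precision plugin and strategy. Below is a
minimal sketch of the pattern under that assumption; NoisyManualOptimExample and its
hyper-parameter are hypothetical stand-ins, not the project's module:

    import torch
    from lightning.pytorch import LightningModule


    class NoisyManualOptimExample(LightningModule):
        # Hypothetical module, only to illustrate manual optimization with manual_backward().
        def __init__(self, gradient_noise_std: float = 0.01):
            super().__init__()
            self.automatic_optimization = False  # opt out of Lightning's automatic loop
            self.gradient_noise_std = gradient_noise_std
            self.model = torch.nn.Linear(28 * 28, 10)

        def training_step(self, batch, batch_index: int):
            x, y = batch
            optimizer = self.optimizers()
            optimizer.zero_grad()

            loss = torch.nn.functional.cross_entropy(self.model(x.flatten(1)), y)
            # Route the backward pass through the trainer (precision, strategy, ...)
            # instead of calling loss.backward() directly.
            self.manual_backward(loss)

            # Every parameter that contributed to the loss should now have a gradient.
            for name, parameter in self.named_parameters():
                assert parameter.grad is not None, name
                parameter.grad += self.gradient_noise_std * torch.randn_like(parameter.grad)

            optimizer.step()
            return loss

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)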