Make Algorithm a protocol class
Signed-off-by: Fabrice Normandin <[email protected]>
lebrice committed Jul 8, 2024
1 parent 494d240 commit 2eb3dc1
Showing 5 changed files with 25 additions and 800 deletions.
2 changes: 0 additions & 2 deletions project/__init__.py
@@ -1,5 +1,4 @@
 from . import algorithms, configs, datamodules, experiment, main, networks, utils
-from .algorithms import Algorithm
 from .configs import Config
 from .experiment import Experiment

@@ -14,7 +13,6 @@
     "configs",
     "datamodules",
     "networks",
-    "Algorithm",
     "DataModule",
     "utils",
     # "ExampleAlgorithm",
2 changes: 0 additions & 2 deletions project/algorithms/__init__.py
@@ -3,7 +3,6 @@
 from project.algorithms.jax_example import ExampleJaxAlgo
 from project.algorithms.no_op import NoOp

-from .algorithm import Algorithm
 from .example_algo import ExampleAlgorithm
 from .manual_optimization_example import ManualGradientsExample
@@ -23,7 +22,6 @@
 algorithm_store.add_to_hydra_store()

 __all__ = [
-    "Algorithm",
     "ExampleAlgorithm",
     "ManualGradientsExample",
     "ExampleJaxAlgo",
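Since `Algorithm` is no longer re-exported from `project` or `project.algorithms`, downstream code that still refers to it would presumably import it from its defining module instead (a hypothetical usage line, not part of this diff):

    from project.algorithms.algorithm import Algorithm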
65 changes: 10 additions & 55 deletions project/algorithms/algorithm.py
@@ -1,6 +1,4 @@
-from abc import ABC, abstractmethod
-from dataclasses import dataclass
-from typing import NotRequired, TypedDict
+from typing import NotRequired, Protocol, TypedDict

 import torch
 from lightning import Callback, LightningModule, Trainer
@@ -31,7 +29,7 @@ class StepOutputDict(TypedDict, total=False):
 )


-class Algorithm(LightningModule, ABC, Generic[BatchType, StepOutputType]):
+class Algorithm(Module, Protocol[BatchType, StepOutputType]):
     """Base class for a learning algorithm.

     This is an extension of the LightningModule class from PyTorch Lightning, with some common
@@ -42,21 +40,21 @@ class Algorithm(LightningModule, ABC, Generic[BatchType, StepOutputType]):
     architecture.
     """

-    @dataclass
-    class HParams:
-        """Hyper-parameters of the algorithm."""
+    datamodule: DataModule[BatchType]
+    network: Module
+
+    example_input_array = LightningModule.example_input_array
+    _device: torch.device | None = None

     def __init__(
         self,
         *,
-        datamodule: DataModule[BatchType] | None = None,
-        network: Module | None = None,
-        hp: HParams | None = None,
+        datamodule: DataModule[BatchType],
+        network: Module,
     ):
         super().__init__()
         self.datamodule = datamodule
         self.network = network
-        self.hp = hp or self.HParams()
         # fix for `self.device` property which defaults to cpu.
         self._device = None

@@ -91,11 +89,9 @@ def shared_step(self, batch: BatchType, batch_index: int, phase: PhaseStr) -> St
         """
         raise NotImplementedError

-    @abstractmethod
     def configure_optimizers(self):
         # """Creates the optimizers and the learning rate schedulers."""'
-        # super().configure_optimizers()
-        ...
+        raise NotImplementedError

     def forward(self, x: Tensor) -> Tensor:
         """Performs a forward pass.
@@ -105,47 +101,6 @@ def forward(self, x: Tensor) -> Tensor:
         assert self.network is not None
         return self.network(x)

-    def training_step_end(self, step_output: StepOutputDict) -> StepOutputDict:
-        """Called with the results of each worker / replica's output.
-
-        See the `training_step_end` of pytorch-lightning for more info.
-        """
-        return self.shared_step_end(step_output, phase="train")
-
-    def validation_step_end[Out: torch.Tensor | StepOutputDict](self, step_output: Out) -> Out:
-        return self.shared_step_end(step_output, phase="val")
-
-    def test_step_end[Out: torch.Tensor | StepOutputDict](self, step_output: Out) -> Out:
-        return self.shared_step_end(step_output, phase="test")
-
-    def shared_step_end[Out: torch.Tensor | StepOutputDict](
-        self, step_output: Out, phase: PhaseStr
-    ) -> Out:
-        """This is a default implementation for `[train/validation/test]_step_end`.
-
-        This does the following:
-        - Averages out the `loss` tensor if it was left unreduced.
-        - the main metrics are logged inside `training_step_end` (supposed to be better for DP/DDP)
-        """
-
-        if (
-            isinstance(step_output, dict)
-            and isinstance((loss := step_output.get("loss")), torch.Tensor)
-            and loss.shape
-        ):
-            # Replace the loss with its mean. This is useful when automatic
-            # optimization is enabled, for example in the example algo, where each replica
-            # returns the un-reduced cross-entropy loss. Here we need to reduce it to a scalar.
-            # todo: find out if this was already logged, to not log it twice.
-            # self.log(f"{phase}/loss", loss.mean(), sync_dist=True)
-            return step_output | {"loss": loss.mean()}
-
-        elif isinstance(step_output, torch.Tensor) and (loss := step_output).shape:
-            return loss.mean()
-
-        # self.log(f"{phase}/loss", torch.as_tensor(loss).mean(), sync_dist=True)
-        return step_output
-
     def configure_callbacks(self) -> list[Callback]:
         """Use this to add some callbacks that should always be included with the model."""
         return []
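For context, with `Algorithm` now declared as a `Protocol`, a concrete algorithm can still inherit from it explicitly (explicit subclassing turns a protocol into an ordinary base class) and fill in the two methods that now simply `raise NotImplementedError`. A minimal sketch under that assumption; the class name, the `(input, target)` batch layout, and the Adam optimizer are illustrative, not part of this commit:

    import torch
    from torch import nn

    from project.algorithms.algorithm import Algorithm


    class MyAlgorithm(Algorithm):
        """Illustrative concrete algorithm subclassing the `Algorithm` protocol."""

        def shared_step(self, batch, batch_index: int, phase):
            # Assumes an (input, target) classification batch; illustrative only.
            x, y = batch
            logits = self.network(x)  # `network` is set in Algorithm.__init__
            loss = nn.functional.cross_entropy(logits, y)
            return {"loss": loss}

        def configure_optimizers(self):
            # Any optimizer works here; Adam is just an example.
            return torch.optim.Adam(self.parameters(), lr=1e-3)

Instantiation follows the new keyword-only `__init__`: `MyAlgorithm(datamodule=..., network=...)`.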
