Skip to content

Commit

Permalink
Remove the older (uglier) test suite for algos
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed Jul 17, 2024
1 parent 8631015 commit f0c0020
Show file tree
Hide file tree
Showing 7 changed files with 221 additions and 125 deletions.
19 changes: 12 additions & 7 deletions project/algorithms/example_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from typing import ClassVar

import torch

from project.algorithms.testsuites.classification_tests import ClassificationAlgorithmTests
from project.algorithms.testsuites.algorithm_tests import LearningAlgorithmTests
from project.datamodules.image_classification.image_classification import (
ImageClassificationDataModule,
)
from project.utils.testutils import run_for_all_configs_of_type

from .example import ExampleAlgorithm


class TestExampleAlgorithm(ClassificationAlgorithmTests[ExampleAlgorithm]):
algorithm_type = ExampleAlgorithm
unsupported_datamodule_names: ClassVar[list[str]] = ["rl"]
_supported_network_types: ClassVar[list[type]] = [torch.nn.Module]
@run_for_all_configs_of_type("datamodule", ImageClassificationDataModule)
@run_for_all_configs_of_type("network", torch.nn.Module)
class TestExampleAlgo(LearningAlgorithmTests[ExampleAlgorithm]):
"""Tests for the `ExampleAlgorithm`.
See `LearningAlgorithmTests` for more information on the built-in tests.
"""
19 changes: 10 additions & 9 deletions project/algorithms/jax_example_test.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from typing import ClassVar

import flax
import flax.linen
import torch
import pytest

from project.algorithms.jax_example import JaxExample
from project.datamodules.image_classification.image_classification import (
ImageClassificationDataModule,
)
from project.utils.testutils import run_for_all_configs_of_type

from .testsuites.algorithm_tests import AlgorithmTests

from .testsuites.algorithm_tests import LearningAlgorithmTests

class TestJaxExample(AlgorithmTests[JaxExample]):
"""This algorithm only works with Jax modules."""

unsupported_network_types: ClassVar[list[type]] = [torch.nn.Module]
_supported_network_types: ClassVar[list[type]] = [flax.linen.Module]
@pytest.mark.timeout(10)
@run_for_all_configs_of_type("datamodule", ImageClassificationDataModule)
@run_for_all_configs_of_type("network", flax.linen.Module)
class TestJaxExample(LearningAlgorithmTests[JaxExample]): ...
62 changes: 5 additions & 57 deletions project/algorithms/testsuites/algorithm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Literal, NotRequired, Protocol, TypedDict
from typing import NotRequired, Protocol, TypedDict

import torch
from lightning import Callback, Trainer
from lightning import LightningDataModule, LightningModule, Trainer
from torch import Tensor
from typing_extensions import TypeVar

Expand Down Expand Up @@ -33,7 +33,7 @@ class Algorithm(Module, Protocol[BatchType, StepOutputType]):
architecture.
"""

datamodule: DataModule[BatchType]
datamodule: LightningDataModule | DataModule[BatchType]
network: Module

def __init__(
Expand All @@ -47,57 +47,5 @@ def __init__(
self.network = network
self.trainer: Trainer

def training_step(self, batch: BatchType, batch_index: int) -> StepOutputType:
"""Performs a training step.
See `LightningModule.training_step` for more information.
"""
return self.shared_step(batch=batch, batch_index=batch_index, phase="train")

def validation_step(self, batch: BatchType, batch_index: int) -> StepOutputType:
"""Performs a validation step."""
return self.shared_step(batch=batch, batch_index=batch_index, phase="val")

def test_step(self, batch: BatchType, batch_index: int) -> StepOutputType:
"""Performs a test step."""
return self.shared_step(batch=batch, batch_index=batch_index, phase="test")

def shared_step(
self, batch: BatchType, batch_index: int, phase: Literal["train", "val", "test"]
) -> StepOutputType:
"""Performs a training/validation/test step.
This must return a nested dictionary of tensors matching the `StepOutputType` typedict for
this algorithm. By default,
`loss` entry. This is so that the training of the model is easier to parallelize the
training across GPUs:
- the cross entropy loss gets calculated using the global batch size
- the main metrics are logged inside `training_step_end` (supposed to be better for DP/DDP)
"""
raise NotImplementedError

def configure_optimizers(self):
# """Creates the optimizers and the learning rate schedulers."""'
raise NotImplementedError

def forward(self, x: Tensor) -> Tensor:
"""Performs a forward pass.
Feel free to overwrite this to do whatever you'd like.
"""
assert self.network is not None
return self.network(x)

def configure_callbacks(self) -> list[Callback]:
"""Use this to add some callbacks that should always be included with the model."""
return []

@property
def device(self) -> torch.device:
if self._device is None:
self._device = next((p.device for p in self.parameters()), torch.device("cpu"))
device = self._device
# make this more explicit to always include the index
if device.type == "cuda" and device.index is None:
return torch.device("cuda", index=torch.cuda.current_device())
return device
training_step = LightningModule.training_step
# validation_step = LightningModule.validation_step
Loading

0 comments on commit f0c0020

Please sign in to comment.