diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 03f67b11..b5eeb12c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ default_language_version: repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.6.0 hooks: # list of supported hooks: https://pre-commit.com/hooks.html - id: trailing-whitespace @@ -32,7 +32,7 @@ repos: - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. - rev: "v0.3.3" + rev: "v0.5.1" hooks: - id: ruff args: ['--line-length', '99', '--fix'] @@ -41,7 +41,7 @@ repos: # python docstring formatting - repo: https://github.com/myint/docformatter - rev: v1.5.1 + rev: v1.7.5 hooks: - id: docformatter args: [--in-place, --wrap-summaries=99, --wrap-descriptions=99] @@ -64,7 +64,7 @@ repos: # Dependency management - repo: https://github.com/pdm-project/pdm - rev: 2.12.4 + rev: 2.16.1 hooks: - id: pdm-lock-check require_serial: true @@ -91,7 +91,7 @@ repos: # word spelling linter - repo: https://github.com/codespell-project/codespell - rev: v2.2.2 + rev: v2.3.0 hooks: - id: codespell args: diff --git a/project/__init__.py b/project/__init__.py index 9a0b0091..de8fc9a0 100644 --- a/project/__init__.py +++ b/project/__init__.py @@ -1,11 +1,14 @@ from . import algorithms, configs, datamodules, experiment, main, networks, utils -from .algorithms import Algorithm -from .configs import Config +from .configs import Config, add_configs_to_hydra_store from .experiment import Experiment +from .utils.hydra_utils import patched_safe_name # noqa # from .networks import FcNet from .utils.types import DataModule +add_configs_to_hydra_store() + + __all__ = [ "algorithms", "experiment", @@ -14,7 +17,6 @@ "configs", "datamodules", "networks", - "Algorithm", "DataModule", "utils", # "ExampleAlgorithm", diff --git a/project/algorithms/__init__.py b/project/algorithms/__init__.py index cd52e402..d5e78cc3 100644 --- a/project/algorithms/__init__.py +++ b/project/algorithms/__init__.py @@ -1,30 +1,10 @@ -from hydra_zen import builds, store - -from project.algorithms.jax_algo import JaxAlgorithm +from project.algorithms.jax_example import JaxExample from project.algorithms.no_op import NoOp -from .algorithm import Algorithm -from .example_algo import ExampleAlgorithm -from .manual_optimization_example import ManualGradientsExample - -# NOTE: This works the same way as creating config files for each algorithm under -# `configs/algorithm`. From the command-line, you can select both configs that are yaml files as -# well as structured config (dataclasses). - -# If you add a configuration file under `configs/algorithm`, it will also be available as an option -# from the command-line, and be validated against the schema. -# todo: It might be nicer if we did this this `configs/algorithms` instead of here, no? 
-algorithm_store = store(group="algorithm") -algorithm_store(ExampleAlgorithm.HParams(), name="example_algo") -algorithm_store(ManualGradientsExample.HParams(), name="manual_optimization") -algorithm_store(builds(NoOp, populate_full_signature=False), name="no_op") -algorithm_store(JaxAlgorithm.HParams(), name="jax_algo") - -algorithm_store.add_to_hydra_store() +from .example import ExampleAlgorithm __all__ = [ - "Algorithm", "ExampleAlgorithm", - "ManualGradientsExample", - "JaxAlgorithm", + "JaxExample", + "NoOp", ] diff --git a/project/algorithms/callbacks/callback.py b/project/algorithms/callbacks/callback.py index 8b0bb3ca..0bbde608 100644 --- a/project/algorithms/callbacks/callback.py +++ b/project/algorithms/callbacks/callback.py @@ -1,24 +1,30 @@ from __future__ import annotations +from collections.abc import Mapping from logging import getLogger as get_logger from pathlib import Path -from typing import Literal, override +from typing import Any, Generic, Literal, override import torch -from lightning import Trainer +from lightning import LightningModule, Trainer from lightning import pytorch as pl -from typing_extensions import Generic # noqa +from typing_extensions import TypeVar -from project.algorithms.algorithm import Algorithm, BatchType, StepOutputDict, StepOutputType -from project.utils.types import PhaseStr, PyTree +from project.utils.types import PyTree from project.utils.utils import get_log_dir logger = get_logger(__name__) +BatchType = TypeVar("BatchType", bound=PyTree[torch.Tensor], contravariant=True) +StepOutputType = TypeVar( + "StepOutputType", + bound=torch.Tensor | Mapping[str, Any] | None, + default=dict[str, torch.Tensor], + contravariant=True, +) -class Callback[BatchType: PyTree[torch.Tensor], StepOutputType: torch.Tensor | StepOutputDict]( - pl.Callback -): + +class Callback(pl.Callback, Generic[BatchType, StepOutputType]): """Adds a bit of typing info and shared functions to the PyTorch Lightning Callback class. Adds the following typing information: @@ -40,7 +46,7 @@ def __init__(self) -> None: def setup( self, trainer: pl.Trainer, - pl_module: Algorithm[BatchType, StepOutputType], + pl_module: LightningModule, # todo: "tune" is mentioned in the docstring, is it still used? stage: Literal["fit", "validate", "test", "predict", "tune"], ) -> None: @@ -49,10 +55,10 @@ def setup( def on_shared_batch_start( self, trainer: Trainer, - pl_module: Algorithm[BatchType, StepOutputType], + pl_module: LightningModule, batch: BatchType, batch_index: int, - phase: PhaseStr, + phase: Literal["train", "val", "test"], dataloader_idx: int | None = None, ): """Shared hook, called by `on_[train/validation/test]_batch_start`. @@ -63,11 +69,11 @@ def on_shared_batch_start( def on_shared_batch_end( self, trainer: Trainer, - pl_module: Algorithm[BatchType, StepOutputType], + pl_module: LightningModule, outputs: StepOutputType, batch: BatchType, batch_index: int, - phase: PhaseStr, + phase: Literal["train", "val", "test"], dataloader_idx: int | None = None, ): """Shared hook, called by `on_[train/validation/test]_batch_end`. @@ -78,8 +84,8 @@ def on_shared_batch_end( def on_shared_epoch_start( self, trainer: Trainer, - pl_module: Algorithm[BatchType, StepOutputType], - phase: PhaseStr, + pl_module: LightningModule, + phase: Literal["train", "val", "test"], ) -> None: """Shared hook, called by `on_[train/validation/test]_epoch_start`. 
@@ -89,8 +95,8 @@ def on_shared_epoch_start( def on_shared_epoch_end( self, trainer: Trainer, - pl_module: Algorithm[BatchType, StepOutputType], - phase: PhaseStr, + pl_module: LightningModule, + phase: Literal["train", "val", "test"], ) -> None: """Shared hook, called by `on_[train/validation/test]_epoch_end`. @@ -101,7 +107,7 @@ def on_shared_epoch_end( def on_train_batch_end( self, trainer: Trainer, - pl_module: Algorithm[BatchType, StepOutputType], + pl_module: LightningModule, outputs: StepOutputType, batch: BatchType, batch_index: int, @@ -126,7 +132,7 @@ def on_train_batch_end( def on_validation_batch_end( self, trainer: Trainer, - pl_module: Algorithm[BatchType, StepOutputType], + pl_module: LightningModule, outputs: StepOutputType, batch: BatchType, batch_index: int, @@ -154,7 +160,7 @@ def on_validation_batch_end( def on_test_batch_end( self, trainer: Trainer, - pl_module: Algorithm[BatchType, StepOutputType], + pl_module: LightningModule, outputs: StepOutputType, batch: BatchType, batch_index: int, @@ -182,7 +188,7 @@ def on_test_batch_end( def on_train_batch_start( self, trainer: Trainer, - pl_module: Algorithm[BatchType, StepOutputType], + pl_module: LightningModule, batch: BatchType, batch_index: int, ) -> None: @@ -199,7 +205,7 @@ def on_train_batch_start( def on_validation_batch_start( self, trainer: Trainer, - pl_module: Algorithm[BatchType, StepOutputType], + pl_module: LightningModule, batch: BatchType, batch_index: int, dataloader_idx: int = 0, @@ -218,7 +224,7 @@ def on_validation_batch_start( def on_test_batch_start( self, trainer: Trainer, - pl_module: Algorithm[BatchType, StepOutputType], + pl_module: LightningModule, batch: BatchType, batch_index: int, dataloader_idx: int = 0, @@ -234,43 +240,31 @@ def on_test_batch_start( ) @override - def on_train_epoch_start( - self, trainer: Trainer, pl_module: Algorithm[BatchType, StepOutputType] - ) -> None: + def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: super().on_train_epoch_start(trainer, pl_module) self.on_shared_epoch_start(trainer, pl_module, phase="train") @override - def on_validation_epoch_start( - self, trainer: Trainer, pl_module: Algorithm[BatchType, StepOutputType] - ) -> None: + def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: super().on_validation_epoch_start(trainer, pl_module) self.on_shared_epoch_start(trainer, pl_module, phase="val") @override - def on_test_epoch_start( - self, trainer: Trainer, pl_module: Algorithm[BatchType, StepOutputType] - ) -> None: + def on_test_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: super().on_test_epoch_start(trainer, pl_module) self.on_shared_epoch_start(trainer, pl_module, phase="test") @override - def on_train_epoch_end( - self, trainer: Trainer, pl_module: Algorithm[BatchType, StepOutputType] - ) -> None: + def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: super().on_train_epoch_end(trainer, pl_module) self.on_shared_epoch_end(trainer, pl_module, phase="train") @override - def on_validation_epoch_end( - self, trainer: Trainer, pl_module: Algorithm[BatchType, StepOutputType] - ) -> None: + def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: super().on_validation_epoch_end(trainer, pl_module) self.on_shared_epoch_end(trainer, pl_module, phase="val") @override - def on_test_epoch_end( - self, trainer: Trainer, pl_module: Algorithm[BatchType, StepOutputType] - ) -> None: + def 
on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: super().on_test_epoch_end(trainer, pl_module) self.on_shared_epoch_end(trainer, pl_module, phase="test") diff --git a/project/algorithms/callbacks/classification_metrics.py b/project/algorithms/callbacks/classification_metrics.py index c76d9db9..48cd38da 100644 --- a/project/algorithms/callbacks/classification_metrics.py +++ b/project/algorithms/callbacks/classification_metrics.py @@ -1,6 +1,6 @@ import warnings from logging import getLogger as get_logger -from typing import NotRequired, Required, TypedDict, override +from typing import Literal, NotRequired, Required, TypedDict, override import torch import torchmetrics @@ -8,17 +8,16 @@ from torch import Tensor from torchmetrics.classification import MulticlassAccuracy -from project.algorithms.algorithm import Algorithm, BatchType -from project.algorithms.callbacks.callback import Callback -from project.utils.types import PhaseStr +from project.algorithms.callbacks.callback import BatchType, Callback from project.utils.types.protocols import ClassificationDataModule logger = get_logger(__name__) class ClassificationOutputs(TypedDict, total=False): - """The dictionary format that is minimally required to be returned from - `training/val/test_step` for classification algorithms.""" + """The outputs that should be minimally returned from the training/val/test_step of + classification LightningModules so that metrics can be added automatically by the + `ClassificationMetricsCallback`.""" loss: NotRequired[torch.Tensor | float] """The loss at this step.""" @@ -31,14 +30,14 @@ class ClassificationOutputs(TypedDict, total=False): class ClassificationMetricsCallback(Callback[BatchType, ClassificationOutputs]): - """Callback that adds classification metrics to the pl module.""" + """Callback that adds classification metrics to a LightningModule.""" def __init__(self) -> None: super().__init__() self.disabled = False @classmethod - def attach_to(cls, algorithm: Algorithm, num_classes: int): + def attach_to(cls, algorithm: LightningModule, num_classes: int): callback = cls() callback.add_metrics_to(algorithm, num_classes=num_classes) return callback @@ -84,8 +83,8 @@ def _get_metric(pl_module: LightningModule, name: str): def setup( self, trainer: Trainer, - pl_module: Algorithm[BatchType, ClassificationOutputs], - stage: PhaseStr, + pl_module: LightningModule, + stage: Literal["fit", "validate", "test", "predict", "tune"], ) -> None: if self.disabled: return @@ -108,11 +107,11 @@ def setup( def on_shared_batch_end( self, trainer: Trainer, - pl_module: Algorithm[BatchType, ClassificationOutputs], + pl_module: LightningModule, outputs: ClassificationOutputs, batch: BatchType, batch_index: int, - phase: PhaseStr, + phase: Literal["train", "val", "test"], dataloader_idx: int | None = None, ): if self.disabled: diff --git a/project/algorithms/callbacks/samples_per_second.py b/project/algorithms/callbacks/samples_per_second.py index 187bd247..fc3b6de7 100644 --- a/project/algorithms/callbacks/samples_per_second.py +++ b/project/algorithms/callbacks/samples_per_second.py @@ -1,19 +1,18 @@ import time -from typing import override +from typing import Literal, override from lightning import LightningModule, Trainer from torch import Tensor from torch.optim import Optimizer -from project.algorithms.algorithm import Algorithm, BatchType, StepOutputDict -from project.algorithms.callbacks.callback import Callback -from project.utils.types import PhaseStr, is_sequence_of +from 
project.algorithms.callbacks.callback import BatchType, Callback, StepOutputType +from project.utils.types import is_sequence_of -class MeasureSamplesPerSecondCallback(Callback[BatchType, StepOutputDict]): +class MeasureSamplesPerSecondCallback(Callback[BatchType, StepOutputType]): def __init__(self): super().__init__() - self.last_step_times: dict[PhaseStr, float] = {} + self.last_step_times: dict[Literal["train", "val", "test"], float] = {} self.last_update_time: dict[int, float | None] = {} self.num_optimizers: int | None = None @@ -21,8 +20,8 @@ def __init__(self): def on_shared_epoch_start( self, trainer: Trainer, - pl_module: Algorithm[BatchType, StepOutputDict], - phase: PhaseStr, + pl_module: LightningModule, + phase: Literal["train", "val", "test"], ) -> None: self.last_update_time.clear() self.last_step_times.pop(phase, None) @@ -37,11 +36,11 @@ def on_shared_epoch_start( def on_shared_batch_end( self, trainer: Trainer, - pl_module: Algorithm[BatchType, StepOutputDict], - outputs: StepOutputDict, + pl_module: LightningModule, + outputs: StepOutputType, batch: BatchType, batch_index: int, - phase: PhaseStr, + phase: Literal["train", "val", "test"], dataloader_idx: int | None = None, ): super().on_shared_batch_end( @@ -71,7 +70,11 @@ def on_shared_batch_end( @override def on_before_optimizer_step( - self, trainer: Trainer, pl_module: LightningModule, optimizer: Optimizer, opt_idx: int = 0 + self, + trainer: Trainer, + pl_module: LightningModule, + optimizer: Optimizer, + opt_idx: int = 0, ) -> None: if opt_idx not in self.last_update_time or self.last_update_time[opt_idx] is None: self.last_update_time[opt_idx] = time.perf_counter() diff --git a/project/algorithms/example.py b/project/algorithms/example.py new file mode 100644 index 00000000..3fbbe7e7 --- /dev/null +++ b/project/algorithms/example.py @@ -0,0 +1,166 @@ +"""Example of an algorithm, which is a Pytorch Lightning image classifier. + +Uses regular backpropagation. +""" + +import dataclasses +import functools +from logging import getLogger +from typing import Annotated, Any, Literal + +import pydantic +import torch +from hydra_zen import instantiate +from lightning import LightningModule +from lightning.pytorch.callbacks import Callback, EarlyStopping +from pydantic import NonNegativeInt, PositiveInt +from torch import Tensor +from torch.nn import functional as F +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler + +from project.algorithms.callbacks.classification_metrics import ClassificationMetricsCallback +from project.configs.algorithm.lr_scheduler import CosineAnnealingLRConfig +from project.configs.algorithm.optimizer import AdamConfig +from project.datamodules.image_classification import ImageClassificationDataModule + +logger = getLogger(__name__) + +LRSchedulerConfig = Annotated[Any, pydantic.Field(default_factory=CosineAnnealingLRConfig)] + + +class ExampleAlgorithm(LightningModule): + """Example learning algorithm for image classification.""" + + @pydantic.dataclasses.dataclass(frozen=True) + class HParams: + """Hyper-Parameters.""" + + # Arguments to be passed to the LR scheduler. + lr_scheduler: LRSchedulerConfig = dataclasses.field( + default=CosineAnnealingLRConfig(T_max=85, eta_min=1e-5), + metadata={"omegaconf_ignore": True}, + ) + + lr_scheduler_interval: Literal["step", "epoch"] = "epoch" + + # Frequency of the LR scheduler. Set to 0 to disable the lr scheduler. 
+ lr_scheduler_frequency: NonNegativeInt = 1 + + # Hyper-parameters for the optimizer + optimizer: Any = AdamConfig(lr=3e-4) + + batch_size: PositiveInt = 128 + + # Max number of epochs to train for without an improvement to the validation + # accuracy before the training is stopped. + early_stopping_patience: NonNegativeInt = 0 + + def __init__( + self, + datamodule: ImageClassificationDataModule, + network: torch.nn.Module, + hp: HParams = HParams(), + ): + super().__init__() + self.datamodule = datamodule + self.network = network + self.hp = hp or self.HParams() + + # Used by Pytorch-Lightning to compute the input/output shapes of the network. + self.example_input_array = torch.zeros( + (datamodule.batch_size, *datamodule.dims), device=self.device + ) + # Do a forward pass to initialize any lazy weights. This is necessary for distributed + # training and to infer shapes. + _ = self.network(self.example_input_array) + + # Save hyper-parameters. + self.save_hyperparameters({"hp": dataclasses.asdict(self.hp)}) + + def forward(self, input: Tensor) -> Tensor: + logits = self.network(input) + return logits + + def training_step(self, batch: tuple[Tensor, Tensor], batch_index: int): + return self.shared_step(batch, batch_index=batch_index, phase="train") + + def validation_step(self, batch: tuple[Tensor, Tensor], batch_index: int): + return self.shared_step(batch, batch_index=batch_index, phase="val") + + def test_step(self, batch: tuple[Tensor, Tensor], batch_index: int): + return self.shared_step(batch, batch_index=batch_index, phase="test") + + def shared_step( + self, + batch: tuple[Tensor, Tensor], + batch_index: int, + phase: Literal["train", "val", "test"], + ): + x, y = batch + logits = self(x) + loss = F.cross_entropy(logits, y, reduction="mean") + self.log(f"{phase}/loss", loss.detach().mean()) + acc = logits.detach().argmax(-1).eq(y).float().mean() + self.log(f"{phase}/accuracy", acc) + return {"loss": loss, "logits": logits, "y": y} + + def configure_optimizers(self) -> dict: + """Creates the optimizers and the LR scheduler (if needed).""" + optimizer_partial: functools.partial[Optimizer] + if isinstance(self.hp.optimizer, functools.partial): + optimizer_partial = self.hp.optimizer + else: + optimizer_partial = instantiate(self.hp.optimizer) + optimizer = optimizer_partial(self.parameters()) + optimizers: dict[str, Any] = {"optimizer": optimizer} + + lr_scheduler_partial: functools.partial[_LRScheduler] + if isinstance(self.hp.lr_scheduler, functools.partial): + lr_scheduler_partial = self.hp.lr_scheduler + else: + lr_scheduler_partial = instantiate(self.hp.lr_scheduler) + + if self.hp.lr_scheduler_frequency != 0: + lr_scheduler = lr_scheduler_partial(optimizer) + optimizers["lr_scheduler"] = { + "scheduler": lr_scheduler, + # NOTE: These two keys are ignored if doing manual optimization. + # https://pytorch-lightning.readthedocs.io/en/stable/common/optimization.html#learning-rate-scheduling + "interval": self.hp.lr_scheduler_interval, + "frequency": self.hp.lr_scheduler_frequency, + } + return optimizers + + def configure_callbacks(self) -> list[Callback]: + callbacks: list[Callback] = [] + callbacks.append( + # Log some classification metrics. (This callback adds some metrics on this module). 
+ ClassificationMetricsCallback.attach_to(self, num_classes=self.datamodule.num_classes) + ) + if self.hp.lr_scheduler_frequency != 0: + from lightning.pytorch.callbacks import LearningRateMonitor + + callbacks.append(LearningRateMonitor()) + if self.hp.early_stopping_patience != 0: + # If early stopping is enabled, add a Callback for it: + callbacks.append( + EarlyStopping( + "val/accuracy", + mode="max", + patience=self.hp.early_stopping_patience, + verbose=True, + ) + ) + return callbacks + + @property + def device(self) -> torch.device: + """Small fixup for the `device` property in LightningModule, which is CPU by default.""" + if self._device.type == "cpu": + self._device = next((p.device for p in self.parameters()), torch.device("cpu")) + device = self._device + # make this more explicit to always include the index + if device.type == "cuda" and device.index is None: + return torch.device("cuda", index=torch.cuda.current_device()) + return device diff --git a/project/algorithms/example_algo.py b/project/algorithms/example_algo.py deleted file mode 100644 index 9d1d424c..00000000 --- a/project/algorithms/example_algo.py +++ /dev/null @@ -1,142 +0,0 @@ -"""Example of an algorithm, which is a Pytorch Lightning image classifier. - -Uses regular backpropagation. -""" - -from __future__ import annotations - -import dataclasses -import functools -from dataclasses import dataclass -from logging import getLogger -from typing import Any - -import torch -from hydra_zen import instantiate -from lightning.pytorch.callbacks import Callback, EarlyStopping -from torch import Tensor -from torch.nn import functional as F -from torch.optim import Optimizer -from torch.optim.lr_scheduler import _LRScheduler - -from project.algorithms.algorithm import Algorithm -from project.algorithms.callbacks.classification_metrics import ( - ClassificationMetricsCallback, - ClassificationOutputs, -) -from project.configs.algorithm.lr_scheduler import CosineAnnealingLRConfig -from project.configs.algorithm.optimizer import AdamConfig -from project.datamodules.image_classification.image_classification import ( - ImageClassificationDataModule, -) -from project.utils.types import PhaseStr - -logger = getLogger(__name__) - - -class ExampleAlgorithm(Algorithm): - """Example algorithm for image classification.""" - - # TODO: Make this less specific to Image classification once we add other supervised learning - # settings. - - @dataclass - class HParams(Algorithm.HParams): - """Hyper-Parameters of the baseline model.""" - - # Arguments to be passed to the LR scheduler. - lr_scheduler: CosineAnnealingLRConfig = CosineAnnealingLRConfig(T_max=85, eta_min=1e-5) - - lr_scheduler_interval: str = "epoch" - - # Frequency of the LR scheduler. Set to 0 to disable the lr scheduler. - lr_scheduler_frequency: int = 1 - - # Max number of training epochs in total. - max_epochs: int = 90 - - # Hyper-parameters for the forward optimizer - # BUG: seems to be reproducible given a seed when using SGD, but not when using Adam! - optimizer: AdamConfig = AdamConfig(lr=3e-4) - - # batch size - batch_size: int = 128 - - # Max number of epochs to train for without an improvement to the validation - # accuracy before the training is stopped. 
- early_stopping_patience: int = 0 - - def __init__( - self, - datamodule: ImageClassificationDataModule, - network: torch.nn.Module, - hp: ExampleAlgorithm.HParams | None = None, - ): - super().__init__() - self.datamodule = datamodule - self.network = network - self.hp = hp or self.HParams() - - self.automatic_optimization = True - - # Used by PL to compute the input/output shapes of the network. - self.example_input_array = torch.zeros( - (datamodule.batch_size, *datamodule.dims), device=self.device - ) - # Initialize any lazy weights. Necessary for distributed training and to infer shapes. - _ = self.network(self.example_input_array) - # Save hyper-parameters. - self.save_hyperparameters({"hp": dataclasses.asdict(self.hp)}) - - def forward(self, input: Tensor) -> Tensor: - logits = self.network(input) - return logits - - def shared_step( - self, - batch: tuple[Tensor, Tensor], - batch_index: int, - phase: PhaseStr, - ) -> ClassificationOutputs: - x, y = batch - logits = self(x) - loss = F.cross_entropy(logits, y, reduction="mean") - self.log(f"{phase}/loss", loss.detach().mean()) - acc = logits.detach().argmax(-1).eq(y).float().mean() - self.log(f"{phase}/accuracy", acc) - return {"loss": loss, "logits": logits, "y": y} - - def configure_optimizers(self) -> dict: - """Creates the optimizers and the LR schedulers (if needed).""" - optimizer_partial: functools.partial[Optimizer] = instantiate(self.hp.optimizer) - lr_scheduler_partial: functools.partial[_LRScheduler] = instantiate(self.hp.lr_scheduler) - optimizer = optimizer_partial(self.parameters()) - - optimizers: dict[str, Any] = {"optimizer": optimizer} - - if self.hp.lr_scheduler_frequency != 0: - lr_scheduler = lr_scheduler_partial(optimizer) - optimizers["lr_scheduler"] = { - "scheduler": lr_scheduler, - # NOTE: These two keys are ignored if doing manual optimization. 
- # https://pytorch-lightning.readthedocs.io/en/stable/common/optimization.html#learning-rate-scheduling - "interval": self.hp.lr_scheduler_interval, - "frequency": self.hp.lr_scheduler_frequency, - } - return optimizers - - def configure_callbacks(self) -> list[Callback]: - callbacks: list[Callback] = [ - ClassificationMetricsCallback.attach_to(self, num_classes=self.datamodule.num_classes) - ] - if self.hp.early_stopping_patience != 0: - # If early stopping is enabled, add a PL Callback for it: - callbacks.append( - EarlyStopping( - "val/accuracy", - mode="max", - patience=self.hp.early_stopping_patience, - verbose=True, - ) - ) - return callbacks diff --git a/project/algorithms/example_algo_test.py b/project/algorithms/example_test.py similarity index 64% rename from project/algorithms/example_algo_test.py rename to project/algorithms/example_test.py index 52b624b0..bf4dcc4b 100644 --- a/project/algorithms/example_algo_test.py +++ b/project/algorithms/example_test.py @@ -2,13 +2,12 @@ import torch -from project.algorithms.classification_tests import ClassificationAlgorithmTests +from project.algorithms.testsuites.classification_tests import ClassificationAlgorithmTests -from .example_algo import ExampleAlgorithm +from .example import ExampleAlgorithm class TestExampleAlgorithm(ClassificationAlgorithmTests[ExampleAlgorithm]): algorithm_type = ExampleAlgorithm - algorithm_name: str = "example_algo" unsupported_datamodule_names: ClassVar[list[str]] = ["rl"] _supported_network_types: ClassVar[list[type]] = [torch.nn.Module] diff --git a/project/algorithms/jax_algo.py b/project/algorithms/jax_example.py similarity index 78% rename from project/algorithms/jax_algo.py rename to project/algorithms/jax_example.py index 97becd04..257a24c7 100644 --- a/project/algorithms/jax_algo.py +++ b/project/algorithms/jax_example.py @@ -10,17 +10,15 @@ import rich.logging import torch import torch.distributed -from lightning import Callback, Trainer +from lightning import Callback, LightningModule, Trainer from torch_jax_interop import WrappedJaxFunction, torch_to_jax -from project.algorithms.algorithm import Algorithm from project.algorithms.callbacks.classification_metrics import ClassificationMetricsCallback from project.algorithms.callbacks.samples_per_second import MeasureSamplesPerSecondCallback from project.datamodules.image_classification.image_classification import ( ImageClassificationDataModule, ) from project.datamodules.image_classification.mnist import MNISTDataModule -from project.utils.types import PhaseStr from project.utils.types.protocols import ClassificationDataModule os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" @@ -75,15 +73,17 @@ def _parameter_to_jax_array(value: torch.nn.Parameter) -> jax.Array: return torch_to_jax(value.data) -class JaxAlgorithm(Algorithm): - """Example of an algorithm that uses Jax. +class JaxExample(LightningModule): + """Example of a learning algorithm (`LightningModule`) that uses Jax. In this case, the network is a flax.linen.Module, and its forward and backward passes are - written in Jax. + written in Jax, while the loss function is computed in PyTorch. 
""" - @dataclasses.dataclass - class HParams(Algorithm.HParams): + @dataclasses.dataclass(frozen=True) + class HParams: + """Hyper-parameters of the algo.""" + lr: float = 1e-3 seed: int = 123 debug: bool = True @@ -93,10 +93,11 @@ def __init__( *, network: flax.linen.Module, datamodule: ImageClassificationDataModule, - hp: HParams | None = None, + hp: HParams = HParams(), ): - super().__init__(datamodule=datamodule) - self.hp: JaxAlgorithm.HParams = hp or self.HParams() + super().__init__() + self.datamodule = datamodule + self.hp = hp or self.HParams() example_input = torch.zeros( (datamodule.batch_size, *datamodule.dims), @@ -117,8 +118,24 @@ def __init__( self.example_input_array = example_input + def forward(self, input: torch.Tensor) -> torch.Tensor: + logits = self.network(input) + return logits + + def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int): + return self.shared_step(batch, batch_index=batch_index, phase="train") + + def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int): + return self.shared_step(batch, batch_index=batch_index, phase="val") + + def test_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int): + return self.shared_step(batch, batch_index=batch_index, phase="test") + def shared_step( - self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int, phase: PhaseStr + self, + batch: tuple[torch.Tensor, torch.Tensor], + batch_index: int, + phase: Literal["train", "val", "test"], ): x, y = batch assert not x.requires_grad @@ -137,11 +154,22 @@ def configure_optimizers(self): def configure_callbacks(self) -> list[Callback]: assert isinstance(self.datamodule, ClassificationDataModule) - return super().configure_callbacks() + [ + return [ MeasureSamplesPerSecondCallback(), ClassificationMetricsCallback.attach_to(self, num_classes=self.datamodule.num_classes), ] + @property + def device(self) -> torch.device: + """Small fixup for the `device` property in LightningModule, which is CPU by default.""" + if self._device.type == "cpu": + self._device = next((p.device for p in self.parameters()), torch.device("cpu")) + device = self._device + # make this more explicit to always include the index + if device.type == "cuda" and device.index is None: + return torch.device("cuda", index=torch.cuda.current_device()) + return device + def is_channels_first(shape: tuple[int, ...]) -> bool: if len(shape) == 4: @@ -206,7 +234,7 @@ def main(): datamodule = MNISTDataModule(num_workers=4, batch_size=512) network = CNN(num_classes=datamodule.num_classes) - model = JaxAlgorithm(network=network, datamodule=datamodule) + model = JaxExample(network=network, datamodule=datamodule) trainer.fit(model, datamodule=datamodule) ... 
diff --git a/project/algorithms/jax_algo_test.py b/project/algorithms/jax_example_test.py similarity index 58% rename from project/algorithms/jax_algo_test.py rename to project/algorithms/jax_example_test.py index ed280101..f8ac8cac 100644 --- a/project/algorithms/jax_algo_test.py +++ b/project/algorithms/jax_example_test.py @@ -4,14 +4,13 @@ import flax.linen import torch -from project.algorithms.jax_algo import JaxAlgorithm +from project.algorithms.jax_example import JaxExample -from .algorithm_test import AlgorithmTests +from .testsuites.algorithm_tests import AlgorithmTests -class TestJaxAlgorithm(AlgorithmTests[JaxAlgorithm]): +class TestJaxExample(AlgorithmTests[JaxExample]): """This algorithm only works with Jax modules.""" - algorithm_name: ClassVar[str] = "jax_algo" unsupported_network_types: ClassVar[list[type]] = [torch.nn.Module] _supported_network_types: ClassVar[list[type]] = [flax.linen.Module] diff --git a/project/algorithms/manual_optimization_example.py b/project/algorithms/manual_optimization_example.py deleted file mode 100644 index 317a6e1b..00000000 --- a/project/algorithms/manual_optimization_example.py +++ /dev/null @@ -1,112 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass - -import torch -from torch import Tensor, nn - -from project.algorithms.algorithm import Algorithm -from project.algorithms.callbacks.classification_metrics import ( - ClassificationMetricsCallback, - ClassificationOutputs, -) -from project.datamodules.image_classification.image_classification import ( - ImageClassificationDataModule, -) -from project.utils.types import PhaseStr - - -class ManualGradientsExample(Algorithm): - """Example of an algorithm that calculates the gradients manually instead of having PL do the - backward pass.""" - - @dataclass - class HParams(Algorithm.HParams): - """Hyper-parameters of this example algorithm.""" - - lr: float = 0.1 - - gradient_noise_std: float = 0.01 - """Standard deviation of the Gaussian noise added to the gradients.""" - - def __init__( - self, - datamodule: ImageClassificationDataModule, - network: nn.Module, - hp: ManualGradientsExample.HParams | None = None, - ): - super().__init__() - self.datamodule = datamodule - self.network = network - self.hp = hp or self.HParams() - # Just to let the type checker know the right type. - self.hp: ManualGradientsExample.HParams - - # Setting this to False tells PL that we will be calculating the gradients manually. - # This turns off a few nice things in PL that we might not care about here, such as - # easy multi-gpu / multi-node / TPU / mixed precision training. - self.automatic_optimization = False - - # Instantiate any lazy weights with a dummy forward pass (optional). 
- self.example_input_array = torch.zeros( - (datamodule.batch_size, *datamodule.dims), device=self.device - ) - self.network(self.example_input_array) - - def forward(self, x: Tensor) -> Tensor: - return self.network(x) - - def training_step( - self, batch: tuple[Tensor, Tensor], batch_index: int - ) -> ClassificationOutputs: - return self.shared_step(batch, batch_index, "train") - - def validation_step( - self, batch: tuple[Tensor, Tensor], batch_index: int - ) -> ClassificationOutputs: - return self.shared_step(batch, batch_index, "val") - - def shared_step( - self, batch: tuple[Tensor, Tensor], batch_index: int, phase: PhaseStr - ) -> ClassificationOutputs: - """Performs a training/validation/test step.""" - x, y = batch - logits = self(x) - - loss = torch.nn.functional.cross_entropy(logits, y) - - if phase == "train": - # We don't care about AMP, TPUs, gradient accumulation here, so just get the "real" - # optimizers instead of the PL optimizer wrappers: - optimizers = self.optimizers(use_pl_optimizer=False) - # We only have one optimizer in this example. Otherwise we'd have a list here. - assert not isinstance(optimizers, list) - optimizer = optimizers - - optimizer.zero_grad() - - # NOTE: Whenever possible, if you have a simple "loss" tensor, use `manual_backward`. - # self.manual_backward(loss) - # However, in this example here, it still works even if we manipulate the grads - # directly. We're also not training on multiple GPUs, which makes this easier. - - # NOTE: You don't need to call `loss.backward()`, you could also just set .grads - # directly! - self.manual_backward(loss) - - for name, parameter in self.named_parameters(): - assert parameter.grad is not None, name - parameter.grad += self.hp.gradient_noise_std * torch.randn_like(parameter.grad) - - optimizer.step() - - return {"y": y, "logits": logits, "loss": loss.detach()} - - def configure_optimizers(self): - """Creates the optimizer(s) and learning rate scheduler(s).""" - return torch.optim.SGD(self.parameters(), lr=self.hp.lr) - - def configure_callbacks(self): - return super().configure_callbacks() + [ - ClassificationMetricsCallback.attach_to(self, num_classes=self.datamodule.num_classes) - ] diff --git a/project/algorithms/manual_optimization_example_test.py b/project/algorithms/manual_optimization_example_test.py deleted file mode 100644 index 1185e153..00000000 --- a/project/algorithms/manual_optimization_example_test.py +++ /dev/null @@ -1,16 +0,0 @@ -from typing import ClassVar - -import torch - -from project.algorithms.classification_tests import ClassificationAlgorithmTests -from project.datamodules.vision import VisionDataModule - -from .manual_optimization_example import ManualGradientsExample - - -class TestManualOptimizationExample(ClassificationAlgorithmTests[ManualGradientsExample]): - algorithm_type = ManualGradientsExample - algorithm_name: str = "manual_optimization" - - _supported_datamodule_types: ClassVar[list[type]] = [VisionDataModule] - _supported_network_types: ClassVar[list[type]] = [torch.nn.Module] diff --git a/project/algorithms/no_op.py b/project/algorithms/no_op.py index 029f8b16..11bfb898 100644 --- a/project/algorithms/no_op.py +++ b/project/algorithms/no_op.py @@ -1,36 +1,44 @@ -from typing import Any +from typing import Any, Literal import torch -from lightning import Callback +from lightning import Callback, LightningModule from torch import nn -from project.algorithms.algorithm import Algorithm, StepOutputDict from project.algorithms.callbacks.samples_per_second import 
MeasureSamplesPerSecondCallback -from project.utils.types import PhaseStr from project.utils.types.protocols import DataModule -class NoOp(Algorithm): +class NoOp(LightningModule): """No-op algorithm that does no learning and is used to benchmark the dataloading speed.""" def __init__(self, datamodule: DataModule, network: nn.Module): - super().__init__(datamodule=datamodule, network=network) + super().__init__() + self.datamodule = datamodule + self.network = network # Set this so PyTorch-Lightning doesn't try to train the model using our 'loss' self.automatic_optimization = False - self.last_step_times: dict[PhaseStr, float] = {} + + def training_step(self, batch: Any, batch_index: int): + return self.shared_step(batch, batch_index, "train") + + def validation_step(self, batch: Any, batch_index: int): + return self.shared_step(batch, batch_index, "val") + + def test_step(self, batch: Any, batch_index: int): + return self.shared_step(batch, batch_index, "test") def shared_step( self, batch: Any, batch_index: int, - phase: PhaseStr, - ) -> StepOutputDict: + phase: Literal["train", "val", "test"], + ): fake_loss = torch.rand(1) self.log(f"{phase}/loss", fake_loss) - return {"loss": fake_loss} + return fake_loss def configure_callbacks(self) -> list[Callback]: - return super().configure_callbacks() + [MeasureSamplesPerSecondCallback()] + return [MeasureSamplesPerSecondCallback()] def configure_optimizers(self): return torch.optim.SGD(self.parameters(), lr=0.123) diff --git a/project/algorithms/algorithm.py b/project/algorithms/testsuites/algorithm.py similarity index 52% rename from project/algorithms/algorithm.py rename to project/algorithms/testsuites/algorithm.py index 71e98dbc..bd612992 100644 --- a/project/algorithms/algorithm.py +++ b/project/algorithms/testsuites/algorithm.py @@ -1,16 +1,14 @@ -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import NotRequired, TypedDict +from typing import Literal, NotRequired, Protocol, TypedDict import torch from lightning import Callback, LightningModule, Trainer from torch import Tensor -from typing_extensions import Generic, TypeVar # noqa +from typing_extensions import TypeVar from project.datamodules.image_classification.image_classification import ( ImageClassificationDataModule, ) -from project.utils.types import PhaseStr, PyTree +from project.utils.types import PyTree from project.utils.types.protocols import DataModule, Module @@ -23,43 +21,36 @@ class StepOutputDict(TypedDict, total=False): BatchType = TypeVar("BatchType", bound=PyTree[torch.Tensor], contravariant=True) -# StepOutputType = TypeVar( -# "StepOutputType", bound=StepOutputDict | PyTree[torch.Tensor], covariant=True -# ) -StepOutputType = TypeVar( - "StepOutputType", - bound=torch.Tensor | StepOutputDict, - default=StepOutputDict, - covariant=True, -) +StepOutputType = TypeVar("StepOutputType", bound=StepOutputDict, covariant=True) -class Algorithm(LightningModule, ABC, Generic[BatchType, StepOutputType]): - """Base class for a learning algorithm. +class Algorithm(Module, Protocol[BatchType, StepOutputType]): + """Protocol that adds more type information to the `lightning.LightningModule` class. - This is an extension of the LightningModule class from PyTorch Lightning, with some common - boilerplate code to keep the algorithm implementations as simple as possible. 
+ This adds some type information on top of the LightningModule class, namely: + - `BatchType`: The type of batch that is produced by the dataloaders of the datamodule + - `StepOutputType`, the output type created by the step methods. The networks themselves are created separately and passed as a constructor argument. This is meant to make it easier to compare different learning algorithms on the same network architecture. """ - @dataclass - class HParams: - """Hyper-parameters of the algorithm.""" + datamodule: DataModule[BatchType] + network: Module + + example_input_array = LightningModule.example_input_array + _device: torch.device | None = None def __init__( self, *, - datamodule: DataModule[BatchType] | None = None, - network: Module | None = None, - hp: HParams | None = None, + datamodule: DataModule[BatchType], + network: Module, ): super().__init__() self.datamodule = datamodule self.network = network - self.hp = hp or self.HParams() # fix for `self.device` property which defaults to cpu. self._device = None @@ -82,7 +73,9 @@ def test_step(self, batch: BatchType, batch_index: int) -> StepOutputType: """Performs a test step.""" return self.shared_step(batch=batch, batch_index=batch_index, phase="test") - def shared_step(self, batch: BatchType, batch_index: int, phase: PhaseStr) -> StepOutputType: + def shared_step( + self, batch: BatchType, batch_index: int, phase: Literal["train", "val", "test"] + ) -> StepOutputType: """Performs a training/validation/test step. This must return a nested dictionary of tensors matching the `StepOutputType` typedict for @@ -94,11 +87,9 @@ def shared_step(self, batch: BatchType, batch_index: int, phase: PhaseStr) -> St """ raise NotImplementedError - @abstractmethod def configure_optimizers(self): # """Creates the optimizers and the learning rate schedulers."""' - # super().configure_optimizers() - ... + raise NotImplementedError def forward(self, x: Tensor) -> Tensor: """Performs a forward pass. @@ -108,47 +99,6 @@ def forward(self, x: Tensor) -> Tensor: assert self.network is not None return self.network(x) - def training_step_end(self, step_output: StepOutputDict) -> StepOutputDict: - """Called with the results of each worker / replica's output. - - See the `training_step_end` of pytorch-lightning for more info. - """ - return self.shared_step_end(step_output, phase="train") - - def validation_step_end[Out: torch.Tensor | StepOutputDict](self, step_output: Out) -> Out: - return self.shared_step_end(step_output, phase="val") - - def test_step_end[Out: torch.Tensor | StepOutputDict](self, step_output: Out) -> Out: - return self.shared_step_end(step_output, phase="test") - - def shared_step_end[Out: torch.Tensor | StepOutputDict]( - self, step_output: Out, phase: PhaseStr - ) -> Out: - """This is a default implementation for `[train/validation/test]_step_end`. - - This does the following: - - Averages out the `loss` tensor if it was left unreduced. - - the main metrics are logged inside `training_step_end` (supposed to be better for DP/DDP) - """ - - if ( - isinstance(step_output, dict) - and isinstance((loss := step_output.get("loss")), torch.Tensor) - and loss.shape - ): - # Replace the loss with its mean. This is useful when automatic - # optimization is enabled, for example in the example algo, where each replica - # returns the un-reduced cross-entropy loss. Here we need to reduce it to a scalar. - # todo: find out if this was already logged, to not log it twice. 
- # self.log(f"{phase}/loss", loss.mean(), sync_dist=True) - return step_output | {"loss": loss.mean()} - - elif isinstance(step_output, torch.Tensor) and (loss := step_output).shape: - return loss.mean() - - # self.log(f"{phase}/loss", torch.as_tensor(loss).mean(), sync_dist=True) - return step_output - def configure_callbacks(self) -> list[Callback]: """Use this to add some callbacks that should always be included with the model.""" return [] diff --git a/project/algorithms/algorithm_test.py b/project/algorithms/testsuites/algorithm_tests.py similarity index 93% rename from project/algorithms/algorithm_test.py rename to project/algorithms/testsuites/algorithm_tests.py index 4983b026..f45e52f8 100644 --- a/project/algorithms/algorithm_test.py +++ b/project/algorithms/testsuites/algorithm_tests.py @@ -8,7 +8,7 @@ from collections.abc import Callable, Sequence from logging import getLogger as get_logger from pathlib import Path -from typing import Any, ClassVar, Literal +from typing import Any, ClassVar, Generic, Literal, TypeVar import pytest import torch @@ -18,17 +18,16 @@ from tensor_regression import TensorRegressionFixture from torch import Tensor, nn from torch.utils.data import DataLoader -from typing_extensions import ParamSpec -from project.algorithms.algorithm import Algorithm from project.algorithms.callbacks.callback import Callback +from project.algorithms.testsuites.algorithm import Algorithm from project.configs import Config, cs from project.conftest import setup_hydra_for_tests_and_compose from project.datamodules.image_classification.image_classification import ( ImageClassificationDataModule, ) -from project.datamodules.vision import VisionDataModule from project.experiment import ( + instantiate_algorithm, instantiate_datamodule, instantiate_network, ) @@ -45,14 +44,15 @@ from project.utils.types.protocols import DataModule logger = get_logger(__name__) -P = ParamSpec("P") SKIP_OR_XFAIL = pytest.xfail if "-vvv" in sys.argv else pytest.skip """Either skips the test entirely (default) or tries to run it and expect it to fail (slower).""" +AlgorithmType = TypeVar("AlgorithmType", bound=Algorithm) -class AlgorithmTests[AlgorithmType: Algorithm]: + +class AlgorithmTests(Generic[AlgorithmType]): """Unit tests for an algorithm class. The algorithm creation is parametrized with all the datasets and all the networks, but the @@ -70,7 +70,11 @@ class AlgorithmTests[AlgorithmType: Algorithm]: """ algorithm_type: type[AlgorithmType] - algorithm_name: ClassVar[str] + algorithm_config_name: ClassVar[str | None] = None + """Name of the config to use for the algorithm. + + Defaults to the algorithm class name. 
+ """ unsupported_datamodule_names: ClassVar[list[str]] = [] unsupported_datamodule_types: ClassVar[list[type[DataModule]]] = [] @@ -171,7 +175,7 @@ def test_overfit_training_batch( testing_callbacks=testing_callbacks, ) - def _train( + def _train[**P]( self, algorithm: AlgorithmType, tmp_path: Path, @@ -241,7 +245,7 @@ def test_experiment_reproducible_given_seed( if "resnet" in network_name and datamodule_name in ["mnist", "fashion_mnist"]: pytest.skip(reason="ResNet's can't be used on MNIST datasets.") - algorithm_name = self.algorithm_name or self.algorithm_cls.__name__.lower() + algorithm_name = self.algorithm_config_name or self.algorithm_cls.__name__ assert isinstance(algorithm_name, str) assert isinstance(datamodule_name, str) assert isinstance(network_name, str) @@ -333,8 +337,7 @@ def _hydra_config( if "resnet" in network_name and datamodule_name in ["mnist", "fashion_mnist"]: pytest.skip(reason="ResNet's can't be used on MNIST datasets.") - # todo: Get the name of the algorithm from the hydra config? - algorithm_name = self.algorithm_name + algorithm_name = self.algorithm_config_name or self.algorithm_type.__name__ combination = set([datamodule_name, network_name, algorithm_name]) for configs, marks in default_marks_for_config_combinations.items(): @@ -401,28 +404,33 @@ def network( network = network.to(device=device) return network - @pytest.fixture(scope="class") - def hp(self, experiment_config: Config) -> Algorithm.HParams: # type: ignore - """The hyperparameters for the algorithm. + # @pytest.fixture(scope="class") + # def hp(self, experiment_config: Config) -> Algorithm.HParams: # type: ignore + # """The hyperparameters for the algorithm. - NOTE: This should ideally be parametrized to test different hyperparameter settings. - """ - return experiment_config.algorithm - # return self.algorithm_cls.HParams() + # NOTE: This should ideally be parametrized to test different hyperparameter settings. + # """ + # return experiment_config.algorithm + # return self.algorithm_cls.HParams() - @pytest.fixture(scope="function") - def algorithm_kwargs( - self, datamodule: VisionDataModule, network: nn.Module, hp: Algorithm.HParams - ): - """Fixture that gives the keyword arguments to use to create the algorithm. + # @pytest.fixture(scope="function") + # def algorithm_kwargs( + # self, datamodule: VisionDataModule, network: nn.Module, hp: Algorithm.HParams + # ): + # """Fixture that gives the keyword arguments to use to create the algorithm. - NOTE: This should be further parametrized by base classes as needed. - """ - return dict(datamodule=datamodule, network=copy.deepcopy(network), hp=hp) + # NOTE: This should be further parametrized by base classes as needed. 
+ # """ + # return dict(datamodule=datamodule, network=copy.deepcopy(network), hp=hp) @pytest.fixture(scope="function") - def algorithm(self, algorithm_kwargs: dict) -> AlgorithmType: - return self.algorithm_cls(**algorithm_kwargs) + def algorithm( + self, experiment_config: Config, datamodule: DataModule, network: nn.Module + ) -> AlgorithmType: + algo = instantiate_algorithm( + experiment_config, datamodule=datamodule, network=copy.deepcopy(network) + ) + return algo @property def algorithm_cls(self) -> type[AlgorithmType]: @@ -445,7 +453,9 @@ def _algorithm_cls(cls) -> type[AlgorithmType]: from typing import get_args class_under_test = get_args(cls.__orig_bases__[0])[0] # type: ignore - if not (inspect.isclass(class_under_test) and issubclass(class_under_test, Algorithm)): + if not ( + inspect.isclass(class_under_test) and issubclass(class_under_test, LightningModule) + ): raise RuntimeError( "Your test class needs to pass the class under test to the generic base class.\n" "for example: `class TestMyAlgorithm(AlgorithmTests[MyAlgorithm]):`\n" @@ -655,7 +665,7 @@ def on_after_backward(self, trainer: Trainer, pl_module: LightningModule) -> Non def on_train_batch_end( self, trainer: Trainer, - pl_module: Algorithm, + pl_module: LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_index: int, diff --git a/project/algorithms/classification_tests.py b/project/algorithms/testsuites/classification_tests.py similarity index 88% rename from project/algorithms/classification_tests.py rename to project/algorithms/testsuites/classification_tests.py index 3833311a..e74fa84b 100644 --- a/project/algorithms/classification_tests.py +++ b/project/algorithms/testsuites/classification_tests.py @@ -1,3 +1,17 @@ +"""Suite of example tests for classification algorithms. + +You can use this as a template to create tests for your own algorithm by inheriting from the class: + +```python + +from project.algorithms.classification_tests import ClassificationAlgorithmTests + +class TestMyAlgorithm(ClassificationAlgorithmTests[MyAlgorithm]): + algorithm_type = MyAlgorithm + algorithm_name: str = "my_algo" # name of your algorithm's config. +``` +""" + from __future__ import annotations from pathlib import Path @@ -8,13 +22,13 @@ from torch import Tensor, nn from torch.utils.data import DataLoader, TensorDataset -from project.algorithms.algorithm import Algorithm -from project.algorithms.algorithm_test import ( +from project.algorithms.callbacks.classification_metrics import ClassificationOutputs +from project.algorithms.testsuites.algorithm import Algorithm +from project.algorithms.testsuites.algorithm_tests import ( AlgorithmTests, CheckBatchesAreTheSameAtEachStep, MetricShouldImprove, ) -from project.algorithms.callbacks.classification_metrics import ClassificationOutputs from project.datamodules.image_classification.image_classification import ( ImageClassificationDataModule, ) @@ -35,7 +49,8 @@ class ClassificationAlgorithmTests[ unsupported_network_types: ClassVar[list[type[nn.Module]]] = [] _supported_datamodule_types: ClassVar[list[type[ClassificationDataModule]]] = [ # VisionDataModule, - ClassificationDataModule, # type: ignore (we actually support this case). + # (we actually support this case). 
+ ClassificationDataModule, # type: ignore # ImageClassificationDataModule, ] diff --git a/project/configs/__init__.py b/project/configs/__init__.py index 13e6a8d7..bbd0f684 100644 --- a/project/configs/__init__.py +++ b/project/configs/__init__.py @@ -2,17 +2,22 @@ from hydra.core.config_store import ConfigStore -from ..utils.env_vars import REPO_ROOTDIR, SLURM_JOB_ID, SLURM_TMPDIR -from .config import Config -from .datamodule import ( - datamodule_store, -) -from .network import network_store +from project.configs.algorithm import register_algorithm_configs +from project.configs.config import Config +from project.configs.datamodule import datamodule_store +from project.configs.network import network_store +from project.utils.env_vars import REPO_ROOTDIR, SLURM_JOB_ID, SLURM_TMPDIR cs = ConfigStore.instance() cs.store(name="base_config", node=Config) -datamodule_store.add_to_hydra_store() -network_store.add_to_hydra_store() + + +def add_configs_to_hydra_store(with_dynamic_configs: bool = True): + datamodule_store.add_to_hydra_store() + network_store.add_to_hydra_store() + register_algorithm_configs(with_dynamic_configs=with_dynamic_configs) + + # todo: move the algorithm_store.add_to_hydra_store() here? __all__ = [ diff --git a/project/configs/algorithm/__init__.py b/project/configs/algorithm/__init__.py index e69de29b..572f3167 100644 --- a/project/configs/algorithm/__init__.py +++ b/project/configs/algorithm/__init__.py @@ -0,0 +1,57 @@ +from hydra_zen import make_custom_builds_fn, store +from lightning import LightningModule + +from project.utils.hydra_utils import make_config_and_store + +from .lr_scheduler import add_configs_for_all_torch_schedulers, lr_scheduler_store +from .optimizer import add_configs_for_all_torch_optimizers, optimizers_store + +build_algorithm_config_fn = make_custom_builds_fn( + zen_partial=True, + populate_full_signature=True, + zen_exclude=["datamodule", "network"], +) + + +# NOTE: This works the same way as creating config files for each algorithm under +# `configs/algorithm`. From the command-line, you can select both configs that are yaml files as +# well as structured config (dataclasses). + +# If you add a configuration file under `project/configs/algorithm`, it will also be available as an option +# from the command-line, and can use these configs in their default list. +algorithm_store = store(group="algorithm") + + +def register_algorithm_configs(with_dynamic_configs: bool = True): + if with_dynamic_configs: + add_configs_for_all_torch_optimizers() + add_configs_for_all_torch_schedulers() + + optimizers_store.add_to_hydra_store() + lr_scheduler_store.add_to_hydra_store() + + import inspect + + # Note: import algorithms here to avoid circular import errors. 
+ import project.algorithms + + for algo_name, algo_class in [ + (k, v) + for (k, v) in vars(project.algorithms).items() + if inspect.isclass(v) and issubclass(v, LightningModule) + ]: + make_config_and_store( + algo_class, store=algorithm_store, zen_exclude=["datamodule", "network"] + ) + # config_class_name = f"{algo_name}Config" + # config_class = build_algorithm_config_fn( + # algo_class, zen_dataclass={"cls_name": config_class_name} + # ) + # algorithm_store(config_class, name=algo_name) + + # from project.algorithms import ExampleAlgorithm, JaxExample, NoOp + # algorithm_store(build_algorithm_config_fn(ExampleAlgorithm), name="example") + # algorithm_store(build_algorithm_config_fn(NoOp), name="no_op") + # algorithm_store(build_algorithm_config_fn(JaxExample), name="jax_example") + + algorithm_store.add_to_hydra_store() diff --git a/project/configs/algorithm/example_from_config.yaml b/project/configs/algorithm/example_from_config.yaml new file mode 100644 index 00000000..f79db6c6 --- /dev/null +++ b/project/configs/algorithm/example_from_config.yaml @@ -0,0 +1,17 @@ +defaults: + # Use the example as a schema for this config, and inherit its default values. + # BUG: This doesn't work when the lr scheduler or optimizer types change from their defaults, + # because OmegaConf seems to use the type of the value (not of the field?) when merging configs. + # - ExampleAlgorithm + + # Use a custom config for the Adam optimizer (optimizer/custom_adam.yaml) at hp.optimizer + - optimizer/custom_adam@hp.optimizer + # Apply the config for the StepLR learning rate scheduler at `hp.lr_scheduler` in this config. + - lr_scheduler/StepLR@hp.lr_scheduler + +_target_: project.algorithms.example.ExampleAlgorithm +_partial_: true +hp: + _target_: project.algorithms.example.ExampleAlgorithm.HParams + lr_scheduler: + step_size: 5 # Required argument for the StepLR scheduler. diff --git a/project/configs/algorithm/lr_scheduler/__init__.py b/project/configs/algorithm/lr_scheduler/__init__.py index 67180aa1..54bfd7a6 100644 --- a/project/configs/algorithm/lr_scheduler/__init__.py +++ b/project/configs/algorithm/lr_scheduler/__init__.py @@ -1,33 +1,91 @@ -from typing import TypeVar +"""Configs for learning rate schedulers from `torch.optim.lr_scheduler`. +These configurations are created dynamically using [hydra-zen.builds](https://mit-ll-responsible-ai.github.io/hydra-zen/generated/hydra_zen.builds.html#). +""" + +import inspect +from logging import getLogger as get_logger + +import hydra_zen +import torch import torch.optim.lr_scheduler -from hydra_zen import hydrated_dataclass -from hydra_zen.typing._implementations import PartialBuilds +from hydra_zen.typing import PartialBuilds + +from project.utils.hydra_utils import make_config_and_store + +_logger = get_logger(__name__) + +_LR_SCHEDULER_GROUP = "algorithm/lr_scheduler" +lr_scheduler_store = hydra_zen.ZenStore(name="schedulers") + +lr_scheduler_group_store = lr_scheduler_store(group=_LR_SCHEDULER_GROUP) + +# Some LR Schedulers have constructors with arguments without a default value (in addition to optimizer). +# In this case, we specify the missing arguments here so we get a nice error message if it isn't passed. 
+ +CosineAnnealingLRConfig = make_config_and_store( + torch.optim.lr_scheduler.CosineAnnealingLR, + T_max="???", + store=lr_scheduler_group_store, +) + +StepLRConfig = make_config_and_store( + torch.optim.lr_scheduler.StepLR, + step_size="???", + store=lr_scheduler_group_store, +) + + +def add_configs_for_all_torch_schedulers(): + """Generates configuration dataclasses for all torch.optim.lr_scheduler classes that are not + already configured. + + Registers the configs using the `make_config_and_store` function. + """ + configured_schedulers = [ + hydra_zen.get_target(config) for config in get_all_scheduler_configs() + ] + missing_torch_schedulers = { + _name: _scheduler_type + for _name, _scheduler_type in vars(torch.optim.lr_scheduler).items() + if inspect.isclass(_scheduler_type) + and issubclass(_scheduler_type, torch.optim.lr_scheduler.LRScheduler) + and _scheduler_type + not in (torch.optim.lr_scheduler.LRScheduler, torch.optim.lr_scheduler._LRScheduler) + and _scheduler_type not in configured_schedulers + } + for scheduler_name, scheduler_type in missing_torch_schedulers.items(): + _logger.warning(f"Making a config for {scheduler_type=}") + _config = make_config_and_store(scheduler_type, store=lr_scheduler_group_store) -LRSchedulerType = TypeVar("LRSchedulerType", bound=torch.optim.lr_scheduler._LRScheduler) -# TODO: Double-check this, but defining LRSchedulerConfig like this makes it unusable as a type -# annotation on the hparams, since omegaconf will complain it isn't a valid base class. -type LRSchedulerConfig[LRSchedulerType: torch.optim.lr_scheduler._LRScheduler] = PartialBuilds[ - LRSchedulerType -] +def get_all_config_names() -> list[str]: + return sorted( + [config_name for (_group, config_name) in lr_scheduler_store[_LR_SCHEDULER_GROUP].keys()] + ) -# TODO: getting doctest issues here? -@hydrated_dataclass(target=torch.optim.lr_scheduler.StepLR, zen_partial=True) -class StepLRConfig: - """Config for the StepLR Scheduler.""" - step_size: int = 30 - gamma: float = 0.1 - last_epoch: int = -1 - verbose: bool = False +def get_all_scheduler_configs() -> list[type[PartialBuilds[torch.optim.lr_scheduler.LRScheduler]]]: + return list(lr_scheduler_store[_LR_SCHEDULER_GROUP].values()) -@hydrated_dataclass(target=torch.optim.lr_scheduler.CosineAnnealingLR, zen_partial=True) -class CosineAnnealingLRConfig: - """Config for the CosineAnnealingLR Scheduler.""" +# def __getattr__(config_name: str) -> type[PartialBuilds[torch.optim.lr_scheduler.LRScheduler]]: +# """Get the dynamically generated LR scheduler config with the given name.""" +# if config_name in globals(): +# return globals()[config_name] +# if not config_name.endswith("Config"): +# raise AttributeError +# scheduler_name = config_name.removesuffix("Config") +# # the keys for the config store are tuples of the form (group, config_name) +# store_key = (_LR_SCHEDULER_GROUP, scheduler_name) +# if store_key in lr_scheduler_store[_LR_SCHEDULER_GROUP]: +# _logger.debug(f"Dynamically retrieving the config for {scheduler_name!r}") +# return lr_scheduler_store[store_key] +# available_configs = sorted( +# config_name for (_group, config_name) in lr_scheduler_store[_LR_SCHEDULER_GROUP].keys() +# ) +# _logger.error( +# f"Unable to find the config for {scheduler_name=}. Available configs: {available_configs}." 
+# ) - T_max: int = 85 - eta_min: float = 0 - last_epoch: int = -1 - verbose: bool = False +# raise AttributeError diff --git a/project/configs/algorithm/optimizer/__init__.py b/project/configs/algorithm/optimizer/__init__.py index 2efcedc0..99900e02 100644 --- a/project/configs/algorithm/optimizer/__init__.py +++ b/project/configs/algorithm/optimizer/__init__.py @@ -1,57 +1,75 @@ -import functools -from collections.abc import Callable -from typing import Any, Protocol, runtime_checkable +import inspect +from logging import getLogger as get_logger +import hydra_zen import torch -from hydra_zen import hydrated_dataclass, instantiate +import torch.optim +from hydra_zen.typing import PartialBuilds -# OptimizerType = TypeVar("OptimizerType", bound=torch.optim.Optimizer, covariant=True) +from project.utils.hydra_utils import make_config_and_store +_OPTIMIZER_GROUP = "algorithm/optimizer" -@runtime_checkable -class OptimizerConfig[OptimizerType: torch.optim.Optimizer](Protocol): - """Configuration for an optimizer. +_logger = get_logger(__name__) - Returns a partial[OptimizerType] when instantiated. - """ - - __call__: Callable[..., functools.partial[OptimizerType]] = instantiate - - -# NOTE: Getting weird bug in omegaconf if I try to make OptimizerConfig generic! -# Instead I'm making it a protocol. - - -@hydrated_dataclass(target=torch.optim.SGD, zen_partial=True) -class SGDConfig: - """Configuration for the SGD optimizer.""" +optimizers_store = hydra_zen.ZenStore(name="optimizers") +optimizers_group_store = optimizers_store(group=_OPTIMIZER_GROUP) - lr: Any - momentum: float = 0.0 - dampening: float = 0.0 - weight_decay: float = 0.0 - nesterov: bool = False +# Create some configs manually so they can get nice type hints when imported. +AdamConfig = make_config_and_store(torch.optim.Adam, store=optimizers_group_store) +SGDConfig = make_config_and_store(torch.optim.SGD, store=optimizers_group_store) - __call__ = instantiate +def add_configs_for_all_torch_optimizers(): + """Generates configuration dataclasses for all `torch.optim.Optimizer` classes that are not + already configured. -# TODO: frozen=True doesn't work here, which is a real bummer. It would have saved us a lot of -# functools.partial(AdamConfig, lr=123) nonsense. -@hydrated_dataclass(target=torch.optim.Adam, zen_partial=True) -class AdamConfig: - """Configuration for the Adam optimizer.""" - - lr: Any = 0.001 - betas: Any = (0.9, 0.999) - eps: float = 1e-08 - weight_decay: float = 0 - amsgrad: bool = False - - __call__ = instantiate - - -# NOTE: we don't add an `optimizer` group, since models could have one or more optimizers. -# Models can register their own groups, e.g. `model/optimizer`. if they want to. -# cs = ConfigStore.instance() -# cs.store(group="optimizer", name="sgd", node=SGDConfig) -# cs.store(group="optimizer", name="adam", node=AdamConfig) + Registers the configs using the `make_config_and_store` function. 
+    """
+    configured_optimizers = [
+        hydra_zen.get_target(config) for config in get_all_optimizer_configs()
+    ]
+    missing_torch_optimizers = {
+        _name: _optimizer_type
+        for _name, _optimizer_type in vars(torch.optim).items()
+        if inspect.isclass(_optimizer_type)
+        and issubclass(_optimizer_type, torch.optim.Optimizer)
+        and _optimizer_type is not torch.optim.Optimizer
+        and _optimizer_type not in configured_optimizers
+    }
+    for optimizer_name, optimizer_type in missing_torch_optimizers.items():
+        _logger.warning(f"Making a config for {optimizer_type=}")
+        _config = make_config_and_store(optimizer_type, store=optimizers_group_store)
+
+
+# def __getattr__(config_name: str) -> type[PartialBuilds[torch.optim.Optimizer]]:
+#     """Get the optimizer config with the given name."""
+#     if config_name in globals():
+#         return globals()[config_name]
+
+#     if not config_name.endswith("Config"):
+#         raise AttributeError
+#     optimizer_name = config_name.removesuffix("Config")
+#     # the keys for the config store are tuples of the form (group, config_name)
+#     store_key = (_OPTIMIZER_GROUP, optimizer_name)
+#     if store_key in optimizers_store[_OPTIMIZER_GROUP]:
+#         _logger.debug(f"Dynamically retrieving the config for {optimizer_name=}.")
+#         return optimizers_store[store_key]
+#     available_optimizers = sorted(
+#         optimizer_name for (_, optimizer_name) in optimizers_store[_OPTIMIZER_GROUP].keys()
+#     )
+#     _logger.error(
+#         f"Unable to find the config for optimizer {optimizer_name}. Available optimizers: {available_optimizers}."
+#     )
+
+#     raise AttributeError
+
+
+def get_all_config_names() -> list[str]:
+    return sorted(
+        [config_name for (_group, config_name) in optimizers_store[_OPTIMIZER_GROUP].keys()]
+    )
+
+
+def get_all_optimizer_configs() -> list[type[PartialBuilds[torch.optim.Optimizer]]]:
+    return list(optimizers_store[_OPTIMIZER_GROUP].values())
diff --git a/project/configs/algorithm/optimizer/adamw.yaml b/project/configs/algorithm/optimizer/adamw.yaml
deleted file mode 100644
index 24f58e64..00000000
--- a/project/configs/algorithm/optimizer/adamw.yaml
+++ /dev/null
@@ -1,6 +0,0 @@
-_target_: torch.optim.AdamW
-_partial_: True
-# Learning rate of the optimizer.
-lr: 4e-3
-# Weight decay coefficient.
-weight_decay: null diff --git a/project/configs/config.yaml b/project/configs/config.yaml index f6c31838..0d180a55 100644 --- a/project/configs/config.yaml +++ b/project/configs/config.yaml @@ -2,7 +2,7 @@ defaults: - base_config - _self_ - datamodule: cifar10 - - algorithm: example_algo + - algorithm: ExampleAlgorithm - network: resnet18 - trainer: default.yaml - trainer/callbacks: default.yaml diff --git a/project/configs/config_test.py b/project/configs/config_test.py new file mode 100644 index 00000000..13ae0c2d --- /dev/null +++ b/project/configs/config_test.py @@ -0,0 +1,102 @@ +"""TODO: tests for the configs?""" + +import functools + +import hydra_zen +import pytest +import torch +from hydra_zen.third_party.pydantic import pydantic_parser +from hydra_zen.typing import PartialBuilds + +from project.configs.algorithm.lr_scheduler import get_all_scheduler_configs +from project.configs.algorithm.optimizer import get_all_optimizer_configs +from project.utils.testutils import seeded + + +@pytest.fixture(scope="session") +def net(device: torch.device): + with seeded(123): + net = torch.nn.Linear(10, 1).to(device) + return net + + +@pytest.mark.parametrize("optimizer_config", get_all_optimizer_configs()) +def test_optimizer_configs( + optimizer_config: type[PartialBuilds[torch.optim.Optimizer]], net: torch.nn.Module +): + assert hydra_zen.is_partial_builds(optimizer_config) + target = hydra_zen.get_target(optimizer_config) + assert issubclass(target, torch.optim.Optimizer) + + optimizer_partial = hydra_zen.instantiate(optimizer_config) + assert isinstance(optimizer_partial, functools.partial) + + optimizer = optimizer_partial(net.parameters()) + + assert isinstance(optimizer, torch.optim.Optimizer), optimizer + assert isinstance(optimizer, target) + + +# This could also be used to test with all optimizers, but isn't necessary. +# @pytest.fixture(scope="session", params=get_all_optimizer_configs()) +# @pytest.fixture(scope="session") +# def optimizer(device: torch.device, net: torch.nn.Module): +# # optimizer_config: type[PartialBuilds[torch.optim.Optimizer]] = request.param +# # optimizer = hydra_zen.instantiate(optimizer_config)(net.parameters()) +# return torch.optim.SGD(net.parameters(), lr=0.1) +# return optimizer + + +_optim = torch.optim.SGD([torch.zeros(1, requires_grad=True)]) + +default_kwargs: dict[type[torch.optim.lr_scheduler.LRScheduler], dict] = { + torch.optim.lr_scheduler.StepLR: {"step_size": 1}, + torch.optim.lr_scheduler.CosineAnnealingLR: {"T_max": 10}, + torch.optim.lr_scheduler.LambdaLR: {"lr_lambda": lambda epoch: 0.95**epoch}, + torch.optim.lr_scheduler.MultiplicativeLR: {"lr_lambda": lambda epoch: 0.95**epoch}, + torch.optim.lr_scheduler.MultiStepLR: {"milestones": [0, 1]}, + torch.optim.lr_scheduler.ExponentialLR: {"gamma": 0.8}, + torch.optim.lr_scheduler.SequentialLR: { + "schedulers": [ + torch.optim.lr_scheduler.ExponentialLR(_optim, gamma=0.9), + torch.optim.lr_scheduler.ExponentialLR(_optim, gamma=0.9), + ], + "milestones": [0], + }, + torch.optim.lr_scheduler.CyclicLR: {"base_lr": 0.1, "max_lr": 0.9}, + torch.optim.lr_scheduler.CosineAnnealingWarmRestarts: {"T_0": 1}, + torch.optim.lr_scheduler.OneCycleLR: {"max_lr": 1.0, "total_steps": 10}, +} +"""The missing arguments for some LR schedulers so we can create them during testing. + +The values don't really matter, as long as they are accepted by the constructor. +""" + + +schedulers_to_skip = { + # torch.optim.lr_scheduler.SequentialLR: "Requires other schedulers as arguments. Ugly." 
+    torch.optim.lr_scheduler.ChainedScheduler: "Requires passing a list of schedulers as arguments."
+}
+
+
+# pytest.param(scheduler_config, marks=[pytest.mark.skipif(hydra_zen.get_target(scheduler_config) is torch.optim.lr_scheduler.SequentialLR, reason="Requires other schedulers as arguments. Ugly.")
+@pytest.mark.parametrize("scheduler_config", get_all_scheduler_configs())
+def test_scheduler_configs(
+    scheduler_config: type[PartialBuilds[torch.optim.lr_scheduler.LRScheduler]],
+    net: torch.nn.Module,
+    # optimizer: torch.optim.Optimizer,
+):
+    assert hydra_zen.is_partial_builds(scheduler_config)
+    target = hydra_zen.get_target(scheduler_config)
+    if target in schedulers_to_skip:
+        pytest.skip(reason=schedulers_to_skip[target])
+
+    assert issubclass(target, torch.optim.lr_scheduler.LRScheduler)
+
+    scheduler_partial = hydra_zen.instantiate(
+        scheduler_config, _target_wrapper_=pydantic_parser, **default_kwargs.get(target, {})
+    )
+    assert isinstance(scheduler_partial, functools.partial)
+
+    lr_scheduler = scheduler_partial(_optim)
+    assert isinstance(lr_scheduler, target)
diff --git a/project/configs/datamodule/vision.yaml b/project/configs/datamodule/vision.yaml
index e3f10b79..561a36b1 100644
--- a/project/configs/datamodule/vision.yaml
+++ b/project/configs/datamodule/vision.yaml
@@ -1,3 +1,4 @@
+# todo: This config should not show up as an option on the command-line.
 _target_: project.datamodules.VisionDataModule
 data_dir: ${constant:DATA_DIR}
 num_workers: ${constant:NUM_WORKERS}
diff --git a/project/configs/network/jax_cnn.yaml b/project/configs/network/jax_cnn.yaml
index 4fc6dc8c..2b76cb7a 100644
--- a/project/configs/network/jax_cnn.yaml
+++ b/project/configs/network/jax_cnn.yaml
@@ -1,2 +1,2 @@
-_target_: project.algorithms.jax_algo.CNN
+_target_: project.algorithms.jax_example.CNN
 num_classes: ${instance_attr:datamodule.num_classes}
diff --git a/project/configs/network/jax_fcnet.yaml b/project/configs/network/jax_fcnet.yaml
index 55ed3023..0c7df8d4 100644
--- a/project/configs/network/jax_fcnet.yaml
+++ b/project/configs/network/jax_fcnet.yaml
@@ -1,3 +1,3 @@
-_target_: project.algorithms.jax_algo.JaxFcNet
+_target_: project.algorithms.jax_example.JaxFcNet
 num_classes: ${instance_attr:datamodule.num_classes}
 num_features: 256
diff --git a/project/conftest.py b/project/conftest.py
index bbf29d58..bf230708 100644
--- a/project/conftest.py
+++ b/project/conftest.py
@@ -1,3 +1,9 @@
+"""Fixtures and test utilities.
+
+This module contains [PyTest fixtures](https://docs.pytest.org/en/6.2.x/fixture.html) that are used
+by tests.
+""" + from __future__ import annotations import os diff --git a/project/experiment.py b/project/experiment.py index a48f6390..af8c45ae 100644 --- a/project/experiment.py +++ b/project/experiment.py @@ -1,8 +1,11 @@ from __future__ import annotations +import dataclasses +import functools import logging import os import random +from collections.abc import Callable from dataclasses import dataclass, is_dataclass from logging import getLogger as get_logger from typing import Any @@ -12,27 +15,31 @@ import rich.logging import rich.traceback import torch -from hydra.utils import instantiate -from lightning import Callback, Trainer, seed_everything +from hydra_zen.third_party.pydantic import pydantic_parser +from lightning import Callback, LightningModule, Trainer, seed_everything from omegaconf import DictConfig from torch import nn -from project.algorithms import Algorithm from project.configs.config import Config from project.datamodules.image_classification.image_classification import ( ImageClassificationDataModule, ) from project.utils.hydra_utils import get_outer_class from project.utils.types import Dataclass -from project.utils.types.protocols import ( - DataModule, - Module, -) +from project.utils.types.protocols import DataModule, Module from project.utils.utils import validate_datamodule logger = get_logger(__name__) +# todo: fix this. +def _use_pydantic[C: Callable](fn: C) -> C: + return functools.partial(hydra_zen.instantiate, _target_wrapper_=pydantic_parser) # type: ignore + + +instantiate = _use_pydantic(hydra_zen.instantiate) + + @dataclass class Experiment: """Dataclass containing everything used in an experiment. @@ -42,7 +49,7 @@ class Experiment: come in handy with `submitit` later on. """ - algorithm: Algorithm + algorithm: LightningModule network: nn.Module datamodule: DataModule trainer: Trainer @@ -159,7 +166,7 @@ def instantiate_datamodule(experiment_config: Config) -> DataModule: ) datamodule = datamodule_config else: - datamodule = instantiate(datamodule_config, **datamodule_overrides) + datamodule = hydra_zen.instantiate(datamodule_config, **datamodule_overrides) assert isinstance(datamodule, DataModule) datamodule = validate_datamodule(datamodule) @@ -208,10 +215,10 @@ def instantiate_network(experiment_config: Config, datamodule: DataModule) -> nn def instantiate_algorithm( experiment_config: Config, datamodule: DataModule, network: nn.Module -) -> Algorithm: +) -> LightningModule: # Create the algorithm algo_config = experiment_config.algorithm - if isinstance(algo_config, Algorithm): + if isinstance(algo_config, LightningModule): logger.info( f"Algorithm was already instantiated (probably to interpolate a field value)." f"{algo_config=}" @@ -229,9 +236,9 @@ def instantiate_algorithm( else: algorithm = instantiate(algo_config, datamodule=datamodule, network=network) - if not isinstance(algorithm, Algorithm): + if not isinstance(algorithm, LightningModule): raise NotImplementedError( - f"The algorithm config didn't create an Algorithm instance:\n" + f"The algorithm config didn't create a LightningModule instance:\n" f"{algo_config=}\n" f"{algorithm=}" ) @@ -239,17 +246,25 @@ def instantiate_algorithm( if hasattr(algo_config, "_target_"): # A dataclass of some sort, with a _target_ attribute. 
- algorithm = instantiate(algo_config, datamodule=datamodule, network=network) - assert isinstance(algorithm, Algorithm) + if hydra_zen.is_partial_builds(algo_config): + algo_partial = instantiate(algo_config) + assert isinstance(algo_partial, functools.partial) + algorithm = algo_partial(network=network, datamodule=datamodule) + else: + algorithm = instantiate(algo_config, datamodule=datamodule, network=network) + assert isinstance(algorithm, LightningModule), algorithm return algorithm - if not isinstance(algo_config, Algorithm.HParams): + if not dataclasses.is_dataclass(algo_config): + if issubclass(algo_class := get_outer_class(type(algo_config)), LightningModule): + return algo_class(datamodule=datamodule, network=network, hp=algo_config) + raise NotImplementedError( f"For now the algorithm config can either have a _target_ set to an Algorithm class, " f"or configure an inner Algorithm.HParams dataclass. Got:\n{algo_config=}" ) - algorithm_type: type[Algorithm] = get_outer_class(type(algo_config)) + algorithm_type: type[LightningModule] = get_outer_class(type(algo_config)) assert isinstance( algo_config, algorithm_type.HParams, # type: ignore diff --git a/project/main_test.py b/project/main_test.py index 32d16e4b..f636ec69 100644 --- a/project/main_test.py +++ b/project/main_test.py @@ -7,7 +7,7 @@ import hydra_zen import pytest -from project.algorithms import Algorithm, ExampleAlgorithm +from project.algorithms.example import ExampleAlgorithm from project.configs.config import Config from project.conftest import setup_hydra_for_tests_and_compose, use_overrides from project.datamodules.image_classification.cifar10 import CIFAR10DataModule @@ -40,7 +40,10 @@ def set_testing_hydra_dir(): @use_overrides([""]) def test_defaults(experiment_config: Config) -> None: - assert isinstance(experiment_config.algorithm, ExampleAlgorithm.HParams) + assert ( + hydra_zen.is_partial_builds(experiment_config.algorithm) + and hydra_zen.get_target(experiment_config.algorithm) is ExampleAlgorithm + ) assert ( isinstance(experiment_config.datamodule, CIFAR10DataModule) or hydra_zen.get_target(experiment_config.datamodule) is CIFAR10DataModule @@ -63,7 +66,7 @@ def _ids(v): ) def test_setting_algorithm( overrides: list[str], - expected_type: type[Algorithm.HParams], + expected_type: type, testing_overrides: list[str], tmp_path: Path, ) -> None: diff --git a/project/utils/__init__.py b/project/utils/__init__.py index d316a228..e9cfa161 100644 --- a/project/utils/__init__.py +++ b/project/utils/__init__.py @@ -1,5 +1,9 @@ from .device import default_device +# Import this patch for https://github.com/mit-ll-responsible-ai/hydra-zen/issues/705 to make sure that it gets applied. 
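+# Without it, hydra-zen's `safe_name` only uses `__name__` for nested classes, so configs built for
+# inner classes such as `ExampleAlgorithm.HParams` would get an incomplete `_target_` path.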
+from .hydra_utils import patched_safe_name + __all__ = [ "default_device", + "patched_safe_name", ] diff --git a/project/utils/hydra_utils.py b/project/utils/hydra_utils.py index 1ae7cc0e..1218240d 100644 --- a/project/utils/hydra_utils.py +++ b/project/utils/hydra_utils.py @@ -15,7 +15,10 @@ TypeVar, ) +import hydra_zen.structured_configs._utils from hydra_zen import instantiate +from hydra_zen.structured_configs._utils import safe_name +from hydra_zen.third_party.pydantic import pydantic_parser from hydra_zen.typing._implementations import Partial as _Partial from omegaconf import DictConfig, OmegaConf @@ -24,9 +27,32 @@ logger = get_logger(__name__) + T = TypeVar("T") +def patched_safe_name(obj: Any, repr_allowed: bool = True): + """Patches a bug in Hydra-zen where the _target_ of inner classes is incorrect: + https://github.com/mit-ll-responsible-ai/hydra-zen/issues/705 + """ + + if not hasattr(obj, "__qualname__"): + return safe_name(obj, repr_allowed=repr_allowed) + + name = safe_name(obj, repr_allowed=repr_allowed) + qualname = obj.__qualname__ + assert isinstance(qualname, str) + + if name != qualname and qualname.endswith("." + name): + logger.debug(f"Using patched fn: returning {qualname} for target {obj}") + return qualname + + return name + + +hydra_zen.structured_configs._utils.safe_name = patched_safe_name + + def interpolate_config_attribute(*attributes: str, default: Any | Literal[MISSING] = MISSING): """Use this in a config to to get an attribute from another config after it is instantiated. @@ -377,3 +403,37 @@ def _default_factory( if default_factory is not dataclasses.MISSING: return default_factory() return default # type: ignore + + +def make_config_and_store[Target]( + target: Callable[..., Target], *, store: hydra_zen.ZenStore, **overrides +): + """Creates a config dataclass for the given target and stores it in the config store. + + This uses [hydra_zen.builds](https://mit-ll-responsible-ai.github.io/hydra-zen/generated/hydra_zen.builds.html) + to create the config dataclass and stores it at the name `config_name`, or `target.__name__`. + """ + _current_frame = inspect.currentframe() + assert _current_frame + _calling_module = inspect.getmodule(_current_frame.f_back) + assert _calling_module + + config = hydra_zen.builds( + target, + populate_full_signature=True, + zen_partial=True, + zen_dataclass={ + "cls_name": f"{target.__name__}Config", + # BUG: Causes issues, tries to get the config from the module again, which re-creates + # it? + # "module": _calling_module.__name__, + # TODO: Seems to be causing issues with `_target_` being overwritten? 
+ "frozen": False, + }, + zen_wrappers=pydantic_parser, + **overrides, + ) + name_of_config_in_store = target.__name__ + logger.warning(f"Created a config entry {name_of_config_in_store} for {target.__qualname__}") + store(config, name=name_of_config_in_store) + return config diff --git a/project/utils/testutils.py b/project/utils/testutils.py index 49fae182..5605f011 100644 --- a/project/utils/testutils.py +++ b/project/utils/testutils.py @@ -13,7 +13,7 @@ from contextlib import contextmanager from logging import getLogger as get_logger from pathlib import Path -from typing import Any, TypeVar +from typing import Any, Literal, TypeVar import hydra_zen import numpy as np @@ -34,7 +34,6 @@ from project.experiment import instantiate_trainer from project.utils.env_vars import NETWORK_DIR from project.utils.hydra_utils import get_attr, get_outer_class -from project.utils.types import PhaseStr from project.utils.types.protocols import ( DataModule, ) @@ -427,7 +426,10 @@ def validation_step(self, batch: tuple[Tensor, Tensor], batch_index: int) -> Ten return self.shared_step(batch, batch_index, phase="val") def shared_step( - self, batch: tuple[Tensor, Tensor], batch_index: int, phase: PhaseStr + self, + batch: tuple[Tensor, Tensor], + batch_index: int, + phase: Literal["train", "val", "test"], ) -> Tensor: x, _y = batch latents = self.inf_network(x) @@ -461,7 +463,10 @@ def forward(self, input: Tensor) -> Tensor: return output def shared_step( - self, batch: tuple[Tensor, Tensor], batch_index: int, phase: PhaseStr + self, + batch: tuple[Tensor, Tensor], + batch_index: int, + phase: Literal["train", "val", "test"], ) -> Tensor: x, y = batch latents = self.inf_network(x) @@ -500,7 +505,10 @@ def validation_step(self, batch: tuple[Tensor, Tensor], batch_index: int) -> Ten return self.shared_step(batch, batch_index, phase="val") def shared_step( - self, batch: tuple[Tensor, Tensor], batch_index: int, phase: PhaseStr + self, + batch: tuple[Tensor, Tensor], + batch_index: int, + phase: Literal["train", "val", "test"], ) -> Tensor: x, y = batch logits = self.network(x) diff --git a/project/utils/types/__init__.py b/project/utils/types/__init__.py index 381ed570..6ac73ada 100644 --- a/project/utils/types/__init__.py +++ b/project/utils/types/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Any, Literal, NewType, TypeGuard, Unpack +from typing import Annotated, Any, NewType, TypeGuard, Unpack import annotated_types from torch import Tensor @@ -16,16 +16,6 @@ S = NewType("S", int) -# todo: Fix this. Why do we have these enums? Are they necessary? Could we use the same ones as PL if we wanted to? -# from lightning.pytorch.trainer.states import RunningStage as PhaseStr -# from lightning.pytorch.trainer.states import TrainerFn as StageStr - -PhaseStr = Literal["train", "val", "test"] -"""The trainer phases. - -TODO: There has to exist an enum for it somewhere in PyTorch Lightning. -""" - # Types used with pydantic: FloatBetween0And1 = Annotated[float, annotated_types.Ge(0), annotated_types.Le(1)]
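For reference, a minimal sketch of how the dynamically generated configs behave, assuming hydra-zen and torch are installed. It hand-rolls the equivalent of `make_config_and_store(torch.optim.SGD, ...)` (omitting the pydantic wrapper and the store registration that the real helper adds) and mirrors the assertions in `project/configs/config_test.py`; the names below are local to the sketch:

    import functools

    import hydra_zen
    import torch

    # Hand-rolled equivalent of the generated SGDConfig (illustrative only):
    SGDConfig = hydra_zen.builds(
        torch.optim.SGD,
        populate_full_signature=True,  # hyper-parameters with defaults become config fields
        zen_partial=True,  # instantiating yields a functools.partial, not an optimizer
        zen_dataclass={"cls_name": "SGDConfig"},
    )

    assert hydra_zen.is_partial_builds(SGDConfig)
    assert hydra_zen.get_target(SGDConfig) is torch.optim.SGD

    # `params` has no default, so it isn't a config field; it is supplied when calling the partial.
    sgd_partial = hydra_zen.instantiate(SGDConfig)
    assert isinstance(sgd_partial, functools.partial)
    optimizer = sgd_partial([torch.zeros(1, requires_grad=True)], lr=0.1)
    assert isinstance(optimizer, torch.optim.SGD)

The algorithm configs follow the same pattern: `zen_exclude=["datamodule", "network"]` keeps those arguments out of the generated config, and `instantiate_algorithm` supplies them when it calls the instantiated partial.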