diff --git a/project/__init__.py b/project/__init__.py
index 16882217..aa3d30a6 100644
--- a/project/__init__.py
+++ b/project/__init__.py
@@ -1,10 +1,13 @@
 from . import algorithms, configs, datamodules, experiment, main, networks, utils
-from .configs import Config
+from .configs import Config, add_configs_to_hydra_store
 from .experiment import Experiment

 # from .networks import FcNet
 from .utils.types import DataModule

+add_configs_to_hydra_store()
+
+
 __all__ = [
     "algorithms",
     "experiment",
diff --git a/project/algorithms/__init__.py b/project/algorithms/__init__.py
index 5ffca1a8..d5e78cc3 100644
--- a/project/algorithms/__init__.py
+++ b/project/algorithms/__init__.py
@@ -1,26 +1,10 @@
-from hydra_zen import builds, store
-
 from project.algorithms.jax_example import JaxExample
 from project.algorithms.no_op import NoOp

 from .example import ExampleAlgorithm

-# NOTE: This works the same way as creating config files for each algorithm under
-# `configs/algorithm`. From the command-line, you can select both configs that are yaml files as
-# well as structured config (dataclasses).
-
-# If you add a configuration file under `configs/algorithm`, it will also be available as an option
-# from the command-line, and be validated against the schema.
-# todo: It might be nicer if we did this this `configs/algorithms` instead of here, no?
-algorithm_store = store(group="algorithm")
-algorithm_store(ExampleAlgorithm.HParams(), name="example_algo")
-algorithm_store(builds(NoOp, populate_full_signature=False), name="no_op")
-algorithm_store(JaxExample.HParams(), name="jax_example")
-
-algorithm_store.add_to_hydra_store()
-
 __all__ = [
     "ExampleAlgorithm",
-    "ManualGradientsExample",
     "JaxExample",
+    "NoOp",
 ]
diff --git a/project/algorithms/example.py b/project/algorithms/example.py
index 8ec1d52f..11536ed8 100644
--- a/project/algorithms/example.py
+++ b/project/algorithms/example.py
@@ -20,9 +20,9 @@
 from torch.optim.lr_scheduler import _LRScheduler

 from project.algorithms.callbacks.classification_metrics import ClassificationMetricsCallback
-from project.configs.algorithm.lr_scheduler import CosineAnnealingLRConfig
-from project.configs.algorithm.optimizer import AdamConfig
-from project.utils.types.protocols import DataModule
+from project.configs.lr_scheduler import CosineAnnealingLRConfig
+from project.configs.optimizer import AdamConfig
+from project.datamodules.image_classification import ImageClassificationDataModule

 logger = getLogger(__name__)

@@ -30,7 +30,7 @@
 class ExampleAlgorithm(LightningModule):
     """Example learning algorithm for image classification."""

-    @dataclasses.dataclass
+    @dataclasses.dataclass(frozen=True)
     class HParams:
         """Hyper-Parameters."""

@@ -53,9 +53,9 @@ class HParams:

     def __init__(
         self,
-        datamodule: DataModule[tuple[torch.Tensor, torch.Tensor]],
+        datamodule: ImageClassificationDataModule,
         network: torch.nn.Module,
-        hp: ExampleAlgorithm.HParams | None = None,
+        hp: ExampleAlgorithm.HParams = HParams(),
     ):
         super().__init__()
         self.datamodule = datamodule
@@ -102,12 +102,20 @@ def shared_step(

     def configure_optimizers(self) -> dict:
         """Creates the optimizers and the LR scheduler (if needed)."""
-        optimizer_partial: functools.partial[Optimizer] = instantiate(self.hp.optimizer)
-        lr_scheduler_partial: functools.partial[_LRScheduler] = instantiate(self.hp.lr_scheduler)
+        optimizer_partial: functools.partial[Optimizer]
+        if isinstance(self.hp.optimizer, functools.partial):
+            optimizer_partial = self.hp.optimizer
+        else:
+            optimizer_partial = instantiate(self.hp.optimizer)
         optimizer = optimizer_partial(self.parameters())

-        optimizers: dict[str, Any] = {"optimizer": optimizer}
+        lr_scheduler_partial: functools.partial[_LRScheduler]
+        if isinstance(self.hp.lr_scheduler, functools.partial):
+            lr_scheduler_partial = self.hp.lr_scheduler
+        else:
+            lr_scheduler_partial = instantiate(self.hp.lr_scheduler)
+        optimizers: dict[str, Any] = {"optimizer": optimizer}

         if self.hp.lr_scheduler_frequency != 0:
             lr_scheduler = lr_scheduler_partial(optimizer)
             optimizers["lr_scheduler"] = {
diff --git a/project/algorithms/example_test.py b/project/algorithms/example_test.py
index 9d9a1824..fa5ea9b1 100644
--- a/project/algorithms/example_test.py
+++ b/project/algorithms/example_test.py
@@ -9,6 +9,6 @@

 class TestExampleAlgorithm(ClassificationAlgorithmTests[ExampleAlgorithm]):
     algorithm_type = ExampleAlgorithm
-    algorithm_name: str = "example_algo"
+    algorithm_name: str = "example"
     unsupported_datamodule_names: ClassVar[list[str]] = ["rl"]
     _supported_network_types: ClassVar[list[type]] = [torch.nn.Module]
diff --git a/project/algorithms/jax_example.py b/project/algorithms/jax_example.py
index edb12090..257a24c7 100644
--- a/project/algorithms/jax_example.py
+++ b/project/algorithms/jax_example.py
@@ -80,7 +80,7 @@ class JaxExample(LightningModule):
     written in Jax, and the loss function is in pytorch.
     """

-    @dataclasses.dataclass
+    @dataclasses.dataclass(frozen=True)
     class HParams:
         """Hyper-parameters of the algo."""

@@ -93,7 +93,7 @@ def __init__(
         *,
         network: flax.linen.Module,
         datamodule: ImageClassificationDataModule,
-        hp: HParams | None = None,
+        hp: HParams = HParams(),
     ):
         super().__init__()
         self.datamodule = datamodule
diff --git a/project/configs/__init__.py b/project/configs/__init__.py
index 13e6a8d7..c8d8fc85 100644
--- a/project/configs/__init__.py
+++ b/project/configs/__init__.py
@@ -2,17 +2,27 @@
 from hydra.core.config_store import ConfigStore

-from ..utils.env_vars import REPO_ROOTDIR, SLURM_JOB_ID, SLURM_TMPDIR
-from .config import Config
-from .datamodule import (
-    datamodule_store,
-)
-from .network import network_store
+from project.configs.algorithm import algorithm_store, populate_algorithm_store
+from project.configs.config import Config
+from project.configs.datamodule import datamodule_store
+from project.configs.lr_scheduler import lr_scheduler_store
+from project.configs.network import network_store
+from project.configs.optimizer import optimizer_store
+from project.utils.env_vars import REPO_ROOTDIR, SLURM_JOB_ID, SLURM_TMPDIR

 cs = ConfigStore.instance()
 cs.store(name="base_config", node=Config)

-datamodule_store.add_to_hydra_store()
-network_store.add_to_hydra_store()
+
+
+def add_configs_to_hydra_store():
+    datamodule_store.add_to_hydra_store()
+    network_store.add_to_hydra_store()
+    optimizer_store.add_to_hydra_store()
+    lr_scheduler_store.add_to_hydra_store()
+    populate_algorithm_store()
+    algorithm_store.add_to_hydra_store()
+
+
 # todo: move the algorithm_store.add_to_hydra_store() here?

 __all__ = [
diff --git a/project/configs/algorithm/__init__.py b/project/configs/algorithm/__init__.py
index e69de29b..b8c08718 100644
--- a/project/configs/algorithm/__init__.py
+++ b/project/configs/algorithm/__init__.py
@@ -0,0 +1,24 @@
+from hydra_zen import make_custom_builds_fn, store
+from hydra_zen.third_party.pydantic import pydantic_parser
+
+builds_fn = make_custom_builds_fn(
+    zen_partial=True, populate_full_signature=True, zen_wrappers=pydantic_parser
+)
+
+
+# NOTE: This works the same way as creating config files for each algorithm under
+# `configs/algorithm`. From the command-line, you can select both configs that are yaml files as
+# well as structured config (dataclasses).
+
+# If you add a configuration file under `configs/algorithm`, it will also be available as an option
+# from the command-line, and can use these configs in their default list.
+algorithm_store = store(group="algorithm")
+
+
+def populate_algorithm_store():
+    # Note: import here to avoid circular imports.
+    from project.algorithms import ExampleAlgorithm, JaxExample, NoOp
+
+    algorithm_store(builds_fn(ExampleAlgorithm), name="example_algo")
+    algorithm_store(builds_fn(NoOp), name="no_op")
+    algorithm_store(builds_fn(JaxExample), name="jax_example")
diff --git a/project/configs/algorithm/example_from_config.yaml b/project/configs/algorithm/example_from_config.yaml
new file mode 100644
index 00000000..5557e38b
--- /dev/null
+++ b/project/configs/algorithm/example_from_config.yaml
@@ -0,0 +1,7 @@
+defaults:
+  - optimizer/adam@hp.optimizer
+  - lr_scheduler/step_lr@hp.lr_scheduler
+_target_: project.algorithms.example.ExampleAlgorithm
+_partial_: true
+hp:
+  _target_: project.algorithms.example.ExampleAlgorithm.HParams
diff --git a/project/configs/algorithm/lr_scheduler/__init__.py b/project/configs/algorithm/lr_scheduler/__init__.py
deleted file mode 100644
index 67180aa1..00000000
--- a/project/configs/algorithm/lr_scheduler/__init__.py
+++ /dev/null
@@ -1,33 +0,0 @@
-from typing import TypeVar
-
-import torch.optim.lr_scheduler
-from hydra_zen import hydrated_dataclass
-from hydra_zen.typing._implementations import PartialBuilds
-
-LRSchedulerType = TypeVar("LRSchedulerType", bound=torch.optim.lr_scheduler._LRScheduler)
-# TODO: Double-check this, but defining LRSchedulerConfig like this makes it unusable as a type
-# annotation on the hparams, since omegaconf will complain it isn't a valid base class.
-type LRSchedulerConfig[LRSchedulerType: torch.optim.lr_scheduler._LRScheduler] = PartialBuilds[
-    LRSchedulerType
-]
-
-
-# TODO: getting doctest issues here?
-@hydrated_dataclass(target=torch.optim.lr_scheduler.StepLR, zen_partial=True)
-class StepLRConfig:
-    """Config for the StepLR Scheduler."""
-
-    step_size: int = 30
-    gamma: float = 0.1
-    last_epoch: int = -1
-    verbose: bool = False
-
-
-@hydrated_dataclass(target=torch.optim.lr_scheduler.CosineAnnealingLR, zen_partial=True)
-class CosineAnnealingLRConfig:
-    """Config for the CosineAnnealingLR Scheduler."""
-
-    T_max: int = 85
-    eta_min: float = 0
-    last_epoch: int = -1
-    verbose: bool = False
diff --git a/project/configs/algorithm/optimizer/__init__.py b/project/configs/algorithm/optimizer/__init__.py
deleted file mode 100644
index 2efcedc0..00000000
--- a/project/configs/algorithm/optimizer/__init__.py
+++ /dev/null
@@ -1,57 +0,0 @@
-import functools
-from collections.abc import Callable
-from typing import Any, Protocol, runtime_checkable
-
-import torch
-from hydra_zen import hydrated_dataclass, instantiate
-
-# OptimizerType = TypeVar("OptimizerType", bound=torch.optim.Optimizer, covariant=True)
-
-
-@runtime_checkable
-class OptimizerConfig[OptimizerType: torch.optim.Optimizer](Protocol):
-    """Configuration for an optimizer.
-
-    Returns a partial[OptimizerType] when instantiated.
-    """
-
-    __call__: Callable[..., functools.partial[OptimizerType]] = instantiate
-
-
-# NOTE: Getting weird bug in omegaconf if I try to make OptimizerConfig generic!
-# Instead I'm making it a protocol.
-
-
-@hydrated_dataclass(target=torch.optim.SGD, zen_partial=True)
-class SGDConfig:
-    """Configuration for the SGD optimizer."""
-
-    lr: Any
-    momentum: float = 0.0
-    dampening: float = 0.0
-    weight_decay: float = 0.0
-    nesterov: bool = False
-
-    __call__ = instantiate
-
-
-# TODO: frozen=True doesn't work here, which is a real bummer. It would have saved us a lot of
-# functools.partial(AdamConfig, lr=123) nonsense.
-@hydrated_dataclass(target=torch.optim.Adam, zen_partial=True)
-class AdamConfig:
-    """Configuration for the Adam optimizer."""
-
-    lr: Any = 0.001
-    betas: Any = (0.9, 0.999)
-    eps: float = 1e-08
-    weight_decay: float = 0
-    amsgrad: bool = False
-
-    __call__ = instantiate
-
-
-# NOTE: we don't add an `optimizer` group, since models could have one or more optimizers.
-# Models can register their own groups, e.g. `model/optimizer`. if they want to.
-# cs = ConfigStore.instance()
-# cs.store(group="optimizer", name="sgd", node=SGDConfig)
-# cs.store(group="optimizer", name="adam", node=AdamConfig)
diff --git a/project/configs/config.yaml b/project/configs/config.yaml
index f6c31838..72a76392 100644
--- a/project/configs/config.yaml
+++ b/project/configs/config.yaml
@@ -2,7 +2,7 @@ defaults:
   - base_config
   - _self_
   - datamodule: cifar10
-  - algorithm: example_algo
+  - algorithm: example
   - network: resnet18
   - trainer: default.yaml
   - trainer/callbacks: default.yaml
diff --git a/project/configs/lr_scheduler/__init__.py b/project/configs/lr_scheduler/__init__.py
new file mode 100644
index 00000000..1e8101c1
--- /dev/null
+++ b/project/configs/lr_scheduler/__init__.py
@@ -0,0 +1,20 @@
+import torch
+import torch.optim.lr_scheduler
+from hydra_zen import make_custom_builds_fn, store
+from hydra_zen.third_party.pydantic import pydantic_parser
+
+builds_fn = make_custom_builds_fn(
+    zen_partial=True, populate_full_signature=True, zen_wrappers=pydantic_parser
+)
+
+CosineAnnealingLRConfig = builds_fn(torch.optim.lr_scheduler.CosineAnnealingLR, T_max=85)
+StepLRConfig = builds_fn(torch.optim.lr_scheduler.StepLR)
+lr_scheduler_store = store(group="algorithm/lr_scheduler")
+lr_scheduler_store(StepLRConfig, name="step_lr")
+lr_scheduler_store(CosineAnnealingLRConfig, name="cosine_annealing_lr")
+
+
+# IDEA: Could be interesting to generate configs for any member of the torch.optim.lr_scheduler
+# package dynamically (and store it)?
+# def __getattr__(self, name: str):
+#     """"""
diff --git a/project/configs/optimizer/__init__.py b/project/configs/optimizer/__init__.py
new file mode 100644
index 00000000..704b40e5
--- /dev/null
+++ b/project/configs/optimizer/__init__.py
@@ -0,0 +1,14 @@
+import torch
+import torch.optim
+from hydra_zen import make_custom_builds_fn, store
+from hydra_zen.third_party.pydantic import pydantic_parser
+
+builds_fn = make_custom_builds_fn(
+    zen_partial=True, populate_full_signature=True, zen_wrappers=pydantic_parser
+)
+
+optimizer_store = store(group="algorithm/optimizer")
+AdamConfig = builds_fn(torch.optim.Adam)
+SGDConfig = builds_fn(torch.optim.SGD)
+optimizer_store(AdamConfig, name="adam")
+optimizer_store(SGDConfig, name="sgd")
diff --git a/project/configs/algorithm/optimizer/adamw.yaml b/project/configs/optimizer/adamw.yaml
similarity index 100%
rename from project/configs/algorithm/optimizer/adamw.yaml
rename to project/configs/optimizer/adamw.yaml
diff --git a/project/experiment.py b/project/experiment.py
index 3434135a..ca0fa78e 100644
--- a/project/experiment.py
+++ b/project/experiment.py
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import dataclasses
+import functools
 import logging
 import os
 import random
@@ -237,7 +238,9 @@ def instantiate_algorithm(
     if hasattr(algo_config, "_target_"):
         # A dataclass of some sort, with a _target_ attribute.
         algorithm = instantiate(algo_config, datamodule=datamodule, network=network)
-        assert isinstance(algorithm, LightningModule)
+        if isinstance(algorithm, functools.partial):
+            algorithm = algorithm()
+        assert isinstance(algorithm, LightningModule), algorithm
         return algorithm

     if not dataclasses.is_dataclass(algo_config):
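
A minimal sketch (not part of the diff above) of how the relocated optimizer configs behave, assuming hydra-zen with pydantic installed. It mirrors the `builds_fn`/`AdamConfig` calls from the new `project/configs/optimizer/__init__.py` and the branching added to `ExampleAlgorithm.configure_optimizers`; the `torch.nn.Linear` model and the `lr` value are illustrative placeholders only.

```python
# Sketch: zen_partial=True configs instantiate to a functools.partial that is
# only called later with the model parameters, which is why the updated
# configure_optimizers accepts either a partial or a config object.
import functools

import torch
from hydra_zen import instantiate, make_custom_builds_fn
from hydra_zen.third_party.pydantic import pydantic_parser

builds_fn = make_custom_builds_fn(
    zen_partial=True, populate_full_signature=True, zen_wrappers=pydantic_parser
)
AdamConfig = builds_fn(torch.optim.Adam)  # same call as project/configs/optimizer/__init__.py

model = torch.nn.Linear(4, 2)  # stand-in for the real network
hp_optimizer = AdamConfig(lr=1e-3)  # could also be functools.partial(torch.optim.Adam, lr=1e-3)

# Same branching as the new ExampleAlgorithm.configure_optimizers:
if isinstance(hp_optimizer, functools.partial):
    optimizer_partial = hp_optimizer
else:
    optimizer_partial = instantiate(hp_optimizer)  # expected to yield a functools.partial

optimizer = optimizer_partial(list(model.parameters()))
print(type(optimizer).__name__)  # Adam
```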