Clean up the configs for algorithms
Signed-off-by: Fabrice Normandin <[email protected]>
lebrice committed Jul 10, 2024
1 parent 1560307 commit 2c04bac
Showing 15 changed files with 113 additions and 130 deletions.
5 changes: 4 additions & 1 deletion project/__init__.py
@@ -1,10 +1,13 @@
from . import algorithms, configs, datamodules, experiment, main, networks, utils
from .configs import Config
from .configs import Config, add_configs_to_hydra_store
from .experiment import Experiment

# from .networks import FcNet
from .utils.types import DataModule

add_configs_to_hydra_store()


__all__ = [
"algorithms",
"experiment",
18 changes: 1 addition & 17 deletions project/algorithms/__init__.py
@@ -1,26 +1,10 @@
from hydra_zen import builds, store

from project.algorithms.jax_example import JaxExample
from project.algorithms.no_op import NoOp

from .example import ExampleAlgorithm

# NOTE: This works the same way as creating config files for each algorithm under
# `configs/algorithm`. From the command-line, you can select both configs that are yaml files as
# well as structured config (dataclasses).

# If you add a configuration file under `configs/algorithm`, it will also be available as an option
# from the command-line, and be validated against the schema.
# todo: It might be nicer if we did this in `configs/algorithms` instead of here, no?
algorithm_store = store(group="algorithm")
algorithm_store(ExampleAlgorithm.HParams(), name="example_algo")
algorithm_store(builds(NoOp, populate_full_signature=False), name="no_op")
algorithm_store(JaxExample.HParams(), name="jax_example")

algorithm_store.add_to_hydra_store()

__all__ = [
"ExampleAlgorithm",
"ManualGradientsExample",
"JaxExample",
"NoOp",
]
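
For readers unfamiliar with hydra-zen stores, here is a minimal, self-contained sketch of the equivalence the NOTE above describes; the ToyHParams class and the "toy" name are made up for illustration:

import dataclasses

from hydra_zen import store


@dataclasses.dataclass
class ToyHParams:
    lr: float = 3e-4


toy_store = store(group="algorithm")
toy_store(ToyHParams(), name="toy")  # selectable from the command line as `algorithm=toy`,
toy_store.add_to_hydra_store()       # just like a `configs/algorithm/toy.yaml` file would be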
26 changes: 17 additions & 9 deletions project/algorithms/example.py
@@ -20,17 +20,17 @@
from torch.optim.lr_scheduler import _LRScheduler

from project.algorithms.callbacks.classification_metrics import ClassificationMetricsCallback
from project.configs.algorithm.lr_scheduler import CosineAnnealingLRConfig
from project.configs.algorithm.optimizer import AdamConfig
from project.utils.types.protocols import DataModule
from project.configs.lr_scheduler import CosineAnnealingLRConfig
from project.configs.optimizer import AdamConfig
from project.datamodules.image_classification import ImageClassificationDataModule

logger = getLogger(__name__)


class ExampleAlgorithm(LightningModule):
"""Example learning algorithm for image classification."""

@dataclasses.dataclass
@dataclasses.dataclass(frozen=True)
class HParams:
"""Hyper-Parameters."""

@@ -53,9 +53,9 @@ class HParams:

def __init__(
self,
datamodule: DataModule[tuple[torch.Tensor, torch.Tensor]],
datamodule: ImageClassificationDataModule,
network: torch.nn.Module,
hp: ExampleAlgorithm.HParams | None = None,
hp: ExampleAlgorithm.HParams = HParams(),
):
super().__init__()
self.datamodule = datamodule
@@ -102,12 +102,20 @@ def shared_step(

def configure_optimizers(self) -> dict:
"""Creates the optimizers and the LR scheduler (if needed)."""
optimizer_partial: functools.partial[Optimizer] = instantiate(self.hp.optimizer)
lr_scheduler_partial: functools.partial[_LRScheduler] = instantiate(self.hp.lr_scheduler)
optimizer_partial: functools.partial[Optimizer]
if isinstance(self.hp.optimizer, functools.partial):
optimizer_partial = self.hp.optimizer
else:
optimizer_partial = instantiate(self.hp.optimizer)
optimizer = optimizer_partial(self.parameters())

optimizers: dict[str, Any] = {"optimizer": optimizer}

lr_scheduler_partial: functools.partial[_LRScheduler]
if isinstance(self.hp.lr_scheduler, functools.partial):
lr_scheduler_partial = self.hp.lr_scheduler
else:
lr_scheduler_partial = instantiate(self.hp.lr_scheduler)

if self.hp.lr_scheduler_frequency != 0:
lr_scheduler = lr_scheduler_partial(optimizer)
optimizers["lr_scheduler"] = {
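
A minimal sketch (with assumed values, and with the pydantic wrapper from the new config modules omitted for brevity) of why `configure_optimizers` now handles both cases: a structured config built with `zen_partial=True` instantiates into a `functools.partial`, while the hyper-parameter may already hold a ready-made partial:

import functools

import torch
from hydra_zen import instantiate, make_custom_builds_fn

builds_fn = make_custom_builds_fn(zen_partial=True, populate_full_signature=True)
AdamConfig = builds_fn(torch.optim.Adam)

hp_optimizer = AdamConfig(lr=1e-3)  # could instead already be a functools.partial

# Same branching as the updated configure_optimizers():
if isinstance(hp_optimizer, functools.partial):
    optimizer_partial = hp_optimizer
else:
    optimizer_partial = instantiate(hp_optimizer)

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = optimizer_partial(params)  # -> a torch.optim.Adam bound to `params`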
2 changes: 1 addition & 1 deletion project/algorithms/example_test.py
@@ -9,6 +9,6 @@

class TestExampleAlgorithm(ClassificationAlgorithmTests[ExampleAlgorithm]):
algorithm_type = ExampleAlgorithm
algorithm_name: str = "example_algo"
algorithm_name: str = "example"
unsupported_datamodule_names: ClassVar[list[str]] = ["rl"]
_supported_network_types: ClassVar[list[type]] = [torch.nn.Module]
4 changes: 2 additions & 2 deletions project/algorithms/jax_example.py
@@ -80,7 +80,7 @@ class JaxExample(LightningModule):
written in Jax, and the loss function is in pytorch.
"""

@dataclasses.dataclass
@dataclasses.dataclass(frozen=True)
class HParams:
"""Hyper-parameters of the algo."""

@@ -93,7 +93,7 @@ def __init__(
*,
network: flax.linen.Module,
datamodule: ImageClassificationDataModule,
hp: HParams | None = None,
hp: HParams = HParams(),
):
super().__init__()
self.datamodule = datamodule
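
As an aside, a small sketch of why a frozen dataclass pairs well with the new `hp: HParams = HParams()` default (the rationale is an assumption, not stated in the commit): the single default instance is shared across calls, and freezing it keeps that sharing harmless.

import dataclasses


@dataclasses.dataclass(frozen=True)
class HParams:
    lr: float = 3e-4


def make_algo(hp: HParams = HParams()) -> HParams:
    return hp


assert make_algo() is make_algo()  # one shared, immutable default instance
# make_algo().lr = 1e-2            # would raise dataclasses.FrozenInstanceError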
26 changes: 18 additions & 8 deletions project/configs/__init__.py
@@ -2,17 +2,27 @@

from hydra.core.config_store import ConfigStore

from ..utils.env_vars import REPO_ROOTDIR, SLURM_JOB_ID, SLURM_TMPDIR
from .config import Config
from .datamodule import (
datamodule_store,
)
from .network import network_store
from project.configs.algorithm import algorithm_store, populate_algorithm_store
from project.configs.config import Config
from project.configs.datamodule import datamodule_store
from project.configs.lr_scheduler import lr_scheduler_store
from project.configs.network import network_store
from project.configs.optimizer import optimizer_store
from project.utils.env_vars import REPO_ROOTDIR, SLURM_JOB_ID, SLURM_TMPDIR

cs = ConfigStore.instance()
cs.store(name="base_config", node=Config)
datamodule_store.add_to_hydra_store()
network_store.add_to_hydra_store()


def add_configs_to_hydra_store():
datamodule_store.add_to_hydra_store()
network_store.add_to_hydra_store()
optimizer_store.add_to_hydra_store()
lr_scheduler_store.add_to_hydra_store()
populate_algorithm_store()
algorithm_store.add_to_hydra_store()


# todo: move the algorithm_store.add_to_hydra_store() here?

__all__ = [
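
A hypothetical usage sketch of the new registration function (the Compose-API calls below are assumptions for illustration, not part of this commit):

from hydra import compose, initialize_config_module

from project.configs import add_configs_to_hydra_store

add_configs_to_hydra_store()  # registers the datamodule/network/optimizer/lr_scheduler/algorithm options

with initialize_config_module(config_module="project.configs", version_base=None):
    cfg = compose(config_name="config", overrides=["algorithm=jax_example"])

Since project/__init__.py now calls add_configs_to_hydra_store() on import, importing the package is normally enough; the explicit call above is only for illustration.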
24 changes: 24 additions & 0 deletions project/configs/algorithm/__init__.py
@@ -0,0 +1,24 @@
from hydra_zen import make_custom_builds_fn, store
from hydra_zen.third_party.pydantic import pydantic_parser

builds_fn = make_custom_builds_fn(
zen_partial=True, populate_full_signature=True, zen_wrappers=pydantic_parser
)


# NOTE: This works the same way as creating config files for each algorithm under
# `configs/algorithm`. From the command-line, you can select both configs that are yaml files as
# well as structured config (dataclasses).

# If you add a configuration file under `configs/algorithm`, it will also be available as an option
# from the command-line, and it can use these configs in its defaults list.
algorithm_store = store(group="algorithm")


def populate_algorithm_store():
# Note: import here to avoid circular imports.
from project.algorithms import ExampleAlgorithm, JaxExample, NoOp

algorithm_store(builds_fn(ExampleAlgorithm), name="example_algo")
algorithm_store(builds_fn(NoOp), name="no_op")
algorithm_store(builds_fn(JaxExample), name="jax_example")
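
To make the builds_fn behaviour concrete, a toy sketch (the Toy class is made up; the repo classes are handled the same way): zen_partial=True means the generated config instantiates into a partial that still needs the remaining constructor arguments, and the pydantic wrapper validates arguments against the type hints when the target is finally called.

from hydra_zen import instantiate, make_custom_builds_fn
from hydra_zen.third_party.pydantic import pydantic_parser

builds_fn = make_custom_builds_fn(
    zen_partial=True, populate_full_signature=True, zen_wrappers=pydantic_parser
)


class Toy:
    def __init__(self, x: int = 1):
        self.x = x


ToyConfig = builds_fn(Toy)
toy_partial = instantiate(ToyConfig(x=2))  # a partial-like callable, not a Toy yet
toy = toy_partial()                        # arguments are validated and Toy is constructed here
assert toy.x == 2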
7 changes: 7 additions & 0 deletions project/configs/algorithm/example_from_config.yaml
@@ -0,0 +1,7 @@
defaults:
- optimizer/adam@hp.optimizer
- lr_scheduler/cosine_annealing_lr@hp.lr_scheduler
_target_: project.algorithms.example.ExampleAlgorithm
_partial_: true
hp:
_target_: project.algorithms.example.ExampleAlgorithm.HParams
33 changes: 0 additions & 33 deletions project/configs/algorithm/lr_scheduler/__init__.py

This file was deleted.

57 changes: 0 additions & 57 deletions project/configs/algorithm/optimizer/__init__.py

This file was deleted.

2 changes: 1 addition & 1 deletion project/configs/config.yaml
@@ -2,7 +2,7 @@ defaults:
- base_config
- _self_
- datamodule: cifar10
- algorithm: example_algo
- algorithm: example
- network: resnet18
- trainer: default.yaml
- trainer/callbacks: default.yaml
20 changes: 20 additions & 0 deletions project/configs/lr_scheduler/__init__.py
@@ -0,0 +1,20 @@
import torch
import torch.optim.lr_scheduler
from hydra_zen import make_custom_builds_fn, store
from hydra_zen.third_party.pydantic import pydantic_parser

builds_fn = make_custom_builds_fn(
zen_partial=True, populate_full_signature=True, zen_wrappers=pydantic_parser
)

CosineAnnealingLRConfig = builds_fn(torch.optim.lr_scheduler.CosineAnnealingLR, T_max=85)
StepLRConfig = builds_fn(torch.optim.lr_scheduler.StepLR)
lr_scheduler_store = store(group="algorithm/lr_scheduler")
lr_scheduler_store(StepLRConfig, name="step_lr")
lr_scheduler_store(CosineAnnealingLRConfig, name="cosine_annealing_lr")


# IDEA: Could be interesting to generate configs for any member of the torch.optim.lr_scheduler
# package dynamically (and store it)?
# def __getattr__(self, name: str):
# """"""
14 changes: 14 additions & 0 deletions project/configs/optimizer/__init__.py
@@ -0,0 +1,14 @@
import torch
import torch.optim
from hydra_zen import make_custom_builds_fn, store
from hydra_zen.third_party.pydantic import pydantic_parser

builds_fn = make_custom_builds_fn(
zen_partial=True, populate_full_signature=True, zen_wrappers=pydantic_parser
)

optimizer_store = store(group="algorithm/optimizer")
AdamConfig = builds_fn(torch.optim.Adam)
SGDConfig = builds_fn(torch.optim.SGD)
optimizer_store(AdamConfig, name="adam")
optimizer_store(SGDConfig, name="sgd")
File renamed without changes.
5 changes: 4 additions & 1 deletion project/experiment.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import dataclasses
import functools
import logging
import os
import random
@@ -237,7 +238,9 @@ def instantiate_algorithm(
if hasattr(algo_config, "_target_"):
# A dataclass of some sort, with a _target_ attribute.
algorithm = instantiate(algo_config, datamodule=datamodule, network=network)
assert isinstance(algorithm, LightningModule)
if isinstance(algorithm, functools.partial):
algorithm = algorithm()
assert isinstance(algorithm, LightningModule), algorithm
return algorithm

if not dataclasses.is_dataclass(algo_config):
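
For context, a minimal sketch (the Algo class and values are made up) of why the extra call was added: algorithm configs that are partials, like the `_partial_: true` yaml config added earlier, instantiate into a functools.partial rather than a LightningModule, so it has to be called once more:

import functools

from hydra_zen import builds, instantiate


class Algo:
    def __init__(self, datamodule=None, network=None):
        self.datamodule, self.network = datamodule, network


AlgoConfig = builds(Algo, zen_partial=True)  # analogous to the new algorithm configs
algorithm = instantiate(AlgoConfig, datamodule="dm", network="net")
if isinstance(algorithm, functools.partial):
    algorithm = algorithm()  # now an actual Algo instance
assert isinstance(algorithm, Algo)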
