Create dyn. configs for optimizers and schedulers
Signed-off-by: Fabrice Normandin <[email protected]>
lebrice committed Jul 10, 2024
1 parent 2c04bac commit 8d25f8d
Showing 10 changed files with 131 additions and 47 deletions.
8 changes: 6 additions & 2 deletions project/algorithms/example.py
@@ -20,8 +20,8 @@
from torch.optim.lr_scheduler import _LRScheduler

from project.algorithms.callbacks.classification_metrics import ClassificationMetricsCallback
-from project.configs.lr_scheduler import CosineAnnealingLRConfig
-from project.configs.optimizer import AdamConfig
+from project.configs.algorithm.lr_scheduler import CosineAnnealingLRConfig
+from project.configs.algorithm.optimizer import AdamConfig
from project.datamodules.image_classification import ImageClassificationDataModule

logger = getLogger(__name__)
@@ -133,6 +133,10 @@ def configure_callbacks(self) -> list[Callback]:
# Log some classification metrics. (This callback adds some metrics on this module).
ClassificationMetricsCallback.attach_to(self, num_classes=self.datamodule.num_classes)
)
+if self.hp.lr_scheduler_frequency != 0:
+    from lightning.pytorch.callbacks import LearningRateMonitor
+
+    callbacks.append(LearningRateMonitor())
if self.hp.early_stopping_patience != 0:
# If early stopping is enabled, add a Callback for it:
callbacks.append(
9 changes: 2 additions & 7 deletions project/configs/__init__.py
@@ -2,12 +2,10 @@

from hydra.core.config_store import ConfigStore

-from project.configs.algorithm import algorithm_store, populate_algorithm_store
+from project.configs.algorithm import register_algorithm_configs
from project.configs.config import Config
from project.configs.datamodule import datamodule_store
-from project.configs.lr_scheduler import lr_scheduler_store
from project.configs.network import network_store
-from project.configs.optimizer import optimizer_store
from project.utils.env_vars import REPO_ROOTDIR, SLURM_JOB_ID, SLURM_TMPDIR

cs = ConfigStore.instance()
@@ -17,10 +15,7 @@
def add_configs_to_hydra_store():
datamodule_store.add_to_hydra_store()
network_store.add_to_hydra_store()
-optimizer_store.add_to_hydra_store()
-lr_scheduler_store.add_to_hydra_store()
-populate_algorithm_store()
-algorithm_store.add_to_hydra_store()
+register_algorithm_configs()


# todo: move the algorithm_store.add_to_hydra_store() here?
11 changes: 9 additions & 2 deletions project/configs/algorithm/__init__.py
@@ -1,6 +1,9 @@
from hydra_zen import make_custom_builds_fn, store
from hydra_zen.third_party.pydantic import pydantic_parser

+from .lr_scheduler import lr_scheduler_store
+from .optimizer import optimizer_store

builds_fn = make_custom_builds_fn(
zen_partial=True, populate_full_signature=True, zen_wrappers=pydantic_parser
)
@@ -15,10 +18,14 @@
algorithm_store = store(group="algorithm")


-def populate_algorithm_store():
+def register_algorithm_configs():
# Note: import here to avoid circular imports.
from project.algorithms import ExampleAlgorithm, JaxExample, NoOp

-algorithm_store(builds_fn(ExampleAlgorithm), name="example_algo")
+algorithm_store(builds_fn(ExampleAlgorithm), name="example")
algorithm_store(builds_fn(NoOp), name="no_op")
algorithm_store(builds_fn(JaxExample), name="jax_example")

+optimizer_store.add_to_hydra_store()
+lr_scheduler_store.add_to_hydra_store()
+algorithm_store.add_to_hydra_store()
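
For reference, a rough sketch (not part of the commit) of what this registration makes visible in Hydra's ConfigStore; the nested `repo` layout and the ".yaml" key suffix are assumptions about Hydra's internals rather than anything this commit defines:

# Rough sketch (not part of the commit): checking what register_algorithm_configs()
# ends up registering, assuming hydra-zen's add_to_hydra_store() writes into Hydra's
# ConfigStore and that ConfigStore.repo nests groups by path component.
from hydra.core.config_store import ConfigStore

from project.configs.algorithm import register_algorithm_configs

register_algorithm_configs()
cs = ConfigStore.instance()

# Entries are stored under "<name>.yaml" keys next to nested subgroup dicts, so the
# "algorithm" group should now contain the three algorithm configs plus the
# "optimizer" and "lr_scheduler" subgroups registered above.
print(sorted(cs.repo["algorithm"].keys()))
# e.g. ['example.yaml', 'jax_example.yaml', 'lr_scheduler', 'no_op.yaml', 'optimizer']
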
7 changes: 5 additions & 2 deletions project/configs/algorithm/example_from_config.yaml
@@ -1,7 +1,10 @@
defaults:
-  - optimizer/[email protected]
-  - lr_scheduler/[email protected]_scheduler
+  # Apply the `algorithm/optimizer/Adam` config at `hp.optimizer` in this config.
+  - optimizer/Adam@hp.optimizer
+  - lr_scheduler/StepLR@hp.lr_scheduler
_target_: project.algorithms.example.ExampleAlgorithm
_partial_: true
hp:
  _target_: project.algorithms.example.ExampleAlgorithm.HParams
+  lr_scheduler:
+    step_size: 1 # Required argument for the StepLR scheduler. (reduce LR every {step_size} epochs)
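
The `@hp.optimizer` / `@hp.lr_scheduler` package directives graft the selected partial configs into the `hp` node of this config. Below is a hypothetical sketch (illustrative only, not the commit's actual ExampleAlgorithm code) of how such nodes are typically consumed, assuming the configs are zen_partial so that instantiating them yields functools.partial objects:

# Hypothetical sketch: consuming the composed hp.optimizer / hp.lr_scheduler nodes.
from hydra_zen import instantiate


def configure_optimizers(self):
    # hp.optimizer holds the algorithm/optimizer/Adam config placed there by the defaults
    # list above; instantiating it returns a partial that still needs the parameters.
    optimizer = instantiate(self.hp.optimizer)(self.parameters())
    # Same idea for the scheduler: the partial still needs the optimizer instance.
    lr_scheduler = instantiate(self.hp.lr_scheduler)(optimizer)
    return {"optimizer": optimizer, "lr_scheduler": {"scheduler": lr_scheduler}}
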
63 changes: 63 additions & 0 deletions project/configs/algorithm/lr_scheduler/__init__.py
@@ -0,0 +1,63 @@
import dataclasses
import inspect
from logging import getLogger as get_logger

import torch
import torch.optim.lr_scheduler
from hydra_zen import make_custom_builds_fn, store

logger = get_logger(__name__)

builds_fn = make_custom_builds_fn(zen_partial=True, populate_full_signature=True)

# LR Schedulers whose constructors have arguments with missing defaults have to be created manually,
# because we otherwise get some errors if we try to use them (e.g. T_max doesn't have a default.)

CosineAnnealingLRConfig = builds_fn(torch.optim.lr_scheduler.CosineAnnealingLR, T_max="???")
StepLRConfig = builds_fn(torch.optim.lr_scheduler.StepLR, step_size="???")
lr_scheduler_store = store(group="algorithm/lr_scheduler")
lr_scheduler_store(StepLRConfig, name="StepLR")
lr_scheduler_store(CosineAnnealingLRConfig, name="CosineAnnealingLR")


# IDEA: Could be interesting to generate configs for any member of the torch.optimizer.lr_scheduler
# package dynamically (and store it)?
# def __getattr__(self, name: str):
# """"""

_configs_defined_so_far = [k for k, v in locals().items() if dataclasses.is_dataclass(v)]
for scheduler_name, scheduler_type in [
(_name, _obj)
for _name, _obj in vars(torch.optim.lr_scheduler).items()
if inspect.isclass(_obj)
and issubclass(_obj, torch.optim.lr_scheduler.LRScheduler)
and _obj is not torch.optim.lr_scheduler.LRScheduler
]:
_config_name = f"{scheduler_name}Config"
if _config_name in _configs_defined_so_far:
# We already have a hand-made config for this scheduler. Skip it.
continue

_lr_scheduler_config = builds_fn(scheduler_type, zen_dataclass={"cls_name": _config_name})
lr_scheduler_store(_lr_scheduler_config, name=scheduler_name)
logger.debug(f"Registering config for the {scheduler_type} LR scheduler.")


def __getattr__(config_name: str):
if not config_name.endswith("Config"):
raise AttributeError
scheduler_name = config_name.removesuffix("Config")
# the keys for the config store are tuples of the form (group, config_name)
group = "algorithm/lr_scheduler"
store_key = (group, scheduler_name)
if store_key in lr_scheduler_store[group]:
logger.debug(f"Dynamically retrieving the config for the {scheduler_name} LR scheduler.")
return lr_scheduler_store[store_key]
available_configs = sorted(
config_name for (_group, config_name) in lr_scheduler_store[group].keys()
)
logger.error(
f"Unable to find the config for {scheduler_name=}. Available configs: {available_configs}."
)

raise AttributeError
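
A small usage sketch (not from the commit) of the dynamic lookup above: configs that were only registered in the store, and never assigned as module attributes, can still be imported by name through the module-level __getattr__. ExponentialLRConfig is just an illustrative choice of auto-generated config here:

# Usage sketch, assuming hydra-zen's to_yaml and the store lookup defined above.
from hydra_zen import to_yaml

# CosineAnnealingLRConfig is defined explicitly above; ExponentialLRConfig is not
# defined as a module attribute, so this import is resolved by __getattr__ from the
# "algorithm/lr_scheduler" store.
from project.configs.algorithm.lr_scheduler import (
    CosineAnnealingLRConfig,
    ExponentialLRConfig,
)

print(to_yaml(CosineAnnealingLRConfig))  # _target_: ...CosineAnnealingLR, T_max: ???
print(to_yaml(ExponentialLRConfig))      # auto-generated from the class signature
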
45 changes: 45 additions & 0 deletions project/configs/algorithm/optimizer/__init__.py
@@ -0,0 +1,45 @@
import inspect
from logging import getLogger as get_logger

import torch
import torch.optim
from hydra_zen import make_custom_builds_fn, store

logger = get_logger(__name__)
builds_fn = make_custom_builds_fn(zen_partial=True, populate_full_signature=True)

optimizer_store = store(group="algorithm/optimizer")
# AdamConfig = builds_fn(torch.optim.Adam)
# SGDConfig = builds_fn(torch.optim.SGD)
# optimizer_store(AdamConfig, name="adam")
# optimizer_store(SGDConfig, name="sgd")

for optimizer_name, optimizer_type in [
(k, v)
for k, v in vars(torch.optim).items()
if inspect.isclass(v)
and issubclass(v, torch.optim.Optimizer)
and v is not torch.optim.Optimizer
]:
_algo_config = builds_fn(optimizer_type, zen_dataclass={"cls_name": f"{optimizer_name}Config"})
optimizer_store(_algo_config, name=optimizer_name)
logger.debug(f"Registering config for the {optimizer_type} optimizer.")


def __getattr__(config_name: str):
if not config_name.endswith("Config"):
raise AttributeError
optimizer_name = config_name.removesuffix("Config")
# the keys for the config store are tuples of the form (group, config_name)
store_key = ("algorithm/optimizer", optimizer_name)
if store_key in optimizer_store["algorithm/optimizer"]:
logger.debug(f"Dynamically retrieving the config for the {optimizer_name} optimizer.")
return optimizer_store[store_key]
available_optimizers = sorted(
optimizer_name for (_, optimizer_name) in optimizer_store["algorithm/optimizer"].keys()
)
logger.error(
f"Unable to find the config for optimizer {optimizer_name}. Available optimizers: {available_optimizers}."
)

raise AttributeError
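
Similarly, a minimal sketch (not part of the commit) of how one of the auto-generated optimizer configs could be used; SGDConfig is resolved through the __getattr__ above, and the partial/override behaviour assumes Hydra's instantiate with hydra-zen's zen_partial configs:

# Minimal usage sketch for an auto-generated optimizer config (assumptions noted above).
import torch
from hydra_zen import instantiate

from project.configs.algorithm.optimizer import SGDConfig  # resolved via __getattr__

model = torch.nn.Linear(4, 2)
# The config is partial: instantiating it returns functools.partial(torch.optim.SGD, ...),
# so the model's parameters are supplied afterwards. Extra keyword arguments to
# instantiate override or supplement the configured values.
make_optimizer = instantiate(SGDConfig, lr=0.01, momentum=0.9)
optimizer = make_optimizer(model.parameters())
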
File renamed without changes.
1 change: 1 addition & 0 deletions project/configs/datamodule/vision.yaml
@@ -1,3 +1,4 @@
+# todo: This config should not show up as an option on the command-line.
_target_: project.datamodules.VisionDataModule
data_dir: ${constant:DATA_DIR}
num_workers: ${constant:NUM_WORKERS}
20 changes: 0 additions & 20 deletions project/configs/lr_scheduler/__init__.py

This file was deleted.

14 changes: 0 additions & 14 deletions project/configs/optimizer/__init__.py

This file was deleted.
