Create dyn. configs for optimizers and schedulers
Signed-off-by: Fabrice Normandin <[email protected]>
Showing 10 changed files with 131 additions and 47 deletions.
Changed file: the `ExampleAlgorithm` algorithm config (YAML):

```diff
@@ -1,7 +1,10 @@
 defaults:
-  - optimizer/[email protected]
-  - lr_scheduler/[email protected]_scheduler
+  # Apply the `algorithm/optimizer/Adam` config at `hp.optimizer` in this config.
+  - optimizer/[email protected]
+  - lr_scheduler/[email protected]_scheduler
 _target_: project.algorithms.example.ExampleAlgorithm
 _partial_: true
 hp:
   _target_: project.algorithms.example.ExampleAlgorithm.HParams
+  lr_scheduler:
+    step_size: 1  # Required argument for the StepLR scheduler. (reduce LR every {step_size} epochs)
```
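To make the `_partial_` / `zen_partial` mechanics concrete, here is a minimal sketch (not part of the commit) of how such an optimizer config behaves once instantiated. The inline `AdamConfig` stands in for whatever `optimizer/[email protected]` resolves to; `builds` and `instantiate` are real hydra-zen APIs, and the toy `model` is only for illustration:

```python
import functools

import torch
from hydra_zen import builds, instantiate

# Stand-in for the stored `algorithm/optimizer/Adam` config: a partial config whose
# fields mirror torch.optim.Adam's signature. `params` has no default, so hydra-zen
# leaves it out of the config's fields when zen_partial=True.
AdamConfig = builds(torch.optim.Adam, populate_full_signature=True, zen_partial=True)

# Instantiating a zen_partial config yields a functools.partial, not an optimizer.
optimizer_partial = instantiate(AdamConfig, lr=3e-4)
assert isinstance(optimizer_partial, functools.partial)

# The algorithm can later bind the partial to its own model's parameters.
model = torch.nn.Linear(4, 2)  # toy stand-in for the real network
optimizer = optimizer_partial(model.parameters())
print(type(optimizer).__name__)  # -> "Adam"
```

This deferral is the point of `_partial_: true`: the optimizer config can be composed before the model (and hence its parameters) exists.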
New file: dynamic config generation for LR schedulers (63 added lines):

```python
import dataclasses
import inspect
from logging import getLogger as get_logger

import torch
import torch.optim.lr_scheduler
from hydra_zen import make_custom_builds_fn, store

logger = get_logger(__name__)

builds_fn = make_custom_builds_fn(zen_partial=True, populate_full_signature=True)

# LR schedulers whose constructors have arguments without defaults have to be created
# manually, because we otherwise get errors when trying to use them
# (e.g. T_max doesn't have a default).

CosineAnnealingLRConfig = builds_fn(torch.optim.lr_scheduler.CosineAnnealingLR, T_max="???")
StepLRConfig = builds_fn(torch.optim.lr_scheduler.StepLR, step_size="???")
lr_scheduler_store = store(group="algorithm/lr_scheduler")
lr_scheduler_store(StepLRConfig, name="StepLR")
lr_scheduler_store(CosineAnnealingLRConfig, name="CosineAnnealingLR")


# IDEA: Could be interesting to generate configs for any member of the
# torch.optim.lr_scheduler package dynamically (and store them)?

_configs_defined_so_far = [k for k, v in locals().items() if dataclasses.is_dataclass(v)]
for scheduler_name, scheduler_type in [
    (_name, _obj)
    for _name, _obj in vars(torch.optim.lr_scheduler).items()
    if inspect.isclass(_obj)
    and issubclass(_obj, torch.optim.lr_scheduler.LRScheduler)
    and _obj is not torch.optim.lr_scheduler.LRScheduler
]:
    _config_name = f"{scheduler_name}Config"
    if _config_name in _configs_defined_so_far:
        # We already have a hand-made config for this scheduler. Skip it.
        continue

    _lr_scheduler_config = builds_fn(scheduler_type, zen_dataclass={"cls_name": _config_name})
    lr_scheduler_store(_lr_scheduler_config, name=scheduler_name)
    logger.debug(f"Registering config for the {scheduler_type} LR scheduler.")


def __getattr__(config_name: str):
    """Dynamically resolves `<SchedulerName>Config` attributes from the store (PEP 562)."""
    if not config_name.endswith("Config"):
        raise AttributeError(config_name)
    scheduler_name = config_name.removesuffix("Config")
    # The keys of the config store are tuples of the form (group, config_name).
    group = "algorithm/lr_scheduler"
    store_key = (group, scheduler_name)
    if store_key in lr_scheduler_store[group]:
        logger.debug(f"Dynamically retrieving the config for the {scheduler_name} LR scheduler.")
        return lr_scheduler_store[store_key]
    available_configs = sorted(
        config_name for (_group, config_name) in lr_scheduler_store[group].keys()
    )
    logger.error(
        f"Unable to find the config for {scheduler_name=}. Available configs: {available_configs}."
    )
    raise AttributeError(config_name)
```
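A short usage sketch for the module above. The import path is a guess (labeled as such below), but what it illustrates is real: the module-level `__getattr__` (PEP 562) resolves any `<SchedulerName>Config` attribute against the store, so configs generated in the loop can be accessed as if they were defined statically:

```python
import torch
from hydra_zen import instantiate

import lr_schedulers  # hypothetical import path for the new module shown above

# Resolved dynamically by the module-level __getattr__: no ExponentialLRConfig
# class is written out anywhere in the file.
ExponentialLRConfig = lr_schedulers.ExponentialLRConfig

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# zen_partial configs instantiate to a partial; required constructor args (here
# `gamma`, which has no default) are supplied at instantiation time, and the
# optimizer is bound last.
scheduler = instantiate(ExponentialLRConfig, gamma=0.9)(optimizer)
assert isinstance(scheduler, torch.optim.lr_scheduler.ExponentialLR)
```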
New file: dynamic config generation for optimizers (45 added lines):

```python
import inspect
from logging import getLogger as get_logger

import torch
import torch.optim
from hydra_zen import make_custom_builds_fn, store

logger = get_logger(__name__)
builds_fn = make_custom_builds_fn(zen_partial=True, populate_full_signature=True)

optimizer_store = store(group="algorithm/optimizer")
# AdamConfig = builds_fn(torch.optim.Adam)
# SGDConfig = builds_fn(torch.optim.SGD)
# optimizer_store(AdamConfig, name="adam")
# optimizer_store(SGDConfig, name="sgd")

for optimizer_name, optimizer_type in [
    (k, v)
    for k, v in vars(torch.optim).items()
    if inspect.isclass(v)
    and issubclass(v, torch.optim.Optimizer)
    and v is not torch.optim.Optimizer
]:
    _algo_config = builds_fn(optimizer_type, zen_dataclass={"cls_name": f"{optimizer_name}Config"})
    optimizer_store(_algo_config, name=optimizer_name)
    logger.debug(f"Registering config for the {optimizer_type} optimizer.")


def __getattr__(config_name: str):
    """Dynamically resolves `<OptimizerName>Config` attributes from the store (PEP 562)."""
    if not config_name.endswith("Config"):
        raise AttributeError(config_name)
    optimizer_name = config_name.removesuffix("Config")
    # The keys of the config store are tuples of the form (group, config_name).
    store_key = ("algorithm/optimizer", optimizer_name)
    if store_key in optimizer_store["algorithm/optimizer"]:
        logger.debug(f"Dynamically retrieving the config for the {optimizer_name} optimizer.")
        return optimizer_store[store_key]
    available_optimizers = sorted(
        name for (_, name) in optimizer_store["algorithm/optimizer"].keys()
    )
    logger.error(
        f"Unable to find the config for optimizer {optimizer_name}. "
        f"Available optimizers: {available_optimizers}."
    )
    raise AttributeError(config_name)
```
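One caveat worth noting: hydra-zen's `store` collects these entries in its own registry, and they only become visible to Hydra (and hence usable from a defaults list like the YAML config above) once they are pushed into Hydra's `ConfigStore`. A minimal sketch of that wiring, assuming the project performs it at startup (the import names are assumptions):

```python
from hydra_zen import store

# Importing the modules runs their registration loops (hypothetical module names).
import lr_schedulers  # noqa: F401
import optimizers  # noqa: F401

# Push every collected entry into Hydra's ConfigStore so that defaults-list
# entries like `optimizer/[email protected]` can resolve.
store.add_to_hydra_store()
```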
File renamed without changes.
Two files were deleted (contents not shown).