Add a proof-of-concept for an Algorithm that uses Jax for its forward/backward passes [RT-71] #4

Merged
merged 27 commits into from
Jun 14, 2024
Changes from 1 commit
Commits
27 commits
61c0662
Add an example algo that uses jax!
lebrice May 30, 2024
752afd1
Simplify the jax example
lebrice May 30, 2024
d5b400c
Slightly tweak the jax example
lebrice Jun 3, 2024
1a71d48
Tweak the jax example
lebrice Jun 3, 2024
4deb156
Tweak algo a bit (again)
lebrice Jun 4, 2024
6306ccd
Use flax nn.Module
lebrice Jun 4, 2024
a52972c
Hacky: Wrap jax fn into a torch.autograd.Function
lebrice Jun 4, 2024
8feece6
Make it work with automatic optimization and jit!
lebrice Jun 4, 2024
8d79e67
Able to use jax in intermediate node in graph!
lebrice Jun 6, 2024
0473808
Update to use git packages
lebrice Jun 6, 2024
70e7053
Rename `batch_idx`->`batch_index` everywhere
lebrice Jun 7, 2024
2c53544
Fix broken callback due to `batch_idx` rename
lebrice Jun 7, 2024
e3f4c90
Use a callback to log classification metrics
lebrice Jun 7, 2024
18f6c25
Update the jax algo
lebrice Jun 11, 2024
fcd8484
Make the callback compatible with more recent PL
lebrice Jun 12, 2024
8f595fd
Make the Jax algo usable from CLI, tweak configs
lebrice Jun 13, 2024
75765e9
Fix tests to use the tensor_regression package
lebrice Jun 13, 2024
61e204f
Fix some issues with config registration in tests
lebrice Jun 13, 2024
9a220a1
Fix other tiny issues in test code
lebrice Jun 13, 2024
542d086
Fix issue with resnet50 config
lebrice Jun 13, 2024
00cbd4f
Add some generated tests for the Jax algo example
lebrice Jun 13, 2024
cdd548c
Fix tests for algo that doesnt support jax
lebrice Jun 13, 2024
3cf5cfe
'fix' issue with doctest of some configs
lebrice Jun 13, 2024
8c159cc
Set JAX_PLATFORMS=cpu in GitHub CI
lebrice Jun 14, 2024
df3b907
Tweak build.yml again
lebrice Jun 14, 2024
1e4c8db
Fix build.yml
lebrice Jun 14, 2024
6c16f03
Set rounding precision for regression tests
lebrice Jun 14, 2024
Fix some issues with config registration in tests
Signed-off-by: Fabrice Normandin <[email protected]>
lebrice committed Jun 13, 2024

Verified

This commit was signed with the committer’s verified signature.
lebrice Fabrice Normandin
commit 61e204f7d7c39672803098476b03049ba59941a3
2 changes: 1 addition & 1 deletion project/algorithms/__init__.py
@@ -14,7 +14,7 @@

# If you add a configuration file under `configs/algorithm`, it will also be available as an option
# from the command-line, and be validated against the schema.

# todo: It might be nicer if we did this in `configs/algorithms` instead of here, no?
algorithm_store = store(group="algorithm")
algorithm_store(ExampleAlgorithm.HParams(), name="example_algo")
algorithm_store(ManualGradientsExample.HParams(), name="manual_optimization")
2 changes: 1 addition & 1 deletion project/algorithms/bases/algorithm_test.py
@@ -23,7 +23,7 @@
from torch.utils.data import DataLoader
from typing_extensions import ParamSpec

from project.configs.config import Config, cs
from project.configs import Config, cs
from project.conftest import setup_hydra_for_tests_and_compose
from project.datamodules.image_classification import (
ImageClassificationDataModule,
5 changes: 2 additions & 3 deletions project/configs/__init__.py
@@ -11,12 +11,11 @@
)
from .network import network_store

# todo: look into using this instead:
# from hydra_zen import store

cs = ConfigStore.instance()
cs.store(name="base_config", node=Config)
datamodule_store.add_to_hydra_store()
network_store.add_to_hydra_store()
# todo: move the algorithm_store.add_to_hydra_store() here?

__all__ = [
"Config",
6 changes: 0 additions & 6 deletions project/configs/config.py
@@ -3,8 +3,6 @@
from logging import getLogger as get_logger
from typing import Any, Literal

from hydra.core.config_store import ConfigStore

logger = get_logger(__name__)
LogLevel = Literal["debug", "info", "warning", "error", "critical"]

@@ -39,7 +37,3 @@ class Config:
debug: bool = False

verbose: bool = False


cs = ConfigStore.instance()
cs.store(name="base_config", node=Config)
25 changes: 22 additions & 3 deletions project/configs/network/__init__.py
@@ -1,7 +1,26 @@
import hydra_zen
import torchvision.models
from hydra_zen import store

from project.networks import FcNetConfig, ResNet18Config
from project.networks.fcnet import FcNet
from project.utils.hydra_utils import interpolate_config_attribute

network_store = store(group="network")
network_store(FcNetConfig, name="fcnet")
network_store(ResNet18Config, name="resnet18")
network_store(
hydra_zen.builds(
torchvision.models.resnet18,
populate_full_signature=True,
num_classes=interpolate_config_attribute("datamodule.num_classes"),
),
name="resnet18",
)
network_store(
hydra_zen.builds(
FcNet,
hydra_convert="object",
hydra_recursive=True,
populate_full_signature=True,
output_dims=interpolate_config_attribute("datamodule.num_classes"),
),
name="fcnet",
)
2 changes: 2 additions & 0 deletions project/configs/network/jax_cnn.yaml
@@ -0,0 +1,2 @@
_target_: project.algorithms.jax_algo.CNN
num_classes: ${instance_attr:datamodule.num_classes}
3 changes: 3 additions & 0 deletions project/configs/network/jax_fcnet.yaml
@@ -0,0 +1,3 @@
_target_: project.algorithms.jax_algo.JaxFcNet
num_classes: ${instance_attr:datamodule.num_classes}
num_features: 256
3 changes: 3 additions & 0 deletions project/configs/network/resnet50.yaml
@@ -0,0 +1,3 @@
_target_: torchvision.models.resnet50
pretrained: true
num_classes: "${instance_attr:datamodule.num_classes,datamodule.action_dims:1000}"
9 changes: 5 additions & 4 deletions project/main_test.py
@@ -4,14 +4,15 @@
import typing
from pathlib import Path

import hydra_zen
import pytest

from project.algorithms import Algorithm, ExampleAlgorithm
from project.configs.config import Config
from project.configs.datamodule import CIFAR10DataModuleConfig
from project.conftest import setup_hydra_for_tests_and_compose, use_overrides
from project.datamodules.image_classification.cifar10 import CIFAR10DataModule
from project.networks import FcNetConfig
from project.networks.fcnet import FcNet
from project.utils.hydra_utils import resolve_dictconfig

if typing.TYPE_CHECKING:
@@ -77,13 +78,13 @@ def test_setting_algorithm(
@pytest.mark.parametrize(
("overrides", "expected_type"),
[
(["algorithm=example_algo", "network=fcnet"], FcNetConfig),
(["algorithm=example_algo", "network=fcnet"], FcNet),
],
ids=_ids,
)
def test_setting_network(
overrides: list[str],
expected_type: type[Algorithm.HParams],
expected_type: type,
testing_overrides: list[str],
tmp_path: Path,
) -> None:
@@ -93,7 +94,7 @@ def test_setting_network(
) as dictconfig:
options = resolve_dictconfig(dictconfig)
assert isinstance(options, Config)
assert isinstance(options.network, expected_type)
assert hydra_zen.get_target(options.network) is expected_type


# TODO: Add some more integration tests:
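
The assertion change above reflects that `options.network` now holds a structured config (a `hydra_zen.builds`-produced dataclass) rather than an instantiated network, so the test checks the config's target instead of the instance type. A minimal sketch of that behaviour, assuming the network configs are produced with `hydra_zen.builds`:

```python
# Minimal sketch (assumed, not taken from this PR): hydra_zen.get_target recovers the
# target class from a config created with hydra_zen.builds.
import hydra_zen

from project.networks.fcnet import FcNet

fcnet_config = hydra_zen.builds(FcNet)
assert hydra_zen.get_target(fcnet_config) is FcNet
```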
23 changes: 0 additions & 23 deletions project/networks/__init__.py
@@ -13,30 +13,7 @@
# _cs.store(group="network", name="fcnet", node=FcNetConfig)
# _cs.store(group="network", name="resnet18", node=ResNet18Config)
# Add your network configs here.
from dataclasses import field

from hydra_zen import hydrated_dataclass
from torchvision.models import resnet18

from project.utils.hydra_utils import interpolated_field

from .fcnet import FcNet


@hydrated_dataclass(target=FcNet, hydra_convert="object", hydra_recursive=True)
class FcNetConfig:
output_dims: int = interpolated_field(
"${instance_attr:datamodule.num_classes,datamodule.action_dims}", default=-1
)
hparams: FcNet.HParams = field(default_factory=FcNet.HParams)


@hydrated_dataclass(target=resnet18)
class ResNet18Config:
pretrained: bool = False
num_classes: int = interpolated_field(
"${instance_attr:datamodule.num_classes,datamodule.action_dims}", default=1000
)


__all__ = ["FcNet"]
126 changes: 80 additions & 46 deletions project/utils/hydra_utils.py
@@ -27,6 +27,86 @@
T = TypeVar("T")


def interpolate_config_attribute(*attributes: str, default: Any | Literal[MISSING] = MISSING):
"""Use this in a config to to get an attribute from another config after it is instantiated.

Multiple attributes can be specified, which will lead to trying each of them in order until the
attribute is found. If none are found, then an error will be raised.

For example, if we only know the number of classes in the datamodule after it is instantiated,
we can set this in the network config so it is created with the right number of output dims.

```yaml
_target_: torchvision.models.resnet50
num_classes: ${instance_attr:datamodule.num_classes}
```

This is equivalent to:

>>> import hydra_zen
>>> import torchvision.models
>>> resnet50_config = hydra_zen.builds(
... torchvision.models.resnet50,
... num_classes=interpolate_config_attribute("datamodule.num_classes"),
... populate_full_signature=True,
... )
>>> print(hydra_zen.to_yaml(resnet50_config)) # doctest: +NORMALIZE_WHITESPACE
_target_: torchvision.models.resnet.resnet50
weights: null
progress: true
num_classes: ${instance_attr:datamodule.num_classes}
"""
if default is MISSING:
return "${instance_attr:" + ",".join(attributes) + "}"
return "${instance_attr:" + ",".join(attributes) + ":" + str(default) + "}"


def interpolated_field(
interpolation: str,
default: T | Literal[MISSING] = MISSING,
default_factory: Callable[[], T] | Literal[MISSING] = MISSING,
instance_attr: bool = False,
) -> T:
"""Field with a default value computed with a OmegaConf-style interpolation when appropriate.

When the dataclass is created by Hydra / OmegaConf, the interpolation is used.
Otherwise, behaves as usual (either using default or calling the default_factory).

Parameters
----------
interpolation: The string interpolation to use to get the default value.
default: The default value to use when not in a hydra/OmegaConf context.
default_factory: The default value to use when not in a hydra/OmegaConf context.
instance_attr: Whether to use the `instance_attr` custom resolver to run the interpolation \
with respect to instantiated objects instead of their configs.
Passing `interpolation='${instance_attr:some_config.some_attr}'` has the same effect.

This last parameter is important, since in order to retrieve the instance attribute, we need to
instantiate the objects, which could be expensive. These instantiated objects are reused at
least, but still, be mindful when using this parameter.
"""
assert "${" in interpolation and "}" in interpolation

if instance_attr:
if not interpolation.startswith("${instance_attr:"):
interpolation = interpolation.removeprefix("${")
interpolation = "${instance_attr:" + interpolation

if default is MISSING and default_factory is MISSING:
raise RuntimeError(
"Interpolated fields currently still require a default value or default factory for "
"when they are used outside the Hydra/OmegaConf context."
)
return field(
default_factory=functools.partial(
_default_factory,
interpolation=interpolation,
default=default,
default_factory=default_factory,
)
)


# @dataclass(init=False)
class Partial(functools.partial[T], _Partial[T]):
def __getattr__(self, name: str):
@@ -262,52 +342,6 @@ def get_instantiated_attr(
)


def interpolated_field(
interpolation: str,
default: T | Literal[MISSING] = MISSING,
default_factory: Callable[[], T] | Literal[MISSING] = MISSING,
instance_attr: bool = False,
) -> T:
"""Field with a default value computed with a OmegaConf-style interpolation when appropriate.

When the dataclass is created by Hydra / OmegaConf, the interpolation is used.
Otherwise, behaves as usual (either using default or calling the default_factory).

Parameters
----------
interpolation: The string interpolation to use to get the default value.
default: The default value to use when not in a hydra/OmegaConf context.
default_factory: The default value to use when not in a hydra/OmegaConf context.
instance_attr: Whether to use the `instance_attr` custom resolver to run the interpolation \
with respect to instantiated objects instead of their configs.
Passing `interpolation='${instance_attr:some_config.some_attr}'` has the same effect.

This last parameter is important, since in order to retrieve the instance attribute, we need to
instantiate the objects, which could be expensive. These instantiated objects are reused at
least, but still, be mindful when using this parameter.
"""
assert "${" in interpolation and "}" in interpolation

if instance_attr:
if not interpolation.startswith("${instance_attr:"):
interpolation = interpolation.removeprefix("${")
interpolation = "${instance_attr:" + interpolation

if default is MISSING and default_factory is MISSING:
raise RuntimeError(
"Interpolated fields currently still require a default value or default factory for "
"when they are used outside the Hydra/OmegaConf context."
)
return field(
default_factory=functools.partial(
_default_factory,
interpolation=interpolation,
default=default,
default_factory=default_factory,
)
)


def being_called_in_hydra_context() -> bool:
import hydra.core.utils
import omegaconf._utils
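
`interpolated_field` (moved up within the module above) gives a config dataclass a default that only becomes an interpolation string when the dataclass is created by Hydra/OmegaConf. A hypothetical usage sketch, assuming the behaviour described in its docstring:

```python
# Hypothetical sketch (not part of this diff): outside a Hydra/OmegaConf context the
# plain default is used; inside, the "${instance_attr:...}" interpolation is emitted so
# OmegaConf resolves it against the instantiated datamodule.
from dataclasses import dataclass

from project.utils.hydra_utils import interpolated_field


@dataclass
class MyNetworkConfig:  # hypothetical config dataclass, for illustration only
    num_classes: int = interpolated_field(
        "${instance_attr:datamodule.num_classes}", default=10
    )


config = MyNetworkConfig()  # created directly, outside Hydra
assert config.num_classes == 10  # falls back to the plain default
```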
25 changes: 19 additions & 6 deletions project/utils/testutils.py
@@ -22,7 +22,7 @@
from torch import Tensor, nn
from torch.optim import Optimizer

from project.configs.config import Config, cs
from project.configs import Config, cs
from project.configs.datamodule import DATA_DIR, SLURM_JOB_ID
from project.datamodules.image_classification import (
ImageClassificationDataModule,
@@ -125,11 +125,24 @@ def _parametrized_fixture_method(request: pytest.FixtureRequest):


def get_all_configs_in_group(group_name: str) -> list[str]:
names_yaml = cs.list(group_name)
names = [name.rpartition(".")[0] for name in names_yaml]
if "base" in names:
names.remove("base")
return names
# note: here we're copying a bit of the internal code from Hydra so that we also get the
# configs that are just yaml files, in addition to the configs we added programmatically to the
# configstores.

# names_yaml = cs.list(group_name)
# names = [name.rpartition(".")[0] for name in names_yaml]
# if "base" in names:
# names.remove("base")
# return names

from hydra._internal.config_loader_impl import ConfigLoaderImpl
from hydra._internal.utils import create_automatic_config_search_path

search_path = create_automatic_config_search_path(
calling_file=None, calling_module=None, config_path="pkg://project.configs"
)
config_loader = ConfigLoaderImpl(config_search_path=search_path)
return config_loader.get_group_options(group_name)


def get_all_algorithm_names() -> list[str]:
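
With the rewritten `get_all_configs_in_group`, YAML-only configs such as the new `jax_cnn`, `jax_fcnet` and `resnet50` network configs are listed alongside the configs registered programmatically in the config stores. A rough usage sketch (the listed names are assumed from the files touched in this PR):

```python
# Rough usage sketch; the printed names are assumed, not verified against the repo.
from project.utils.testutils import get_all_configs_in_group

network_configs = get_all_configs_in_group("network")
print(sorted(network_configs))
# e.g. ['fcnet', 'jax_cnn', 'jax_fcnet', 'resnet18', 'resnet50']
```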
1 change: 1 addition & 0 deletions pyproject.toml
@@ -63,6 +63,7 @@ build-backend = "setuptools.build_meta"

[tool.pytest.ini_options]
testpaths = ["project"]
addopts = ["--doctest-modules"]

[tool.ruff]
line-length = 99