Skip to content

Commit

Permalink
Fix a bug in hydra-zen for inner classes
Browse files Browse the repository at this point in the history
- Adds a patch for mit-ll-responsible-ai/hydra-zen#705

Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed Jul 10, 2024
1 parent 8d25f8d commit 935a164
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 1 deletion.
1 change: 1 addition & 0 deletions project/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from . import algorithms, configs, datamodules, experiment, main, networks, utils
from .configs import Config, add_configs_to_hydra_store
from .experiment import Experiment
from .utils.hydra_utils import patched_safe_name # noqa

# from .networks import FcNet
from .utils.types import DataModule
Expand Down
2 changes: 1 addition & 1 deletion project/configs/algorithm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# `configs/algorithm`. From the command-line, you can select both configs that are yaml files as
# well as structured config (dataclasses).

# If you add a configuration file under `configs/algorithm`, it will also be available as an option
# If you add a configuration file under `project/configs/algorithm`, it will also be available as an option
# from the command-line, and can use these configs in their default list.
algorithm_store = store(group="algorithm")

Expand Down
4 changes: 4 additions & 0 deletions project/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from .device import default_device

# Import this patch for https://github.com/mit-ll-responsible-ai/hydra-zen/issues/705 to make sure that it gets applied.
from .hydra_utils import patched_safe_name

__all__ = [
"default_device",
"patched_safe_name",
]
24 changes: 24 additions & 0 deletions project/utils/hydra_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
TypeVar,
)

import hydra_zen.structured_configs._utils
from hydra_zen import instantiate
from hydra_zen.structured_configs._utils import safe_name
from hydra_zen.typing._implementations import Partial as _Partial
from omegaconf import DictConfig, OmegaConf

Expand All @@ -27,6 +29,28 @@
T = TypeVar("T")


def patched_safe_name(obj: Any, repr_allowed: bool = True):
"""Patches a bug in Hydra-zen where the _target_ of inner classes is incorrect:
https://github.com/mit-ll-responsible-ai/hydra-zen/issues/705
"""

if not hasattr(obj, "__qualname__"):
return safe_name(obj, repr_allowed=repr_allowed)

name = safe_name(obj, repr_allowed=repr_allowed)
qualname = obj.__qualname__
assert isinstance(qualname, str)

if name != qualname and qualname.endswith("." + name):
logger.debug(f"Using patched fn: returning {qualname} for target {obj}")
return qualname

return name


hydra_zen.structured_configs._utils.safe_name = patched_safe_name


def interpolate_config_attribute(*attributes: str, default: Any | Literal[MISSING] = MISSING):
"""Use this in a config to to get an attribute from another config after it is instantiated.
Expand Down

0 comments on commit 935a164

Please sign in to comment.