diff --git a/project/__init__.py b/project/__init__.py index aa3d30a6..de8fc9a0 100644 --- a/project/__init__.py +++ b/project/__init__.py @@ -1,6 +1,7 @@ from . import algorithms, configs, datamodules, experiment, main, networks, utils from .configs import Config, add_configs_to_hydra_store from .experiment import Experiment +from .utils.hydra_utils import patched_safe_name # noqa # from .networks import FcNet from .utils.types import DataModule diff --git a/project/configs/algorithm/__init__.py b/project/configs/algorithm/__init__.py index de418e63..565a750c 100644 --- a/project/configs/algorithm/__init__.py +++ b/project/configs/algorithm/__init__.py @@ -13,7 +13,7 @@ # `configs/algorithm`. From the command-line, you can select both configs that are yaml files as # well as structured config (dataclasses). -# If you add a configuration file under `configs/algorithm`, it will also be available as an option +# If you add a configuration file under `project/configs/algorithm`, it will also be available as an option # from the command-line, and can use these configs in their default list. algorithm_store = store(group="algorithm") diff --git a/project/utils/__init__.py b/project/utils/__init__.py index d316a228..e9cfa161 100644 --- a/project/utils/__init__.py +++ b/project/utils/__init__.py @@ -1,5 +1,9 @@ from .device import default_device +# Import this patch for https://github.com/mit-ll-responsible-ai/hydra-zen/issues/705 to make sure that it gets applied. +from .hydra_utils import patched_safe_name + __all__ = [ "default_device", + "patched_safe_name", ] diff --git a/project/utils/hydra_utils.py b/project/utils/hydra_utils.py index 1ae7cc0e..07c0fbca 100644 --- a/project/utils/hydra_utils.py +++ b/project/utils/hydra_utils.py @@ -15,7 +15,9 @@ TypeVar, ) +import hydra_zen.structured_configs._utils from hydra_zen import instantiate +from hydra_zen.structured_configs._utils import safe_name from hydra_zen.typing._implementations import Partial as _Partial from omegaconf import DictConfig, OmegaConf @@ -27,6 +29,28 @@ T = TypeVar("T") +def patched_safe_name(obj: Any, repr_allowed: bool = True): + """Patches a bug in Hydra-zen where the _target_ of inner classes is incorrect: + https://github.com/mit-ll-responsible-ai/hydra-zen/issues/705 + """ + + if not hasattr(obj, "__qualname__"): + return safe_name(obj, repr_allowed=repr_allowed) + + name = safe_name(obj, repr_allowed=repr_allowed) + qualname = obj.__qualname__ + assert isinstance(qualname, str) + + if name != qualname and qualname.endswith("." + name): + logger.debug(f"Using patched fn: returning {qualname} for target {obj}") + return qualname + + return name + + +hydra_zen.structured_configs._utils.safe_name = patched_safe_name + + def interpolate_config_attribute(*attributes: str, default: Any | Literal[MISSING] = MISSING): """Use this in a config to to get an attribute from another config after it is instantiated.