diff --git a/docs/pages/examples.rst b/docs/pages/examples.rst index e385df4..36229ed 100644 --- a/docs/pages/examples.rst +++ b/docs/pages/examples.rst @@ -252,16 +252,15 @@ Creating ``pd.DataFrame`` from config # By default, hydra-slayer use partial fit for functions # (what is useful with activation functions in neural networks). # But if we want to call ``pandas.read_csv`` function instead, - # then we should pass ``call_meta_factory`` manually. - meta_factory: &call_function - _target_: hydra_slayer.call_meta_factory + # then we should set ``call`` mode manually. + _mode_: call right: _target_: pandas.read_csv filepath_or_buffer: dataset/dataset_part2.csv - meta_factory: *call_function + _mode_: call how: inner 'on': user - meta_factory: *call_function + _mode_: call .. code-block:: python @@ -319,11 +318,10 @@ Extending configs # config.yaml dataset: - _target_: hydra_slayer.get_from_params # ``yaml.safe_load`` will return dictionary with parameters, # but to get ``DataLoader`` additional ``hydra_slayer.get_from_params`` # should be used. - + _target_: hydra_slayer.get_from_params kwargs: # Read dataset from "dataset.yaml", roughly equivalent to # with open("dataset.yaml") as stream: @@ -332,16 +330,13 @@ Extending configs stream: _target_: open file: dataset.yaml - meta_factory: &call_function - _target_: hydra_slayer.call_meta_factory - - meta_factory: *call_function + _mode_: call + _mode_: call model: _target_: torchvision.models.resnet18 pretrained: true - meta_factory: - _target_: hydra_slayer.call_meta_factory + _mode_: call criterion: _target_: torch.nn.CrossEntropyLoss diff --git a/hydra_slayer/factory.py b/hydra_slayer/factory.py index 5e56072..8a075c1 100644 --- a/hydra_slayer/factory.py +++ b/hydra_slayer/factory.py @@ -1,4 +1,5 @@ from typing import Any, Callable, Mapping, Tuple, Type, Union +import copy import functools import inspect @@ -7,6 +8,8 @@ Factory = Union[Type, Callable[..., Any]] MetaFactory = Callable[[Factory, Tuple, Mapping], Any] +DEFAULT_CALL_MODE_KEY = "_mode_" + def call_meta_factory(factory: Factory, args: Tuple, kwargs: Mapping): """Creates a new instance from ``factory``. @@ -41,10 +44,24 @@ def partial_meta_factory(factory: Factory, args: Tuple, kwargs: Mapping): def default_meta_factory(factory: Factory, args: Tuple, kwargs: Mapping): - """ - Creates a new instance from ``factory`` if ``factory`` is class - (like :py:func:`call_meta_factory`), else returns a new partial object - (like :py:func:`partial_meta_factory`). + """Returns a new instance or a new partial object. + + * _mode_='auto' + + Creates a new instance from ``factory`` if ``factory`` is class + (like :py:func:`call_meta_factory`), else returns a new partial object + (like :py:func:`partial_meta_factory`). + + * _mode_='call' + + Returns a result of the factory called with the positional arguments + ``args`` and keyword arguments ``kwargs``. + + * _mode_='partial' + + Returns a new partial object which when called will behave like factory + called with the positional arguments ``args`` and keyword arguments + ``kwargs``. Args: factory: factory to create instance from @@ -54,7 +71,33 @@ def default_meta_factory(factory: Factory, args: Tuple, kwargs: Mapping): Returns: Instance. + Raises: + ValueError: if mode not in list: ``'auto'``, ``'call'``, ``'partial'``. + + Examples: + >>> default_meta_factory(int, (42,)) + 42 + + >>> # please note that additional () are used + >>> default_meta_factory(lambda x: x, (42,))() + 42 + + >>> default_meta_factory(int, ('42',), {"base": 16}) + 66 + + >>> # please note that additional () are not needed + >>> default_meta_factory(lambda x: x, (42,), {"_mode_": "call"}) + 42 + + >>> default_meta_factory(lambda x: x, ('42',), {"_mode_": "partial", "base": 16})() + 66 """ - if inspect.isfunction(factory): + # make a copy of kwargs since we don't want to modify them directly + kwargs = copy.copy(kwargs) + mode = kwargs.pop(DEFAULT_CALL_MODE_KEY, "auto") + if mode not in {"auto", "call", "partial"}: + raise ValueError(f"`{mode}` is not a valid call mode") + + if mode == "auto" and inspect.isfunction(factory) or mode == "partial": return partial_meta_factory(factory, args, kwargs) return call_meta_factory(factory, args, kwargs) diff --git a/hydra_slayer/functional.py b/hydra_slayer/functional.py index e61ca80..8a64635 100644 --- a/hydra_slayer/functional.py +++ b/hydra_slayer/functional.py @@ -194,10 +194,12 @@ def get_from_params(*, shared_params: Optional[Dict[str, Any]] = None, **kwargs) Creates instance based in configuration dict with ``instantiation_fn``. Note: - The name of the factory to use should be provided by ``'_target_'`` keyword. + The name of the factory to use should be provided + by ``'_target_'`` keyword. Args: - shared_params: params to pass on all levels in case of recursive creation + shared_params: params to pass on all levels in case of + recursive creation **kwargs: named parameters for factory Returns: diff --git a/hydra_slayer/registry.py b/hydra_slayer/registry.py index bd9adfa..8174970 100644 --- a/hydra_slayer/registry.py +++ b/hydra_slayer/registry.py @@ -117,7 +117,8 @@ def add_from_module( module: module to scan prefix: prefix string for all the module's factories. If prefix is a list, all values will be treated as aliases - ignore_all: if ``True``, ignores ``__all__`` attribute of the module + ignore_all: if ``True``, ignores ``__all__`` attribute + of the module Raises: TypeError: if prefix is not a list or a string @@ -203,7 +204,8 @@ def get_from_params( If ``config[name_key]`` is None, ``None`` is returned. Args: - shared_params: params to pass on all levels in case of recursive creation + shared_params: params to pass on all levels in case of + recursive creation **kwargs: \*\*kwargs to pass to the factory Returns: diff --git a/tests/test_factory.py b/tests/test_factory.py index 5b3545b..b28251e 100644 --- a/tests/test_factory.py +++ b/tests/test_factory.py @@ -38,3 +38,35 @@ def test_default_meta_factory(): res = default_meta_factory(lambda x: x, (42,), {}) assert res() == 42 + + +def test_default_meta_factory_mode(): + # `int` is class, so `call_meta_factory` is expected + res = default_meta_factory(int, (42,), {"_mode_": "auto"}) + + assert res == 42 + + # `lambda` is function, so `partial_meta_factory` is expected + res = default_meta_factory(lambda x: x, (42,), {"_mode_": "auto"}) + + assert res() == 42 + + # _mode_='call', so `call_meta_factory` is expected + res = default_meta_factory(int, (42,), {"_mode_": "call"}) + + assert res == 42 + + # _mode_='call', so `call_meta_factory` is expected + res = default_meta_factory(lambda x: x, (42,), {"_mode_": "call"}) + + assert res == 42 + + # _mode_='partial', so `partial_meta_factory` is expected + res = default_meta_factory(int, (42,), {"_mode_": "partial"}) + + assert res() == 42 + + # _mode_='partial', so `partial_meta_factory` is expected + res = default_meta_factory(lambda x: x, (42,), {"_mode_": "partial"}) + + assert res() == 42