From 8a03919b55a911b799f9169f80c4c251385fb10b Mon Sep 17 00:00:00 2001 From: Benjamin Date: Tue, 12 Sep 2023 12:22:10 -0400 Subject: [PATCH 1/7] RegistryMixin - tooling for easier registry/plugin patterns across NM repos --- src/sparsezoo/utils/registry.py | 168 ++++++++++++++++++++++++++++++++ 1 file changed, 168 insertions(+) create mode 100644 src/sparsezoo/utils/registry.py diff --git a/src/sparsezoo/utils/registry.py b/src/sparsezoo/utils/registry.py new file mode 100644 index 00000000..24b626ae --- /dev/null +++ b/src/sparsezoo/utils/registry.py @@ -0,0 +1,168 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Universal registry to support registration and loading of child classes and plugins +of neuralmagic utilities +""" + +import importlib +from collections import defaultdict +from typing import Any, Dict, Optional, Type + + +_REGISTRY: Dict[Type, Dict[str, Any]] = defaultdict(dict) + + +class RegistryMixin: + """ + Universal registry to support registration and loading of child classes and plugins + of neuralmagic utilities. + + Classes that require a registry or plugins may add the `RegistryMixin` and use + `register` and `load` as the main entrypoints for adding new implementations and + loading requested values from its registry. + + example + ```python + class Dataset(RegistryMixin): + pass + + + # register with default name + @Dataset.register() + class ImageNetDataset(Dataset) + pass + + # load as "ImageNetDataset" + imagenet = Dataset.load("ImageNetDataset") + + # register with custom name + @Dataset.register(name="cifar-dataset") + class Cifar(Dataset): + pass + + # load as "cifar-dataset" + cifar = Dataset.load("cifar-dataset") + + # load from custom file that implements a dataset + mnist = Dataset.load("/path/to/mnnist_dataset.py:MnistDataset") + ``` + """ + + @classmethod + def register(cls, name: Optional[str] = None): + def decorator(value: Any): + cls.register_value(value, name=name) + return value + + return decorator + + @classmethod + def register_value( + cls, value: Any, name: Optional[str] = None, require_subclass: bool = False + ): + _register( + parent_class=cls, + value=value, + name=name, + require_subclass=require_subclass, + ) + + @classmethod + def load( + cls, class_name: str, require_subclass: bool = False, **constructor_kwargs + ): + constructor = cls.get_value( + class_name=class_name, require_subclass=require_subclass + ) + return constructor(**constructor_kwargs) + + @classmethod + def get_value(cls, class_name: str, require_subclass: bool = False): + return _get_from_registry( + parent_class=cls, name=class_name, require_subclass=require_subclass + ) + + +def _register( + parent_class: Type, + value: Any, + name: Optional[str] = None, + require_subclass: bool = False, +): + if name is None: + name = value.__name__ + + if require_subclass: + _validate_subclass(parent_class, value) + + if name in _REGISTRY[parent_class]: + registered_value = _REGISTRY[parent_class][name] + if registered_value is not value: + raise RuntimeError( + f"Attempting to register name {name} as {value} " + f"however {name} has already been registered as {registered_value}" + ) + else: + _REGISTRY[parent_class][name] = value + + +def _get_from_registry( + parent_class: Type, name: str, require_subclass: bool = False +) -> Any: + + if ":" in name: + # user specifying specific module to load and value to import + module_path, value_name = name.split(":") + retrieved_value = _import_and_get_value_from_module(module_path, value_name) + else: + retrieved_value = _REGISTRY[parent_class].get(name) + if retrieved_value is None: + raise ValueError( + f"Unable to find {name} registered under type {parent_class}. " + f"Registered values for {parent_class}: " + f"{list(_REGISTRY[parent_class].keys())}" + ) + + if require_subclass: + _validate_subclass(parent_class, retrieved_value) + + return retrieved_value + + +def _import_and_get_value_from_module(module_path: str, value_name: str) -> Any: + # load module + spec = importlib.util.spec_from_file_location( + f"plugin_module_for_{value_name}", module_path + ) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + # get value from module + value = getattr(module, value_name) + + if not value: + raise RuntimeError( + f"Unable to find attribute {value_name} in module {module_path}" + ) + return value + + +def _validate_subclass(parent_class: Type, child_class: Type): + if not issubclass(child_class, parent_class): + raise ValueError( + f"class {child_class} is not a subclass of the class it is " + f"registered for: {parent_class}." + ) From 0594e3efb534d09f1d3d263b3c31cddd2f5fc625 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Tue, 12 Sep 2023 13:33:32 -0400 Subject: [PATCH 2/7] add 'registry' to method names --- src/sparsezoo/utils/registry.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/sparsezoo/utils/registry.py b/src/sparsezoo/utils/registry.py index 24b626ae..db55c2ab 100644 --- a/src/sparsezoo/utils/registry.py +++ b/src/sparsezoo/utils/registry.py @@ -54,10 +54,10 @@ class Cifar(Dataset): pass # load as "cifar-dataset" - cifar = Dataset.load("cifar-dataset") + cifar = Dataset.load_from_registry("cifar-dataset") # load from custom file that implements a dataset - mnist = Dataset.load("/path/to/mnnist_dataset.py:MnistDataset") + mnist = Dataset.load_from_registry("/path/to/mnnist_dataset.py:MnistDataset") ``` """ @@ -81,7 +81,7 @@ def register_value( ) @classmethod - def load( + def load_from_registry( cls, class_name: str, require_subclass: bool = False, **constructor_kwargs ): constructor = cls.get_value( @@ -90,7 +90,7 @@ def load( return constructor(**constructor_kwargs) @classmethod - def get_value(cls, class_name: str, require_subclass: bool = False): + def get_value_from_registry(cls, class_name: str, require_subclass: bool = False): return _get_from_registry( parent_class=cls, name=class_name, require_subclass=require_subclass ) From ce3ac2914d60320652c22163f32bc9c63bc0d44b Mon Sep 17 00:00:00 2001 From: Benjamin Date: Tue, 12 Sep 2023 13:42:29 -0400 Subject: [PATCH 3/7] add registered_names --- src/sparsezoo/utils/registry.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/sparsezoo/utils/registry.py b/src/sparsezoo/utils/registry.py index db55c2ab..eb28d2df 100644 --- a/src/sparsezoo/utils/registry.py +++ b/src/sparsezoo/utils/registry.py @@ -19,7 +19,7 @@ import importlib from collections import defaultdict -from typing import Any, Dict, Optional, Type +from typing import Any, Dict, List, Optional, Type _REGISTRY: Dict[Type, Dict[str, Any]] = defaultdict(dict) @@ -95,6 +95,10 @@ def get_value_from_registry(cls, class_name: str, require_subclass: bool = False parent_class=cls, name=class_name, require_subclass=require_subclass ) + @classmethod + def registered_names(cls) -> List[str]: + return list(_REGISTRY[cls].keys()) + def _register( parent_class: Type, From b301f943ac79aede809fc97f9df07a8c2b2f27d6 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Tue, 12 Sep 2023 15:28:40 -0400 Subject: [PATCH 4/7] testing --- src/sparsezoo/utils/registry.py | 97 ++++++++++++++++++++++---- tests/sparsezoo/utils/test_registry.py | 60 ++++++++++++++++ 2 files changed, 145 insertions(+), 12 deletions(-) create mode 100644 tests/sparsezoo/utils/test_registry.py diff --git a/src/sparsezoo/utils/registry.py b/src/sparsezoo/utils/registry.py index eb28d2df..5eb0d8a4 100644 --- a/src/sparsezoo/utils/registry.py +++ b/src/sparsezoo/utils/registry.py @@ -22,6 +22,14 @@ from typing import Any, Dict, List, Optional, Type +__all__ = [ + "RegistryMixin", + "register", + "get_from_registry", + "registered_names", +] + + _REGISTRY: Dict[Type, Dict[str, Any]] = defaultdict(dict) @@ -63,6 +71,14 @@ class Cifar(Dataset): @classmethod def register(cls, name: Optional[str] = None): + """ + Decorator for registering a value (ie class or function) wrapped by this + decorator to the base class (class that .register is called from) + + :param name: name to register the wrapped value as, defaults to value.__name__ + :return: register decorator + """ + def decorator(value: Any): cls.register_value(value, name=name) return value @@ -73,7 +89,14 @@ def decorator(value: Any): def register_value( cls, value: Any, name: Optional[str] = None, require_subclass: bool = False ): - _register( + """ + Registers the given value to the class `.register_value` is called from + :param value: value to register + :param name: name to register the wrapped value as, defaults to value.__name__ + :param require_subclass: require that value is a subclass of the class this + method is called from + """ + register( parent_class=cls, value=value, name=name, @@ -82,37 +105,67 @@ def register_value( @classmethod def load_from_registry( - cls, class_name: str, require_subclass: bool = False, **constructor_kwargs - ): - constructor = cls.get_value( - class_name=class_name, require_subclass=require_subclass + cls, name: str, require_subclass: bool = False, **constructor_kwargs + ) -> object: + """ + :param name: name of registered class to load + :param require_subclass: require that object is a subclass of the class this + method is called from + :param constructor_kwargs: arguments to pass to the constructor retrieved + from the registry + :return: loaded object registered to this class under the given name, + constructed with the given kwargs. Raises error if the name is + not found in the registry + """ + constructor = cls.get_value_from_registry( + name=name, require_subclass=require_subclass ) return constructor(**constructor_kwargs) @classmethod - def get_value_from_registry(cls, class_name: str, require_subclass: bool = False): - return _get_from_registry( - parent_class=cls, name=class_name, require_subclass=require_subclass + def get_value_from_registry(cls, name: str, require_subclass: bool = False): + """ + :param name: name to retrieve from the registry + :param require_subclass: require that value is a subclass of the class this + method is called from + :return: value from retrieved the registry for the given name, raises + error if not found + """ + return get_from_registry( + parent_class=cls, name=name, require_subclass=require_subclass ) @classmethod def registered_names(cls) -> List[str]: - return list(_REGISTRY[cls].keys()) + """ + :return: list of all names registered to this class + """ + return registered_names(cls) -def _register( +def register( parent_class: Type, value: Any, name: Optional[str] = None, require_subclass: bool = False, ): + """ + :param parent_class: class to register the name under + :param value: value to register + :param name: name to register the wrapped value as, defaults to value.__name__ + :param require_subclass: require that value is a subclass of the class this + method is called from + """ if name is None: + # default name name = value.__name__ if require_subclass: _validate_subclass(parent_class, value) if name in _REGISTRY[parent_class]: + # name already exists - raise error if two different values are attempting + # to share the same name registered_value = _REGISTRY[parent_class][name] if registered_value is not value: raise RuntimeError( @@ -123,21 +176,30 @@ def _register( _REGISTRY[parent_class][name] = value -def _get_from_registry( +def get_from_registry( parent_class: Type, name: str, require_subclass: bool = False ) -> Any: + """ + :param parent_class: class that the name is registered under + :param name: name to retrieve from the registry of the class + :param require_subclass: require that value is a subclass of the class this + method is called from + :return: value from retrieved the registry for the given name, raises + error if not found + """ if ":" in name: # user specifying specific module to load and value to import module_path, value_name = name.split(":") retrieved_value = _import_and_get_value_from_module(module_path, value_name) else: + # look up name in registry retrieved_value = _REGISTRY[parent_class].get(name) if retrieved_value is None: raise ValueError( f"Unable to find {name} registered under type {parent_class}. " f"Registered values for {parent_class}: " - f"{list(_REGISTRY[parent_class].keys())}" + f"{registered_names(parent_class)}" ) if require_subclass: @@ -146,7 +208,18 @@ def _get_from_registry( return retrieved_value +def registered_names(parent_class: Type) -> List[str]: + """ + :param parent_class: class to look up the registry of + :return: all names registered to the given class + """ + return list(_REGISTRY[parent_class].keys()) + + def _import_and_get_value_from_module(module_path: str, value_name: str) -> Any: + # import the given module path and try to get the value_name if it is included + # in the module + # load module spec = importlib.util.spec_from_file_location( f"plugin_module_for_{value_name}", module_path diff --git a/tests/sparsezoo/utils/test_registry.py b/tests/sparsezoo/utils/test_registry.py new file mode 100644 index 00000000..410b9414 --- /dev/null +++ b/tests/sparsezoo/utils/test_registry.py @@ -0,0 +1,60 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from sparsezoo.utils.registry import RegistryMixin + + +def test_registery_flow_single(): + class Foo(RegistryMixin): + pass + + @Foo.register() + class Foo1(Foo): + pass + + @Foo.register(name="name_2") + class Foo2(Foo): + pass + + assert {"Foo1", "name_2"} == set(Foo.registered_names()) + + with pytest.raises(ValueError): + Foo.get_value_from_registry("Foo2") + + assert Foo.get_value_from_registry("Foo1") is Foo1 + assert isinstance(Foo.load_from_registry("name_2"), Foo2) + + +def test_registry_flow_multiple(): + class Foo(RegistryMixin): + pass + + class Bar(RegistryMixin): + pass + + @Foo.register() + class Foo1(Foo): + pass + + @Bar.register() + class Bar1(Bar): + pass + + assert ["Foo1"] == Foo.registered_names() + assert ["Bar1"] == Bar.registered_names() + + assert Foo.get_value_from_registry("Foo1") is Foo1 + assert Bar.get_value_from_registry("Bar1") is Bar1 From b3e098fb26e8625f7bd37eddb7a1cff065a32468 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Tue, 12 Sep 2023 16:50:31 -0400 Subject: [PATCH 5/7] add class level setting for requires_subclass --- src/sparsezoo/utils/registry.py | 33 ++++++++++++-------------- tests/sparsezoo/utils/test_registry.py | 15 ++++++++++++ 2 files changed, 30 insertions(+), 18 deletions(-) diff --git a/src/sparsezoo/utils/registry.py b/src/sparsezoo/utils/registry.py index 5eb0d8a4..f447c8d5 100644 --- a/src/sparsezoo/utils/registry.py +++ b/src/sparsezoo/utils/registry.py @@ -42,6 +42,9 @@ class RegistryMixin: `register` and `load` as the main entrypoints for adding new implementations and loading requested values from its registry. + If a class should only have its child classes in its registry, the class should + set the static attribute `registry_requires_subclass` to True + example ```python class Dataset(RegistryMixin): @@ -69,6 +72,10 @@ class Cifar(Dataset): ``` """ + # set to True in child class to add check that registered/retrieved values + # implement the class it is registered to + registry_requires_subclass: bool = False + @classmethod def register(cls, name: Optional[str] = None): """ @@ -86,53 +93,43 @@ def decorator(value: Any): return decorator @classmethod - def register_value( - cls, value: Any, name: Optional[str] = None, require_subclass: bool = False - ): + def register_value(cls, value: Any, name: Optional[str] = None): """ Registers the given value to the class `.register_value` is called from :param value: value to register :param name: name to register the wrapped value as, defaults to value.__name__ - :param require_subclass: require that value is a subclass of the class this - method is called from """ register( parent_class=cls, value=value, name=name, - require_subclass=require_subclass, + require_subclass=cls.registry_requires_subclass, ) @classmethod - def load_from_registry( - cls, name: str, require_subclass: bool = False, **constructor_kwargs - ) -> object: + def load_from_registry(cls, name: str, **constructor_kwargs) -> object: """ :param name: name of registered class to load - :param require_subclass: require that object is a subclass of the class this - method is called from :param constructor_kwargs: arguments to pass to the constructor retrieved from the registry :return: loaded object registered to this class under the given name, constructed with the given kwargs. Raises error if the name is not found in the registry """ - constructor = cls.get_value_from_registry( - name=name, require_subclass=require_subclass - ) + constructor = cls.get_value_from_registry(name=name) return constructor(**constructor_kwargs) @classmethod - def get_value_from_registry(cls, name: str, require_subclass: bool = False): + def get_value_from_registry(cls, name: str): """ :param name: name to retrieve from the registry - :param require_subclass: require that value is a subclass of the class this - method is called from :return: value from retrieved the registry for the given name, raises error if not found """ return get_from_registry( - parent_class=cls, name=name, require_subclass=require_subclass + parent_class=cls, + name=name, + require_subclass=cls.registry_requires_subclass, ) @classmethod diff --git a/tests/sparsezoo/utils/test_registry.py b/tests/sparsezoo/utils/test_registry.py index 410b9414..b1fcbb3b 100644 --- a/tests/sparsezoo/utils/test_registry.py +++ b/tests/sparsezoo/utils/test_registry.py @@ -58,3 +58,18 @@ class Bar1(Bar): assert Foo.get_value_from_registry("Foo1") is Foo1 assert Bar.get_value_from_registry("Bar1") is Bar1 + + +def test_registry_requires_subclass(): + class Foo(RegistryMixin): + registry_requires_subclass = True + + @Foo.register() + class Foo1(Foo): + pass + + with pytest.raises(ValueError): + + @Foo.register() + class NotFoo: + pass From 881e868c071fe1e5e983daa7bfc34fabaa4d294f Mon Sep 17 00:00:00 2001 From: Benjamin Date: Tue, 12 Sep 2023 16:52:56 -0400 Subject: [PATCH 6/7] docstring code example typo --- src/sparsezoo/utils/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparsezoo/utils/registry.py b/src/sparsezoo/utils/registry.py index f447c8d5..af6c62fd 100644 --- a/src/sparsezoo/utils/registry.py +++ b/src/sparsezoo/utils/registry.py @@ -53,7 +53,7 @@ class Dataset(RegistryMixin): # register with default name @Dataset.register() - class ImageNetDataset(Dataset) + class ImageNetDataset(Dataset): pass # load as "ImageNetDataset" From 7b3e3bc2d845fc23d2c44f281568d159f79ce821 Mon Sep 17 00:00:00 2001 From: Benjamin Fineran Date: Thu, 5 Oct 2023 14:21:22 -0400 Subject: [PATCH 7/7] review suggestion Co-authored-by: Rahul Tuli --- src/sparsezoo/utils/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparsezoo/utils/registry.py b/src/sparsezoo/utils/registry.py index af6c62fd..679d977e 100644 --- a/src/sparsezoo/utils/registry.py +++ b/src/sparsezoo/utils/registry.py @@ -225,7 +225,7 @@ def _import_and_get_value_from_module(module_path: str, value_name: str) -> Any: spec.loader.exec_module(module) # get value from module - value = getattr(module, value_name) + value = getattr(module, value_name, None) if not value: raise RuntimeError(