diff --git a/metaflow/plugins/pypi/conda_decorator.py b/metaflow/plugins/pypi/conda_decorator.py index fc0e4d83a1b..b1b7ee833d9 100644 --- a/metaflow/plugins/pypi/conda_decorator.py +++ b/metaflow/plugins/pypi/conda_decorator.py @@ -49,6 +49,13 @@ class CondaStepDecorator(StepDecorator): # CONDA_CHANNELS in their environment. For pinning specific packages to specific # conda channels, users can specify channel::package as the package name. + def __init__(self, attributes=None, statically_defined=False): + self._attributes_with_user_values = ( + set(attributes.keys()) if attributes is not None else set() + ) + + super(CondaStepDecorator, self).__init__(attributes, statically_defined) + def init(self): super(CondaStepDecorator, self).init() @@ -58,9 +65,11 @@ def init(self): **self.attributes["packages"], } del self.attributes["libraries"] + if self.attributes["packages"]: + self._attributes_with_user_values.add("packages") def is_attribute_user_defined(self, name): - return name in self._user_defined_attributes + return name in self._attributes_with_user_values def step_init(self, flow, graph, step, decos, environment, flow_datastore, logger): # The init_environment hook for Environment creates the relevant virtual @@ -80,9 +89,10 @@ def step_init(self, flow, graph, step, decos, environment, flow_datastore, logge **super_attributes["packages"], **self.attributes["packages"], } - self._user_defined_attributes = self._user_defined_attributes.union( - conda_base._user_defined_attributes + self._attributes_with_user_values.update( + conda_base._attributes_with_user_values ) + self.attributes["python"] = ( self.attributes["python"] or super_attributes["python"] ) @@ -329,22 +339,15 @@ class CondaFlowDecorator(FlowDecorator): } def __init__(self, attributes=None, statically_defined=False): - self._user_defined_attributes = ( - attributes.copy() if attributes is not None else {} + self._attributes_with_user_values = ( + set(attributes.keys()) if attributes is not None else set() ) + super(CondaFlowDecorator, self).__init__(attributes, statically_defined) def init(self): super(CondaFlowDecorator, self).init() - # We have to go back and fixup _user_defined_attributes for potential - # config resolution - self._user_defined_attributes = { - k: v - for k, v in self.attributes.items() - if k in self._user_defined_attributes - } - # Support legacy 'libraries=' attribute for the decorator. self.attributes["packages"] = { **self.attributes["libraries"], @@ -355,7 +358,7 @@ def init(self): self.attributes["python"] = str(self.attributes["python"]) def is_attribute_user_defined(self, name): - return name in self._user_defined_attributes + return name in self._attributes_with_user_values def flow_init( self, flow, graph, environment, flow_datastore, metadata, logger, echo, options diff --git a/metaflow/plugins/pypi/pypi_decorator.py b/metaflow/plugins/pypi/pypi_decorator.py index 575531eb458..bcc7dda891a 100644 --- a/metaflow/plugins/pypi/pypi_decorator.py +++ b/metaflow/plugins/pypi/pypi_decorator.py @@ -24,6 +24,13 @@ class PyPIStepDecorator(StepDecorator): name = "pypi" defaults = {"packages": {}, "python": None, "disabled": None} # wheels + def __init__(self, attributes=None, statically_defined=False): + self._attributes_with_user_values = ( + set(attributes.keys()) if attributes is not None else set() + ) + + super().__init__(attributes, statically_defined) + def step_init(self, flow, graph, step, decos, environment, flow_datastore, logger): # The init_environment hook for Environment creates the relevant virtual # environments. The step_init hook sets up the relevant state for that hook to @@ -36,8 +43,8 @@ def step_init(self, flow, graph, step, decos, environment, flow_datastore, logge if "pypi_base" in self.flow._flow_decorators: pypi_base = self.flow._flow_decorators["pypi_base"][0] super_attributes = pypi_base.attributes - self._user_defined_attributes = self._user_defined_attributes.union( - pypi_base._user_defined_attributes + self._attributes_with_user_values.update( + pypi_base._attributes_with_user_values ) self.attributes["packages"] = { **super_attributes["packages"], @@ -99,7 +106,7 @@ def step_init(self, flow, graph, step, decos, environment, flow_datastore, logge environment.set_local_root(LocalStorage.get_datastore_root_from_config(logger)) def is_attribute_user_defined(self, name): - return name in self._user_defined_attributes + return name in self._attributes_with_user_values class PyPIFlowDecorator(FlowDecorator): @@ -122,9 +129,10 @@ class PyPIFlowDecorator(FlowDecorator): defaults = {"packages": {}, "python": None, "disabled": None} def __init__(self, attributes=None, statically_defined=False): - self._user_defined_attributes = ( - attributes.copy() if attributes is not None else {} + self._attributes_with_user_values = ( + set(attributes.keys()) if attributes is not None else set() ) + super().__init__(attributes, statically_defined) def flow_init( diff --git a/test/unit/test_conda_decorator.py b/test/unit/test_conda_decorator.py new file mode 100644 index 00000000000..1f8e2d75649 --- /dev/null +++ b/test/unit/test_conda_decorator.py @@ -0,0 +1,29 @@ +from metaflow.plugins.pypi.conda_decorator import CondaStepDecorator + + +def test_decorator_custom_attributes(): + deco = CondaStepDecorator(attributes={"python": "3.9"}) + deco.init() + assert deco.is_attribute_user_defined( + "python" + ), "python is supposed to be an user-defined attribute" + assert not deco.is_attribute_user_defined( + "packages" + ), "packages is supposed to be default" + assert not deco.is_attribute_user_defined( + "libraries" + ), "libraries is supposed to be default" + + +def test_decorator_custom_attributes_with_backward_compatibility(): + deco = CondaStepDecorator(attributes={"libraries": {"a": "test"}}) + deco.init() + assert not deco.is_attribute_user_defined( + "python" + ), "python is supposed to be default" + assert deco.is_attribute_user_defined( + "packages" + ), "packages is supposed to be user-defined" + assert deco.is_attribute_user_defined( + "libraries" + ), "libraries is supposed to be user-defined" diff --git a/test/unit/test_pypi_decorator.py b/test/unit/test_pypi_decorator.py new file mode 100644 index 00000000000..32289080524 --- /dev/null +++ b/test/unit/test_pypi_decorator.py @@ -0,0 +1,12 @@ +from metaflow.plugins.pypi.pypi_decorator import PyPIStepDecorator + + +def test_decorator_custom_attributes(): + deco = PyPIStepDecorator(attributes={"python": "3.9"}) + deco.init() + assert deco.is_attribute_user_defined( + "python" + ), "python is supposed to be an user-defined attribute" + assert not deco.is_attribute_user_defined( + "packages" + ), "packages is supposed to be default"