From 0ccb1fd68643bb2bbb349d9849e5175d65f23dce Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Fri, 29 Nov 2024 14:00:11 +0100 Subject: [PATCH] Fix input resolution for steps with dynamic artifact names (#3228) * Fix input resolution for steps with dynamic artifact names * Improve logic * Linting * Add test * Fix variable access * Really fix test * Rename --- src/zenml/models/v2/core/pipeline_run.py | 9 +++++++ src/zenml/orchestrators/input_utils.py | 25 ++++++++++++++----- src/zenml/orchestrators/step_run_utils.py | 3 +++ ...ming.py => test_dynamic_artifact_names.py} | 20 +++++++++++++++ tests/unit/conftest.py | 13 ++++++++++ tests/unit/orchestrators/test_input_utils.py | 20 ++++++++------- 6 files changed, 75 insertions(+), 15 deletions(-) rename tests/integration/functional/steps/{test_step_naming.py => test_dynamic_artifact_names.py} (95%) diff --git a/src/zenml/models/v2/core/pipeline_run.py b/src/zenml/models/v2/core/pipeline_run.py index a9cace6bc12..5d0fdd55052 100644 --- a/src/zenml/models/v2/core/pipeline_run.py +++ b/src/zenml/models/v2/core/pipeline_run.py @@ -550,6 +550,15 @@ def is_templatable(self) -> bool: """ return self.get_metadata().is_templatable + @property + def step_substitutions(self) -> Dict[str, Dict[str, str]]: + """The `step_substitutions` property. + + Returns: + the value of the property. + """ + return self.get_metadata().step_substitutions + @property def model_version(self) -> Optional[ModelVersionResponse]: """The `model_version` property. 
diff --git a/src/zenml/orchestrators/input_utils.py b/src/zenml/orchestrators/input_utils.py index 0094bdf84f1..1487f395a6a 100644 --- a/src/zenml/orchestrators/input_utils.py +++ b/src/zenml/orchestrators/input_utils.py @@ -20,7 +20,7 @@ from zenml.config.step_configurations import Step from zenml.enums import ArtifactSaveType, StepRunInputArtifactType from zenml.exceptions import InputResolutionError -from zenml.utils import pagination_utils +from zenml.utils import pagination_utils, string_utils if TYPE_CHECKING: from zenml.models import PipelineRunResponse @@ -53,7 +53,8 @@ def resolve_step_inputs( current_run_steps = { run_step.name: run_step for run_step in pagination_utils.depaginate( - Client().list_run_steps, pipeline_run_id=pipeline_run.id + Client().list_run_steps, + pipeline_run_id=pipeline_run.id, ) } @@ -66,11 +67,23 @@ def resolve_step_inputs( f"No step `{input_.step_name}` found in current run." ) + # Try to get the substitutions from the pipeline run first, as we + # already have a hydrated version of that. In the unlikely case + # that the pipeline run is outdated, we fetch it from the step + # run instead which will cost us one hydration call. + substitutions = ( + pipeline_run.step_substitutions.get(step_run.name) + or step_run.config.substitutions + ) + output_name = string_utils.format_name_template( + input_.output_name, substitutions=substitutions + ) + try: - outputs = step_run.outputs[input_.output_name] + outputs = step_run.outputs[output_name] except KeyError: raise InputResolutionError( - f"No step output `{input_.output_name}` found for step " + f"No step output `{output_name}` found for step " f"`{input_.step_name}`." ) @@ -83,12 +96,12 @@ def resolve_step_inputs( # This should never happen, there can only be a single regular step # output for a name raise InputResolutionError( - f"Too many step outputs for output `{input_.output_name}` of " + f"Too many step outputs for output `{output_name}` of " f"step `{input_.step_name}`." 
) elif len(step_outputs) == 0: raise InputResolutionError( - f"No step output `{input_.output_name}` found for step " + f"No step output `{output_name}` found for step " f"`{input_.step_name}`." ) diff --git a/src/zenml/orchestrators/step_run_utils.py b/src/zenml/orchestrators/step_run_utils.py index d226a09f275..e371b4c509a 100644 --- a/src/zenml/orchestrators/step_run_utils.py +++ b/src/zenml/orchestrators/step_run_utils.py @@ -309,6 +309,9 @@ def create_cached_step_runs( for invocation_id in cache_candidates: visited_invocations.add(invocation_id) + # Make sure the request factory has the most up to date pipeline + # run to avoid hydration calls + request_factory.pipeline_run = pipeline_run try: step_run_request = request_factory.create_request( invocation_id diff --git a/tests/integration/functional/steps/test_step_naming.py b/tests/integration/functional/steps/test_dynamic_artifact_names.py similarity index 95% rename from tests/integration/functional/steps/test_step_naming.py rename to tests/integration/functional/steps/test_dynamic_artifact_names.py index 0446d4a1c26..2259ca2c7e2 100644 --- a/tests/integration/functional/steps/test_step_naming.py +++ b/tests/integration/functional/steps/test_dynamic_artifact_names.py @@ -12,6 +12,7 @@ # or implied. See the License for the specific language governing # permissions and limitations under the License. 
+from contextlib import ExitStack as does_not_raise from typing import Callable, Tuple import pytest @@ -122,6 +123,11 @@ def mixed_with_unannotated_returns() -> ( ) +@step +def step_with_string_input(input_: str) -> None: + pass + + @pytest.mark.parametrize( "step", [ @@ -362,3 +368,17 @@ def _inner(pass_to_step: str = ""): assert p2_step_subs["date"] == "step_level" assert p1_step_subs["funny_name"] == "pipeline_level" assert p2_step_subs["funny_name"] == "step_level" + + +def test_dynamically_named_artifacts_in_downstream_steps( + clean_client: "Client", +): + """Test that dynamically named artifacts can be used in downstream steps.""" + + @pipeline(enable_cache=False) + def _inner(ret: str): + artifact = dynamic_single_string_standard() + step_with_string_input(artifact) + + with does_not_raise(): + _inner("output_1") diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 9dcf69a0c72..81b3e95126f 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing # permissions and limitations under the License. 
+from collections import defaultdict from datetime import datetime from typing import Any, Callable, Dict, List, Optional from uuid import uuid4 @@ -416,6 +417,12 @@ def sample_pipeline_run( sample_workspace_model: WorkspaceResponse, ) -> PipelineRunResponse: """Return sample pipeline run view for testing purposes.""" + now = datetime.utcnow() + substitutions = { + "date": now.strftime("%Y_%m_%d"), + "time": now.strftime("%H_%M_%S_%f"), + } + return PipelineRunResponse( id=uuid4(), name="sample_run_name", @@ -430,6 +437,7 @@ def sample_pipeline_run( workspace=sample_workspace_model, config=PipelineConfiguration(name="aria_pipeline"), is_templatable=False, + steps_substitutions=defaultdict(lambda: substitutions.copy()), ), resources=PipelineRunResponseResources(tags=[]), ) @@ -543,10 +551,15 @@ def f( spec = StepSpec.model_validate( {"source": "module.step_class", "upstream_steps": []} ) + now = datetime.utcnow() config = StepConfiguration.model_validate( { "name": step_name, "outputs": outputs or {}, + "substitutions": { + "date": now.strftime("%Y_%m_%d"), + "time": now.strftime("%H_%M_%S_%f"), + }, } ) return StepRunResponse( diff --git a/tests/unit/orchestrators/test_input_utils.py b/tests/unit/orchestrators/test_input_utils.py index 8c97e41feb6..d56a3c15a31 100644 --- a/tests/unit/orchestrators/test_input_utils.py +++ b/tests/unit/orchestrators/test_input_utils.py @@ -12,14 +12,13 @@ # or implied. See the License for the specific language governing # permissions and limitations under the License. 
-from uuid import uuid4 import pytest from zenml.config.step_configurations import Step from zenml.enums import StepRunInputArtifactType from zenml.exceptions import InputResolutionError -from zenml.models import Page, PipelineRunResponse +from zenml.models import Page from zenml.models.v2.core.artifact_version import ArtifactVersionResponse from zenml.models.v2.core.step_run import StepRunInputResponse from zenml.orchestrators import input_utils @@ -29,6 +28,7 @@ def test_input_resolution( mocker, sample_artifact_version_model: ArtifactVersionResponse, create_step_run, + sample_pipeline_run, ): """Tests that input resolution works if the correct models exist in the zen store.""" @@ -60,7 +60,7 @@ def test_input_resolution( ) input_artifacts, parent_ids = input_utils.resolve_step_inputs( - step=step, pipeline_run=PipelineRunResponse(id=uuid4(), name="foo") + step=step, pipeline_run=sample_pipeline_run ) assert input_artifacts == { "input_name": StepRunInputResponse( @@ -71,7 +71,7 @@ def test_input_resolution( assert parent_ids == [step_run.id] -def test_input_resolution_with_missing_step_run(mocker): +def test_input_resolution_with_missing_step_run(mocker, sample_pipeline_run): """Tests that input resolution fails if the upstream step run is missing.""" mocker.patch( "zenml.zen_stores.sql_zen_store.SqlZenStore.list_run_steps", @@ -97,11 +97,13 @@ def test_input_resolution_with_missing_step_run(mocker): with pytest.raises(InputResolutionError): input_utils.resolve_step_inputs( - step=step, pipeline_run=PipelineRunResponse(id=uuid4(), name="foo") + step=step, pipeline_run=sample_pipeline_run ) -def test_input_resolution_with_missing_artifact(mocker, create_step_run): +def test_input_resolution_with_missing_artifact( + mocker, create_step_run, sample_pipeline_run +): """Tests that input resolution fails if the upstream step run output artifact is missing.""" step_run = create_step_run( @@ -132,12 +134,12 @@ def test_input_resolution_with_missing_artifact(mocker, 
create_step_run): with pytest.raises(InputResolutionError): input_utils.resolve_step_inputs( - step=step, pipeline_run=PipelineRunResponse(id=uuid4(), name="foo") + step=step, pipeline_run=sample_pipeline_run ) def test_input_resolution_fetches_all_run_steps( - mocker, sample_artifact_version_model, create_step_run + mocker, sample_artifact_version_model, create_step_run, sample_pipeline_run ): """Tests that input resolution fetches all step runs of the pipeline run.""" step_run = create_step_run( @@ -178,7 +180,7 @@ def test_input_resolution_fetches_all_run_steps( ) input_utils.resolve_step_inputs( - step=step, pipeline_run=PipelineRunResponse(id=uuid4(), name="foo") + step=step, pipeline_run=sample_pipeline_run ) # `resolve_step_inputs(...)` depaginates the run steps so we fetch all