From 67dc26e6c532892e0fc46e54a7a4b3bf64e88ba4 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Mon, 2 Dec 2024 13:24:02 +0100 Subject: [PATCH] Set stack before importing pipeline --- src/zenml/cli/pipeline.py | 125 ++++++++++++++------------------------ 1 file changed, 46 insertions(+), 79 deletions(-) diff --git a/src/zenml/cli/pipeline.py b/src/zenml/cli/pipeline.py index 7a7e07b177..150348f849 100644 --- a/src/zenml/cli/pipeline.py +++ b/src/zenml/cli/pipeline.py @@ -40,6 +40,33 @@ logger = get_logger(__name__) +def _import_pipeline(source: str) -> Pipeline: + """Import a pipeline. + + Args: + source: The pipeline source. + + Returns: + The pipeline. + """ + try: + pipeline_instance = source_utils.load(source) + except ModuleNotFoundError as e: + source_root = source_utils.get_source_root() + cli_utils.error( + f"Unable to import module `{e.name}`. Make sure the source path is " + f"relative to your source root `{source_root}`." + ) + except AttributeError as e: + cli_utils.error("Unable to load attribute from module: " + str(e)) + + if not isinstance(pipeline_instance, Pipeline): + cli_utils.error( + f"The given source path `{source}` does not resolve to a pipeline " + "object." + ) + + @cli.group(cls=TagGroup, tag=CliCategories.MANAGEMENT_TOOLS) def pipeline() -> None: """Interact with pipelines, runs and schedules.""" @@ -85,22 +112,7 @@ def register_pipeline( "source code root." ) - try: - pipeline_instance = source_utils.load(source) - except ModuleNotFoundError as e: - source_root = source_utils.get_source_root() - cli_utils.error( - f"Unable to import module `{e.name}`. Make sure the source path is " - f"relative to your source root `{source_root}`." - ) - except AttributeError as e: - cli_utils.error("Unable to load attribute from module: " + str(e)) - - if not isinstance(pipeline_instance, Pipeline): - cli_utils.error( - f"The given source path `{source}` does not resolve to a pipeline " - "object." - ) + pipeline_instance = _import_pipeline(source=source) parameters: Dict[str, Any] = {} if parameters_path: @@ -176,24 +188,9 @@ def build_pipeline( "your source code root." ) - try: - pipeline_instance = source_utils.load(source) - except ModuleNotFoundError as e: - source_root = source_utils.get_source_root() - cli_utils.error( - f"Unable to import module `{e.name}`. Make sure the source path is " - f"relative to your source root `{source_root}`." - ) - except AttributeError as e: - cli_utils.error("Unable to load attribute from module: " + str(e)) - - if not isinstance(pipeline_instance, Pipeline): - cli_utils.error( - f"The given source path `{source}` does not resolve to a pipeline " - "object." - ) - with cli_utils.temporary_active_stack(stack_name_or_id=stack_name_or_id): + pipeline_instance = _import_pipeline(source=source) + pipeline_instance = pipeline_instance.with_options( config_path=config_path ) @@ -277,36 +274,21 @@ def run_pipeline( "your source code root." ) - try: - pipeline_instance = source_utils.load(source) - except ModuleNotFoundError as e: - source_root = source_utils.get_source_root() - cli_utils.error( - f"Unable to import module `{e.name}`. Make sure the source path is " - f"relative to your source root `{source_root}`." - ) - except AttributeError as e: - cli_utils.error("Unable to load attribute from module: " + str(e)) - - if not isinstance(pipeline_instance, Pipeline): - cli_utils.error( - f"The given source path `{source}` does not resolve to a pipeline " - "object." - ) - - build: Union[str, PipelineBuildBase, None] = None - if build_path_or_id: - if uuid_utils.is_valid_uuid(build_path_or_id): - build = build_path_or_id - elif os.path.exists(build_path_or_id): - build = PipelineBuildBase.from_yaml(build_path_or_id) - else: - cli_utils.error( - f"The specified build {build_path_or_id} is not a valid UUID " - "or file path." - ) - with cli_utils.temporary_active_stack(stack_name_or_id=stack_name_or_id): + pipeline_instance = _import_pipeline(source=source) + + build: Union[str, PipelineBuildBase, None] = None + if build_path_or_id: + if uuid_utils.is_valid_uuid(build_path_or_id): + build = build_path_or_id + elif os.path.exists(build_path_or_id): + build = PipelineBuildBase.from_yaml(build_path_or_id) + else: + cli_utils.error( + f"The specified build {build_path_or_id} is not a valid UUID " + "or file path." + ) + pipeline_instance = pipeline_instance.with_options( config_path=config_path, build=build, @@ -369,24 +351,9 @@ def create_run_template( "init` at your source code root." ) - try: - pipeline_instance = source_utils.load(source) - except ModuleNotFoundError as e: - source_root = source_utils.get_source_root() - cli_utils.error( - f"Unable to import module `{e.name}`. Make sure the source path is " - f"relative to your source root `{source_root}`." - ) - except AttributeError as e: - cli_utils.error("Unable to load attribute from module: " + str(e)) - - if not isinstance(pipeline_instance, Pipeline): - cli_utils.error( - f"The given source path `{source}` does not resolve to a pipeline " - "object." - ) - with cli_utils.temporary_active_stack(stack_name_or_id=stack_name_or_id): + pipeline_instance = _import_pipeline(source=source) + pipeline_instance = pipeline_instance.with_options( config_path=config_path )