Skip to content

Commit

Permalink
Set stack before importing pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
schustmi committed Dec 2, 2024
1 parent fbbfc29 commit 67dc26e
Showing 1 changed file with 46 additions and 79 deletions.
125 changes: 46 additions & 79 deletions src/zenml/cli/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,33 @@
logger = get_logger(__name__)


def _import_pipeline(source: str) -> Pipeline:
"""Import a pipeline.
Args:
source: The pipeline source.
Returns:
The pipeline.
"""
try:
pipeline_instance = source_utils.load(source)
except ModuleNotFoundError as e:
source_root = source_utils.get_source_root()
cli_utils.error(
f"Unable to import module `{e.name}`. Make sure the source path is "
f"relative to your source root `{source_root}`."
)
except AttributeError as e:
cli_utils.error("Unable to load attribute from module: " + str(e))

if not isinstance(pipeline_instance, Pipeline):
cli_utils.error(
f"The given source path `{source}` does not resolve to a pipeline "
"object."
)


@cli.group(cls=TagGroup, tag=CliCategories.MANAGEMENT_TOOLS)
def pipeline() -> None:
"""Interact with pipelines, runs and schedules."""
Expand Down Expand Up @@ -85,22 +112,7 @@ def register_pipeline(
"source code root."
)

try:
pipeline_instance = source_utils.load(source)
except ModuleNotFoundError as e:
source_root = source_utils.get_source_root()
cli_utils.error(
f"Unable to import module `{e.name}`. Make sure the source path is "
f"relative to your source root `{source_root}`."
)
except AttributeError as e:
cli_utils.error("Unable to load attribute from module: " + str(e))

if not isinstance(pipeline_instance, Pipeline):
cli_utils.error(
f"The given source path `{source}` does not resolve to a pipeline "
"object."
)
pipeline_instance = _import_pipeline(source=source)

parameters: Dict[str, Any] = {}
if parameters_path:
Expand Down Expand Up @@ -176,24 +188,9 @@ def build_pipeline(
"your source code root."
)

try:
pipeline_instance = source_utils.load(source)
except ModuleNotFoundError as e:
source_root = source_utils.get_source_root()
cli_utils.error(
f"Unable to import module `{e.name}`. Make sure the source path is "
f"relative to your source root `{source_root}`."
)
except AttributeError as e:
cli_utils.error("Unable to load attribute from module: " + str(e))

if not isinstance(pipeline_instance, Pipeline):
cli_utils.error(
f"The given source path `{source}` does not resolve to a pipeline "
"object."
)

with cli_utils.temporary_active_stack(stack_name_or_id=stack_name_or_id):
pipeline_instance = _import_pipeline(source=source)

pipeline_instance = pipeline_instance.with_options(
config_path=config_path
)
Expand Down Expand Up @@ -277,36 +274,21 @@ def run_pipeline(
"your source code root."
)

try:
pipeline_instance = source_utils.load(source)
except ModuleNotFoundError as e:
source_root = source_utils.get_source_root()
cli_utils.error(
f"Unable to import module `{e.name}`. Make sure the source path is "
f"relative to your source root `{source_root}`."
)
except AttributeError as e:
cli_utils.error("Unable to load attribute from module: " + str(e))

if not isinstance(pipeline_instance, Pipeline):
cli_utils.error(
f"The given source path `{source}` does not resolve to a pipeline "
"object."
)

build: Union[str, PipelineBuildBase, None] = None
if build_path_or_id:
if uuid_utils.is_valid_uuid(build_path_or_id):
build = build_path_or_id
elif os.path.exists(build_path_or_id):
build = PipelineBuildBase.from_yaml(build_path_or_id)
else:
cli_utils.error(
f"The specified build {build_path_or_id} is not a valid UUID "
"or file path."
)

with cli_utils.temporary_active_stack(stack_name_or_id=stack_name_or_id):
pipeline_instance = _import_pipeline(source=source)

build: Union[str, PipelineBuildBase, None] = None
if build_path_or_id:
if uuid_utils.is_valid_uuid(build_path_or_id):
build = build_path_or_id
elif os.path.exists(build_path_or_id):
build = PipelineBuildBase.from_yaml(build_path_or_id)
else:
cli_utils.error(
f"The specified build {build_path_or_id} is not a valid UUID "
"or file path."
)

pipeline_instance = pipeline_instance.with_options(
config_path=config_path,
build=build,
Expand Down Expand Up @@ -369,24 +351,9 @@ def create_run_template(
"init` at your source code root."
)

try:
pipeline_instance = source_utils.load(source)
except ModuleNotFoundError as e:
source_root = source_utils.get_source_root()
cli_utils.error(
f"Unable to import module `{e.name}`. Make sure the source path is "
f"relative to your source root `{source_root}`."
)
except AttributeError as e:
cli_utils.error("Unable to load attribute from module: " + str(e))

if not isinstance(pipeline_instance, Pipeline):
cli_utils.error(
f"The given source path `{source}` does not resolve to a pipeline "
"object."
)

with cli_utils.temporary_active_stack(stack_name_or_id=stack_name_or_id):
pipeline_instance = _import_pipeline(source=source)

pipeline_instance = pipeline_instance.with_options(
config_path=config_path
)
Expand Down

0 comments on commit 67dc26e

Please sign in to comment.