diff --git a/examples/e2e/.copier-answers.yml b/examples/e2e/.copier-answers.yml index e6fb1292beb..a9c5acbf8c3 100644 --- a/examples/e2e/.copier-answers.yml +++ b/examples/e2e/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2024.11.20-2-g760142f +_commit: 2024.11.28 _src_path: gh:zenml-io/template-e2e-batch data_quality_checks: true email: info@zenml.io diff --git a/examples/e2e_nlp/.copier-answers.yml b/examples/e2e_nlp/.copier-answers.yml index 274927e3ce5..b12ebdd786b 100644 --- a/examples/e2e_nlp/.copier-answers.yml +++ b/examples/e2e_nlp/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2024.10.30-2-g1ae14e3 +_commit: 2024.11.28 _src_path: gh:zenml-io/template-nlp accelerator: cpu cloud_of_choice: aws diff --git a/examples/llm_finetuning/.copier-answers.yml b/examples/llm_finetuning/.copier-answers.yml index 7deecebb1d2..47bfa4cf2af 100644 --- a/examples/llm_finetuning/.copier-answers.yml +++ b/examples/llm_finetuning/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2024.11.08-2-gece1d46 +_commit: 2024.11.28 _src_path: gh:zenml-io/template-llm-finetuning bf16: true cuda_version: cuda11.8 diff --git a/examples/mlops_starter/.copier-answers.yml b/examples/mlops_starter/.copier-answers.yml index 364bccaa9d0..ec87b32240d 100644 --- a/examples/mlops_starter/.copier-answers.yml +++ b/examples/mlops_starter/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2024.10.30-7-gb60e441 +_commit: 2024.11.28 _src_path: gh:zenml-io/template-starter email: info@zenml.io full_name: ZenML GmbH diff --git a/src/zenml/cli/pipeline.py b/src/zenml/cli/pipeline.py index 7a7e07b1777..0e50ffded05 100644 --- a/src/zenml/cli/pipeline.py +++ b/src/zenml/cli/pipeline.py @@ -40,6 +40,35 @@ logger = get_logger(__name__) +def _import_pipeline(source: str) -> Pipeline: + """Import a pipeline. + + Args: + source: The pipeline source. + + Returns: + The pipeline. + """ + try: + pipeline_instance = source_utils.load(source) + except ModuleNotFoundError as e: + source_root = source_utils.get_source_root() + cli_utils.error( + f"Unable to import module `{e.name}`. Make sure the source path is " + f"relative to your source root `{source_root}`." + ) + except AttributeError as e: + cli_utils.error("Unable to load attribute from module: " + str(e)) + + if not isinstance(pipeline_instance, Pipeline): + cli_utils.error( + f"The given source path `{source}` does not resolve to a pipeline " + "object." + ) + + return pipeline_instance + + @cli.group(cls=TagGroup, tag=CliCategories.MANAGEMENT_TOOLS) def pipeline() -> None: """Interact with pipelines, runs and schedules.""" @@ -85,22 +114,7 @@ def register_pipeline( "source code root." ) - try: - pipeline_instance = source_utils.load(source) - except ModuleNotFoundError as e: - source_root = source_utils.get_source_root() - cli_utils.error( - f"Unable to import module `{e.name}`. Make sure the source path is " - f"relative to your source root `{source_root}`." - ) - except AttributeError as e: - cli_utils.error("Unable to load attribute from module: " + str(e)) - - if not isinstance(pipeline_instance, Pipeline): - cli_utils.error( - f"The given source path `{source}` does not resolve to a pipeline " - "object." - ) + pipeline_instance = _import_pipeline(source=source) parameters: Dict[str, Any] = {} if parameters_path: @@ -176,24 +190,9 @@ def build_pipeline( "your source code root." ) - try: - pipeline_instance = source_utils.load(source) - except ModuleNotFoundError as e: - source_root = source_utils.get_source_root() - cli_utils.error( - f"Unable to import module `{e.name}`. Make sure the source path is " - f"relative to your source root `{source_root}`." - ) - except AttributeError as e: - cli_utils.error("Unable to load attribute from module: " + str(e)) - - if not isinstance(pipeline_instance, Pipeline): - cli_utils.error( - f"The given source path `{source}` does not resolve to a pipeline " - "object." - ) - with cli_utils.temporary_active_stack(stack_name_or_id=stack_name_or_id): + pipeline_instance = _import_pipeline(source=source) + pipeline_instance = pipeline_instance.with_options( config_path=config_path ) @@ -277,36 +276,21 @@ def run_pipeline( "your source code root." ) - try: - pipeline_instance = source_utils.load(source) - except ModuleNotFoundError as e: - source_root = source_utils.get_source_root() - cli_utils.error( - f"Unable to import module `{e.name}`. Make sure the source path is " - f"relative to your source root `{source_root}`." - ) - except AttributeError as e: - cli_utils.error("Unable to load attribute from module: " + str(e)) - - if not isinstance(pipeline_instance, Pipeline): - cli_utils.error( - f"The given source path `{source}` does not resolve to a pipeline " - "object." - ) - - build: Union[str, PipelineBuildBase, None] = None - if build_path_or_id: - if uuid_utils.is_valid_uuid(build_path_or_id): - build = build_path_or_id - elif os.path.exists(build_path_or_id): - build = PipelineBuildBase.from_yaml(build_path_or_id) - else: - cli_utils.error( - f"The specified build {build_path_or_id} is not a valid UUID " - "or file path." - ) - with cli_utils.temporary_active_stack(stack_name_or_id=stack_name_or_id): + pipeline_instance = _import_pipeline(source=source) + + build: Union[str, PipelineBuildBase, None] = None + if build_path_or_id: + if uuid_utils.is_valid_uuid(build_path_or_id): + build = build_path_or_id + elif os.path.exists(build_path_or_id): + build = PipelineBuildBase.from_yaml(build_path_or_id) + else: + cli_utils.error( + f"The specified build {build_path_or_id} is not a valid UUID " + "or file path." + ) + pipeline_instance = pipeline_instance.with_options( config_path=config_path, build=build, @@ -369,24 +353,9 @@ def create_run_template( "init` at your source code root." ) - try: - pipeline_instance = source_utils.load(source) - except ModuleNotFoundError as e: - source_root = source_utils.get_source_root() - cli_utils.error( - f"Unable to import module `{e.name}`. Make sure the source path is " - f"relative to your source root `{source_root}`." - ) - except AttributeError as e: - cli_utils.error("Unable to load attribute from module: " + str(e)) - - if not isinstance(pipeline_instance, Pipeline): - cli_utils.error( - f"The given source path `{source}` does not resolve to a pipeline " - "object." - ) - with cli_utils.temporary_active_stack(stack_name_or_id=stack_name_or_id): + pipeline_instance = _import_pipeline(source=source) + pipeline_instance = pipeline_instance.with_options( config_path=config_path ) diff --git a/src/zenml/integrations/neptune/experiment_trackers/neptune_experiment_tracker.py b/src/zenml/integrations/neptune/experiment_trackers/neptune_experiment_tracker.py index df410fc7e58..e71d84d2e77 100644 --- a/src/zenml/integrations/neptune/experiment_trackers/neptune_experiment_tracker.py +++ b/src/zenml/integrations/neptune/experiment_trackers/neptune_experiment_tracker.py @@ -77,10 +77,12 @@ def prepare_step_run(self, info: "StepRunInfo") -> None: NeptuneExperimentTrackerSettings, self.get_settings(info) ) - self.run_state.token = self.config.api_token - self.run_state.project = self.config.project - self.run_state.run_name = info.run_name - self.run_state.tags = list(settings.tags) + self.run_state.initialize( + project=self.config.project, + token=self.config.api_token, + run_name=info.run_name, + tags=list(settings.tags), + ) def get_step_run_metadata( self, info: "StepRunInfo" @@ -107,4 +109,4 @@ def cleanup_step_run(self, info: "StepRunInfo", step_failed: bool) -> None: """ self.run_state.active_run.sync() self.run_state.active_run.stop() - self.run_state.reset_active_run() + self.run_state.reset() diff --git a/src/zenml/integrations/neptune/experiment_trackers/run_state.py b/src/zenml/integrations/neptune/experiment_trackers/run_state.py index 145bc464fab..6c0dbbcca25 100644 --- a/src/zenml/integrations/neptune/experiment_trackers/run_state.py +++ b/src/zenml/integrations/neptune/experiment_trackers/run_state.py @@ -20,7 +20,6 @@ import zenml from zenml.client import Client -from zenml.integrations.constants import NEPTUNE from zenml.utils.singleton import SingletonMetaClass if TYPE_CHECKING: @@ -29,20 +28,38 @@ _INTEGRATION_VERSION_KEY = "source_code/integrations/zenml" -class InvalidExperimentTrackerSelected(Exception): - """Raised if a Neptune run is fetched while using a different experiment tracker.""" - - class RunProvider(metaclass=SingletonMetaClass): """Singleton object used to store and persist a Neptune run state across the pipeline.""" def __init__(self) -> None: """Initialize RunProvider. Called with no arguments.""" self._active_run: Optional["Run"] = None - self._project: Optional[str] - self._run_name: Optional[str] - self._token: Optional[str] - self._tags: Optional[List[str]] + self._project: Optional[str] = None + self._run_name: Optional[str] = None + self._token: Optional[str] = None + self._tags: Optional[List[str]] = None + self._initialized = False + + def initialize( + self, + project: Optional[str] = None, + token: Optional[str] = None, + run_name: Optional[str] = None, + tags: Optional[List[str]] = None, + ) -> None: + """Initialize the run state. + + Args: + project: The neptune project. + token: The neptune token. + run_name: The neptune run name. + tags: Tags for the neptune run. + """ + self._project = project + self._token = token + self._run_name = run_name + self._tags = tags + self._initialized = True @property def project(self) -> Optional[Any]: @@ -53,15 +70,6 @@ def project(self) -> Optional[Any]: """ return self._project - @project.setter - def project(self, project: str) -> None: - """Setter for project name. - - Args: - project: Neptune project name - """ - self._project = project - @property def token(self) -> Optional[Any]: """Getter for API token. @@ -71,15 +79,6 @@ def token(self) -> Optional[Any]: """ return self._token - @token.setter - def token(self, token: str) -> None: - """Setter for API token. - - Args: - token: Neptune API token - """ - self._token = token - @property def run_name(self) -> Optional[Any]: """Getter for run name. @@ -89,15 +88,6 @@ def run_name(self) -> Optional[Any]: """ return self._run_name - @run_name.setter - def run_name(self, run_name: str) -> None: - """Setter for run name. - - Args: - run_name: name of the pipeline run - """ - self._run_name = run_name - @property def tags(self) -> Optional[Any]: """Getter for run tags. @@ -107,14 +97,14 @@ def tags(self) -> Optional[Any]: """ return self._tags - @tags.setter - def tags(self, tags: List[str]) -> None: - """Setter for run tags. + @property + def initialized(self) -> bool: + """If the run state is initialized. - Args: - tags: list of tags associated with a Neptune run + Returns: + If the run state is initialized. """ - self._tags = tags + return self._initialized @property def active_run(self) -> "Run": @@ -137,9 +127,14 @@ def active_run(self) -> "Run": self._active_run = run return self._active_run - def reset_active_run(self) -> None: - """Resets the active run state to None.""" + def reset(self) -> None: + """Reset the run state.""" self._active_run = None + self._project = None + self._run_name = None + self._token = None + self._tags = None + self._initialized = False def get_neptune_run() -> "Run": @@ -149,14 +144,35 @@ def get_neptune_run() -> "Run": Neptune run object Raises: - InvalidExperimentTrackerSelected: when called while using an experiment tracker other than Neptune + RuntimeError: When unable to fetch the active neptune run. """ - client = Client() - experiment_tracker = client.active_stack.experiment_tracker - if experiment_tracker.flavor == NEPTUNE: # type: ignore - return experiment_tracker.run_state.active_run # type: ignore - raise InvalidExperimentTrackerSelected( - "Fetching a Neptune run works only with the 'neptune' flavor of " - "the experiment tracker. The flavor currently selected is %s" - % experiment_tracker.flavor # type: ignore + from zenml.integrations.neptune.experiment_trackers import ( + NeptuneExperimentTracker, ) + + experiment_tracker = Client().active_stack.experiment_tracker + + if not experiment_tracker: + raise RuntimeError( + "Unable to get neptune run: Missing experiment tracker in the " + "active stack." + ) + + if not isinstance(experiment_tracker, NeptuneExperimentTracker): + raise RuntimeError( + "Unable to get neptune run: Experiment tracker in the active " + f"stack ({experiment_tracker.flavor}) is not a neptune experiment " + "tracker." + ) + + run_state = experiment_tracker.run_state + if not run_state.initialized: + raise RuntimeError( + "Unable to get neptune run: The experiment tracker has not been " + "initialized. To solve this, make sure you use the experiment " + "tracker in your step. See " + "https://docs.zenml.io/stack-components/experiment-trackers/neptune#how-do-you-use-it " + "for more information." + ) + + return experiment_tracker.run_state.active_run diff --git a/src/zenml/integrations/registry.py b/src/zenml/integrations/registry.py index 1ba15574d03..6d9f8cc10b8 100644 --- a/src/zenml/integrations/registry.py +++ b/src/zenml/integrations/registry.py @@ -111,7 +111,7 @@ def select_integration_requirements( ) else: raise KeyError( - f"Version {integration_name} does not exist. " + f"Integration {integration_name} does not exist. " f"Currently the following integrations are implemented. " f"{self.list_integration_names}" ) @@ -148,7 +148,7 @@ def select_uninstall_requirements( ].get_uninstall_requirements(target_os=target_os) else: raise KeyError( - f"Version {integration_name} does not exist. " + f"Integration {integration_name} does not exist. " f"Currently the following integrations are implemented. " f"{self.list_integration_names}" )