diff --git a/src/zenml/integrations/neptune/experiment_trackers/neptune_experiment_tracker.py b/src/zenml/integrations/neptune/experiment_trackers/neptune_experiment_tracker.py index df410fc7e58..e71d84d2e77 100644 --- a/src/zenml/integrations/neptune/experiment_trackers/neptune_experiment_tracker.py +++ b/src/zenml/integrations/neptune/experiment_trackers/neptune_experiment_tracker.py @@ -77,10 +77,12 @@ def prepare_step_run(self, info: "StepRunInfo") -> None: NeptuneExperimentTrackerSettings, self.get_settings(info) ) - self.run_state.token = self.config.api_token - self.run_state.project = self.config.project - self.run_state.run_name = info.run_name - self.run_state.tags = list(settings.tags) + self.run_state.initialize( + project=self.config.project, + token=self.config.api_token, + run_name=info.run_name, + tags=list(settings.tags), + ) def get_step_run_metadata( self, info: "StepRunInfo" @@ -107,4 +109,4 @@ def cleanup_step_run(self, info: "StepRunInfo", step_failed: bool) -> None: """ self.run_state.active_run.sync() self.run_state.active_run.stop() - self.run_state.reset_active_run() + self.run_state.reset() diff --git a/src/zenml/integrations/neptune/experiment_trackers/run_state.py b/src/zenml/integrations/neptune/experiment_trackers/run_state.py index 145bc464fab..6c0dbbcca25 100644 --- a/src/zenml/integrations/neptune/experiment_trackers/run_state.py +++ b/src/zenml/integrations/neptune/experiment_trackers/run_state.py @@ -20,7 +20,6 @@ import zenml from zenml.client import Client -from zenml.integrations.constants import NEPTUNE from zenml.utils.singleton import SingletonMetaClass if TYPE_CHECKING: @@ -29,20 +28,38 @@ _INTEGRATION_VERSION_KEY = "source_code/integrations/zenml" -class InvalidExperimentTrackerSelected(Exception): - """Raised if a Neptune run is fetched while using a different experiment tracker.""" - - class RunProvider(metaclass=SingletonMetaClass): """Singleton object used to store and persist a Neptune run state across the pipeline.""" def __init__(self) -> None: """Initialize RunProvider. Called with no arguments.""" self._active_run: Optional["Run"] = None - self._project: Optional[str] - self._run_name: Optional[str] - self._token: Optional[str] - self._tags: Optional[List[str]] + self._project: Optional[str] = None + self._run_name: Optional[str] = None + self._token: Optional[str] = None + self._tags: Optional[List[str]] = None + self._initialized = False + + def initialize( + self, + project: Optional[str] = None, + token: Optional[str] = None, + run_name: Optional[str] = None, + tags: Optional[List[str]] = None, + ) -> None: + """Initialize the run state. + + Args: + project: The neptune project. + token: The neptune token. + run_name: The neptune run name. + tags: Tags for the neptune run. + """ + self._project = project + self._token = token + self._run_name = run_name + self._tags = tags + self._initialized = True @property def project(self) -> Optional[Any]: @@ -53,15 +70,6 @@ def project(self) -> Optional[Any]: """ return self._project - @project.setter - def project(self, project: str) -> None: - """Setter for project name. - - Args: - project: Neptune project name - """ - self._project = project - @property def token(self) -> Optional[Any]: """Getter for API token. @@ -71,15 +79,6 @@ def token(self) -> Optional[Any]: """ return self._token - @token.setter - def token(self, token: str) -> None: - """Setter for API token. - - Args: - token: Neptune API token - """ - self._token = token - @property def run_name(self) -> Optional[Any]: """Getter for run name. @@ -89,15 +88,6 @@ def run_name(self) -> Optional[Any]: """ return self._run_name - @run_name.setter - def run_name(self, run_name: str) -> None: - """Setter for run name. - - Args: - run_name: name of the pipeline run - """ - self._run_name = run_name - @property def tags(self) -> Optional[Any]: """Getter for run tags. @@ -107,14 +97,14 @@ def tags(self) -> Optional[Any]: """ return self._tags - @tags.setter - def tags(self, tags: List[str]) -> None: - """Setter for run tags. + @property + def initialized(self) -> bool: + """If the run state is initialized. - Args: - tags: list of tags associated with a Neptune run + Returns: + If the run state is initialized. """ - self._tags = tags + return self._initialized @property def active_run(self) -> "Run": @@ -137,9 +127,14 @@ def active_run(self) -> "Run": self._active_run = run return self._active_run - def reset_active_run(self) -> None: - """Resets the active run state to None.""" + def reset(self) -> None: + """Reset the run state.""" self._active_run = None + self._project = None + self._run_name = None + self._token = None + self._tags = None + self._initialized = False def get_neptune_run() -> "Run": @@ -149,14 +144,35 @@ def get_neptune_run() -> "Run": Neptune run object Raises: - InvalidExperimentTrackerSelected: when called while using an experiment tracker other than Neptune + RuntimeError: When unable to fetch the active neptune run. """ - client = Client() - experiment_tracker = client.active_stack.experiment_tracker - if experiment_tracker.flavor == NEPTUNE: # type: ignore - return experiment_tracker.run_state.active_run # type: ignore - raise InvalidExperimentTrackerSelected( - "Fetching a Neptune run works only with the 'neptune' flavor of " - "the experiment tracker. The flavor currently selected is %s" - % experiment_tracker.flavor # type: ignore + from zenml.integrations.neptune.experiment_trackers import ( + NeptuneExperimentTracker, ) + + experiment_tracker = Client().active_stack.experiment_tracker + + if not experiment_tracker: + raise RuntimeError( + "Unable to get neptune run: Missing experiment tracker in the " + "active stack." + ) + + if not isinstance(experiment_tracker, NeptuneExperimentTracker): + raise RuntimeError( + "Unable to get neptune run: Experiment tracker in the active " + f"stack ({experiment_tracker.flavor}) is not a neptune experiment " + "tracker." + ) + + run_state = experiment_tracker.run_state + if not run_state.initialized: + raise RuntimeError( + "Unable to get neptune run: The experiment tracker has not been " + "initialized. To solve this, make sure you use the experiment " + "tracker in your step. See " + "https://docs.zenml.io/stack-components/experiment-trackers/neptune#how-do-you-use-it " + "for more information." + ) + + return experiment_tracker.run_state.active_run