Skip to content

Commit

Permalink
Improve logic when fetching neptune run
Browse files Browse the repository at this point in the history
  • Loading branch information
schustmi committed Dec 2, 2024
1 parent 67dc26e commit 3f0a929
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,12 @@ def prepare_step_run(self, info: "StepRunInfo") -> None:
NeptuneExperimentTrackerSettings, self.get_settings(info)
)

self.run_state.token = self.config.api_token
self.run_state.project = self.config.project
self.run_state.run_name = info.run_name
self.run_state.tags = list(settings.tags)
self.run_state.initialize(
project=self.config.project,
token=self.config.api_token,
run_name=info.run_name,
tags=list(settings.tags),
)

def get_step_run_metadata(
self, info: "StepRunInfo"
Expand All @@ -107,4 +109,4 @@ def cleanup_step_run(self, info: "StepRunInfo", step_failed: bool) -> None:
"""
self.run_state.active_run.sync()
self.run_state.active_run.stop()
self.run_state.reset_active_run()
self.run_state.reset()
122 changes: 69 additions & 53 deletions src/zenml/integrations/neptune/experiment_trackers/run_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import zenml
from zenml.client import Client
from zenml.integrations.constants import NEPTUNE
from zenml.utils.singleton import SingletonMetaClass

if TYPE_CHECKING:
Expand All @@ -29,20 +28,38 @@
_INTEGRATION_VERSION_KEY = "source_code/integrations/zenml"


class InvalidExperimentTrackerSelected(Exception):
"""Raised if a Neptune run is fetched while using a different experiment tracker."""


class RunProvider(metaclass=SingletonMetaClass):
"""Singleton object used to store and persist a Neptune run state across the pipeline."""

def __init__(self) -> None:
"""Initialize RunProvider. Called with no arguments."""
self._active_run: Optional["Run"] = None
self._project: Optional[str]
self._run_name: Optional[str]
self._token: Optional[str]
self._tags: Optional[List[str]]
self._project: Optional[str] = None
self._run_name: Optional[str] = None
self._token: Optional[str] = None
self._tags: Optional[List[str]] = None
self._initialized = False

def initialize(
self,
project: Optional[str] = None,
token: Optional[str] = None,
run_name: Optional[str] = None,
tags: Optional[List[str]] = None,
) -> None:
"""Initialize the run state.
Args:
project: The neptune project.
token: The neptune token.
run_name: The neptune run name.
tags: Tags for the neptune run.
"""
self._project = project
self._token = token
self._run_name = run_name
self._tags = tags
self._initialized = True

@property
def project(self) -> Optional[Any]:
Expand All @@ -53,15 +70,6 @@ def project(self) -> Optional[Any]:
"""
return self._project

@project.setter
def project(self, project: str) -> None:
"""Setter for project name.
Args:
project: Neptune project name
"""
self._project = project

@property
def token(self) -> Optional[Any]:
"""Getter for API token.
Expand All @@ -71,15 +79,6 @@ def token(self) -> Optional[Any]:
"""
return self._token

@token.setter
def token(self, token: str) -> None:
"""Setter for API token.
Args:
token: Neptune API token
"""
self._token = token

@property
def run_name(self) -> Optional[Any]:
"""Getter for run name.
Expand All @@ -89,15 +88,6 @@ def run_name(self) -> Optional[Any]:
"""
return self._run_name

@run_name.setter
def run_name(self, run_name: str) -> None:
"""Setter for run name.
Args:
run_name: name of the pipeline run
"""
self._run_name = run_name

@property
def tags(self) -> Optional[Any]:
"""Getter for run tags.
Expand All @@ -107,14 +97,14 @@ def tags(self) -> Optional[Any]:
"""
return self._tags

@tags.setter
def tags(self, tags: List[str]) -> None:
"""Setter for run tags.
@property
def initialized(self) -> bool:
"""If the run state is initialized.
Args:
tags: list of tags associated with a Neptune run
Returns:
If the run state is initialized.
"""
self._tags = tags
return self._initialized

@property
def active_run(self) -> "Run":
Expand All @@ -137,9 +127,14 @@ def active_run(self) -> "Run":
self._active_run = run
return self._active_run

def reset_active_run(self) -> None:
"""Resets the active run state to None."""
def reset(self) -> None:
"""Reset the run state."""
self._active_run = None
self._project = None
self._run_name = None
self._token = None
self._tags = None
self._initialized = False


def get_neptune_run() -> "Run":
Expand All @@ -149,14 +144,35 @@ def get_neptune_run() -> "Run":
Neptune run object
Raises:
InvalidExperimentTrackerSelected: when called while using an experiment tracker other than Neptune
RuntimeError: When unable to fetch the active neptune run.
"""
client = Client()
experiment_tracker = client.active_stack.experiment_tracker
if experiment_tracker.flavor == NEPTUNE: # type: ignore
return experiment_tracker.run_state.active_run # type: ignore
raise InvalidExperimentTrackerSelected(
"Fetching a Neptune run works only with the 'neptune' flavor of "
"the experiment tracker. The flavor currently selected is %s"
% experiment_tracker.flavor # type: ignore
from zenml.integrations.neptune.experiment_trackers import (
NeptuneExperimentTracker,
)

experiment_tracker = Client().active_stack.experiment_tracker

if not experiment_tracker:
raise RuntimeError(
"Unable to get neptune run: Missing experiment tracker in the "
"active stack."
)

if not isinstance(experiment_tracker, NeptuneExperimentTracker):
raise RuntimeError(
"Unable to get neptune run: Experiment tracker in the active "
f"stack ({experiment_tracker.flavor}) is not a neptune experiment "
"tracker."
)

run_state = experiment_tracker.run_state
if not run_state.initialized:
raise RuntimeError(
"Unable to get neptune run: The experiment tracker has not been "
"initialized. To solve this, make sure you use the experiment "
"tracker in your step. See "
"https://docs.zenml.io/stack-components/experiment-trackers/neptune#how-do-you-use-it "
"for more information."
)

return experiment_tracker.run_state.active_run

0 comments on commit 3f0a929

Please sign in to comment.