Skip to content

Commit

Permalink
Misc cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
schustmi committed Nov 29, 2024
1 parent d16e3a4 commit e8aa3b2
Show file tree
Hide file tree
Showing 9 changed files with 30 additions and 63 deletions.
14 changes: 0 additions & 14 deletions src/zenml/artifacts/artifact_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from zenml.logger import get_logger
from zenml.metadata.metadata_types import MetadataType
from zenml.utils.pydantic_utils import before_validator_handler
from zenml.utils.string_utils import format_name_template

logger = get_logger(__name__)

Expand Down Expand Up @@ -118,16 +117,3 @@ def _remove_old_attributes(cls, data: Dict[str, Any]) -> Dict[str, Any]:
data.setdefault("artifact_type", ArtifactType.SERVICE)

return data

def _evaluated_name(self, substitutions: Dict[str, str]) -> Optional[str]:
"""Evaluated name of the artifact.
Args:
substitutions: Extra placeholders to use in the name template.
Returns:
The evaluated name of the artifact.
"""
if self.name:
return format_name_template(self.name, substitutions=substitutions)
return self.name
4 changes: 2 additions & 2 deletions src/zenml/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ def _get_or_create_model(self) -> "ModelResponse":

zenml_client = Client()
# backup logic, if the Model class is used directly from the code
self.name = format_name_template(self.name, substitutions={})
self.name = format_name_template(self.name)
if self.model_version_id:
mv = zenml_client.get_model_version(
model_version_name_or_number_or_id=self.model_version_id,
Expand Down Expand Up @@ -667,7 +667,7 @@ def _get_or_create_model_version(

# backup logic, if the Model class is used directly from the code
if isinstance(self.version, str):
self.version = format_name_template(self.version, substitutions={})
self.version = format_name_template(self.version)

try:
if self.version or self.model_version_id:
Expand Down
2 changes: 1 addition & 1 deletion src/zenml/models/v2/core/pipeline_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ class PipelineRunResponseMetadata(WorkspaceScopedResponseMetadata):
default=False,
description="Whether a template can be created from this run.",
)
steps_substitutions: Dict[str, Dict[str, str]] = Field(
step_substitutions: Dict[str, Dict[str, str]] = Field(
title="Substitutions used in the step runs of this pipeline run.",
default_factory=dict,
)
Expand Down
4 changes: 2 additions & 2 deletions src/zenml/orchestrators/step_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,8 @@ def _create_or_reuse_run(self) -> Tuple[PipelineRunResponse, bool]:
and a boolean indicating whether the run was created or reused.
"""
start_time = datetime.utcnow()
run_name = orchestrator_utils.get_run_name(
run_name_template=self._deployment.run_name_template,
run_name = string_utils.format_name_template(
name_template=self._deployment.run_name_template,
substitutions=self._deployment.pipeline_configuration._get_full_substitutions(
start_time
),
Expand Down
15 changes: 7 additions & 8 deletions src/zenml/orchestrators/step_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
parse_return_type_annotations,
resolve_type_annotation,
)
from zenml.utils import materializer_utils, source_utils
from zenml.utils import materializer_utils, source_utils, string_utils
from zenml.utils.typing_utils import get_origin, is_union

if TYPE_CHECKING:
Expand Down Expand Up @@ -292,16 +292,15 @@ def _evaluate_artifact_names_in_collections(
"""
collections.append(output_annotations)
for k, v in list(output_annotations.items()):
_evaluated_name = None
if v.artifact_config:
_evaluated_name = v.artifact_config._evaluated_name(
step_run.config.substitutions
name = k
if v.artifact_config and v.artifact_config.name:
name = string_utils.format_name_template(
v.artifact_config.name,
substitutions=step_run.config.substitutions,
)
if _evaluated_name is None:
_evaluated_name = k

for d in collections:
d[_evaluated_name] = d.pop(k)
d[name] = d.pop(k)

def _load_step(self) -> "BaseStep":
"""Load the step instance.
Expand Down
24 changes: 0 additions & 24 deletions src/zenml/orchestrators/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from zenml.enums import AuthScheme, StackComponentType, StoreType
from zenml.logger import get_logger
from zenml.stack import StackComponent
from zenml.utils.string_utils import format_name_template

logger = get_logger(__name__)

Expand Down Expand Up @@ -196,29 +195,6 @@ def get_config_environment_vars(
return environment_vars


def get_run_name(run_name_template: str, substitutions: Dict[str, str]) -> str:
"""Fill out the run name template to get a complete run name.
Args:
run_name_template: The run name template to fill out.
substitutions: The substitutions to use in the template.
Raises:
ValueError: If the run name is empty.
Returns:
The run name derived from the template.
"""
run_name = format_name_template(
run_name_template, substitutions=substitutions
)

if run_name == "":
raise ValueError("Empty run names are not allowed.")

return run_name


class register_artifact_store_filesystem:
"""Context manager for the artifact_store/filesystem_registry dependency.
Expand Down
7 changes: 3 additions & 4 deletions src/zenml/pipelines/run_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@
StackResponse,
)
from zenml.orchestrators.publish_utils import publish_failed_pipeline_run
from zenml.orchestrators.utils import get_run_name
from zenml.stack import Flavor, Stack
from zenml.utils import code_utils, notebook_utils, source_utils
from zenml.utils import code_utils, notebook_utils, source_utils, string_utils
from zenml.zen_stores.base_zen_store import BaseZenStore

if TYPE_CHECKING:
Expand Down Expand Up @@ -68,8 +67,8 @@ def create_placeholder_run(
return None
start_time = datetime.utcnow()
run_request = PipelineRunRequest(
name=get_run_name(
run_name_template=deployment.run_name_template,
name=string_utils.format_name_template(
name_template=deployment.run_name_template,
substitutions=deployment.pipeline_configuration._get_full_substitutions(
start_time
),
Expand Down
19 changes: 13 additions & 6 deletions src/zenml/utils/string_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import random
import string
from datetime import datetime
from typing import Any, Callable, Dict, TypeVar, cast
from typing import Any, Callable, Dict, Optional, TypeVar, cast

from pydantic import BaseModel

Expand Down Expand Up @@ -147,13 +147,12 @@ def validate_name(model: BaseModel) -> None:

def format_name_template(
name_template: str,
substitutions: Dict[str, str],
substitutions: Optional[Dict[str, str]] = None,
) -> str:
"""Formats a name template with the given arguments.
By default, ZenML support Date and Time placeholders.
E.g. `my_run_{date}_{time}` will be formatted as `my_run_1970_01_01_00_00_00`.
Extra placeholders need to be explicitly passed in as kwargs.
By default, ZenML supports Date and Time placeholders. For example,
`my_run_{date}_{time}` will be formatted as `my_run_1970_01_01_00_00_00`.
Args:
name_template: The name template to format.
Expand All @@ -164,7 +163,10 @@ def format_name_template(
Raises:
KeyError: If a key in template is missing in the kwargs.
ValueError: If the formatted name is empty.
"""
substitutions = substitutions or {}

if ("date" not in substitutions and "{date}" in name_template) or (
"time" not in substitutions and "{time}" in name_template
):
Expand All @@ -183,13 +185,18 @@ def format_name_template(
substitutions.setdefault("time", start_time.strftime("%H_%M_%S_%f"))

try:
return name_template.format(**substitutions)
formatted_name = name_template.format(**substitutions)
except KeyError as e:
raise KeyError(
f"Could not format the name template `{name_template}`. "
f"Missing key: {e}"
)

if not formatted_name:
raise ValueError("Empty names are not allowed.")

return formatted_name


def substitute_string(value: V, substitution_func: Callable[[str], str]) -> V:
"""Recursively substitute strings in objects.
Expand Down
4 changes: 2 additions & 2 deletions src/zenml/zen_stores/schemas/pipeline_run_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def to_model(

steps = {step.name: step.to_model() for step in self.step_runs}

steps_substitutions = {
step_substitutions = {
step_name: step.config.substitutions
for step_name, step in steps.items()
}
Expand All @@ -371,7 +371,7 @@ def to_model(
if self.deployment
else None,
is_templatable=is_templatable,
steps_substitutions=steps_substitutions,
step_substitutions=step_substitutions,
)

resources = None
Expand Down

0 comments on commit e8aa3b2

Please sign in to comment.