From fbbfc2995b76a1d402f5040af239e5b5228fa37b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bar=C4=B1=C5=9F=20Can=20Durak?= <36421093+bcdurak@users.noreply.github.com> Date: Fri, 29 Nov 2024 17:45:32 +0100 Subject: [PATCH] Follow-up on the `run_metadata` changes (#3193) * Initial commit, nuking all metadata responses and seeing what breaks * Removed last remnant of LazyLoader * Reintroducing the lazy loaders. * Add LazyRunMetadataResponse to EntrypointFunctionDefinition * Test for lazy loaders works now * Fixed tests, reformatted * Use updated template * Auto-update of Starter template * Updated more templates * Fixed failing test * Fixed step run schemas * Auto-update of E2E template * Auto-update of NLP template * Fixed tests, removed additional .value access * Further fixing * Fixed linting issues * Reformatted * Linted, formatted and tested again * Typing * Maybe fix everything * Apply some feedback * new operation * new log_metadata function * changes to the base filters * new filters * adding log_metadata to __all__ * checkpoint with float casting * adding tests * final touches and formatting * formatting * moved the utils * modified log metadata function * checkpoint * deprecating the old functions * linting and final fixes * better error message * fixing the client method * better error message * consistent creation\ * adjusting tests * linting * changes for step metadata * more test adjustments * testing unit tests * linting * fixing more tests * fixing more tests * more test fixes * fixing the test * fixing per comments * added validation, constant error message * linting * new changes * second checkpoint * fixing revisions * adding overlap to remove warnings * complete docs changes * adding a parameter to control the related entity behaviour * fixing the toc * fixed the description * docstring * spellcheck * metadata creation during artifact version creation * allowing artifact metadata with name for external artifact * update the template versions * Auto-update of LLM Finetuning template * Auto-update of Starter template * Auto-update of E2E template * Auto-update of NLP template * fixing the migration script * formatting * redirects * minor fixes * working pipelines again * small fix * working checkpoint * fixes, linting, docstrings * fixing unit tests * docs updates 1 * docs update 2 * fixing integration tests * spellcheck * formatting * Auto-update of E2E template * docs changes * review comments * added the batch rbac call * added a validator to check the name of the keys * small adjustments * base schema added * formatting * new functionalities * breaking circular imports * spellchecker * other minor fixes * covering the uncovered case * adjusting tests * fixing the quickstart again * minor change * going back to publisher step id * updating github refs * Auto-update of LLM Finetuning template * Auto-update of Starter template * fixing tests * updated docs * Auto-update of E2E template * Auto-update of NLP template * formatting * review comments * adding some tests in * review comments * Update src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py Co-authored-by: Michael Schuster * Update src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py Co-authored-by: Michael Schuster * Update src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py Co-authored-by: Michael Schuster * Update src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py Co-authored-by: Michael Schuster * Update 
src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py Co-authored-by: Michael Schuster * changed assert to value error * fixed the alembic head * changed the interaction with the models * trimmed down * small bugfix * naming recommendations * linting * fixing the test --------- Co-authored-by: AlexejPenner Co-authored-by: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Co-authored-by: GitHub Actions Co-authored-by: Michael Schuster Co-authored-by: Michael Schuster --- .gitbook.yaml | 1 + .../update-templates-to-examples.yml | 8 +- .../track-metrics-metadata/README.md | 42 ++- .../attach-metadata-to-a-model.md | 71 ++-- .../attach-metadata-to-a-run.md | 87 +++++ .../attach-metadata-to-a-step.md | 105 ++++++ .../attach-metadata-to-an-artifact.md | 85 +++-- .../attach-metadata-to-steps.md | 65 ---- .../fetch-metadata-within-steps.md | 4 + .../grouping-metadata.md | 16 +- .../logging-metadata.md | 15 +- .../build-pipelines/README.md | 5 +- docs/book/toc.md | 5 +- examples/e2e/.copier-answers.yml | 2 +- .../hp_tuning/hp_tuning_single_search.py | 5 +- examples/e2e_nlp/.copier-answers.yml | 2 +- .../e2e_nlp/steps/training/model_trainer.py | 5 +- examples/llm_finetuning/.copier-answers.yml | 2 +- examples/llm_finetuning/steps/log_metadata.py | 9 +- .../llm_finetuning/steps/prepare_datasets.py | 9 +- examples/mlops_starter/.copier-answers.yml | 2 +- .../mlops_starter/steps/data_preprocessor.py | 7 +- .../mlops_starter/steps/model_evaluator.py | 19 +- examples/quickstart/steps/model_evaluator.py | 11 +- examples/quickstart/steps/model_tester.py | 14 +- src/zenml/artifacts/utils.py | 79 ++-- src/zenml/cli/base.py | 8 +- src/zenml/client.py | 20 +- src/zenml/model/model.py | 8 +- src/zenml/model/utils.py | 34 +- src/zenml/models/__init__.py | 6 + src/zenml/models/v2/core/artifact_version.py | 9 +- src/zenml/models/v2/core/model_version.py | 9 +- src/zenml/models/v2/core/pipeline_run.py | 9 +- src/zenml/models/v2/core/run_metadata.py | 39 +- src/zenml/models/v2/core/step_run.py | 8 +- src/zenml/models/v2/misc/run_metadata.py | 38 ++ src/zenml/orchestrators/publish_utils.py | 15 +- src/zenml/steps/utils.py | 8 +- src/zenml/utils/metadata_utils.py | 339 ++++++++++-------- .../routers/workspaces_endpoints.py | 38 +- .../cc269488e5a9_separate_run_metadata.py | 135 +++++++ src/zenml/zen_stores/schemas/__init__.py | 6 +- .../zen_stores/schemas/artifact_schemas.py | 20 +- src/zenml/zen_stores/schemas/model_schemas.py | 24 +- .../schemas/pipeline_run_schemas.py | 43 ++- .../schemas/run_metadata_schemas.py | 97 +++-- .../zen_stores/schemas/step_run_schemas.py | 25 +- src/zenml/zen_stores/schemas/utils.py | 50 ++- src/zenml/zen_stores/sql_zen_store.py | 125 +++++-- .../functional/artifacts/test_utils.py | 11 +- .../functional/model/test_model_version.py | 12 +- .../pipelines/test_pipeline_context.py | 2 +- .../functional/steps/test_step_context.py | 3 +- tests/integration/functional/test_client.py | 32 +- .../functional/utils/test_metadata_utils.py | 184 ++++++++++ .../functional/zen_stores/test_zen_store.py | 16 +- 57 files changed, 1482 insertions(+), 566 deletions(-) create mode 100644 docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-run.md create mode 100644 docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-step.md delete mode 100644 docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-steps.md create mode 100644 src/zenml/models/v2/misc/run_metadata.py create 
mode 100644 src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py create mode 100644 tests/integration/functional/utils/test_metadata_utils.py diff --git a/.gitbook.yaml b/.gitbook.yaml index 984ca276a61..8a1dc252feb 100644 --- a/.gitbook.yaml +++ b/.gitbook.yaml @@ -18,6 +18,7 @@ redirects: how-to/setting-up-a-project-repository/best-practices: how-to/project-setup-and-management/setting-up-a-project-repository/set-up-repository.md getting-started/zenml-pro/system-architectures: getting-started/system-architectures.md how-to/build-pipelines/name-your-pipeline-and-runs: how-to/pipeline-development/build-pipelines/name-your-pipeline-runs.md + how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-steps: how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-step.md # ZenML Pro getting-started/zenml-pro/user-management: getting-started/zenml-pro/core-concepts.md diff --git a/.github/workflows/update-templates-to-examples.yml b/.github/workflows/update-templates-to-examples.yml index ecaaadb5c61..cb6a8ff6a93 100644 --- a/.github/workflows/update-templates-to-examples.yml +++ b/.github/workflows/update-templates-to-examples.yml @@ -46,7 +46,7 @@ jobs: python-version: ${{ inputs.python-version }} stack-name: local ref-zenml: ${{ github.ref }} - ref-template: 2024.11.20 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py + ref-template: 2024.11.28 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py - name: Clean-up run: | rm -rf ./local_checkout @@ -118,7 +118,7 @@ jobs: python-version: ${{ inputs.python-version }} stack-name: local ref-zenml: ${{ github.ref }} - ref-template: 2024.10.30 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py + ref-template: 2024.11.28 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py - name: Clean-up run: | rm -rf ./local_checkout @@ -189,7 +189,7 @@ jobs: python-version: ${{ inputs.python-version }} stack-name: local ref-zenml: ${{ github.ref }} - ref-template: 2024.10.30 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py + ref-template: 2024.11.28 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py - name: Clean-up run: | rm -rf ./local_checkout @@ -261,7 +261,7 @@ jobs: with: python-version: ${{ inputs.python-version }} ref-zenml: ${{ github.ref }} - ref-template: 2024.11.08 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py + ref-template: 2024.11.28 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py - name: Clean-up run: | rm -rf ./local_checkout diff --git a/docs/book/how-to/model-management-metrics/track-metrics-metadata/README.md b/docs/book/how-to/model-management-metrics/track-metrics-metadata/README.md index df281351c70..fd27d792107 100644 --- a/docs/book/how-to/model-management-metrics/track-metrics-metadata/README.md +++ b/docs/book/how-to/model-management-metrics/track-metrics-metadata/README.md @@ -5,10 +5,44 @@ description: Tracking metrics and metadata # Track metrics and metadata -Logging metrics and metadata is standardized in ZenML. The most common pattern is to use the `log_xxx` methods, e.g.: +ZenML provides a unified way to log and manage metrics and metadata through +the `log_metadata` function. This versatile function allows you to log +metadata across various entities like models, artifacts, steps, and runs +through a single interface. 
Additionally, you can control whether the same metadata is +automatically logged for the related entities. -* Log metadata to a [model](attach-metadata-to-a-model.md): `log_model_metadata` -* Log metadata to an [artifact](attach-metadata-to-an-artifact.md): `log_artifact_metadata` -* Log metadata to a [step](attach-metadata-to-steps.md): `log_step_metadata` +### The most basic use-case + +You can use the `log_metadata` function within a step: + +```python +from zenml import step, log_metadata + +@step +def my_step() -> ...: + log_metadata(metadata={"accuracy": 0.91}) + ... +``` + +This will log the `accuracy` for the step, its pipeline run, and, if provided, +its model version. + +### Additional use-cases + +The `log_metadata` function also supports various use-cases by allowing you to +specify the target entity (e.g., model, artifact, step, or run) with flexible +parameters. You can learn more about these use-cases in the following pages: + +- [Log metadata to a step](attach-metadata-to-a-step.md) +- [Log metadata to a run](attach-metadata-to-a-run.md) +- [Log metadata to an artifact](attach-metadata-to-an-artifact.md) +- [Log metadata to a model](attach-metadata-to-a-model.md) + +{% hint style="warning" %} +The older methods for logging metadata to specific entities, such as +`log_model_metadata`, `log_artifact_metadata`, and `log_step_metadata`, are +now deprecated. It is recommended to use `log_metadata` for all future +implementations. +{% endhint %}
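To make the deprecation note above concrete, here is a minimal sketch of how calls to the deprecated functions map onto `log_metadata`, using only parameters that appear elsewhere in this PR (`infer_model`, `infer_artifact`, `artifact_name`). The metadata keys and the output name are illustrative, and `infer_model` assumes a model is configured for the step.

```python
from typing import Annotated

from zenml import step, log_metadata


@step
def my_step() -> Annotated[int, "my_output"]:
    # Previously `log_step_metadata(...)`: attaches to the current step
    # (and its pipeline run) by default.
    log_metadata(metadata={"accuracy": 0.91})

    # Previously `log_model_metadata(...)`: infer the model version
    # configured for this step from the step context.
    log_metadata(metadata={"eval": {"recall": 0.87}}, infer_model=True)

    # Previously `log_artifact_metadata(...)`: infer the output artifact
    # named "my_output" produced by this step.
    log_metadata(
        metadata={"rows": 1000},
        artifact_name="my_output",
        infer_artifact=True,
    )
    return 42
```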
diff --git a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model.md b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model.md index 0e46458b803..05bd97d5529 100644 --- a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model.md +++ b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model.md @@ -1,34 +1,43 @@ --- -description: >- - Attach any metadata as key-value pairs to your models for future reference and - auditability. +description: Learn how to attach metadata to a model. --- # Attach metadata to a model +ZenML allows you to log metadata for models, which provides additional context +that goes beyond individual artifact details. Model metadata can represent +high-level insights, such as evaluation results, deployment information, +or customer-specific details, making it easier to manage and interpret +the model's usage and performance across different versions. + ## Logging Metadata for Models -While artifact metadata is specific to individual outputs of steps, model metadata encapsulates broader and more general information that spans across multiple artifacts. For example, evaluation results or the name of a customer for whom the model is intended could be logged with the model. +To log metadata for a model, use the `log_metadata` function. This function +lets you attach key-value metadata to a model, which can include metrics and +other JSON-serializable values, such as custom ZenML types like `Uri`, +`Path`, and `StorageSize`. Here's an example of logging metadata for a model: ```python -from zenml import step, log_model_metadata, ArtifactConfig, get_step_context from typing import Annotated + import pandas as pd -from sklearn.ensemble import RandomForestClassifier from sklearn.base import ClassifierMixin +from sklearn.ensemble import RandomForestClassifier + +from zenml import step, log_metadata, ArtifactConfig, get_step_context + @step -def train_model(dataset: pd.DataFrame) -> Annotated[ClassifierMixin, ArtifactConfig(name="sklearn_classifier")]: - """Train a model""" - # Fit the model and compute metrics +def train_model(dataset: pd.DataFrame) -> Annotated[ + ClassifierMixin, ArtifactConfig(name="sklearn_classifier") +]: + """Train a model and log model metadata.""" classifier = RandomForestClassifier().fit(dataset) accuracy, precision, recall = ... - - # Log metadata for the model - # This associates the metadata with the ZenML model, not the artifact - log_model_metadata( + + log_metadata( metadata={ "evaluation_metrics": { "accuracy": accuracy, @@ -36,19 +45,36 @@ def train_model(dataset: pd.DataFrame) -> Annotated[ClassifierMixin, ArtifactCon "recall": recall } }, - # Omitted model_name will use the model in the current context - model_name="zenml_model_name", - # Omitted model_version will default to 'latest' - model_version="zenml_model_version", + infer_model=True, ) + return classifier ``` -In this example, the metadata is associated with the model rather than the specific classifier artifact. This is particularly useful when the metadata reflects an aggregation or summary of various steps and artifacts in the pipeline. +In this example, the metadata is associated with the model rather than the +specific classifier artifact. This is particularly useful when the metadata +reflects an aggregation or summary of various steps and artifacts in the +pipeline. 
+ + +### Selecting Models with `log_metadata` + +When using `log_metadata`, ZenML provides flexible options of attaching +metadata to model versions: + +1. **Using `infer_model`**: If used within a step, ZenML will use the step + context to infer the model it is using and attach the metadata to it. +2. **Model Name and Version Provided**: If both a model name and version are + provided, ZenML will use these to identify and attach metadata to the + specific model version. +3. **Model Version ID Provided**: If a model version ID is directly provided, + ZenML will use it to fetch and attach the metadata to that specific model + version. ## Fetching logged metadata -Once metadata has been logged in an [artifact](attach-metadata-to-an-artifact.md), model, or [step](attach-metadata-to-steps.md), we can easily fetch the metadata with the ZenML Client: +Once metadata has been attached to a model, it can be retrieved for inspection +or analysis using the ZenML Client. ```python from zenml.client import Client @@ -56,7 +82,12 @@ from zenml.client import Client client = Client() model = client.get_model_version("my_model", "my_version") -print(model.run_metadata["metadata_key"].value) +print(model.run_metadata["metadata_key"]) ``` +{% hint style="info" %} +When you are fetching metadata using a specific key, the returned value will +always reflect the latest entry. +{% endhint %} +
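As a compact reference for the three selection modes listed above, a hedged sketch follows. The first call only works inside a step with a configured model; `model_version_id` is an assumed parameter name (by analogy with the `artifact_version_id` used elsewhere in this PR), and the placeholder values are illustrative.

```python
from zenml import log_metadata

# 1. Inside a step: infer the model version from the step context.
log_metadata(metadata={"accuracy": 0.91}, infer_model=True)

# 2. Anywhere: identify the model version by name and version.
log_metadata(
    metadata={"accuracy": 0.91},
    model_name="zenml_model_name",
    model_version="zenml_model_version",
)

# 3. Anywhere: target a model version directly by its ID
#    (parameter name assumed; in practice this is the version's UUID).
log_metadata(
    metadata={"accuracy": 0.91},
    model_version_id="<model-version-uuid>",
)
```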
diff --git a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-run.md b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-run.md new file mode 100644 index 00000000000..e04a0c9006f --- /dev/null +++ b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-run.md @@ -0,0 +1,87 @@ +--- +description: Learn how to attach metadata to a run. +--- + +# Attach Metadata to a Run + +In ZenML, you can log metadata directly to a pipeline run, either during or +after execution, using the `log_metadata` function. This function allows you +to attach a dictionary of key-value pairs as metadata to a pipeline run, +with values that can be any JSON-serializable data type, including ZenML +custom types like `Uri`, `Path`, `DType`, and `StorageSize`. + +## Logging Metadata Within a Run + +If you are logging metadata from within a step that’s part of a pipeline run, +calling `log_metadata` will attach the specified metadata to the current +pipeline run where the metadata key will have the `step_name::metadata_key` +pattern. This allows you to use the same metadata key from different steps +while the run's still executing. + +```python +from typing import Annotated + +import pandas as pd +from sklearn.base import ClassifierMixin +from sklearn.ensemble import RandomForestClassifier + +from zenml import step, log_metadata, ArtifactConfig + + +@step +def train_model(dataset: pd.DataFrame) -> Annotated[ + ClassifierMixin, + ArtifactConfig(name="sklearn_classifier", is_model_artifact=True) +]: + """Train a model and log run-level metadata.""" + classifier = RandomForestClassifier().fit(dataset) + accuracy, precision, recall = ... + + # Log metadata at the run level + log_metadata( + metadata={ + "run_metrics": { + "accuracy": accuracy, + "precision": precision, + "recall": recall + } + } + ) + return classifier +``` + +## Manually Logging Metadata to a Pipeline Run + +You can also attach metadata to a specific pipeline run without needing a step, +using identifiers like the run ID. This is useful when logging information or +metrics that were calculated post-execution. + +```python +from zenml import log_metadata + +log_metadata( + metadata={"post_run_info": {"some_metric": 5.0}}, + run_id_name_or_prefix="run_id_name_or_prefix" +) +``` + +## Fetching Logged Metadata + +Once metadata has been logged in a pipeline run, you can retrieve it using +the ZenML Client: + +```python +from zenml.client import Client + +client = Client() +run = client.get_pipeline_run("run_id_name_or_prefix") + +print(run.run_metadata["metadata_key"]) +``` + +{% hint style="info" %} +When you are fetching metadata using a specific key, the returned value will +always reflect the latest entry. +{% endhint %} + +
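To illustrate the `step_name::metadata_key` pattern described above, here is a small sketch, assuming a step named `train_model` logged a plain `accuracy` key while the run was executing.

```python
from zenml.client import Client

client = Client()
run = client.get_pipeline_run("run_id_name_or_prefix")

# Metadata logged from inside a step is keyed as "<step_name>::<metadata_key>"
# at the run level.
print(run.run_metadata["train_model::accuracy"])
```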
diff --git a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-step.md b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-step.md new file mode 100644 index 00000000000..e53b49a8274 --- /dev/null +++ b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-step.md @@ -0,0 +1,105 @@ +--- +description: Learn how to attach metadata to a step. +--- + +# Attach metadata to a step + +In ZenML, you can log metadata for a specific step during or after its +execution by using the `log_metadata` function. This function allows you to +attach a dictionary of key-value pairs as metadata to a step. The metadata +can be any JSON-serializable value, including custom classes such as +`Uri`, `Path`, `DType`, and `StorageSize`. + +## Logging Metadata Within a Step + +If called within a step, `log_metadata` automatically attaches the metadata to +the currently executing step and its associated pipeline run. This is +ideal for logging metrics or information that becomes available during the +step execution. + +```python +from typing import Annotated + +import pandas as pd +from sklearn.base import ClassifierMixin +from sklearn.ensemble import RandomForestClassifier + +from zenml import step, log_metadata, ArtifactConfig + + +@step +def train_model(dataset: pd.DataFrame) -> Annotated[ + ClassifierMixin, + ArtifactConfig(name="sklearn_classifier") +]: + """Train a model and log evaluation metrics.""" + classifier = RandomForestClassifier().fit(dataset) + accuracy, precision, recall = ... + + # Log metadata at the step level + log_metadata( + metadata={ + "evaluation_metrics": { + "accuracy": accuracy, + "precision": precision, + "recall": recall + } + } + ) + return classifier +``` + +{% hint style="info" %} +If you run a pipeline where the step execution is cached, the cached step run +will copy the metadata that was created in the original step execution. +(If there is any metadata that was generated manually after the execution of +the original step, these entries will not be included in this process.) +{% endhint %} + +## Manually Logging Metadata a Step Run + +You can also log metadata for a specific step after execution, using +identifiers to specify the pipeline, step, and run. This approach is +useful when you want to log metadata post-execution. + +```python +from zenml import log_metadata + +log_metadata( + metadata={ + "additional_info": {"a_number": 3} + }, + step_name="step_name", + run_id_name_or_prefix="run_id_name_or_prefix" +) + +# or + +log_metadata( + metadata={ + "additional_info": {"a_number": 3} + }, + step_id="step_id", +) +``` + +## Fetching logged metadata + +Once metadata has been logged in a step, we can easily fetch the metadata with +the ZenML Client: + +```python +from zenml.client import Client + +client = Client() +step = client.get_pipeline_run("pipeline_id").steps["step_name"] + +print(step.run_metadata["metadata_key"]) +``` + +{% hint style="info" %} +When you are fetching metadata using a specific key, the returned value will +always reflect the latest entry. +{% endhint %} + +
diff --git a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-an-artifact.md b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-an-artifact.md index 4d19bca4b51..7f57ac1c18a 100644 --- a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-an-artifact.md +++ b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-an-artifact.md @@ -1,48 +1,77 @@ --- -description: Learn how to log metadata for artifacts and models in ZenML. +description: Learn how to attach metadata to an artifact. --- # Attach metadata to an artifact ![Metadata in the dashboard](../../../.gitbook/assets/metadata-in-dashboard.png) -Metadata plays a critical role in ZenML, providing context and additional information about various entities within the platform. Anything which is `metadata` in ZenML can be compared in the dashboard. - -This guide will explain how to log metadata for artifacts and models in ZenML and detail the types of metadata that can be logged. +In ZenML, metadata enhances artifacts by adding context and important details, +such as size, structure, or performance metrics. This metadata is accessible +in the ZenML dashboard, making it easier to inspect, compare, and track +artifacts across pipeline runs. ## Logging Metadata for Artifacts -Artifacts in ZenML are outputs of steps within a pipeline, such as datasets, models, or evaluation results. Associating metadata with artifacts can help users understand the nature and characteristics of these outputs. +Artifacts in ZenML are outputs of steps within a pipeline, such as datasets, +models, or evaluation results. Associating metadata with artifacts can help +users understand the nature and characteristics of these outputs. -To log metadata for an artifact, you can use the `log_artifact_metadata` method. This method allows you to attach a dictionary of key-value pairs as metadata to an artifact. The metadata can be any JSON-serializable value, including custom classes such as `Uri`, `Path`, `DType`, and `StorageSize`. Find out more about these different types [here](../../model-management-metrics/track-metrics-metadata/logging-metadata.md). +To log metadata for an artifact, use the `log_metadata` function, specifying +the artifact name, version, or ID. The metadata can be any JSON-serializable +value, including ZenML custom types like `Uri`, `Path`, `DType`, and +`StorageSize`. Find out more about these different types +[here](../../model-management-metrics/track-metrics-metadata/logging-metadata.md). Here's an example of logging metadata for an artifact: ```python -from zenml import step, log_artifact_metadata +import pandas as pd + +from zenml import step, log_metadata from zenml.metadata.metadata_types import StorageSize + @step -def process_data_step(dataframe: pd.DataFrame) -> Annotated[pd.DataFrame, "processed_data"],: +def process_data_step(dataframe: pd.DataFrame) -> pd.DataFrame: """Process a dataframe and log metadata about the result.""" - # Perform processing on the dataframe... processed_dataframe = ... 
# Log metadata about the processed dataframe - log_artifact_metadata( - artifact_name="processed_data", + log_metadata( metadata={ "row_count": len(processed_dataframe), "columns": list(processed_dataframe.columns), - "storage_size": StorageSize(processed_dataframe.memory_usage().sum()) - } + "storage_size": StorageSize( + processed_dataframe.memory_usage().sum()) + }, + infer_artifact=True, ) return processed_dataframe ``` +### Selecting the artifact to log the metadata to + +When using `log_metadata` with an artifact name, ZenML provides flexible +options to attach metadata to the correct artifact: + +1. **Using `infer_artifact`**: If used within a step, ZenML will use the step +context to infer the outputs artifacts of the step. If the step has only one +output, this artifact will be selected. However, if you additionally +provide an `artifact_name`, ZenML will search for this name in the output space +of the step (useful for step with multiple outputs). +2. **Name and Version Provided**: If both an artifact name and version are +provided, ZenML will use these to identify and attach metadata to the +specific artifact version. +3. **Artifact Version ID Provided**: If an artifact version ID is provided +directly, ZenML will use it to fetch and attach the metadata to that +specific artifact version. + ## Fetching logged metadata -Once metadata has been logged in an artifact, or [step](../../model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model.md), we can easily fetch the metadata with the ZenML Client: +Once metadata has been logged in an artifact, or +[step](../track-metrics-metadata/attach-metadata-to-a-model.md), we can easily +fetch the metadata with the ZenML Client: ```python from zenml.client import Client @@ -50,19 +79,29 @@ from zenml.client import Client client = Client() artifact = client.get_artifact_version("my_artifact", "my_version") -print(artifact.run_metadata["metadata_key"].value) +print(artifact.run_metadata["metadata_key"]) ``` +{% hint style="info" %} +When you are fetching metadata using a specific key, the returned value will +always reflect the latest entry. +{% endhint %} + ## Grouping Metadata in the Dashboard -When logging metadata passing a dictionary of dictionaries in the `metadata` parameter will group the metadata into cards in the ZenML dashboard. This feature helps organize metadata into logical sections, making it easier to visualize and understand. +When logging metadata passing a dictionary of dictionaries in the `metadata` +parameter will group the metadata into cards in the ZenML dashboard. This +feature helps organize metadata into logical sections, making it easier to +visualize and understand. Here's an example of grouping metadata into cards: ```python +from zenml import log_metadata + from zenml.metadata.metadata_types import StorageSize -log_artifact_metadata( +log_metadata( metadata={ "model_metrics": { "accuracy": 0.95, @@ -73,12 +112,14 @@ log_artifact_metadata( "dataset_size": StorageSize(1500000), "feature_columns": ["age", "income", "score"] } - } + }, + artifact_name="my_artifact", + artifact_version="version", ) ``` -In the ZenML dashboard, "model\_metrics" and "data\_details" would appear as separate cards, each containing their respective key-value pairs. - -
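To summarize the artifact selection modes described in this page, here is a hedged sketch; the artifact names, versions, and metadata keys are illustrative, and the version ID is the artifact version's UUID in practice.

```python
from zenml import log_metadata

# 1. Inside a step: infer the output artifact of the current step;
#    `artifact_name` disambiguates when the step has several outputs.
log_metadata(
    metadata={"row_count": 1000},
    artifact_name="processed_data",
    infer_artifact=True,
)

# 2. Anywhere: identify the artifact version by name and version.
log_metadata(
    metadata={"row_count": 1000},
    artifact_name="my_artifact",
    artifact_version="my_artifact_version",
)

# 3. Anywhere: target an artifact version directly by its ID (a UUID).
log_metadata(
    metadata={"row_count": 1000},
    artifact_version_id="<artifact-version-uuid>",
)
```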
\ No newline at end of file diff --git a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-steps.md b/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-steps.md deleted file mode 100644 index 299bf131099..00000000000 --- a/docs/book/how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-steps.md +++ /dev/null @@ -1,65 +0,0 @@ -# Attach metadata to steps - -You might want to log metadata and have that be attached to a specific step during the course of your work. This is possible by using the `log_step_metadata` method. This method allows you to attach a dictionary of key-value pairs as metadata to a step. The metadata can be any JSON-serializable value, including custom classes such as `Uri`, `Path`, `DType`, and `StorageSize`. - -You can call this method from within a step or from outside. If you call it from within it will attach the metadata to the step and run that is currently being executed. - -```python -from zenml import step, log_step_metadata, ArtifactConfig, get_step_context -from typing import Annotated -import pandas as pd -from sklearn.ensemble import RandomForestClassifier -from sklearn.base import ClassifierMixin - -@step -def train_model(dataset: pd.DataFrame) -> Annotated[ClassifierMixin, ArtifactConfig(name="sklearn_classifier")]: - """Train a model""" - # Fit the model and compute metrics - classifier = RandomForestClassifier().fit(dataset) - accuracy, precision, recall = ... - - # Log metadata at the step level - # This associates the metadata with the ZenML step run - log_step_metadata( - metadata={ - "evaluation_metrics": { - "accuracy": accuracy, - "precision": precision, - "recall": recall - } - }, - ) - return classifier -``` - -If you call it from outside you can attach the metadata to a specific step run from any pipeline and step. This is useful if you want to attach the metadata after you've run the step. - -```python -from zenml import log_step_metadata -# run some step - -# subsequently log the metadata for the step -log_step_metadata( - metadata={ - "some_metadata": {"a_number": 3} - }, - pipeline_name_id_or_prefix="my_pipeline", - step_name="my_step", - run_id="my_step_run_id" -) -``` - -## Fetching logged metadata - -Once metadata has been logged in an [artifact](attach-metadata-to-an-artifact.md), [model](attach-metadata-to-a-model.md), we can easily fetch the metadata with the ZenML Client: - -```python -from zenml.client import Client - -client = Client() -step = client.get_pipeline_run().steps["step_name"] - -print(step.run_metadata["metadata_key"].value) -``` - -
diff --git a/docs/book/how-to/model-management-metrics/track-metrics-metadata/fetch-metadata-within-steps.md b/docs/book/how-to/model-management-metrics/track-metrics-metadata/fetch-metadata-within-steps.md index 2e8d940c33d..d57f523483d 100644 --- a/docs/book/how-to/model-management-metrics/track-metrics-metadata/fetch-metadata-within-steps.md +++ b/docs/book/how-to/model-management-metrics/track-metrics-metadata/fetch-metadata-within-steps.md @@ -11,6 +11,7 @@ To find information about the pipeline or step that is currently running, you ca ```python from zenml import step, get_step_context + @step def my_step(): step_context = get_step_context() @@ -22,6 +23,9 @@ def my_step(): Furthermore, you can also use the `StepContext` to find out where the outputs of your current step will be stored and which [Materializer](../../data-artifact-management/handle-data-artifacts/handle-custom-data-types.md) class will be used to save them: ```python +from zenml import step, get_step_context + + @step def my_step(): step_context = get_step_context() diff --git a/docs/book/how-to/model-management-metrics/track-metrics-metadata/grouping-metadata.md b/docs/book/how-to/model-management-metrics/track-metrics-metadata/grouping-metadata.md index 6595dd3244a..e90400f96c5 100644 --- a/docs/book/how-to/model-management-metrics/track-metrics-metadata/grouping-metadata.md +++ b/docs/book/how-to/model-management-metrics/track-metrics-metadata/grouping-metadata.md @@ -6,14 +6,18 @@ description: Learn how to group key-value pairs in the dashboard. ![Metadata in the dashboard](../../../.gitbook/assets/metadata-in-dashboard.png) -When logging metadata passing a dictionary of dictionaries in the `metadata` parameter will group the metadata into cards in the ZenML dashboard. This feature helps organize metadata into logical sections, making it easier to visualize and understand. +When logging metadata passing a dictionary of dictionaries in the +`metadata` parameter will group the metadata into cards in the ZenML dashboard. +This feature helps organize metadata into logical sections, making it +easier to visualize and understand. Here's an example of grouping metadata into cards: ```python +from zenml import log_metadata from zenml.metadata.metadata_types import StorageSize -log_artifact_metadata( +log_metadata( metadata={ "model_metrics": { "accuracy": 0.95, @@ -24,11 +28,15 @@ log_artifact_metadata( "dataset_size": StorageSize(1500000), "feature_columns": ["age", "income", "score"] } - } + }, + artifact_name="my_artifact", + artifact_version="my_artifact_version", ) ``` -In the ZenML dashboard, "model\_metrics" and "data\_details" would appear as separate cards, each containing their respective key-value pairs. +In the ZenML dashboard, "model_metrics" and "data_details" would appear +as separate cards, each containing their respective key-value pairs. +
diff --git a/docs/book/how-to/model-management-metrics/track-metrics-metadata/logging-metadata.md b/docs/book/how-to/model-management-metrics/track-metrics-metadata/logging-metadata.md index f7d4c67c199..8ea9fc0f1e1 100644 --- a/docs/book/how-to/model-management-metrics/track-metrics-metadata/logging-metadata.md +++ b/docs/book/how-to/model-management-metrics/track-metrics-metadata/logging-metadata.md @@ -4,13 +4,15 @@ description: Tracking your metadata. # Special Metadata Types -ZenML supports several special metadata types to capture specific kinds of information. Here are examples of how to use the special types `Uri`, `Path`, `DType`, and `StorageSize`: +ZenML supports several special metadata types to capture specific kinds of +information. Here are examples of how to use the special types `Uri`, `Path`, +`DType`, and `StorageSize`: ```python -from zenml.metadata.metadata_types import StorageSize, DType -from zenml import log_artifact_metadata +from zenml import log_metadata +from zenml.metadata.metadata_types import StorageSize, DType, Uri, Path -log_artifact_metadata( +log_metadata( metadata={ "dataset_source": Uri("gs://my-bucket/datasets/source.csv"), "preprocessing_script": Path("/scripts/preprocess.py"), @@ -20,7 +22,7 @@ log_artifact_metadata( "score": DType("int") }, "processed_data_size": StorageSize(2500000) - } + }, ) ``` @@ -31,6 +33,7 @@ In this example: * `DType` is used to describe the data types of specific columns. * `StorageSize` is used to indicate the size of the processed data in bytes. -These special types help standardize the format of metadata and ensure that it is logged in a consistent and interpretable manner. +These special types help standardize the format of metadata and ensure that it +is logged in a consistent and interpretable manner.
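When these typed values are read back, this PR also changes the access pattern: fetched metadata values are now returned directly, without the previous `.value` accessor. A minimal before/after sketch, assuming an artifact named `my_artifact` with a logged `metadata_key` exists:

```python
from zenml.client import Client

client = Client()
artifact = client.get_artifact_version("my_artifact", "my_version")

# Before this change:
# print(artifact.run_metadata["metadata_key"].value)

# After this change, the value is returned directly:
print(artifact.run_metadata["metadata_key"])
```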
diff --git a/docs/book/how-to/pipeline-development/build-pipelines/README.md b/docs/book/how-to/pipeline-development/build-pipelines/README.md index 7240901d161..3e9e67ec289 100644 --- a/docs/book/how-to/pipeline-development/build-pipelines/README.md +++ b/docs/book/how-to/pipeline-development/build-pipelines/README.md @@ -8,6 +8,9 @@ description: >- # Build a pipeline ```python +from zenml import pipeline, step + + @step # Just add this decorator def load_data() -> dict: training_data = [[1, 2], [3, 4], [5, 6]] @@ -46,6 +49,6 @@ locally or remotely. See our documentation on this [here](../../../getting-start Check below for more advanced ways to build and interact with your pipeline. -
Configure pipeline/step parameters (use-pipeline-step-parameters.md)
Name and annotate step outputs (step-output-typing-and-annotation.md)
Control caching behavior (control-caching-behavior.md)
Run pipeline from a pipeline (trigger-a-pipeline-from-another.md)
Control the execution order of steps (control-execution-order-of-steps.md)
Customize the step invocation ids (using-a-custom-step-invocation-id.md)
Name your pipeline runs (name-your-pipeline-and-runs.md)
Use failure/success hooks (use-failure-success-hooks.md)
Hyperparameter tuning (hyper-parameter-tuning.md)
Attach metadata to steps (attach-metadata-to-steps.md)
Fetch metadata within steps (fetch-metadata-within-steps.md)
Fetch metadata during pipeline composition (fetch-metadata-within-pipeline.md)
Enable or disable logs storing (enable-or-disable-logs-storing.md)
Special Metadata Types (logging-metadata.md)
Access secrets in a step (access-secrets-in-a-step.md)
+
Configure pipeline/step parameters (use-pipeline-step-parameters.md)
Name and annotate step outputs (step-output-typing-and-annotation.md)
Control caching behavior (control-caching-behavior.md)
Run pipeline from a pipeline (trigger-a-pipeline-from-another.md)
Control the execution order of steps (control-execution-order-of-steps.md)
Customize the step invocation ids (using-a-custom-step-invocation-id.md)
Name your pipeline runs (name-your-pipeline-and-runs.md)
Use failure/success hooks (use-failure-success-hooks.md)
Hyperparameter tuning (hyper-parameter-tuning.md)
Attach metadata to a step (attach-metadata-to-a-step.md)
Fetch metadata within steps (fetch-metadata-within-steps.md)
Fetch metadata during pipeline composition (fetch-metadata-within-pipeline.md)
Enable or disable logs storing (enable-or-disable-logs-storing.md)
Special Metadata Types (logging-metadata.md)
Access secrets in a step (access-secrets-in-a-step.md)
diff --git a/docs/book/toc.md b/docs/book/toc.md index db14ef7678a..1211cdb9d82 100644 --- a/docs/book/toc.md +++ b/docs/book/toc.md @@ -150,9 +150,10 @@ * [Linking model binaries/data to a Model](how-to/model-management-metrics/model-control-plane/linking-model-binaries-data-to-models.md) * [Load artifacts from Model](how-to/model-management-metrics/model-control-plane/load-artifacts-from-model.md) * [Track metrics and metadata](how-to/model-management-metrics/track-metrics-metadata/README.md) - * [Attach metadata to a model](how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model.md) + * [Attach metadata to a step](how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-step.md) + * [Attach metadata to a run](how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-run.md) * [Attach metadata to an artifact](how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-an-artifact.md) - * [Attach metadata to steps](how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-steps.md) + * [Attach metadata to a model](how-to/model-management-metrics/track-metrics-metadata/attach-metadata-to-a-model.md) * [Group metadata](how-to/model-management-metrics/track-metrics-metadata/grouping-metadata.md) * [Special Metadata Types](how-to/model-management-metrics/track-metrics-metadata/logging-metadata.md) * [Fetch metadata within steps](how-to/model-management-metrics/track-metrics-metadata/fetch-metadata-within-steps.md) diff --git a/examples/e2e/.copier-answers.yml b/examples/e2e/.copier-answers.yml index 0a2f40d5a92..e6fb1292beb 100644 --- a/examples/e2e/.copier-answers.yml +++ b/examples/e2e/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2024.11.20 +_commit: 2024.11.20-2-g760142f _src_path: gh:zenml-io/template-e2e-batch data_quality_checks: true email: info@zenml.io diff --git a/examples/e2e/steps/hp_tuning/hp_tuning_single_search.py b/examples/e2e/steps/hp_tuning/hp_tuning_single_search.py index f2f39969a6f..7b55eebae7a 100644 --- a/examples/e2e/steps/hp_tuning/hp_tuning_single_search.py +++ b/examples/e2e/steps/hp_tuning/hp_tuning_single_search.py @@ -25,7 +25,7 @@ from typing_extensions import Annotated from utils import get_model_from_config -from zenml import log_artifact_metadata, step +from zenml import log_metadata, step from zenml.logger import get_logger logger = get_logger(__name__) @@ -95,9 +95,10 @@ def hp_tuning_single_search( y_pred = cv.predict(X_tst) score = accuracy_score(y_tst, y_pred) # log score along with output artifact as metadata - log_artifact_metadata( + log_metadata( metadata={"metric": float(score)}, artifact_name="hp_result", + infer_artifact=True, ) ### YOUR CODE ENDS HERE ### return cv.best_estimator_ diff --git a/examples/e2e_nlp/.copier-answers.yml b/examples/e2e_nlp/.copier-answers.yml index e13858e7da1..274927e3ce5 100644 --- a/examples/e2e_nlp/.copier-answers.yml +++ b/examples/e2e_nlp/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2024.10.30 +_commit: 2024.10.30-2-g1ae14e3 _src_path: gh:zenml-io/template-nlp accelerator: cpu cloud_of_choice: aws diff --git a/examples/e2e_nlp/steps/training/model_trainer.py b/examples/e2e_nlp/steps/training/model_trainer.py index edb9ab23ba5..0a3de574c09 100644 --- a/examples/e2e_nlp/steps/training/model_trainer.py +++ b/examples/e2e_nlp/steps/training/model_trainer.py @@ -30,7 +30,7 @@ from typing_extensions import Annotated from utils.misc import 
compute_metrics -from zenml import ArtifactConfig, log_artifact_metadata, step +from zenml import ArtifactConfig, log_metadata, step from zenml.client import Client from zenml.integrations.mlflow.experiment_trackers import ( MLFlowExperimentTracker, @@ -157,9 +157,10 @@ def model_trainer( eval_results = trainer.evaluate(metric_key_prefix="") # Log the evaluation results in model control plane - log_artifact_metadata( + log_metadata( metadata={"metrics": eval_results}, artifact_name="model", + infer_artifact=True, ) ### YOUR CODE ENDS HERE ### diff --git a/examples/llm_finetuning/.copier-answers.yml b/examples/llm_finetuning/.copier-answers.yml index 2c547f98d61..7deecebb1d2 100644 --- a/examples/llm_finetuning/.copier-answers.yml +++ b/examples/llm_finetuning/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2024.11.08 +_commit: 2024.11.08-2-gece1d46 _src_path: gh:zenml-io/template-llm-finetuning bf16: true cuda_version: cuda11.8 diff --git a/examples/llm_finetuning/steps/log_metadata.py b/examples/llm_finetuning/steps/log_metadata.py index 645f98cc8ea..90109fdf3c4 100644 --- a/examples/llm_finetuning/steps/log_metadata.py +++ b/examples/llm_finetuning/steps/log_metadata.py @@ -17,7 +17,7 @@ from typing import Any, Dict -from zenml import get_step_context, log_model_metadata, step +from zenml import get_step_context, log_metadata, step @step(enable_cache=False) @@ -39,4 +39,9 @@ def log_metadata_from_step_artifact( metadata = {artifact_name: metadata_dict} - log_model_metadata(metadata) + if context.model: + log_metadata( + metadata=metadata, + model_name=context.model.name, + model_version=context.model.version, + ) diff --git a/examples/llm_finetuning/steps/prepare_datasets.py b/examples/llm_finetuning/steps/prepare_datasets.py index fe98126369d..b9cc13c2261 100644 --- a/examples/llm_finetuning/steps/prepare_datasets.py +++ b/examples/llm_finetuning/steps/prepare_datasets.py @@ -22,7 +22,7 @@ from typing_extensions import Annotated from utils.tokenizer import generate_and_tokenize_prompt, load_tokenizer -from zenml import log_model_metadata, step +from zenml import log_metadata, step from zenml.materializers import BuiltInMaterializer from zenml.utils.cuda_utils import cleanup_gpu_memory @@ -49,11 +49,12 @@ def prepare_data( cleanup_gpu_memory(force=True) - log_model_metadata( - { + log_metadata( + metadata={ "system_prompt": system_prompt, "base_model_id": base_model_id, - } + }, + infer_model=True, ) tokenizer = load_tokenizer(base_model_id, False, use_fast) diff --git a/examples/mlops_starter/.copier-answers.yml b/examples/mlops_starter/.copier-answers.yml index fd6b937c7c9..364bccaa9d0 100644 --- a/examples/mlops_starter/.copier-answers.yml +++ b/examples/mlops_starter/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2024.10.30 +_commit: 2024.10.30-7-gb60e441 _src_path: gh:zenml-io/template-starter email: info@zenml.io full_name: ZenML GmbH diff --git a/examples/mlops_starter/steps/data_preprocessor.py b/examples/mlops_starter/steps/data_preprocessor.py index 0cf9d3ab521..f94d1e85f6d 100644 --- a/examples/mlops_starter/steps/data_preprocessor.py +++ b/examples/mlops_starter/steps/data_preprocessor.py @@ -23,7 +23,7 @@ from typing_extensions import Annotated from utils.preprocess import ColumnsDropper, DataFrameCaster, NADropper -from zenml import log_artifact_metadata, step +from zenml import log_metadata, step @step @@ -87,8 +87,9 @@ def data_preprocessor( dataset_tst = 
preprocess_pipeline.transform(dataset_tst) # Log metadata so we can load it in the inference pipeline - log_artifact_metadata( - artifact_name="preprocess_pipeline", + log_metadata( metadata={"random_state": random_state, "target": target}, + artifact_name="preprocess_pipeline", + infer_artifact=True, ) return dataset_trn, dataset_tst, preprocess_pipeline diff --git a/examples/mlops_starter/steps/model_evaluator.py b/examples/mlops_starter/steps/model_evaluator.py index 2a9b6ee9e75..c63c53109f4 100644 --- a/examples/mlops_starter/steps/model_evaluator.py +++ b/examples/mlops_starter/steps/model_evaluator.py @@ -20,7 +20,8 @@ import pandas as pd from sklearn.base import ClassifierMixin -from zenml import log_artifact_metadata, step +from zenml import log_metadata, step +from zenml.client import Client from zenml.logger import get_logger logger = get_logger(__name__) @@ -79,27 +80,31 @@ def model_evaluator( dataset_tst.drop(columns=[target]), dataset_tst[target], ) - logger.info(f"Train accuracy={trn_acc*100:.2f}%") - logger.info(f"Test accuracy={tst_acc*100:.2f}%") + logger.info(f"Train accuracy={trn_acc * 100:.2f}%") + logger.info(f"Test accuracy={tst_acc * 100:.2f}%") messages = [] if trn_acc < min_train_accuracy: messages.append( - f"Train accuracy {trn_acc*100:.2f}% is below {min_train_accuracy*100:.2f}% !" + f"Train accuracy {trn_acc * 100:.2f}% is below {min_train_accuracy * 100:.2f}% !" ) if tst_acc < min_test_accuracy: messages.append( - f"Test accuracy {tst_acc*100:.2f}% is below {min_test_accuracy*100:.2f}% !" + f"Test accuracy {tst_acc * 100:.2f}% is below {min_test_accuracy * 100:.2f}% !" ) else: for message in messages: logger.warning(message) - log_artifact_metadata( + client = Client() + latest_classifier = client.get_artifact_version("sklearn_classifier") + + log_metadata( metadata={ "train_accuracy": float(trn_acc), "test_accuracy": float(tst_acc), }, - artifact_name="sklearn_classifier", + artifact_version_id=latest_classifier.id, ) + return float(tst_acc) diff --git a/examples/quickstart/steps/model_evaluator.py b/examples/quickstart/steps/model_evaluator.py index fc8dac00132..4ae2e979396 100644 --- a/examples/quickstart/steps/model_evaluator.py +++ b/examples/quickstart/steps/model_evaluator.py @@ -20,7 +20,7 @@ T5ForConditionalGeneration, ) -from zenml import get_step_context, log_metadata, step +from zenml import log_metadata, step from zenml.logger import get_logger logger = get_logger(__name__) @@ -50,11 +50,4 @@ def evaluate_model( avg_loss = total_loss / num_batches print(f"Average loss on the dataset: {avg_loss}") - step_context = get_step_context() - - if step_context.model: - log_metadata( - metadata={"Average Loss": avg_loss}, - model_name=step_context.model.name, - model_version=step_context.model.version, - ) + log_metadata(metadata={"Average Loss": avg_loss}, infer_model=True) diff --git a/examples/quickstart/steps/model_tester.py b/examples/quickstart/steps/model_tester.py index 72d68ed7d57..93e261b7ef4 100644 --- a/examples/quickstart/steps/model_tester.py +++ b/examples/quickstart/steps/model_tester.py @@ -21,7 +21,7 @@ T5TokenizerFast, ) -from zenml import get_step_context, log_metadata, step +from zenml import log_metadata, step from zenml.logger import get_logger from .data_loader import PROMPT @@ -70,11 +70,7 @@ def test_model( sentence_without_prompt: decoded_output } - step_context = get_step_context() - - if step_context.model: - log_metadata( - metadata={"Example Prompts": test_collection}, - model_name=step_context.model.name, - 
model_version=step_context.model.version, - ) + log_metadata( + metadata={"Example Prompts": test_collection}, + infer_model=True, + ) diff --git a/src/zenml/artifacts/utils.py b/src/zenml/artifacts/utils.py index 22067abe33b..2573964aa77 100644 --- a/src/zenml/artifacts/utils.py +++ b/src/zenml/artifacts/utils.py @@ -14,6 +14,7 @@ """Utility functions for handling artifacts.""" import base64 +import contextlib import os import tempfile import zipfile @@ -41,7 +42,6 @@ ArtifactSaveType, ArtifactType, ExecutionStatus, - MetadataResourceTypes, StackComponentType, VisualizationType, ) @@ -404,50 +404,71 @@ def log_artifact_metadata( artifact_version: The version of the artifact to log metadata for. If not provided, when being called inside a step that produces an artifact named `artifact_name`, the metadata will be associated to - the corresponding newly created artifact. Or, if not provided when - being called outside a step, or in a step that does not produce - any artifact named `artifact_name`, the metadata will be associated - to the latest version of that artifact. + the corresponding newly created artifact. Raises: ValueError: If no artifact name is provided and the function is not called inside a step with a single output, or, if neither an artifact nor an output with the given name exists. + """ logger.warning( "The `log_artifact_metadata` function is deprecated and will soon be " "removed. Please use `log_metadata` instead." ) - try: + + from zenml import log_metadata + + if artifact_name and artifact_version: + assert artifact_name is not None + + log_metadata( + metadata=metadata, + artifact_name=artifact_name, + artifact_version=artifact_version, + ) + + step_context = None + with contextlib.suppress(RuntimeError): step_context = get_step_context() - in_step_outputs = (artifact_name in step_context._outputs) or ( - not artifact_name and len(step_context._outputs) == 1 + + if step_context and artifact_name in step_context._outputs.keys(): + log_metadata( + metadata=metadata, + artifact_name=artifact_name, + infer_artifact=True, ) - except RuntimeError: - step_context = None - in_step_outputs = False - - if not step_context or not in_step_outputs or artifact_version: - if not artifact_name: - raise ValueError( - "Artifact name must be provided unless the function is called " - "inside a step with a single output." - ) + elif step_context and len(step_context._outputs) == 1: + single_output_name = list(step_context._outputs.keys())[0] + + log_metadata( + metadata=metadata, + artifact_name=single_output_name, + infer_artifact=True, + ) + elif artifact_name: client = Client() - response = client.get_artifact_version(artifact_name, artifact_version) - client.create_run_metadata( + logger.warning( + "Deprecation warning! Currently, you are calling " + "`log_artifact_metadata` from a context, where we use the " + "`artifact_name` to fetch it and link the metadata to its " + "latest version. This behavior is deprecated and will be " + "removed in the future. To circumvent this, please check" + "the `log_metadata` function." 
+ ) + artifact_version_model = client.get_artifact_version( + name_id_or_prefix=artifact_name + ) + log_metadata( metadata=metadata, - resource_id=response.id, - resource_type=MetadataResourceTypes.ARTIFACT_VERSION, + artifact_version_id=artifact_version_model.id, ) - else: - try: - step_context.add_output_metadata( - metadata=metadata, output_name=artifact_name - ) - except StepContextError as e: - raise ValueError(e) + raise ValueError( + "You need to call `log_artifact_metadata` either within a step " + "(potentially with an artifact name) or outside of a step with an " + "artifact name (and/or version)." + ) # ----------------- diff --git a/src/zenml/cli/base.py b/src/zenml/cli/base.py index 23df0206161..8bc22c45446 100644 --- a/src/zenml/cli/base.py +++ b/src/zenml/cli/base.py @@ -79,19 +79,19 @@ def copier_github_url(self) -> str: ZENML_PROJECT_TEMPLATES = dict( e2e_batch=ZenMLProjectTemplateLocation( github_url="zenml-io/template-e2e-batch", - github_tag="2024.11.20", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml + github_tag="2024.11.28", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml ), starter=ZenMLProjectTemplateLocation( github_url="zenml-io/template-starter", - github_tag="2024.10.30", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml + github_tag="2024.11.28", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml ), nlp=ZenMLProjectTemplateLocation( github_url="zenml-io/template-nlp", - github_tag="2024.10.30", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml + github_tag="2024.11.28", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml ), llm_finetuning=ZenMLProjectTemplateLocation( github_url="zenml-io/template-llm-finetuning", - github_tag="2024.11.08", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml + github_tag="2024.11.28", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml ), ) diff --git a/src/zenml/client.py b/src/zenml/client.py index af3d6d2daa1..995f2d8bdb3 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -60,7 +60,6 @@ from zenml.enums import ( ArtifactType, LogicalOperators, - MetadataResourceTypes, ModelStages, OAuthDeviceStatus, PluginSubType, @@ -137,6 +136,7 @@ PipelineRunFilter, PipelineRunResponse, RunMetadataRequest, + RunMetadataResource, RunTemplateFilter, RunTemplateRequest, RunTemplateResponse, @@ -4432,23 +4432,20 @@ def _delete_artifact_from_artifact_store( def create_run_metadata( self, metadata: Dict[str, "MetadataType"], - resource_id: UUID, - resource_type: MetadataResourceTypes, + resources: List[RunMetadataResource], stack_component_id: Optional[UUID] = None, + publisher_step_id: Optional[UUID] = None, ) -> None: """Create run metadata. Args: metadata: The metadata to create as a dictionary of key-value pairs. - resource_id: The ID of the resource for which the - metadata was produced. - resource_type: The type of the resource for which the + resources: The list of IDs and types of the resources for that the metadata was produced. stack_component_id: The ID of the stack component that produced the metadata. - - Returns: - None + publisher_step_id: The ID of the step execution that publishes + this metadata automatically. 
""" from zenml.metadata.metadata_types import get_metadata_type @@ -4477,14 +4474,13 @@ def create_run_metadata( run_metadata = RunMetadataRequest( workspace=self.active_workspace.id, user=self.active_user.id, - resource_id=resource_id, - resource_type=resource_type, + resources=resources, stack_component_id=stack_component_id, + publisher_step_id=publisher_step_id, values=values, types=types, ) self.zen_store.create_run_metadata(run_metadata) - return None # -------------------------------- Secrets --------------------------------- diff --git a/src/zenml/model/model.py b/src/zenml/model/model.py index c2f287f9bb5..b7f3c591518 100644 --- a/src/zenml/model/model.py +++ b/src/zenml/model/model.py @@ -336,12 +336,16 @@ def log_metadata( metadata: The metadata to log. """ from zenml.client import Client + from zenml.models import RunMetadataResource response = self._get_or_create_model_version() Client().create_run_metadata( metadata=metadata, - resource_id=response.id, - resource_type=MetadataResourceTypes.MODEL_VERSION, + resources=[ + RunMetadataResource( + id=response.id, type=MetadataResourceTypes.MODEL_VERSION + ) + ], ) @property diff --git a/src/zenml/model/utils.py b/src/zenml/model/utils.py index 6f83bd2bd60..a3612fc2c12 100644 --- a/src/zenml/model/utils.py +++ b/src/zenml/model/utils.py @@ -52,29 +52,31 @@ def log_model_metadata( `model` in decorator. Raises: - ValueError: If no model name/version is provided and the function is not - called inside a step with configured `model` in decorator. + ValueError: If the function is not called with proper input. """ logger.warning( "The `log_model_metadata` function is deprecated and will soon be " "removed. Please use `log_metadata` instead." ) - if model_name and model_version: - from zenml import Model + from zenml import log_metadata - mv = Model(name=model_name, version=model_version) + if model_name and model_version: + log_metadata( + metadata=metadata, + model_version=model_version, + model_name=model_name, + ) + elif model_name is None and model_version is None: + log_metadata( + metadata=metadata, + infer_model=True, + ) else: - try: - step_context = get_step_context() - except RuntimeError: - raise ValueError( - "Model name and version must be provided unless the function is " - "called inside a step with configured `model` in decorator." - ) - mv = step_context.model - - mv.log_metadata(metadata) + raise ValueError( + "You can call `log_model_metadata` by either providing both " + "`model_name` and `model_version` or keeping both of them None." + ) def link_artifact_version_to_model_version( @@ -107,7 +109,7 @@ def link_artifact_to_model( model: The model to link to. Raises: - RuntimeError: If called outside of a step. + RuntimeError: If called outside a step. 
""" if not model: is_issue = False diff --git a/src/zenml/models/__init__.py b/src/zenml/models/__init__.py index 1bd3af5f22d..d43b1b09708 100644 --- a/src/zenml/models/__init__.py +++ b/src/zenml/models/__init__.py @@ -372,6 +372,10 @@ OAuthRedirectResponse, OAuthTokenResponse, ) +from zenml.models.v2.misc.run_metadata import ( + RunMetadataEntry, + RunMetadataResource, +) from zenml.models.v2.misc.server_models import ( ServerModel, ServerDatabaseType, @@ -752,4 +756,6 @@ "ServiceConnectorInfo", "ServiceConnectorResourcesInfo", "ResourcesInfo", + "RunMetadataEntry", + "RunMetadataResource", ] diff --git a/src/zenml/models/v2/core/artifact_version.py b/src/zenml/models/v2/core/artifact_version.py index 646e5c16ce1..cd5089a3db4 100644 --- a/src/zenml/models/v2/core/artifact_version.py +++ b/src/zenml/models/v2/core/artifact_version.py @@ -576,6 +576,7 @@ def get_custom_filters(self) -> List[Union["ColumnElement[bool]"]]: ModelVersionArtifactSchema, ModelVersionSchema, PipelineRunSchema, + RunMetadataResourceSchema, RunMetadataSchema, StepRunInputArtifactSchema, StepRunOutputArtifactSchema, @@ -679,10 +680,12 @@ def get_custom_filters(self) -> List[Union["ColumnElement[bool]"]]: for key, value in self.run_metadata.items(): additional_filter = and_( - RunMetadataSchema.resource_id == ArtifactVersionSchema.id, - RunMetadataSchema.resource_type + RunMetadataResourceSchema.resource_id + == ArtifactVersionSchema.id, + RunMetadataResourceSchema.resource_type == MetadataResourceTypes.ARTIFACT_VERSION, - RunMetadataSchema.key == key, + RunMetadataResourceSchema.run_metadata_id + == RunMetadataSchema.id, self.generate_custom_query_conditions_for_column( value=value, table=RunMetadataSchema, diff --git a/src/zenml/models/v2/core/model_version.py b/src/zenml/models/v2/core/model_version.py index 9ce9b1692b5..d1a7a951978 100644 --- a/src/zenml/models/v2/core/model_version.py +++ b/src/zenml/models/v2/core/model_version.py @@ -652,6 +652,7 @@ def get_custom_filters( from zenml.zen_stores.schemas import ( ModelVersionSchema, + RunMetadataResourceSchema, RunMetadataSchema, UserSchema, ) @@ -672,10 +673,12 @@ def get_custom_filters( for key, value in self.run_metadata.items(): additional_filter = and_( - RunMetadataSchema.resource_id == ModelVersionSchema.id, - RunMetadataSchema.resource_type + RunMetadataResourceSchema.resource_id + == ModelVersionSchema.id, + RunMetadataResourceSchema.resource_type == MetadataResourceTypes.MODEL_VERSION, - RunMetadataSchema.key == key, + RunMetadataResourceSchema.run_metadata_id + == RunMetadataSchema.id, self.generate_custom_query_conditions_for_column( value=value, table=RunMetadataSchema, diff --git a/src/zenml/models/v2/core/pipeline_run.py b/src/zenml/models/v2/core/pipeline_run.py index 5d0fdd55052..958d662a515 100644 --- a/src/zenml/models/v2/core/pipeline_run.py +++ b/src/zenml/models/v2/core/pipeline_run.py @@ -735,6 +735,7 @@ def get_custom_filters( PipelineDeploymentSchema, PipelineRunSchema, PipelineSchema, + RunMetadataResourceSchema, RunMetadataSchema, ScheduleSchema, StackComponentSchema, @@ -910,10 +911,12 @@ def get_custom_filters( for key, value in self.run_metadata.items(): additional_filter = and_( - RunMetadataSchema.resource_id == PipelineRunSchema.id, - RunMetadataSchema.resource_type + RunMetadataResourceSchema.resource_id + == PipelineRunSchema.id, + RunMetadataResourceSchema.resource_type == MetadataResourceTypes.PIPELINE_RUN, - RunMetadataSchema.key == key, + RunMetadataResourceSchema.run_metadata_id + == RunMetadataSchema.id, 
self.generate_custom_query_conditions_for_column( value=value, table=RunMetadataSchema, diff --git a/src/zenml/models/v2/core/run_metadata.py b/src/zenml/models/v2/core/run_metadata.py index c4a2ef8e678..5822451357d 100644 --- a/src/zenml/models/v2/core/run_metadata.py +++ b/src/zenml/models/v2/core/run_metadata.py @@ -13,16 +13,16 @@ # permissions and limitations under the License. """Models representing run metadata.""" -from typing import Dict, Optional +from typing import Dict, List, Optional from uuid import UUID -from pydantic import Field +from pydantic import Field, model_validator -from zenml.enums import MetadataResourceTypes from zenml.metadata.metadata_types import MetadataType, MetadataTypeEnum from zenml.models.v2.base.scoped import ( WorkspaceScopedRequest, ) +from zenml.models.v2.misc.run_metadata import RunMetadataResource # ------------------ Request Model ------------------ @@ -30,14 +30,12 @@ class RunMetadataRequest(WorkspaceScopedRequest): """Request model for run metadata.""" - resource_id: UUID = Field( - title="The ID of the resource that this metadata belongs to.", - ) - resource_type: MetadataResourceTypes = Field( - title="The type of the resource that this metadata belongs to.", + resources: List[RunMetadataResource] = Field( + title="The list of resources that this metadata belongs to." ) stack_component_id: Optional[UUID] = Field( - title="The ID of the stack component that this metadata belongs to." + title="The ID of the stack component that this metadata belongs to.", + default=None, ) values: Dict[str, "MetadataType"] = Field( title="The metadata to be created.", @@ -45,3 +43,26 @@ class RunMetadataRequest(WorkspaceScopedRequest): types: Dict[str, "MetadataTypeEnum"] = Field( title="The types of the metadata to be created.", ) + publisher_step_id: Optional[UUID] = Field( + title="The ID of the step execution that published this metadata.", + default=None, + ) + + @model_validator(mode="after") + def validate_values_keys(self) -> "RunMetadataRequest": + """Validates if the keys in the metadata are properly defined. + + Returns: + self + + Raises: + ValueError: if one of the key in the metadata contains `:` + """ + invalid_keys = [key for key in self.values.keys() if ":" in key] + if invalid_keys: + raise ValueError( + "You can not use colons (`:`) in the key names when you " + "are creating metadata for your ZenML objects. 
Please change " + f"the following keys: {invalid_keys}" + ) + return self diff --git a/src/zenml/models/v2/core/step_run.py b/src/zenml/models/v2/core/step_run.py index 2916d9236ce..d9ac5e0354a 100644 --- a/src/zenml/models/v2/core/step_run.py +++ b/src/zenml/models/v2/core/step_run.py @@ -594,6 +594,7 @@ def get_custom_filters( from zenml.zen_stores.schemas import ( ModelSchema, ModelVersionSchema, + RunMetadataResourceSchema, RunMetadataSchema, StepRunSchema, ) @@ -612,10 +613,11 @@ def get_custom_filters( for key, value in self.run_metadata.items(): additional_filter = and_( - RunMetadataSchema.resource_id == StepRunSchema.id, - RunMetadataSchema.resource_type + RunMetadataResourceSchema.resource_id == StepRunSchema.id, + RunMetadataResourceSchema.resource_type == MetadataResourceTypes.STEP_RUN, - RunMetadataSchema.key == key, + RunMetadataResourceSchema.run_metadata_id + == RunMetadataSchema.id, self.generate_custom_query_conditions_for_column( value=value, table=RunMetadataSchema, diff --git a/src/zenml/models/v2/misc/run_metadata.py b/src/zenml/models/v2/misc/run_metadata.py new file mode 100644 index 00000000000..1769ff30ad6 --- /dev/null +++ b/src/zenml/models/v2/misc/run_metadata.py @@ -0,0 +1,38 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Utility classes for modeling run metadata.""" + +from datetime import datetime +from uuid import UUID + +from pydantic import BaseModel, Field + +from zenml.enums import MetadataResourceTypes +from zenml.metadata.metadata_types import MetadataType + + +class RunMetadataResource(BaseModel): + """Utility class to help identify resources to tag metadata to.""" + + id: UUID = Field(title="The ID of the resource.") + type: MetadataResourceTypes = Field(title="The type of the resource.") + + +class RunMetadataEntry(BaseModel): + """Utility class to sort/list run metadata entries.""" + + value: MetadataType = Field(title="The value for the run metadata entry") + created: datetime = Field( + title="The timestamp when this resource was created." 
+ ) diff --git a/src/zenml/orchestrators/publish_utils.py b/src/zenml/orchestrators/publish_utils.py index 7e4cf89c6e3..0d5cea792ae 100644 --- a/src/zenml/orchestrators/publish_utils.py +++ b/src/zenml/orchestrators/publish_utils.py @@ -21,6 +21,7 @@ from zenml.models import ( PipelineRunResponse, PipelineRunUpdate, + RunMetadataResource, StepRunResponse, StepRunUpdate, ) @@ -129,8 +130,11 @@ def publish_pipeline_run_metadata( for stack_component_id, metadata in pipeline_run_metadata.items(): client.create_run_metadata( metadata=metadata, - resource_id=pipeline_run_id, - resource_type=MetadataResourceTypes.PIPELINE_RUN, + resources=[ + RunMetadataResource( + id=pipeline_run_id, type=MetadataResourceTypes.PIPELINE_RUN + ) + ], stack_component_id=stack_component_id, ) @@ -150,7 +154,10 @@ def publish_step_run_metadata( for stack_component_id, metadata in step_run_metadata.items(): client.create_run_metadata( metadata=metadata, - resource_id=step_run_id, - resource_type=MetadataResourceTypes.STEP_RUN, + resources=[ + RunMetadataResource( + id=step_run_id, type=MetadataResourceTypes.STEP_RUN + ) + ], stack_component_id=stack_component_id, ) diff --git a/src/zenml/steps/utils.py b/src/zenml/steps/utils.py index cd4b720c068..780ccd62589 100644 --- a/src/zenml/steps/utils.py +++ b/src/zenml/steps/utils.py @@ -42,6 +42,7 @@ from zenml.exceptions import StepInterfaceError from zenml.logger import get_logger from zenml.metadata.metadata_types import MetadataType +from zenml.models import RunMetadataResource from zenml.steps.step_context import get_step_context from zenml.utils import settings_utils, source_code_utils, typing_utils @@ -489,8 +490,11 @@ def log_step_metadata( step_run_id = pipeline_run.steps[step_name].id client.create_run_metadata( metadata=metadata, - resource_id=step_run_id, - resource_type=MetadataResourceTypes.STEP_RUN, + resources=[ + RunMetadataResource( + id=step_run_id, type=MetadataResourceTypes.STEP_RUN + ) + ], ) diff --git a/src/zenml/utils/metadata_utils.py b/src/zenml/utils/metadata_utils.py index 47bd4f06e38..2b4e641f039 100644 --- a/src/zenml/utils/metadata_utils.py +++ b/src/zenml/utils/metadata_utils.py @@ -13,28 +13,30 @@ # permissions and limitations under the License. """Utility functions to handle metadata for ZenML entities.""" -import contextlib -from typing import Dict, Optional, Union, overload +from typing import Dict, List, Optional, Union, overload from uuid import UUID from zenml.client import Client -from zenml.enums import MetadataResourceTypes +from zenml.enums import MetadataResourceTypes, ModelStages from zenml.logger import get_logger from zenml.metadata.metadata_types import MetadataType +from zenml.models import RunMetadataResource from zenml.steps.step_context import get_step_context logger = get_logger(__name__) @overload -def log_metadata(metadata: Dict[str, MetadataType]) -> None: ... +def log_metadata( + metadata: Dict[str, MetadataType], +) -> None: ... @overload def log_metadata( *, metadata: Dict[str, MetadataType], - artifact_version_id: UUID, + step_id: UUID, ) -> None: ... @@ -42,8 +44,8 @@ def log_metadata( def log_metadata( *, metadata: Dict[str, MetadataType], - artifact_name: str, - artifact_version: Optional[str] = None, + step_name: str, + run_id_name_or_prefix: Union[UUID, str], ) -> None: ... @@ -51,7 +53,7 @@ def log_metadata( def log_metadata( *, metadata: Dict[str, MetadataType], - model_version_id: UUID, + run_id_name_or_prefix: Union[UUID, str], ) -> None: ... 
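For orientation, a minimal sketch of how a store-level `RunMetadataRequest` is assembled under the new model and how the key validator behaves. All UUIDs are random placeholders; `get_metadata_type` is the same helper the client uses to infer metadata types.

    from uuid import uuid4

    from zenml.enums import MetadataResourceTypes
    from zenml.metadata.metadata_types import get_metadata_type
    from zenml.models import RunMetadataRequest, RunMetadataResource

    request = RunMetadataRequest(
        workspace=uuid4(),  # placeholder workspace ID
        user=uuid4(),       # placeholder user ID
        resources=[
            RunMetadataResource(
                id=uuid4(), type=MetadataResourceTypes.PIPELINE_RUN
            )
        ],
        values={"accuracy": 0.92},
        types={"accuracy": get_metadata_type(0.92)},
    )

    # Keys containing a colon fail validation, presumably because the
    # run-level metadata view composes keys as "step_name::key":
    # values={"bad:key": 1} would be rejected by `validate_values_keys`.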
@@ -59,8 +61,7 @@ def log_metadata( def log_metadata( *, metadata: Dict[str, MetadataType], - model_name: str, - model_version: str, + artifact_version_id: UUID, ) -> None: ... @@ -68,7 +69,8 @@ def log_metadata( def log_metadata( *, metadata: Dict[str, MetadataType], - step_id: UUID, + artifact_name: str, + artifact_version: Optional[str] = None, ) -> None: ... @@ -76,33 +78,53 @@ def log_metadata( def log_metadata( *, metadata: Dict[str, MetadataType], - run_id_name_or_prefix: Union[UUID, str], + infer_artifact: bool = False, + artifact_name: Optional[str] = None, ) -> None: ... +# Model Metadata @overload def log_metadata( *, metadata: Dict[str, MetadataType], - step_name: str, - run_id_name_or_prefix: Union[UUID, str], + model_version_id: UUID, ) -> None: ... +@overload def log_metadata( + *, metadata: Dict[str, MetadataType], - # Parameters to manually log metadata for steps and runs + model_name: str, + model_version: Union[ModelStages, int, str], +) -> None: ... + + +@overload +def log_metadata( + *, + metadata: Dict[str, MetadataType], + infer_model: bool = False, +) -> None: ... + + +def log_metadata( + metadata: Dict[str, MetadataType], + # Steps and runs step_id: Optional[UUID] = None, step_name: Optional[str] = None, run_id_name_or_prefix: Optional[Union[UUID, str]] = None, - # Parameters to manually log metadata for artifacts + # Artifacts artifact_version_id: Optional[UUID] = None, artifact_name: Optional[str] = None, artifact_version: Optional[str] = None, - # Parameters to manually log metadata for models + infer_artifact: bool = False, + # Models model_version_id: Optional[UUID] = None, model_name: Optional[str] = None, - model_version: Optional[str] = None, + model_version: Optional[Union[ModelStages, int, str]] = None, + infer_model: bool = False, ) -> None: """Logs metadata for various resource types in a generalized way. @@ -114,9 +136,13 @@ def log_metadata( artifact_version_id: The ID of the artifact version artifact_name: The name of the artifact. artifact_version: The version of the artifact. + infer_artifact: Flag deciding whether the artifact version should be + inferred from the step context. model_version_id: The ID of the model version. model_name: The name of the model. - model_version: The version of the model + model_version: The version of the model. + infer_model: Flag deciding whether the model version should be + inferred from the step context. Raises: ValueError: If no identifiers are provided and the function is not @@ -124,141 +150,147 @@ def log_metadata( """ client = Client() - # If a step name is provided, we need a run_id_name_or_prefix and will log - # metadata for the steps pipeline and model accordingly. 
- if step_name is not None and run_id_name_or_prefix is not None: - run_model = client.get_pipeline_run( - name_id_or_prefix=run_id_name_or_prefix - ) - step_model = run_model.steps[step_name] + resources: List[RunMetadataResource] = [] + publisher_step_id = None - client.create_run_metadata( - metadata=metadata, - resource_id=run_model.id, - resource_type=MetadataResourceTypes.PIPELINE_RUN, - ) - client.create_run_metadata( - metadata=metadata, - resource_id=step_model.id, - resource_type=MetadataResourceTypes.STEP_RUN, - ) - if step_model.model_version: - client.create_run_metadata( - metadata=metadata, - resource_id=step_model.model_version.id, - resource_type=MetadataResourceTypes.MODEL_VERSION, + # Log metadata to a step by ID + if step_id is not None: + resources = [ + RunMetadataResource( + id=step_id, type=MetadataResourceTypes.STEP_RUN ) + ] - # If a step is identified by id, fetch it directly through the client, - # follow a similar procedure and log metadata for its pipeline and model - # as well. - elif step_id is not None: - step_model = client.get_run_step(step_run_id=step_id) - client.create_run_metadata( - metadata=metadata, - resource_id=step_model.pipeline_run_id, - resource_type=MetadataResourceTypes.PIPELINE_RUN, - ) - client.create_run_metadata( - metadata=metadata, - resource_id=step_model.id, - resource_type=MetadataResourceTypes.STEP_RUN, + # Log metadata to a step by name and run ID + elif step_name is not None and run_id_name_or_prefix is not None: + step_model_id = ( + client.get_pipeline_run(name_id_or_prefix=run_id_name_or_prefix) + .steps[step_name] + .id ) - if step_model.model_version: - client.create_run_metadata( - metadata=metadata, - resource_id=step_model.model_version.id, - resource_type=MetadataResourceTypes.MODEL_VERSION, + resources = [ + RunMetadataResource( + id=step_model_id, type=MetadataResourceTypes.STEP_RUN ) + ] - # If a pipeline run id is identified, we need to log metadata to it and its - # model as well. + # Log metadata to a run by ID elif run_id_name_or_prefix is not None: run_model = client.get_pipeline_run( name_id_or_prefix=run_id_name_or_prefix ) - client.create_run_metadata( - metadata=metadata, - resource_id=run_model.id, - resource_type=MetadataResourceTypes.PIPELINE_RUN, - ) - if run_model.model_version: - client.create_run_metadata( - metadata=metadata, - resource_id=run_model.model_version.id, - resource_type=MetadataResourceTypes.MODEL_VERSION, + resources = [ + RunMetadataResource( + id=run_model.id, type=MetadataResourceTypes.PIPELINE_RUN ) + ] - # If the user provides a model name and version, we use to model abstraction - # to fetch the model version and attach the corresponding metadata to it. + # Log metadata to a model version by name and version elif model_name is not None and model_version is not None: - from zenml import Model - - mv = Model(name=model_name, version=model_version) - client.create_run_metadata( - metadata=metadata, - resource_id=mv.id, - resource_type=MetadataResourceTypes.MODEL_VERSION, + model_version_model = client.get_model_version( + model_name_or_id=model_name, + model_version_name_or_number_or_id=model_version, ) + resources = [ + RunMetadataResource( + id=model_version_model.id, + type=MetadataResourceTypes.MODEL_VERSION, + ) + ] - # If the user provides a model version id, we use the client to fetch it and - # attach the metadata to it. 
+ # Log metadata to a model version by id elif model_version_id is not None: - model_version_id = client.get_model_version( - model_version_name_or_number_or_id=model_version_id - ).id - client.create_run_metadata( - metadata=metadata, - resource_id=model_version_id, - resource_type=MetadataResourceTypes.MODEL_VERSION, + resources = [ + RunMetadataResource( + id=model_version_id, + type=MetadataResourceTypes.MODEL_VERSION, + ) + ] + + # Log metadata to a model through the step context + elif infer_model is True: + try: + step_context = get_step_context() + except RuntimeError: + raise ValueError( + "If you are using the `infer_model` option, the function must " + "be called inside a step with configured `model` in decorator." + "Otherwise, you can provide a `model_version_id` or a " + "combination of `model_name` and `model_version`." + ) + + if step_context.model_version is None: + raise ValueError( + "The step context does not feature any model versions." + ) + + resources = [ + RunMetadataResource( + id=step_context.model_version.id, + type=MetadataResourceTypes.MODEL_VERSION, + ) + ] + + # Log metadata to an artifact version by its name and version + elif artifact_name is not None and artifact_version is not None: + artifact_version_model = client.get_artifact_version( + name_id_or_prefix=artifact_name, version=artifact_version ) + resources = [ + RunMetadataResource( + id=artifact_version_model.id, + type=MetadataResourceTypes.ARTIFACT_VERSION, + ) + ] - # If the user provides an artifact name, there are three possibilities. If - # an artifact version is also provided with the name, we use both to fetch - # the artifact version and use it to log the metadata. If no version is - # provided, if the function is called within a step we search the artifacts - # of the step if not we fetch the latest version and attach the metadata - # to the latest version. - elif artifact_name is not None: - if artifact_version: - artifact_version_model = client.get_artifact_version( - name_id_or_prefix=artifact_name, version=artifact_version + # Log metadata to an artifact version by its ID + elif artifact_version_id is not None: + resources = [ + RunMetadataResource( + id=artifact_version_id, + type=MetadataResourceTypes.ARTIFACT_VERSION, ) - client.create_run_metadata( - metadata=metadata, - resource_id=artifact_version_model.id, - resource_type=MetadataResourceTypes.ARTIFACT_VERSION, + ] + + # Log metadata to an artifact version through the step context + elif infer_artifact is True: + try: + step_context = get_step_context() + except RuntimeError: + raise ValueError( + "When you are using the `infer_artifact` option when you call " + "`log_metadata`, it must be called inside a step with outputs." + "Otherwise, you can provide a `artifact_version_id` or a " + "combination of `artifact_name` and `artifact_version`." ) - else: - step_context = None - with contextlib.suppress(RuntimeError): - step_context = get_step_context() - if step_context: - step_context.add_output_metadata( - metadata=metadata, output_name=artifact_name - ) - else: - artifact_version_model = client.get_artifact_version( - name_id_or_prefix=artifact_name + step_output_names = list(step_context._outputs.keys()) + + if artifact_name is not None: + # If a name provided, ensure it is in the outputs + if artifact_name not in step_output_names: + raise ValueError( + f"The provided artifact name`{artifact_name}` does not " + f"exist in the step outputs: {step_output_names}." 
) - client.create_run_metadata( - metadata=metadata, - resource_id=artifact_version_model.id, - resource_type=MetadataResourceTypes.ARTIFACT_VERSION, + else: + # If no name provided, ensure there is only one output + if len(step_output_names) > 1: + raise ValueError( + "There is more than one output. If you would like to use " + "the `infer_artifact` option, you need to define an " + "`artifact_name`." ) - # If the user directly provides an artifact_version_id, we use the client to - # fetch is and attach the metadata accordingly. - elif artifact_version_id is not None: - artifact_version_model = client.get_artifact_version( - name_id_or_prefix=artifact_version_id, - ) - client.create_run_metadata( - metadata=metadata, - resource_id=artifact_version_model.id, - resource_type=MetadataResourceTypes.ARTIFACT_VERSION, + if len(step_output_names) == 0: + raise ValueError("The step does not have any outputs.") + + artifact_name = step_output_names[0] + + step_context.add_output_metadata( + metadata=metadata, output_name=artifact_name ) + return # If every additional value is None, that means we are calling it bare bones # and this call needs to happen during a step execution. We will use the @@ -287,22 +319,14 @@ def log_metadata( "of the step execution, please provide the required " "identifiers." ) - client.create_run_metadata( - metadata=metadata, - resource_id=step_context.pipeline_run.id, - resource_type=MetadataResourceTypes.PIPELINE_RUN, - ) - client.create_run_metadata( - metadata=metadata, - resource_id=step_context.step_run.id, - resource_type=MetadataResourceTypes.STEP_RUN, - ) - if step_context.model_version: - client.create_run_metadata( - metadata=metadata, - resource_id=step_context.model_version.id, - resource_type=MetadataResourceTypes.MODEL_VERSION, + + resources = [ + RunMetadataResource( + id=step_context.step_run.id, + type=MetadataResourceTypes.STEP_RUN, ) + ] + publisher_step_id = step_context.step_run.id else: raise ValueError( @@ -310,26 +334,35 @@ def log_metadata( Unsupported way to call the `log_metadata`. Possible combinations " include: - # Inside a step - # Logs the metadata to the step, its run and possibly its model + # Automatic logging to a step (within a step) log_metadata(metadata={}) - # Manually logging for a step - # Logs the metadata to the step, its run and possibly its model + # Manual logging to a step log_metadata(metadata={}, step_name=..., run_id_name_or_prefix=...) log_metadata(metadata={}, step_id=...) - # Manually logging for a run - # Logs the metadata to the run, possibly its model + # Manual logging to a run log_metadata(metadata={}, run_id_name_or_prefix=...) - # Manually logging for a model + # Automatic logging to a model (within a step) + log_metadata(metadata={}, infer_model=True) + + # Manual logging to a model log_metadata(metadata={}, model_name=..., model_version=...) log_metadata(metadata={}, model_version_id=...) - # Manually logging for an artifact - log_metadata(metadata={}, artifact_name=...) # inside a step + # Automatic logging to an artifact (within a step) + log_metadata(metadata={}, infer_artifact=True) # step with single output + log_metadata(metadata={}, artifact_name=..., infer_artifact=True) # specific output of a step + + # Manual logging to an artifact log_metadata(metadata={}, artifact_name=..., artifact_version=...) log_metadata(metadata={}, artifact_version_id=...) 
""" ) + + client.create_run_metadata( + metadata=metadata, + resources=resources, + publisher_step_id=publisher_step_id, + ) diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index 955c0630f1e..dd26289d05a 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Endpoint definitions for workspaces.""" -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from uuid import UUID from fastapi import APIRouter, Depends, Security @@ -102,6 +102,7 @@ ) from zenml.zen_server.rbac.models import Action, ResourceType from zenml.zen_server.rbac.utils import ( + batch_verify_permissions_for_models, dehydrate_page, dehydrate_response_model, get_allowed_resource_ids, @@ -998,24 +999,23 @@ def create_run_metadata( "is not supported." ) - if run_metadata.resource_type == MetadataResourceTypes.PIPELINE_RUN: - run = zen_store().get_run(run_metadata.resource_id) - verify_permission_for_model(run, action=Action.UPDATE) - elif run_metadata.resource_type == MetadataResourceTypes.STEP_RUN: - step = zen_store().get_run_step(run_metadata.resource_id) - verify_permission_for_model(step, action=Action.UPDATE) - elif run_metadata.resource_type == MetadataResourceTypes.ARTIFACT_VERSION: - artifact_version = zen_store().get_artifact_version( - run_metadata.resource_id - ) - verify_permission_for_model(artifact_version, action=Action.UPDATE) - elif run_metadata.resource_type == MetadataResourceTypes.MODEL_VERSION: - model_version = zen_store().get_model_version(run_metadata.resource_id) - verify_permission_for_model(model_version, action=Action.UPDATE) - else: - raise RuntimeError( - f"Unknown resource type: {run_metadata.resource_type}" - ) + verify_models: List[Any] = [] + for resource in run_metadata.resources: + if resource.type == MetadataResourceTypes.PIPELINE_RUN: + verify_models.append(zen_store().get_run(resource.id)) + elif resource.type == MetadataResourceTypes.STEP_RUN: + verify_models.append(zen_store().get_run_step(resource.id)) + elif resource.type == MetadataResourceTypes.ARTIFACT_VERSION: + verify_models.append(zen_store().get_artifact_version(resource.id)) + elif resource.type == MetadataResourceTypes.MODEL_VERSION: + verify_models.append(zen_store().get_model_version(resource.id)) + else: + raise RuntimeError(f"Unknown resource type: {resource.type}") + + batch_verify_permissions_for_models( + models=verify_models, + action=Action.UPDATE, + ) verify_permission( resource_type=ResourceType.RUN_METADATA, action=Action.CREATE diff --git a/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py b/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py new file mode 100644 index 00000000000..52a4cbd8ef2 --- /dev/null +++ b/src/zenml/zen_stores/migrations/versions/cc269488e5a9_separate_run_metadata.py @@ -0,0 +1,135 @@ +"""Separate run metadata into resource link table with new UUIDs. + +Revision ID: cc269488e5a9 +Revises: b73bc71f1106 +Create Date: 2024-11-12 09:46:46.587478 +""" + +import uuid + +import sqlalchemy as sa +import sqlmodel +from alembic import op + +# revision identifiers, used by Alembic. 
+revision = "cc269488e5a9" +down_revision = "b73bc71f1106" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Creates the 'run_metadata_resource' table and migrates data.""" + op.create_table( + "run_metadata_resource", + sa.Column( + "id", + sqlmodel.sql.sqltypes.GUID(), + nullable=False, + primary_key=True, + ), + sa.Column("resource_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("resource_type", sa.String(length=255), nullable=False), + sa.Column( + "run_metadata_id", + sqlmodel.sql.sqltypes.GUID(), + sa.ForeignKey("run_metadata.id", ondelete="CASCADE"), + nullable=False, + ), + ) + + connection = op.get_bind() + + run_metadata_data = connection.execute( + sa.text(""" + SELECT id, resource_id, resource_type + FROM run_metadata + """) + ).fetchall() + + # Prepare data with new UUIDs for bulk insert + resource_data = [ + { + "id": str(uuid.uuid4()), # Generate a new UUID for each row + "resource_id": row.resource_id, + "resource_type": row.resource_type, + "run_metadata_id": row.id, + } + for row in run_metadata_data + ] + + # Perform bulk insert into `run_metadata_resource` + if resource_data: # Only perform insert if there's data to migrate + op.bulk_insert( + sa.table( + "run_metadata_resource", + sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column( + "resource_id", sqlmodel.sql.sqltypes.GUID(), nullable=False + ), + sa.Column( + "resource_type", sa.String(length=255), nullable=False + ), + sa.Column( + "run_metadata_id", + sqlmodel.sql.sqltypes.GUID(), + nullable=False, + ), # Changed to BIGINT + ), + resource_data, + ) + + op.drop_column("run_metadata", "resource_id") + op.drop_column("run_metadata", "resource_type") + + op.add_column( + "run_metadata", + sa.Column( + "publisher_step_id", sqlmodel.sql.sqltypes.GUID(), nullable=True + ), + ) + + +def downgrade() -> None: + """Reverts the 'run_metadata_resource' table and migrates data back.""" + # Recreate the `resource_id` and `resource_type` columns in `run_metadata` + op.add_column( + "run_metadata", + sa.Column("resource_id", sqlmodel.sql.sqltypes.GUID(), nullable=True), + ) + op.add_column( + "run_metadata", + sa.Column("resource_type", sa.String(length=255), nullable=True), + ) + + # Migrate data back from `run_metadata_resource` to `run_metadata` + connection = op.get_bind() + + # Fetch data from `run_metadata_resource` + run_metadata_resource_data = connection.execute( + sa.text(""" + SELECT resource_id, resource_type, run_metadata_id + FROM run_metadata_resource + """) + ).fetchall() + + # Update `run_metadata` with the data from `run_metadata_resource` + for row in run_metadata_resource_data: + connection.execute( + sa.text(""" + UPDATE run_metadata + SET resource_id = :resource_id, resource_type = :resource_type + WHERE id = :run_metadata_id + """), + { + "resource_id": row.resource_id, + "resource_type": row.resource_type, + "run_metadata_id": row.run_metadata_id, + }, + ) + + # Drop the `run_metadata_resource` table + op.drop_table("run_metadata_resource") + + # Drop the cached column + op.drop_column("run_metadata", "publisher_step_id") diff --git a/src/zenml/zen_stores/schemas/__init__.py b/src/zenml/zen_stores/schemas/__init__.py index 2faf233723a..dadb2b747f6 100644 --- a/src/zenml/zen_stores/schemas/__init__.py +++ b/src/zenml/zen_stores/schemas/__init__.py @@ -39,7 +39,10 @@ from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema from zenml.zen_stores.schemas.pipeline_schemas import PipelineSchema from 
zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema -from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema +from zenml.zen_stores.schemas.run_metadata_schemas import ( + RunMetadataResourceSchema, + RunMetadataSchema, +) from zenml.zen_stores.schemas.schedule_schema import ScheduleSchema from zenml.zen_stores.schemas.secret_schemas import SecretSchema from zenml.zen_stores.schemas.service_schemas import ServiceSchema @@ -90,6 +93,7 @@ "PipelineDeploymentSchema", "PipelineRunSchema", "PipelineSchema", + "RunMetadataResourceSchema", "RunMetadataSchema", "ScheduleSchema", "SecretSchema", diff --git a/src/zenml/zen_stores/schemas/artifact_schemas.py b/src/zenml/zen_stores/schemas/artifact_schemas.py index eda2f17927e..02e842a5fb5 100644 --- a/src/zenml/zen_stores/schemas/artifact_schemas.py +++ b/src/zenml/zen_stores/schemas/artifact_schemas.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """SQLModel implementation of artifact table.""" -import json from datetime import datetime from typing import TYPE_CHECKING, Any, List, Optional from uuid import UUID @@ -50,6 +49,7 @@ StepRunOutputArtifactSchema, ) from zenml.zen_stores.schemas.user_schemas import UserSchema +from zenml.zen_stores.schemas.utils import RunMetadataInterface from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema if TYPE_CHECKING: @@ -59,7 +59,9 @@ from zenml.zen_stores.schemas.model_schemas import ( ModelVersionArtifactSchema, ) - from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema + from zenml.zen_stores.schemas.run_metadata_schemas import ( + RunMetadataResourceSchema, + ) from zenml.zen_stores.schemas.tag_schemas import TagResourceSchema @@ -171,7 +173,7 @@ def update(self, artifact_update: ArtifactUpdate) -> "ArtifactSchema": return self -class ArtifactVersionSchema(BaseSchema, table=True): +class ArtifactVersionSchema(BaseSchema, RunMetadataInterface, table=True): """SQL Model for artifact versions.""" __tablename__ = "artifact_version" @@ -242,12 +244,12 @@ class ArtifactVersionSchema(BaseSchema, table=True): workspace: "WorkspaceSchema" = Relationship( back_populates="artifact_versions" ) - run_metadata: List["RunMetadataSchema"] = Relationship( - back_populates="artifact_version", + run_metadata_resources: List["RunMetadataResourceSchema"] = Relationship( + back_populates="artifact_versions", sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataSchema.resource_type=='{MetadataResourceTypes.ARTIFACT_VERSION.value}', foreign(RunMetadataSchema.resource_id)==ArtifactVersionSchema.id)", + primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.ARTIFACT_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ArtifactVersionSchema.id)", cascade="delete", - overlaps="run_metadata", + overlaps="run_metadata_resources", ), ) output_of_step_runs: List["StepRunOutputArtifactSchema"] = Relationship( @@ -376,9 +378,7 @@ def to_model( workspace=self.workspace.to_model(), producer_step_run_id=producer_step_run_id, visualizations=[v.to_model() for v in self.visualizations], - run_metadata={ - m.key: json.loads(m.value) for m in self.run_metadata - }, + run_metadata=self.fetch_metadata(), ) resources = None diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index d6d438aea37..feb4a93dc80 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -13,7 +13,6 @@ # permissions and 
limitations under the License. """SQLModel implementation of model tables.""" -import json from datetime import datetime from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast from uuid import UUID @@ -51,11 +50,16 @@ from zenml.zen_stores.schemas.base_schemas import BaseSchema, NamedSchema from zenml.zen_stores.schemas.constants import MODEL_VERSION_TABLENAME from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema -from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema +from zenml.zen_stores.schemas.run_metadata_schemas import ( + RunMetadataResourceSchema, +) from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field from zenml.zen_stores.schemas.tag_schemas import TagResourceSchema from zenml.zen_stores.schemas.user_schemas import UserSchema -from zenml.zen_stores.schemas.utils import get_page_from_list +from zenml.zen_stores.schemas.utils import ( + RunMetadataInterface, + get_page_from_list, +) from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema if TYPE_CHECKING: @@ -219,7 +223,7 @@ def update( return self -class ModelVersionSchema(NamedSchema, table=True): +class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True): """SQL Model for model version.""" __tablename__ = MODEL_VERSION_TABLENAME @@ -299,12 +303,12 @@ class ModelVersionSchema(NamedSchema, table=True): description: str = Field(sa_column=Column(TEXT, nullable=True)) stage: str = Field(sa_column=Column(TEXT, nullable=True)) - run_metadata: List["RunMetadataSchema"] = Relationship( - back_populates="model_version", + run_metadata_resources: List["RunMetadataResourceSchema"] = Relationship( + back_populates="model_versions", sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataSchema.resource_type=='{MetadataResourceTypes.MODEL_VERSION.value}', foreign(RunMetadataSchema.resource_id)==ModelVersionSchema.id)", + primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.MODEL_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ModelVersionSchema.id)", cascade="delete", - overlaps="run_metadata", + overlaps="run_metadata_resources", ), ) pipeline_runs: List["PipelineRunSchema"] = Relationship( @@ -402,9 +406,7 @@ def to_model( metadata = ModelVersionResponseMetadata( workspace=self.workspace.to_model(), description=self.description, - run_metadata={ - rm.key: json.loads(rm.value) for rm in self.run_metadata - }, + run_metadata=self.fetch_metadata(), ) resources = None diff --git a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py index 70e5a41d76f..d0af218b629 100644 --- a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py @@ -15,7 +15,7 @@ import json from datetime import datetime -from typing import TYPE_CHECKING, Any, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional from uuid import UUID from pydantic import ConfigDict @@ -34,6 +34,7 @@ PipelineRunResponseBody, PipelineRunResponseMetadata, PipelineRunUpdate, + RunMetadataEntry, ) from zenml.models.v2.core.pipeline_run import PipelineRunResponseResources from zenml.zen_stores.schemas.base_schemas import NamedSchema @@ -48,6 +49,7 @@ from zenml.zen_stores.schemas.stack_schemas import StackSchema from zenml.zen_stores.schemas.trigger_schemas import TriggerExecutionSchema from zenml.zen_stores.schemas.user_schemas import UserSchema +from zenml.zen_stores.schemas.utils import RunMetadataInterface 
from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema if TYPE_CHECKING: @@ -56,13 +58,15 @@ ModelVersionPipelineRunSchema, ModelVersionSchema, ) - from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema + from zenml.zen_stores.schemas.run_metadata_schemas import ( + RunMetadataResourceSchema, + ) from zenml.zen_stores.schemas.service_schemas import ServiceSchema from zenml.zen_stores.schemas.step_run_schemas import StepRunSchema from zenml.zen_stores.schemas.tag_schemas import TagResourceSchema -class PipelineRunSchema(NamedSchema, table=True): +class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True): """SQL Model for pipeline runs.""" __tablename__ = "pipeline_run" @@ -136,12 +140,12 @@ class PipelineRunSchema(NamedSchema, table=True): ) workspace: "WorkspaceSchema" = Relationship(back_populates="runs") user: Optional["UserSchema"] = Relationship(back_populates="runs") - run_metadata: List["RunMetadataSchema"] = Relationship( - back_populates="pipeline_run", + run_metadata_resources: List["RunMetadataResourceSchema"] = Relationship( + back_populates="pipeline_runs", sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataSchema.resource_type=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataSchema.resource_id)==PipelineRunSchema.id)", + primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==PipelineRunSchema.id)", cascade="delete", - overlaps="run_metadata", + overlaps="run_metadata_resources", ), ) logs: Optional["LogsSchema"] = Relationship( @@ -249,6 +253,24 @@ def from_request( model_version_id=request.model_version_id, ) + def fetch_metadata_collection(self) -> Dict[str, List[RunMetadataEntry]]: + """Fetches all the metadata entries related to the pipeline run. + + Returns: + a dictionary, where the key is the key of the metadata entry + and the values represent the list of entries with this key. 
+ """ + # Fetch the metadata related to this run + metadata_collection = super().fetch_metadata_collection() + + # Fetch the metadata related to the steps of this run + for s in self.step_runs: + step_metadata = s.fetch_metadata_collection() + for k, v in step_metadata.items(): + metadata_collection[f"{s.name}::{k}"] = v + + return metadata_collection + def to_model( self, include_metadata: bool = False, @@ -275,11 +297,6 @@ def to_model( else {} ) - run_metadata = { - metadata_schema.key: json.loads(metadata_schema.value) - for metadata_schema in self.run_metadata - } - if self.deployment is not None: deployment = self.deployment.to_model() @@ -356,7 +373,7 @@ def to_model( } metadata = PipelineRunResponseMetadata( workspace=self.workspace.to_model(), - run_metadata=run_metadata, + run_metadata=self.fetch_metadata(), config=config, steps=steps, start_time=self.start_time, diff --git a/src/zenml/zen_stores/schemas/run_metadata_schemas.py b/src/zenml/zen_stores/schemas/run_metadata_schemas.py index 18d203111c7..f4465b13e66 100644 --- a/src/zenml/zen_stores/schemas/run_metadata_schemas.py +++ b/src/zenml/zen_stores/schemas/run_metadata_schemas.py @@ -14,15 +14,16 @@ """SQLModel implementation of pipeline run metadata tables.""" from typing import TYPE_CHECKING, List, Optional -from uuid import UUID +from uuid import UUID, uuid4 from sqlalchemy import TEXT, VARCHAR, Column -from sqlmodel import Field, Relationship +from sqlmodel import Field, Relationship, SQLModel from zenml.enums import MetadataResourceTypes from zenml.zen_stores.schemas.base_schemas import BaseSchema from zenml.zen_stores.schemas.component_schemas import StackComponentSchema from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field +from zenml.zen_stores.schemas.step_run_schemas import StepRunSchema from zenml.zen_stores.schemas.user_schemas import UserSchema from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema @@ -30,7 +31,6 @@ from zenml.zen_stores.schemas.artifact_schemas import ArtifactVersionSchema from zenml.zen_stores.schemas.model_schemas import ModelVersionSchema from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema - from zenml.zen_stores.schemas.step_run_schemas import StepRunSchema class RunMetadataSchema(BaseSchema, table=True): @@ -38,35 +38,10 @@ class RunMetadataSchema(BaseSchema, table=True): __tablename__ = "run_metadata" - resource_id: UUID - resource_type: str = Field(sa_column=Column(VARCHAR(255), nullable=False)) - pipeline_run: List["PipelineRunSchema"] = Relationship( - back_populates="run_metadata", - sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataSchema.resource_type=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataSchema.resource_id)==PipelineRunSchema.id)", - overlaps="run_metadata,step_run,artifact_version,model_version", - ), - ) - step_run: List["StepRunSchema"] = Relationship( - back_populates="run_metadata", - sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataSchema.resource_type=='{MetadataResourceTypes.STEP_RUN.value}', foreign(RunMetadataSchema.resource_id)==StepRunSchema.id)", - overlaps="run_metadata,pipeline_run,artifact_version,model_version", - ), - ) - artifact_version: List["ArtifactVersionSchema"] = Relationship( - back_populates="run_metadata", - sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataSchema.resource_type=='{MetadataResourceTypes.ARTIFACT_VERSION.value}', foreign(RunMetadataSchema.resource_id)==ArtifactVersionSchema.id)", - 
overlaps="run_metadata,pipeline_run,step_run,model_version", - ), - ) - model_version: List["ModelVersionSchema"] = Relationship( + # Relationship to link to resources + resources: List["RunMetadataResourceSchema"] = Relationship( back_populates="run_metadata", - sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataSchema.resource_type=='{MetadataResourceTypes.MODEL_VERSION.value}', foreign(RunMetadataSchema.resource_id)==ModelVersionSchema.id)", - overlaps="run_metadata,pipeline_run,step_run,artifact_version", - ), + sa_relationship_kwargs={"cascade": "delete"}, ) stack_component_id: Optional[UUID] = build_foreign_key_field( source=__tablename__, @@ -103,3 +78,63 @@ class RunMetadataSchema(BaseSchema, table=True): key: str value: str = Field(sa_column=Column(TEXT, nullable=False)) type: str + + publisher_step_id: Optional[UUID] = build_foreign_key_field( + source=__tablename__, + target=StepRunSchema.__tablename__, + source_column="publisher_step_id", + target_column="id", + ondelete="SET NULL", + nullable=True, + ) + + +class RunMetadataResourceSchema(SQLModel, table=True): + """Table for linking resources to run metadata entries.""" + + __tablename__ = "run_metadata_resource" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + resource_id: UUID + resource_type: str = Field(sa_column=Column(VARCHAR(255), nullable=False)) + run_metadata_id: UUID = build_foreign_key_field( + source=__tablename__, + target=RunMetadataSchema.__tablename__, + source_column="run_metadata_id", + target_column="id", + ondelete="CASCADE", + nullable=False, + ) + + # Relationship back to the base metadata table + run_metadata: RunMetadataSchema = Relationship(back_populates="resources") + + # Relationship to link specific resource types + pipeline_runs: List["PipelineRunSchema"] = Relationship( + back_populates="run_metadata_resources", + sa_relationship_kwargs=dict( + primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==PipelineRunSchema.id)", + overlaps="run_metadata_resources,step_runs,artifact_versions,model_versions", + ), + ) + step_runs: List["StepRunSchema"] = Relationship( + back_populates="run_metadata_resources", + sa_relationship_kwargs=dict( + primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.STEP_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==StepRunSchema.id)", + overlaps="run_metadata_resources,pipeline_runs,artifact_versions,model_versions", + ), + ) + artifact_versions: List["ArtifactVersionSchema"] = Relationship( + back_populates="run_metadata_resources", + sa_relationship_kwargs=dict( + primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.ARTIFACT_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ArtifactVersionSchema.id)", + overlaps="run_metadata_resources,pipeline_runs,step_runs,model_versions", + ), + ) + model_versions: List["ModelVersionSchema"] = Relationship( + back_populates="run_metadata_resources", + sa_relationship_kwargs=dict( + primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.MODEL_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ModelVersionSchema.id)", + overlaps="run_metadata_resources,pipeline_runs,step_runs,artifact_versions", + ), + ) diff --git a/src/zenml/zen_stores/schemas/step_run_schemas.py b/src/zenml/zen_stores/schemas/step_run_schemas.py index 860ff794421..f8788505156 100644 --- 
a/src/zenml/zen_stores/schemas/step_run_schemas.py +++ b/src/zenml/zen_stores/schemas/step_run_schemas.py @@ -51,16 +51,19 @@ from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field from zenml.zen_stores.schemas.user_schemas import UserSchema +from zenml.zen_stores.schemas.utils import RunMetadataInterface from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema if TYPE_CHECKING: from zenml.zen_stores.schemas.artifact_schemas import ArtifactVersionSchema from zenml.zen_stores.schemas.logs_schemas import LogsSchema from zenml.zen_stores.schemas.model_schemas import ModelVersionSchema - from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema + from zenml.zen_stores.schemas.run_metadata_schemas import ( + RunMetadataResourceSchema, + ) -class StepRunSchema(NamedSchema, table=True): +class StepRunSchema(NamedSchema, RunMetadataInterface, table=True): """SQL Model for steps of pipeline runs.""" __tablename__ = "step_run" @@ -140,12 +143,12 @@ class StepRunSchema(NamedSchema, table=True): deployment: Optional["PipelineDeploymentSchema"] = Relationship( back_populates="step_runs" ) - run_metadata: List["RunMetadataSchema"] = Relationship( - back_populates="step_run", + run_metadata_resources: List["RunMetadataResourceSchema"] = Relationship( + back_populates="step_runs", sa_relationship_kwargs=dict( - primaryjoin=f"and_(RunMetadataSchema.resource_type=='{MetadataResourceTypes.STEP_RUN.value}', foreign(RunMetadataSchema.resource_id)==StepRunSchema.id)", + primaryjoin=f"and_(RunMetadataResourceSchema.resource_type=='{MetadataResourceTypes.STEP_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==StepRunSchema.id)", cascade="delete", - overlaps="run_metadata", + overlaps="run_metadata_resources", ), ) input_artifacts: List["StepRunInputArtifactSchema"] = Relationship( @@ -170,6 +173,9 @@ class StepRunSchema(NamedSchema, table=True): model_version: "ModelVersionSchema" = Relationship( back_populates="step_runs", ) + original_step_run: Optional["StepRunSchema"] = Relationship( + sa_relationship_kwargs={"remote_side": "StepRunSchema.id"} + ) model_config = ConfigDict(protected_namespaces=()) # type: ignore[assignment] @@ -222,11 +228,6 @@ def to_model( RuntimeError: If the step run schema does not have a deployment_id or a step_configuration. """ - run_metadata = { - metadata_schema.key: json.loads(metadata_schema.value) - for metadata_schema in self.run_metadata - } - input_artifacts = { artifact.name: StepRunInputResponse( input_type=StepRunInputArtifactType(artifact.type), @@ -313,7 +314,7 @@ def to_model( pipeline_run_id=self.pipeline_run_id, original_step_run_id=self.original_step_run_id, parent_step_ids=[p.parent_id for p in self.parents], - run_metadata=run_metadata, + run_metadata=self.fetch_metadata(), ) resources = None diff --git a/src/zenml/zen_stores/schemas/utils.py b/src/zenml/zen_stores/schemas/utils.py index ad458a5423e..5484a6a9cc8 100644 --- a/src/zenml/zen_stores/schemas/utils.py +++ b/src/zenml/zen_stores/schemas/utils.py @@ -13,11 +13,14 @@ # permissions and limitations under the License. 
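As a quick illustration of the `RunMetadataInterface` introduced in this file, a sketch (hypothetical values and timestamps) of the "latest entry wins" behaviour of `fetch_metadata`: entries for the same key are ordered by creation time and only the newest value is returned.

    from datetime import datetime

    from zenml.models import RunMetadataEntry

    entries = [
        RunMetadataEntry(value=0.90, created=datetime(2024, 11, 1)),
        RunMetadataEntry(value=0.92, created=datetime(2024, 11, 2)),
    ]
    latest = sorted(entries, key=lambda e: e.created, reverse=True)[0].value
    assert latest == 0.92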
"""Utils for schemas.""" +import json import math -from typing import List, Type, TypeVar +from typing import Dict, List, Type, TypeVar -from zenml.models.v2.base.base import BaseResponse -from zenml.models.v2.base.page import Page +from sqlmodel import Relationship + +from zenml.metadata.metadata_types import MetadataType +from zenml.models import BaseResponse, Page, RunMetadataEntry from zenml.zen_stores.schemas.base_schemas import BaseSchema S = TypeVar("S", bound=BaseSchema) @@ -67,3 +70,44 @@ def get_page_from_list( total=total, items=page_items, ) + + +class RunMetadataInterface: + """The interface for entities with run metadata.""" + + run_metadata_resources = Relationship() + + def fetch_metadata_collection(self) -> Dict[str, List[RunMetadataEntry]]: + """Fetches all the metadata entries related to the artifact version. + + Returns: + a dictionary, where the key is the key of the metadata entry + and the values represent the list of entries with this key. + """ + metadata_collection: Dict[str, List[RunMetadataEntry]] = {} + + # Fetch the metadata related to this step + for rm in self.run_metadata_resources: + if rm.run_metadata.key not in metadata_collection: + metadata_collection[rm.run_metadata.key] = [] + metadata_collection[rm.run_metadata.key].append( + RunMetadataEntry( + value=json.loads(rm.run_metadata.value), + created=rm.run_metadata.created, + ) + ) + + return metadata_collection + + def fetch_metadata(self) -> Dict[str, MetadataType]: + """Fetches the latest metadata entry related to the artifact version. + + Returns: + a dictionary, where the key is the key of the metadata entry + and the values represent the latest entry with this key. + """ + metadata_collection = self.fetch_metadata_collection() + return { + k: sorted(v, key=lambda x: x.created, reverse=True)[0].value + for k, v in metadata_collection.items() + } diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 98ea30fb8e9..5f44873e87b 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -26,6 +26,7 @@ from functools import lru_cache from pathlib import Path from typing import ( + TYPE_CHECKING, Any, Callable, ClassVar, @@ -219,6 +220,7 @@ PipelineRunUpdate, PipelineUpdate, RunMetadataRequest, + RunMetadataResource, RunTemplateFilter, RunTemplateRequest, RunTemplateResponse, @@ -325,6 +327,7 @@ PipelineDeploymentSchema, PipelineRunSchema, PipelineSchema, + RunMetadataResourceSchema, RunMetadataSchema, RunTemplateSchema, ScheduleSchema, @@ -354,6 +357,9 @@ SqlSecretsStoreConfiguration, ) +if TYPE_CHECKING: + from zenml.metadata.metadata_types import MetadataType, MetadataTypeEnum + AnyNamedSchema = TypeVar("AnyNamedSchema", bound=NamedSchema) AnySchema = TypeVar("AnySchema", bound=BaseSchema) @@ -2918,17 +2924,41 @@ def create_artifact_version( # Save metadata of the artifact if artifact_version.metadata: + values: Dict[str, "MetadataType"] = {} + types: Dict[str, "MetadataTypeEnum"] = {} for key, value in artifact_version.metadata.items(): - run_metadata_schema = RunMetadataSchema( - workspace_id=artifact_version.workspace, - user_id=artifact_version.user, - resource_id=artifact_version_id, - resource_type=MetadataResourceTypes.ARTIFACT_VERSION, - key=key, - value=json.dumps(value), - type=get_metadata_type(value), + # Skip metadata that is too large to be stored in the DB. + if len(json.dumps(value)) > TEXT_FIELD_MAX_LENGTH: + logger.warning( + f"Metadata value for key '{key}' is too large to be " + "stored in the database. 
Skipping." + ) + continue + # Skip metadata that is not of a supported type. + try: + metadata_type = get_metadata_type(value) + except ValueError as e: + logger.warning( + f"Metadata value for key '{key}' is not of a " + f"supported type. Skipping. Full error: {e}" + ) + continue + values[key] = value + types[key] = metadata_type + self.create_run_metadata( + RunMetadataRequest( + workspace=artifact_version.workspace, + user=artifact_version.user, + resources=[ + RunMetadataResource( + id=artifact_version_id, + type=MetadataResourceTypes.ARTIFACT_VERSION, + ) + ], + values=values, + types=types, ) - session.add(run_metadata_schema) + ) session.commit() artifact_version_schema = session.exec( @@ -5532,20 +5562,29 @@ def create_run_metadata(self, run_metadata: RunMetadataRequest) -> None: The created run metadata. """ with Session(self.engine) as session: - for key, value in run_metadata.values.items(): - type_ = run_metadata.types[key] - run_metadata_schema = RunMetadataSchema( - workspace_id=run_metadata.workspace, - user_id=run_metadata.user, - resource_id=run_metadata.resource_id, - resource_type=run_metadata.resource_type.value, - stack_component_id=run_metadata.stack_component_id, - key=key, - value=json.dumps(value), - type=type_, - ) - session.add(run_metadata_schema) - session.commit() + if run_metadata.resources: + for key, value in run_metadata.values.items(): + type_ = run_metadata.types[key] + run_metadata_schema = RunMetadataSchema( + workspace_id=run_metadata.workspace, + user_id=run_metadata.user, + stack_component_id=run_metadata.stack_component_id, + key=key, + value=json.dumps(value), + type=type_, + publisher_step_id=run_metadata.publisher_step_id, + ) + session.add(run_metadata_schema) + session.commit() + + for resource in run_metadata.resources: + rm_resource_link = RunMetadataResourceSchema( + resource_id=resource.id, + resource_type=resource.type.value, + run_metadata_id=run_metadata_schema.id, + ) + session.add(rm_resource_link) + session.commit() return None # ----------------------------- Schedules ----------------------------- @@ -8156,6 +8195,46 @@ def create_run_step(self, step_run: StepRunRequest) -> StepRunResponse: ) session.add(log_entry) + # If cached, attach metadata of the original step + if ( + step_run.status == ExecutionStatus.CACHED + and step_run.original_step_run_id is not None + ): + original_metadata_links = session.exec( + select(RunMetadataResourceSchema) + .where( + RunMetadataResourceSchema.run_metadata_id + == RunMetadataSchema.id + ) + .where( + RunMetadataResourceSchema.resource_id + == step_run.original_step_run_id + ) + .where( + RunMetadataResourceSchema.resource_type + == MetadataResourceTypes.STEP_RUN + ) + .where( + RunMetadataSchema.publisher_step_id + == step_run.original_step_run_id + ) + ).all() + + # Create new links in a batch + new_links = [ + RunMetadataResourceSchema( + resource_id=step_schema.id, + resource_type=link.resource_type, + run_metadata_id=link.run_metadata_id, + ) + for link in original_metadata_links + ] + # Add all new links in a single operation + session.add_all(new_links) + # Commit the changes + session.commit() + session.refresh(step_schema) + # Save parent step IDs into the database. 
for parent_step_id in step_run.parent_step_ids: self._set_run_step_parent_step( diff --git a/tests/integration/functional/artifacts/test_utils.py b/tests/integration/functional/artifacts/test_utils.py index 5c091e61fb0..79cb52212a8 100644 --- a/tests/integration/functional/artifacts/test_utils.py +++ b/tests/integration/functional/artifacts/test_utils.py @@ -123,15 +123,16 @@ def _load_pipeline(expected_value, name, version): def test_log_metadata_existing(clean_client): """Test logging artifact metadata for existing artifacts.""" - save_artifact(42, "meaning_of_life") + av = save_artifact(42, "meaning_of_life") log_metadata( metadata={"description": "Aria is great!"}, - artifact_name="meaning_of_life", + artifact_version_id=av.id, ) save_artifact(43, "meaning_of_life", version="43") log_metadata( metadata={"description_2": "Blupus is great!"}, artifact_name="meaning_of_life", + artifact_version="43", ) log_metadata( metadata={"description_3": "Axl is great!"}, @@ -215,7 +216,11 @@ def artifact_multi_output_metadata_logging_step() -> ( "description": "Blupus is great!", "metrics": {"accuracy": 0.9}, } - log_metadata(metadata=output_metadata, artifact_name="int_output") + log_metadata( + metadata=output_metadata, + artifact_name="int_output", + infer_artifact=True, + ) return "42", 42 diff --git a/tests/integration/functional/model/test_model_version.py b/tests/integration/functional/model/test_model_version.py index 7266a541146..d16b9dc31bd 100644 --- a/tests/integration/functional/model/test_model_version.py +++ b/tests/integration/functional/model/test_model_version.py @@ -107,10 +107,18 @@ def __exit__(self, exc_type, exc_value, exc_traceback): @step def step_metadata_logging_functional(mdl_name: str): """Functional logging using implicit Model from context.""" - log_metadata({"foo": "bar"}) + model = get_step_context().model + + log_metadata( + metadata={"foo": "bar"}, + model_name=model.name, + model_version=model.version, + ) assert get_step_context().model.run_metadata["foo"] == "bar" log_metadata( - metadata={"foo": "bar"}, model_name=mdl_name, model_version="other" + metadata={"foo": "bar"}, + model_name=mdl_name, + model_version="other", ) diff --git a/tests/integration/functional/pipelines/test_pipeline_context.py b/tests/integration/functional/pipelines/test_pipeline_context.py index 70e7608f7a8..f070ca0272f 100644 --- a/tests/integration/functional/pipelines/test_pipeline_context.py +++ b/tests/integration/functional/pipelines/test_pipeline_context.py @@ -109,7 +109,7 @@ def producer() -> Annotated[str, "bar"]: ) log_metadata( metadata={"foobar": "artifact_meta_" + model.version}, - artifact_name="bar", + infer_artifact=True, ) return "artifact_data_" + model.version diff --git a/tests/integration/functional/steps/test_step_context.py b/tests/integration/functional/steps/test_step_context.py index d520cfd83a4..f34e53b8e4b 100644 --- a/tests/integration/functional/steps/test_step_context.py +++ b/tests/integration/functional/steps/test_step_context.py @@ -93,7 +93,8 @@ def _simple_step_pipeline(): @step def output_metadata_logging_step() -> Annotated[int, "my_output"]: log_metadata( - metadata={"some_key": "some_value"}, artifact_name="my_output" + metadata={"some_key": "some_value"}, + infer_artifact=True, ) return 42 diff --git a/tests/integration/functional/test_client.py b/tests/integration/functional/test_client.py index bd9583a8a1d..72557777ab3 100644 --- a/tests/integration/functional/test_client.py +++ b/tests/integration/functional/test_client.py @@ -63,6 +63,7 @@ 
PipelineBuildRequest, PipelineDeploymentRequest, PipelineRequest, + RunMetadataResource, StackResponse, ) from zenml.utils import io_utils @@ -484,8 +485,11 @@ def test_create_run_metadata_for_pipeline_run(clean_client_with_run: Client): # Assert that the created metadata is correct clean_client_with_run.create_run_metadata( metadata={"axel": "is awesome"}, - resource_id=pipeline_run.id, - resource_type=MetadataResourceTypes.PIPELINE_RUN, + resources=[ + RunMetadataResource( + id=pipeline_run.id, type=MetadataResourceTypes.PIPELINE_RUN + ) + ], ) rm = clean_client_with_run.get_pipeline_run(pipeline_run.id).run_metadata @@ -501,8 +505,11 @@ def test_create_run_metadata_for_step_run(clean_client_with_run: Client): # Assert that the created metadata is correct clean_client_with_run.create_run_metadata( metadata={"axel": "is awesome"}, - resource_id=step_run.id, - resource_type=MetadataResourceTypes.STEP_RUN, + resources=[ + RunMetadataResource( + id=step_run.id, type=MetadataResourceTypes.STEP_RUN + ) + ], ) rm = clean_client_with_run.get_run_step(step_run.id).run_metadata @@ -518,8 +525,12 @@ def test_create_run_metadata_for_artifact(clean_client_with_run: Client): # Assert that the created metadata is correct clean_client_with_run.create_run_metadata( metadata={"axel": "is awesome"}, - resource_id=artifact_version.id, - resource_type=MetadataResourceTypes.ARTIFACT_VERSION, + resources=[ + RunMetadataResource( + id=artifact_version.id, + type=MetadataResourceTypes.ARTIFACT_VERSION, + ) + ], ) rm = clean_client_with_run.get_artifact_version( @@ -968,7 +979,8 @@ def lazy_producer_test_artifact() -> Annotated[str, "new_one"]: from zenml.client import Client log_metadata( - metadata={"some_meta": "meta_new_one"}, artifact_name="new_one" + metadata={"some_meta": "meta_new_one"}, + infer_artifact=True, ) client = Client() @@ -1140,14 +1152,14 @@ def dummy(): artifact_name="preexisting", artifact_version="1.2.3", ) + with pytest.raises(KeyError): + clean_client.get_artifact_version("new_one") + dummy() log_metadata( metadata={"some_meta": "meta_preexisting"}, model_name="aria", model_version="model_version", ) - with pytest.raises(KeyError): - clean_client.get_artifact_version("new_one") - dummy() class TestModel: diff --git a/tests/integration/functional/utils/test_metadata_utils.py b/tests/integration/functional/utils/test_metadata_utils.py new file mode 100644 index 00000000000..263cc658291 --- /dev/null +++ b/tests/integration/functional/utils/test_metadata_utils.py @@ -0,0 +1,184 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+ + +from typing import Annotated, Tuple + +import pytest + +from zenml import Model, log_metadata, pipeline, step + + +@step +def step_multiple_calls() -> None: + """Step calls log_metadata twice, latest value should be returned.""" + log_metadata(metadata={"blupus": 1}) + log_metadata(metadata={"blupus": 2}) + + +@step +def step_single_output() -> Annotated[int, "first"]: + """Step that tests the usage of infer_artifact flag.""" + log_metadata(metadata={"aria": 1}, infer_artifact=True) + log_metadata( + metadata={"aria": 2}, infer_artifact=True, artifact_name="first" + ) + return 1 + + +@step +def step_multiple_outputs() -> ( + Tuple[Annotated[int, "second"], Annotated[int, "third"]] +): + """Step that tests infer_artifact flag with multiple outputs.""" + log_metadata( + metadata={"axl": 1}, infer_artifact=True, artifact_name="second" + ) + return 1, 2 + + +@step +def step_pipeline_model() -> None: + """Step that tests the infer_model flag.""" + log_metadata(metadata={"p": 1}, infer_model=True) + + +@step(model=Model(name="model_name", version="89a")) +def step_step_model() -> None: + """Step that tests the infer_model flag with a custom model version.""" + log_metadata(metadata={"s": 1}, infer_model=True) + + +@pipeline(model=Model(name="model_name", version="a89"), enable_cache=True) +def pipeline_to_log_metadata(): + """Pipeline definition to test the metadata utils.""" + step_multiple_calls() + step_single_output() + step_multiple_outputs() + step_pipeline_model() + step_step_model() + + +def test_metadata_utils(clean_client): + """Testing different functionalities of the log_metadata function.""" + # Run the pipeline + first_run = pipeline_to_log_metadata() + first_steps = first_run.steps + + # Check if the metadata was tagged correctly + assert first_run.run_metadata["step_multiple_calls::blupus"] == 2 + assert first_steps["step_multiple_calls"].run_metadata["blupus"] == 2 + assert ( + first_steps["step_single_output"] + .outputs["first"][0] + .run_metadata["aria"] + == 2 + ) + assert ( + first_steps["step_multiple_outputs"] + .outputs["second"][0] + .run_metadata["axl"] + == 1 + ) + + model_version_s = Model(name="model_name", version="89a") + assert model_version_s.run_metadata["s"] == 1 + + model_version_p = Model(name="model_name", version="a89") + assert model_version_p.run_metadata["p"] == 1 + + # Manually tag the run + log_metadata( + metadata={"manual_run": True}, run_id_name_or_prefix=first_run.id + ) + + # Manually tag the step + log_metadata( + metadata={"manual_step_1": True}, + step_id=first_run.steps["step_multiple_calls"].id, + ) + log_metadata( + metadata={"manual_step_2": True}, + step_name="step_multiple_calls", + run_id_name_or_prefix=first_run.id, + ) + + # Manually tag a model + log_metadata( + metadata={"manual_model_1": True}, model_version_id=model_version_p.id + ) + log_metadata( + metadata={"manual_model_2": True}, + model_name=model_version_p.name, + model_version=model_version_p.version, + ) + + # Manually tag an artifact + log_metadata( + metadata={"manual_artifact_1": True}, + artifact_version_id=first_run.steps["step_single_output"].output.id, + ) + log_metadata( + metadata={"manual_artifact_2": True}, + artifact_name=first_run.steps["step_single_output"].output.name, + artifact_version=first_run.steps["step_single_output"].output.version, + ) + + # Manually tag one step to test the caching logic later + log_metadata( + metadata={"blupus": 3}, + step_id=first_run.steps["step_multiple_calls"].id, + ) + + # Fetch the run and steps again + 
first_run_fetched = clean_client.get_pipeline_run( + name_id_or_prefix=first_run.id + ) + first_steps_fetched = first_run_fetched.steps + + assert first_run_fetched.run_metadata["manual_run"] + assert first_run_fetched.run_metadata["step_multiple_calls::manual_step_1"] + assert first_run_fetched.run_metadata["step_multiple_calls::manual_step_2"] + assert first_steps_fetched["step_multiple_calls"].run_metadata[ + "manual_step_1" + ] + assert first_steps_fetched["step_multiple_calls"].run_metadata[ + "manual_step_2" + ] + assert first_steps_fetched["step_single_output"].output.run_metadata[ + "manual_artifact_1" + ] + assert first_steps_fetched["step_single_output"].output.run_metadata[ + "manual_artifact_2" + ] + + # Fetch the model again + model_version_p_fetched = Model(name="model_name", version="a89") + + assert model_version_p_fetched.run_metadata["manual_model_1"] + assert model_version_p_fetched.run_metadata["manual_model_2"] + + # Run the cached pipeline + second_run = pipeline_to_log_metadata() + assert second_run.steps["step_multiple_calls"].run_metadata["blupus"] == 2 + + # Test some of the invalid usages + with pytest.raises(ValueError): + log_metadata(metadata={"auto_step_1": True}) + + with pytest.raises(ValueError): + log_metadata(metadata={"auto_model_1": True}, infer_model=True) + + with pytest.raises(ValueError): + log_metadata(metadata={"auto_artifact_1": True}, infer_artifact=True) diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index 7d6c24e56e5..3c87f39e190 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -102,6 +102,7 @@ ModelVersionUpdate, PipelineRunFilter, PipelineRunResponse, + RunMetadataResource, ServiceAccountFilter, ServiceAccountRequest, ServiceAccountUpdate, @@ -2909,7 +2910,7 @@ def test_deleting_run_deletes_steps(): @step def step_to_log_metadata(metadata: Union[str, int, bool]) -> int: - log_metadata({"blupus": metadata}) + log_metadata(metadata={"blupus": metadata}) return 42 @@ -2954,16 +2955,6 @@ def test_pipeline_run_filters_with_oneof_and_run_metadata(clean_client): with pytest.raises(ValidationError): PipelineRunFilter(name="oneof:random_value") - # Test metadata filtering - runs_filter = PipelineRunFilter(run_metadata={"blupus": "lt:30"}) - runs = store.list_runs(runs_filter_model=runs_filter) - assert len(runs) == 2 # The run with 3 and 25 - - for r in runs: - assert "blupus" in r.run_metadata - assert isinstance(r.run_metadata["blupus"], int) - assert r.run_metadata["blupus"] < 30 - # .--------------------. # | Pipeline run steps | @@ -5457,8 +5448,7 @@ def test_metadata_full_cycle_with_cascade_deletion( RunMetadataRequest( user=client.active_user.id, workspace=client.active_workspace.id, - resource_id=resource.id, - resource_type=type_, + resources=[RunMetadataResource(id=resource.id, type=type_)], values={"foo": "bar"}, types={"foo": MetadataTypeEnum.STRING}, stack_component_id=sc.id
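At the store level, the same change surfaces in `RunMetadataRequest`: the single resource fields are replaced by a `resources` list, while `values` and `types` remain parallel dictionaries keyed by metadata key, as exercised in `test_zen_store.py` above. A minimal sketch of building such a request directly against the store, assuming an existing pipeline run ID and the import paths used elsewhere in this patch:

```python
from uuid import UUID

from zenml.client import Client
from zenml.enums import MetadataResourceTypes
from zenml.metadata.metadata_types import MetadataTypeEnum
from zenml.models import RunMetadataRequest, RunMetadataResource

client = Client()

# Placeholder for an existing pipeline run ID.
run_id = UUID("11111111-1111-1111-1111-111111111111")

client.zen_store.create_run_metadata(
    RunMetadataRequest(
        user=client.active_user.id,
        workspace=client.active_workspace.id,
        resources=[
            RunMetadataResource(
                id=run_id, type=MetadataResourceTypes.PIPELINE_RUN
            )
        ],
        values={"foo": "bar"},
        types={"foo": MetadataTypeEnum.STRING},
    )
)
```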