diff --git a/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/types/common/mappers/TimeStampMapper.java b/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/types/common/mappers/TimeStampMapper.java new file mode 100644 index 0000000000000..9792079e66f64 --- /dev/null +++ b/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/types/common/mappers/TimeStampMapper.java @@ -0,0 +1,28 @@ +package com.linkedin.datahub.graphql.types.common.mappers; + +import com.linkedin.datahub.graphql.QueryContext; +import com.linkedin.datahub.graphql.generated.TimeStamp; +import com.linkedin.datahub.graphql.types.mappers.ModelMapper; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +public class TimeStampMapper implements ModelMapper { + + public static final TimeStampMapper INSTANCE = new TimeStampMapper(); + + public static TimeStamp map( + @Nullable QueryContext context, @Nonnull final com.linkedin.common.TimeStamp timestamp) { + return INSTANCE.apply(context, timestamp); + } + + @Override + public TimeStamp apply( + @Nullable QueryContext context, @Nonnull final com.linkedin.common.TimeStamp timestamp) { + final TimeStamp result = new TimeStamp(); + result.setTime(timestamp.getTime()); + if (timestamp.hasActor()) { + result.setActor(timestamp.getActor().toString()); + } + return result; + } +} diff --git a/datahub-graphql-core/src/main/resources/entity.graphql b/datahub-graphql-core/src/main/resources/entity.graphql index e086273068ee5..8d633e95d3f0f 100644 --- a/datahub-graphql-core/src/main/resources/entity.graphql +++ b/datahub-graphql-core/src/main/resources/entity.graphql @@ -14,6 +14,11 @@ Root type used for fetching DataHub Metadata Coming soon listEntity queries for listing all entities of a given type """ type Query { + _empty: String +} + + +extend type Query { """ Fetch a Data Platform by primary key (urn) """ @@ -262,8 +267,16 @@ type Query { Fetch all Business Attributes """ listBusinessAttributes(input: 
ListBusinessAttributesInput!): ListBusinessAttributesResult + + """ + Fetch a Data Process Instance by primary key (urn) + """ + dataProcessInstance(urn: String!): DataProcessInstance + + } + """ An ERModelRelationship is a high-level abstraction that dictates what datasets fields are erModelRelationshiped. """ @@ -9488,6 +9501,18 @@ type AuditStamp { actor: String } +type TimeStamp { + """ + When did the event occur + """ + time: Long! + + """ + The actor involved in the event + """ + actor: String +} + """ Input for creating a new group """ @@ -9832,15 +9857,42 @@ type MLModelGroup implements EntityWithRelationships & Entity & BrowsableEntity privileges: EntityPrivileges } +""" +Properties describing a group of related ML models +""" type MLModelGroupProperties { + """ + Display name of the model group + """ + name: String + """ + Detailed description of the model group's purpose and contents + """ description: String - createdAt: Long + """ + When this model group was created + """ + created: TimeStamp - version: VersionTag + """ + When this model group was last modified + """ + lastModified: TimeStamp + """ + Version identifier for this model group + """ + version: VersionTag + customProperties: [CustomPropertiesEntry!] + + """ + Deprecated creation timestamp + @deprecated Use the 'created' field instead + """ + createdAt: Long @deprecated(reason: "Use `created` instead") } """ @@ -9990,40 +10042,103 @@ description: String } type MLMetric { + """ + Name of the metric (e.g. accuracy, precision, recall) + """ name: String + """ + Description of what this metric measures + """ description: String + """ + The computed value of the metric + """ value: String + """ + Timestamp when this metric was recorded + """ createdAt: Long } type MLModelProperties { + """ + The display name of the model used in the UI + """ + name: String! 
+ """ + Detailed description of the model's purpose and characteristics + """ description: String - date: Long + """ + When the model was last modified + """ + lastModified: TimeStamp + """ + Version identifier for this model + """ version: String + """ + The type/category of ML model (e.g. classification, regression) + """ type: String + """ + Mapping of hyperparameter configurations + """ hyperParameters: HyperParameterMap - hyperParams: [MLHyperParam] + """ + List of hyperparameter settings used to train this model + """ + hyperParams: [MLHyperParam] + """ + Performance metrics from model training + """ trainingMetrics: [MLMetric] + """ + Names of ML features used by this model + """ mlFeatures: [String!] + """ + Tags for categorizing and searching models + """ tags: [String!] + """ + Model groups this model belongs to + """ groups: [MLModelGroup] + """ + Additional custom properties specific to this model + """ customProperties: [CustomPropertiesEntry!] + """ + URL to view this model in external system + """ externalUrl: String + + """ + When this model was created + """ + created: TimeStamp + + """ + Deprecated timestamp for model creation + @deprecated Use 'created' field instead + """ + date: Long @deprecated(reason: "Use `created` instead") } type MLFeatureProperties { @@ -12804,3 +12919,92 @@ type CronSchedule { """ timezone: String! } + + +""" +Properties describing a data process instance's execution metadata +""" +type DataProcessInstanceProperties { + """ + The display name of this process instance + """ + name: String! + + """ + URL to view this process instance in the external system + """ + externalUrl: String + + """ + When this process instance was created + """ + created: TimeStamp + + """ + Additional custom properties specific to this process instance + """ + customProperties: [CustomPropertiesEntry!] 
+} + +""" +Properties specific to an ML model training run instance +""" +type MLTrainingRunProperties { + """ + Unique identifier for this training run + """ + id: String + + """ + List of URLs to access training run outputs (e.g. model artifacts, logs) + """ + outputUrls: [String] + + """ + Hyperparameters used in this training run + """ + hyperParams: [MLHyperParam] + + """ + Performance metrics recorded during this training run + """ + trainingMetrics: [MLMetric] +} + +extend type DataProcessInstance { + + """ + Additional read only properties associated with the Data Job + """ + properties: DataProcessInstanceProperties + + """ + The specific instance of the data platform that this entity belongs to + """ + dataPlatformInstance: DataPlatformInstance + + """ + Sub Types that this entity implements + """ + subTypes: SubTypes + + """ + The parent container in which the entity resides + """ + container: Container + + """ + Standardized platform urn where the data process instance is defined + """ + platform: DataPlatform! 
+ + """ + Recursively get the lineage of containers for this entity + """ + parentContainers: ParentContainersResult + + """ + Additional properties when subtype is Training Run + """ + mlTrainingRunProperties: MLTrainingRunProperties +} \ No newline at end of file diff --git a/metadata-ingestion/examples/recipes/mlflow_to_datahub.dhub.yaml b/metadata-ingestion/examples/recipes/mlflow_to_datahub.dhub.yaml new file mode 100644 index 0000000000000..07e9ed5d786cd --- /dev/null +++ b/metadata-ingestion/examples/recipes/mlflow_to_datahub.dhub.yaml @@ -0,0 +1,9 @@ +source: + type: mlflow + config: + tracking_uri: "http://127.0.0.1:5000" + +sink: + type: datahub-rest + config: + server: "http://localhost:8080" \ No newline at end of file diff --git a/metadata-ingestion/src/datahub/api/entities/dataprocess/dataprocess_instance.py b/metadata-ingestion/src/datahub/api/entities/dataprocess/dataprocess_instance.py index d406fa36e00db..2b847cb6fe4f9 100644 --- a/metadata-ingestion/src/datahub/api/entities/dataprocess/dataprocess_instance.py +++ b/metadata-ingestion/src/datahub/api/entities/dataprocess/dataprocess_instance.py @@ -15,16 +15,21 @@ ) from datahub.metadata.schema_classes import ( AuditStampClass, + DataPlatformInstanceClass, DataProcessInstanceRunEventClass, DataProcessInstanceRunResultClass, DataProcessRunStatusClass, DataProcessTypeClass, + SubTypesClass, + ContainerClass, ) +from datahub.metadata.urns import DataPlatformInstanceUrn, DataPlatformUrn, ContainerUrn from datahub.utilities.str_enum import StrEnum from datahub.utilities.urns.data_flow_urn import DataFlowUrn from datahub.utilities.urns.data_job_urn import DataJobUrn from datahub.utilities.urns.data_process_instance_urn import DataProcessInstanceUrn from datahub.utilities.urns.dataset_urn import DatasetUrn +from datahub.emitter.mcp_builder import ContainerKey class DataProcessInstanceKey(DatahubKey): @@ -61,7 +66,9 @@ class DataProcessInstance: orchestrator: str cluster: Optional[str] = None type: str = 
DataProcessTypeClass.BATCH_SCHEDULED - template_urn: Optional[Union[DataJobUrn, DataFlowUrn, DatasetUrn]] = None + template_urn: Optional[ + Union[DataJobUrn, DataFlowUrn, DatasetUrn, ContainerUrn] + ] = None parent_instance: Optional[DataProcessInstanceUrn] = None properties: Dict[str, str] = field(default_factory=dict) url: Optional[str] = None @@ -71,6 +78,10 @@ class DataProcessInstance: _template_object: Optional[Union[DataJob, DataFlow]] = field( init=False, default=None, repr=False ) + data_platform: Optional[str] = None + data_plaform_instance: Optional[str] = None + subtype: Optional[str] = None + container_urn: Optional[str] = None def __post_init__(self): self.urn = DataProcessInstanceUrn( @@ -80,6 +91,36 @@ def __post_init__(self): id=self.id, ).guid() ) + if self.data_platform is None: + self.data_platform = self.orchestrator + + try: + # We first try to create from string assuming its an urn + self.data_platform = str( + DataPlatformUrn.create_from_string(self.data_platform) + ) + except Exception: + # If it fails, we assume its an id + self.data_platform = str(DataPlatformUrn.create_from_id(self.data_platform)) + + if self.data_plaform_instance is None and self.cluster is not None: + self.data_plaform_instance = self.cluster + + if self.data_plaform_instance is not None: + try: + # We first try to create from string assuming its an urn + self.data_plaform_instance = str( + DataPlatformInstanceUrn.create_from_string( + self.data_plaform_instance + ) + ) + except Exception: + # If it fails, we assume its an id + self.data_plaform_instance = str( + DataPlatformInstanceUrn( + platform=self.data_platform, instance=self.data_plaform_instance + ) + ) def start_event_mcp( self, start_timestamp_millis: int, attempt: Optional[int] = None @@ -269,6 +310,29 @@ def generate_mcp( ) yield mcp + assert self.data_platform + + mcp = MetadataChangeProposalWrapper( + entityUrn=str(self.urn), + aspect=DataPlatformInstanceClass( + platform=self.data_platform, 
instance=self.data_plaform_instance
+            ),
+        )
+        yield mcp
+
+        if self.subtype:
+            mcp = MetadataChangeProposalWrapper(
+                entityUrn=str(self.urn), aspect=SubTypesClass(typeNames=[self.subtype])
+            )
+            yield mcp
+
+        if self.container_urn:
+            mcp = MetadataChangeProposalWrapper(
+                entityUrn=str(self.urn),
+                aspect=ContainerClass(container=self.container_urn),
+            )
+            yield mcp
+
         yield from self.generate_inlet_outlet_mcp(materialize_iolets=materialize_iolets)
 
     @staticmethod
@@ -331,6 +395,31 @@ def from_datajob(
         dpi.outlets = datajob.outlets
         return dpi
 
+    @staticmethod
+    def from_container(
+        container_key: ContainerKey,
+        id: str,
+    ) -> "DataProcessInstance":
+        """
+        Generates DataProcessInstance from a Container
+
+        :param container_key: (ContainerKey) key of the container the instance belongs to;
+            the orchestrator is the platform name parsed from container_key.platform and
+            container_urn is container_key.as_urn()
+        :param id: (str) the id for the DataProcessInstance
+        :return: DataProcessInstance
+        """
+        dpi: DataProcessInstance = DataProcessInstance(
+            id=id,
+            orchestrator=DataPlatformUrn.from_string(
+                container_key.platform
+            ).platform_name,
+            template_urn=None,
+            container_urn=container_key.as_urn(),
+        )
+
+        return dpi
+
     @staticmethod
     def from_dataflow(dataflow: DataFlow, id: str) -> "DataProcessInstance":
         """
diff --git a/metadata-ingestion/src/datahub/emitter/mcp.py b/metadata-ingestion/src/datahub/emitter/mcp.py
index c6fcfad2e0aba..75a7dc4e8740c 100644
--- a/metadata-ingestion/src/datahub/emitter/mcp.py
+++ b/metadata-ingestion/src/datahub/emitter/mcp.py
@@ -14,7 +14,7 @@
     SystemMetadataClass,
     _Aspect,
 )
-from datahub.utilities.urns.urn import guess_entity_type
+from datahub.utilities.urns.urn import Urn, guess_entity_type
 
 if TYPE_CHECKING:
     from datahub.ingestion.api.workunit import MetadataWorkUnit
@@ -71,7 +71,11 @@ class MetadataChangeProposalWrapper:
     systemMetadata: Union[None, SystemMetadataClass] = None
 
     def __post_init__(self) -> None:
-        if
self.entityUrn and self.entityType == _ENTITY_TYPE_UNSET: + if isinstance(self.entityUrn, Urn): + if self.entityType == _ENTITY_TYPE_UNSET: + self.entityType = self.entityUrn.entity_type + self.entityUrn = str(self.entityUrn) + elif self.entityUrn and self.entityType == _ENTITY_TYPE_UNSET: self.entityType = guess_entity_type(self.entityUrn) elif self.entityUrn and self.entityType: guessed_entity_type = guess_entity_type(self.entityUrn).lower() @@ -105,6 +109,7 @@ def construct_many( return [cls(entityUrn=entityUrn, aspect=aspect) for aspect in aspects if aspect] def _make_mcp_without_aspects(self) -> MetadataChangeProposalClass: + assert self.entityUrn is None or isinstance(self.entityUrn, str) return MetadataChangeProposalClass( entityType=self.entityType, entityUrn=self.entityUrn, diff --git a/metadata-ingestion/src/datahub/ingestion/source/mlflow.py b/metadata-ingestion/src/datahub/ingestion/source/mlflow.py index 26d160acf330c..4a3d0a72bf187 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/mlflow.py +++ b/metadata-ingestion/src/datahub/ingestion/source/mlflow.py @@ -1,13 +1,15 @@ from dataclasses import dataclass -from typing import Any, Callable, Iterable, Optional, TypeVar, Union +from typing import Any, Callable, Iterable, Optional, TypeVar, Union, List +import time from mlflow import MlflowClient -from mlflow.entities import Run +from mlflow.entities import Run, Experiment from mlflow.entities.model_registry import ModelVersion, RegisteredModel from mlflow.store.entities import PagedList from pydantic.fields import Field import datahub.emitter.mce_builder as builder +from datahub.emitter.mcp_builder import ContainerKey from datahub.configuration.source_common import EnvConfigMixin from datahub.emitter.mcp import MetadataChangeProposalWrapper from datahub.ingestion.api.common import PipelineContext @@ -29,39 +31,86 @@ TagAssociationClass, TagPropertiesClass, VersionTagClass, - _Aspect, + DataProcessInstanceRunEventClass, + 
DataProcessInstancePropertiesClass, + ContainerPropertiesClass, + AuditStampClass, + TimeStampClass, + DataProcessRunStatusClass, + SubTypesClass, + DataPlatformInstanceClass, + BrowsePathsV2Class, + MetadataChangeProposalClass, + MLTrainingRunPropertiesClass, + DataProcessInstanceRunResultClass, + DataProcessInstanceOutputClass, +) +from datahub.metadata.urns import ( + DatasetUrn, + DataPlatformUrn, + MlModelUrn, + MlModelGroupUrn, + DataProcessInstanceUrn, + DataPlatformInstanceUrn, +) +from datahub.api.entities.dataprocess.dataprocess_instance import ( + DataProcessInstance, + InstanceRunResult, ) T = TypeVar("T") +class ContainerKeyWithId(ContainerKey): + id: str + + +@dataclass +class Container: + key: ContainerKeyWithId + subtype: str + name: Optional[str] = None + description: Optional[str] = None + + def generate_mcp( + self, + ) -> Iterable[Union[MetadataChangeProposalClass, MetadataChangeProposalWrapper]]: + container_urn = self.key.as_urn() + + container_subtype = SubTypesClass(typeNames=[self.subtype]) + + container_info = ContainerPropertiesClass( + name=self.name or self.key.id, + description=self.description, + customProperties={}, + ) + + browse_path = BrowsePathsV2Class(path=[]) + + dpi = DataPlatformInstanceClass( + platform=self.key.platform, + instance=self.key.instance, + ) + + return MetadataChangeProposalWrapper.construct_many( + entityUrn=container_urn, + aspects=[container_subtype, container_info, browse_path, dpi], + ) + + class MLflowConfig(EnvConfigMixin): tracking_uri: Optional[str] = Field( default=None, - description=( - "Tracking server URI. If not set, an MLflow default tracking_uri is used" - " (local `mlruns/` directory or `MLFLOW_TRACKING_URI` environment variable)" - ), + description="Tracking server URI. If not set, an MLflow default tracking_uri is used (local `mlruns/` directory or `MLFLOW_TRACKING_URI` environment variable)", ) registry_uri: Optional[str] = Field( default=None, - description=( - "Registry server URI. 
If not set, an MLflow default registry_uri is used" - " (value of tracking_uri or `MLFLOW_REGISTRY_URI` environment variable)" - ), + description="Registry server URI. If not set, an MLflow default registry_uri is used (value of tracking_uri or `MLFLOW_REGISTRY_URI` environment variable)", ) model_name_separator: str = Field( default="_", description="A string which separates model name from its version (e.g. model_1 or model-1)", ) - base_external_url: Optional[str] = Field( - default=None, - description=( - "Base URL to use when constructing external URLs to MLflow." - " If not set, tracking_uri is used if it's an HTTP URL." - " If neither is set, external URLs are not generated." - ), - ) @dataclass @@ -118,12 +167,10 @@ def get_report(self) -> SourceReport: def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: yield from self._get_tags_workunits() + yield from self._get_experiment_workunits() yield from self._get_ml_model_workunits() def _get_tags_workunits(self) -> Iterable[MetadataWorkUnit]: - """ - Create tags for each Stage in MLflow Model Registry. - """ for stage_info in self.registered_model_stages_info: tag_urn = self._make_stage_tag_urn(stage_info.name) tag_properties = TagPropertiesClass( @@ -142,36 +189,204 @@ def _make_stage_tag_urn(self, stage_name: str) -> str: def _make_stage_tag_name(self, stage_name: str) -> str: return f"{self.platform}_{stage_name.lower()}" - def _create_workunit(self, urn: str, aspect: _Aspect) -> MetadataWorkUnit: - """ - Utility to create an MCP workunit. - """ + def _create_workunit(self, urn: str, aspect: Any) -> MetadataWorkUnit: return MetadataChangeProposalWrapper( entityUrn=urn, aspect=aspect, ).as_workunit() - def _get_ml_model_workunits(self) -> Iterable[MetadataWorkUnit]: - """ - Traverse each Registered Model in Model Registry and generate a corresponding workunit. 
-        """
-        registered_models = self._get_mlflow_registered_models()
-        for registered_model in registered_models:
-            yield self._get_ml_group_workunit(registered_model)
-            model_versions = self._get_mlflow_model_versions(registered_model)
-            for model_version in model_versions:
-                run = self._get_mlflow_run(model_version)
-                yield self._get_ml_model_properties_workunit(
-                    registered_model=registered_model,
-                    model_version=model_version,
-                    run=run,
-                )
-                yield self._get_global_tags_workunit(model_version=model_version)
+    def _get_experiment_workunits(self) -> Iterable[MetadataWorkUnit]:
+        """
+        Emit the container workunits for every MLflow experiment, followed by
+        the workunits of each run recorded under that experiment.
+        """
+        for experiment in self._get_mlflow_experiments():
+            yield from self._get_experiment_container_workunit(experiment)
+            # The run listing is a lazy iterable, so it is iterated directly
+            # rather than truth-tested first (a generator is always truthy).
+            for run in self._get_mlflow_runs_from_experiment(experiment):
+                yield from self._get_run_workunits(experiment, run)
+
+    def _get_experiment_custom_properties(self, experiment):
+        """
+        Build custom properties from the experiment tags, without the note
+        (which is used as the container description) and with the artifact
+        location added.
+        """
+        # Work on a copy: popping from experiment.tags in place would also
+        # remove the note that _get_experiment_container_workunit reads.
+        experiment_custom_props = dict(getattr(experiment, "tags", None) or {})
+        experiment_custom_props.pop("mlflow.note.content", None)
+        experiment_custom_props["artifacts_location"] = experiment.artifact_location
+        return experiment_custom_props
+
+    def _get_experiment_container_workunit(
+        self, experiment: Experiment
+    ) -> List[MetadataWorkUnit]:
+        """
+        Generate the container workunits (subtype, properties, browse path and
+        platform instance) that represent an MLflow experiment.
+        """
+        experiment_container = Container(
+            key=ContainerKeyWithId(
+                platform=str(DataPlatformUrn.create_from_id("mlflow")),
+                id=experiment.name,
+            ),
+            subtype="ML Experiment",
+            name=experiment.name,
+            description=(experiment.tags or {}).get("mlflow.note.content"),
+        )  # TODO: this generates a urn as guid, should we change this to use experiment.id?
+
+        # The MCPs come from the Container built above; the MLflow Experiment
+        # entity has no `key` or `generate_mcp` members.
+        return [mcp.as_workunit() for mcp in experiment_container.generate_mcp()]
+
+    def _get_run_custom_properties(self, run: Run):
+        """
+        Collect the custom properties attached to a run.
+        """
+        # NOTE(review): MLflow exposes run tags as run.data.tags; reading
+        # `run.tags` presumably always yields an empty dict - confirm intent.
+        custom_props = {}
+        custom_props.update(getattr(run, "tags", {}) or {})
+        return custom_props
+
+    def _get_run_metrics(self, run: Run):
+        """
+        Convert the metrics logged on a run into MLMetric aspects.
+        """
+        return [
+            MLMetricClass(name=k, value=str(v)) for k, v in run.data.metrics.items()
+        ]
+
+    def _get_run_params(self, run: Run):
+        """
+        Convert the params logged on a run into MLHyperParam aspects.
+        """
+        return [
+            MLHyperParamClass(name=k, value=str(v)) for k, v in run.data.params.items()
+        ]
+
+    def _convert_run_result_type(
+        self, status: str
+    ) -> DataProcessInstanceRunResultClass:
+        """
+        Map an MLflow run status to a run result: FINISHED -> SUCCESS,
+        FAILED -> FAILURE, anything else -> SKIPPED.
+        """
+        if status == "FINISHED":
+            return DataProcessInstanceRunResultClass(
+                type="SUCCESS", nativeResultType="mlflow"
+            )
+        elif status == "FAILED":
+            return DataProcessInstanceRunResultClass(
+                type="FAILURE", nativeResultType="mlflow"
+            )
+        else:
+            return DataProcessInstanceRunResultClass(
+                type="SKIPPED", nativeResultType="mlflow"
+            )
+
+    def _get_run_workunits(
+        self, experiment: Experiment, run: Run
+    ) -> List[MetadataWorkUnit]:
+        """
+        Generate the DataProcessInstance workunits for one MLflow run:
+        properties, produced model versions (if any), training-run properties,
+        the completion event (only once the run has ended), platform instance
+        and subtype. Each aspect is emitted exactly once.
+        """
+        experiment_key = ContainerKeyWithId(
+            platform=str(DataPlatformUrn.create_from_id("mlflow")), id=experiment.name
+        )
+        # Same fallback as the display name, so an unnamed run still gets a
+        # stable, unique id.
+        run_name = run.info.run_name or run.info.run_id
+
+        data_process_instance = DataProcessInstance.from_container(
+            container_key=experiment_key, id=run_name
+        )  # TODO: this generates a urn as guid, should we change this to use run.info.run_id?
+        dpi_urn = str(data_process_instance.urn)
+        workunits: List[MetadataWorkUnit] = []
+
+        created_time = run.info.start_time or int(time.time() * 1000)
+        # AuditStamp.actor is a required field, so fall back to the unknown
+        # user instead of passing None.
+        created_actor = (
+            f"urn:li:platformResource:{run.info.user_id}"
+            if run.info.user_id
+            else "urn:li:corpuser:unknown"
+        )
+
+        workunits.append(
+            self._create_workunit(
+                urn=dpi_urn,
+                aspect=DataProcessInstancePropertiesClass(
+                    name=run_name,
+                    created=AuditStampClass(
+                        time=created_time,
+                        actor=created_actor,
+                    ),
+                    externalUrl=self._make_external_url_from_run(experiment, run),
+                    customProperties=self._get_run_custom_properties(run),
+                ),
+            )
+        )
+
+        # Link the run to the model versions it registered. Runs without a
+        # registered model simply have no output aspect.
+        model_versions = self.get_mlflow_model_versions_from_run(run.info.run_id)
+        if model_versions:
+            workunits.append(
+                self._create_workunit(
+                    urn=dpi_urn,
+                    aspect=DataProcessInstanceOutputClass(
+                        outputs=[
+                            self._make_ml_model_urn(model_version)
+                            for model_version in model_versions
+                        ]
+                    ),
+                )
+            )
+
+        workunits.append(
+            self._create_workunit(
+                urn=dpi_urn,
+                aspect=MLTrainingRunPropertiesClass(
+                    hyperParams=self._get_run_params(run),
+                    trainingMetrics=self._get_run_metrics(run),
+                    outputUrls=[run.info.artifact_uri],
+                    id=run.info.run_id,
+                ),
+            )
+        )
+
+        # A run that is still in progress has no end_time; only finished runs
+        # get a completion event, and the duration is computed only when both
+        # timestamps are known.
+        end_time = run.info.end_time
+        if end_time:
+            start_time = run.info.start_time
+            duration_millis = (
+                end_time - start_time if start_time is not None else None
+            )
+            workunits.append(
+                self._create_workunit(
+                    urn=dpi_urn,
+                    aspect=DataProcessInstanceRunEventClass(
+                        status=DataProcessRunStatusClass.COMPLETE,
+                        timestampMillis=end_time,
+                        result=self._convert_run_result_type(run.info.status),
+                        durationMillis=duration_millis,
+                    ),
+                )
+            )
+
+        workunits.append(
+            self._create_workunit(
+                urn=dpi_urn,
+                aspect=DataPlatformInstanceClass(
+                    platform=str(DataPlatformUrn.create_from_id("mlflow"))
+                ),
+            )
+        )
+
+        workunits.append(
+            self._create_workunit(
+                urn=dpi_urn,
+                aspect=SubTypesClass(typeNames=["ML Training Run"]),
+            )
+        )
+
+        return workunits
 
     def _get_mlflow_registered_models(self) -> Iterable[RegisteredModel]:
-        """
-        Get all Registered Models in MLflow Model Registry.
- """ registered_models: Iterable[ RegisteredModel ] = self._traverse_mlflow_search_func( @@ -179,34 +394,57 @@ def _get_mlflow_registered_models(self) -> Iterable[RegisteredModel]: ) return registered_models + def _get_mlflow_experiments(self) -> Iterable[Experiment]: + experiments: Iterable[Experiment] = self._traverse_mlflow_search_func( + search_func=self.client.search_experiments, + ) + return experiments + + def _get_mlflow_runs_from_experiment(self, experiment: Experiment) -> List[Run]: + runs: List[Run] = self._traverse_mlflow_search_func( + search_func=self.client.search_runs, + experiment_ids=[experiment.experiment_id], + ) + return runs + @staticmethod def _traverse_mlflow_search_func( search_func: Callable[..., PagedList[T]], **kwargs: Any, ) -> Iterable[T]: - """ - Utility to traverse an MLflow search_* functions which return PagedList. - """ next_page_token = None while True: paged_list = search_func(page_token=next_page_token, **kwargs) - yield from paged_list.to_list() + yield from paged_list next_page_token = paged_list.token if not next_page_token: return + def _get_latest_version(self, registered_model: RegisteredModel) -> Optional[str]: + return ( + str(registered_model.latest_versions[0].version) + if registered_model.latest_versions + else None + ) + def _get_ml_group_workunit( self, registered_model: RegisteredModel, ) -> MetadataWorkUnit: - """ - Generate an MLModelGroup workunit for an MLflow Registered Model. 
- """ ml_model_group_urn = self._make_ml_model_group_urn(registered_model) ml_model_group_properties = MLModelGroupPropertiesClass( customProperties=registered_model.tags, description=registered_model.description, - createdAt=registered_model.creation_timestamp, + created=TimeStampClass( + time=registered_model.creation_timestamp, actor=None + ), + lastModified=TimeStampClass( + time=registered_model.last_updated_timestamp, + actor=None, + ), + version=VersionTagClass( + versionTag=self._get_latest_version(registered_model) + ), ) wu = self._create_workunit( urn=ml_model_group_urn, @@ -226,9 +464,6 @@ def _get_mlflow_model_versions( self, registered_model: RegisteredModel, ) -> Iterable[ModelVersion]: - """ - Get all Model Versions for each Registered Model. - """ filter_string = f"name = '{registered_model.name}'" model_versions: Iterable[ModelVersion] = self._traverse_mlflow_search_func( search_func=self.client.search_model_versions, @@ -236,51 +471,78 @@ def _get_mlflow_model_versions( ) return model_versions + def get_mlflow_model_versions_from_run(self, run_id): + filter_string = f"run_id = '{run_id}'" + + model_versions: Iterable[ModelVersion] = self._traverse_mlflow_search_func( + search_func=self.client.search_model_versions, + filter_string=filter_string, + ) + + return list(model_versions) + def _get_mlflow_run(self, model_version: ModelVersion) -> Union[None, Run]: - """ - Get a Run associated with a Model Version. Some MVs may exist without Run. 
- """ if model_version.run_id: run = self.client.get_run(model_version.run_id) return run else: return None + def _get_ml_model_workunits(self) -> Iterable[MetadataWorkUnit]: + registered_models = self._get_mlflow_registered_models() + for registered_model in registered_models: + yield self._get_ml_group_workunit(registered_model) + model_versions = self._get_mlflow_model_versions(registered_model) + for model_version in model_versions: + run = self._get_mlflow_run(model_version) + yield self._get_ml_model_properties_workunit( + registered_model=registered_model, + model_version=model_version, + run=run, + ) + yield self._get_global_tags_workunit(model_version=model_version) + def _get_ml_model_properties_workunit( self, registered_model: RegisteredModel, model_version: ModelVersion, run: Union[None, Run], ) -> MetadataWorkUnit: - """ - Generate an MLModel workunit for an MLflow Model Version. - Every Model Version is a DataHub MLModel entity associated with an MLModelGroup corresponding to a Registered Model. - If a model was registered without an associated Run then hyperparams and metrics are not available. 
- """ ml_model_group_urn = self._make_ml_model_group_urn(registered_model) ml_model_urn = self._make_ml_model_urn(model_version) + if run: - hyperparams = [ - MLHyperParamClass(name=k, value=str(v)) - for k, v in run.data.params.items() - ] - training_metrics = [ - MLMetricClass(name=k, value=str(v)) for k, v in run.data.metrics.items() - ] + # Use the same metrics and hyperparams from the run + hyperparams = self._get_run_params(run) + training_metrics = self._get_run_metrics(run) else: hyperparams = None training_metrics = None + + created_time = model_version.creation_timestamp + created_actor = ( + f"urn:li:platformResource:{model_version.user_id}" + if model_version.user_id + else None + ) + ml_model_properties = MLModelPropertiesClass( customProperties=model_version.tags, - externalUrl=self._make_external_url(model_version), + lastModified=TimeStampClass( + time=model_version.last_updated_timestamp, + actor=None, + ), + externalUrl=self._make_external_url_from_model_version(model_version), description=model_version.description, - date=model_version.creation_timestamp, + created=TimeStampClass( + time=created_time, + actor=created_actor, + ), version=VersionTagClass(versionTag=str(model_version.version)), hyperParams=hyperparams, trainingMetrics=training_metrics, - # mlflow tags are dicts, but datahub tags are lists. 
currently use only keys from mlflow tags tags=list(model_version.tags.keys()), - groups=[ml_model_group_urn], + groups=[str(ml_model_group_urn)], ) wu = self._create_workunit(urn=ml_model_urn, aspect=ml_model_properties) return wu @@ -293,24 +555,21 @@ def _make_ml_model_urn(self, model_version: ModelVersion) -> str: ) return urn - def _get_base_external_url_from_tracking_uri(self) -> Optional[str]: - if isinstance( - self.client.tracking_uri, str - ) and self.client.tracking_uri.startswith("http"): - return self.client.tracking_uri + def _make_external_url_from_model_version( + self, model_version: ModelVersion + ) -> Union[None, str]: + base_uri = self.client.tracking_uri + if base_uri.startswith("http"): + return f"{base_uri.rstrip('/')}/#/models/{model_version.name}/versions/{model_version.version}" else: return None - def _make_external_url(self, model_version: ModelVersion) -> Optional[str]: - """ - Generate URL for a Model Version to MLflow UI. - """ - base_uri = ( - self.config.base_external_url - or self._get_base_external_url_from_tracking_uri() - ) - if base_uri: - return f"{base_uri.rstrip('/')}/#/models/{model_version.name}/versions/{model_version.version}" + def _make_external_url_from_run( + self, experiment: Experiment, run: Run + ) -> Union[None, str]: + base_uri = self.client.tracking_uri + if base_uri.startswith("http"): + return f"{base_uri.rstrip('/')}/#/experiments/{experiment.experiment_id}/runs/{run.info.run_id}" else: return None diff --git a/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceOutput.pdl b/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceOutput.pdl index f33c41e63efed..fe782dbe01ca9 100644 --- a/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceOutput.pdl +++ b/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceOutput.pdl @@ -15,7 +15,7 @@ record DataProcessInstanceOutput { @Relationship = { "/*": { "name": "Produces", 
- "entityTypes": [ "dataset" ] + "entityTypes": [ "dataset", "mlModel" ] } } @Searchable = { diff --git a/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceProperties.pdl b/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceProperties.pdl index c63cb1a97c017..5c6bfaecf1ef4 100644 --- a/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceProperties.pdl +++ b/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceProperties.pdl @@ -52,4 +52,4 @@ record DataProcessInstanceProperties includes CustomProperties, ExternalReferenc } created: AuditStamp -} +} \ No newline at end of file diff --git a/metadata-models/src/main/pegasus/com/linkedin/ml/metadata/MLModelGroupProperties.pdl b/metadata-models/src/main/pegasus/com/linkedin/ml/metadata/MLModelGroupProperties.pdl index b54e430038082..3cc87633c3e07 100644 --- a/metadata-models/src/main/pegasus/com/linkedin/ml/metadata/MLModelGroupProperties.pdl +++ b/metadata-models/src/main/pegasus/com/linkedin/ml/metadata/MLModelGroupProperties.pdl @@ -4,6 +4,7 @@ import com.linkedin.common.Urn import com.linkedin.common.Time import com.linkedin.common.VersionTag import com.linkedin.common.CustomProperties +import com.linkedin.common.TimeStamp /** * Properties associated with an ML Model Group @@ -13,6 +14,17 @@ import com.linkedin.common.CustomProperties } record MLModelGroupProperties includes CustomProperties { + /** + * Display name of the MLModelGroup + */ + @Searchable = { + "fieldType": "WORD_GRAM", + "enableAutocomplete": true, + "boostScore": 10.0, + "queryByDefault": true, + } + name: optional string + /** * Documentation of the MLModelGroup */ @@ -25,10 +37,21 @@ record MLModelGroupProperties includes CustomProperties { /** * Date when the MLModelGroup was developed */ + @deprecated createdAt: optional Time + /** + * Time and Actor who created the MLModelGroup + */ + created: optional TimeStamp + + /** + * Date when the 
MLModelGroup was last modified + */ + lastModified: optional TimeStamp + /** * Version of the MLModelGroup */ version: optional VersionTag -} +} \ No newline at end of file diff --git a/metadata-models/src/main/pegasus/com/linkedin/ml/metadata/MLModelProperties.pdl b/metadata-models/src/main/pegasus/com/linkedin/ml/metadata/MLModelProperties.pdl index 621a3e1747b50..2a047316ee238 100644 --- a/metadata-models/src/main/pegasus/com/linkedin/ml/metadata/MLModelProperties.pdl +++ b/metadata-models/src/main/pegasus/com/linkedin/ml/metadata/MLModelProperties.pdl @@ -6,6 +6,7 @@ import com.linkedin.common.Time import com.linkedin.common.VersionTag import com.linkedin.common.CustomProperties import com.linkedin.common.ExternalReference +import com.linkedin.common.TimeStamp /** * Properties associated with a ML Model @@ -15,6 +16,18 @@ import com.linkedin.common.ExternalReference } record MLModelProperties includes CustomProperties, ExternalReference { + /** + * Display name of the MLModel + */ + @Searchable = { + "fieldType": "WORD_GRAM", + "enableAutocomplete": true, + "boostScore": 10.0, + "queryByDefault": true, + } + name: optional string + + /** * Documentation of the MLModel */ @@ -27,8 +40,19 @@ record MLModelProperties includes CustomProperties, ExternalReference { /** * Date when the MLModel was developed */ + @deprecated date: optional Time + /** + * Time and Actor who created the MLModel + */ + created: optional TimeStamp + + /** + * Time and Actor who last modified the MLModel + */ + lastModified: optional TimeStamp + /** * Version of the MLModel */ @@ -129,4 +153,4 @@ record MLModelProperties includes CustomProperties, ExternalReference { } } groups: optional array[Urn] -} +} \ No newline at end of file diff --git a/metadata-models/src/main/pegasus/com/linkedin/ml/metadata/MLTrainingRunProperties.pdl b/metadata-models/src/main/pegasus/com/linkedin/ml/metadata/MLTrainingRunProperties.pdl new file mode 100644 index 0000000000000..f8b8eeafe908b --- /dev/null
+++ b/metadata-models/src/main/pegasus/com/linkedin/ml/metadata/MLTrainingRunProperties.pdl @@ -0,0 +1,36 @@ +namespace com.linkedin.ml.metadata + +import com.linkedin.common.AuditStamp +import com.linkedin.common.CustomProperties +import com.linkedin.common.ExternalReference +import com.linkedin.common.Urn +import com.linkedin.common.JobFlowUrn +import com.linkedin.common.DataJobUrn +/** + * Properties of an ML Training Run: its run id, output URLs, hyperparameters and training metrics + */ +@Aspect = { + "name": "mlTrainingRunProperties", +} +record MLTrainingRunProperties includes CustomProperties, ExternalReference { + + /** + * Run Id of the ML Training Run + */ + id: optional string + + /** + * List of URLs for the Outputs of the ML Training Run + */ + outputUrls: optional array[string] + + /** + * Hyperparameters of the ML Training Run + */ + hyperParams: optional array[MLHyperParam] + + /** + * Metrics of the ML Training Run + */ + trainingMetrics: optional array[MLMetric] +} \ No newline at end of file diff --git a/metadata-models/src/main/resources/entity-registry.yml b/metadata-models/src/main/resources/entity-registry.yml index 1c3eb5b574e20..4fe170ced69f3 100644 --- a/metadata-models/src/main/resources/entity-registry.yml +++ b/metadata-models/src/main/resources/entity-registry.yml @@ -116,6 +116,10 @@ entities: - dataProcessInstanceRunEvent - status - testResults + - dataPlatformInstance + - subTypes + - container + - mlTrainingRunProperties - name: chart category: core keyAspect: chartKey