diff --git a/metadata-ingestion/src/datahub/api/entities/dataprocess/dataprocess_instance.py b/metadata-ingestion/src/datahub/api/entities/dataprocess/dataprocess_instance.py index d406fa36e00db..2b847cb6fe4f9 100644 --- a/metadata-ingestion/src/datahub/api/entities/dataprocess/dataprocess_instance.py +++ b/metadata-ingestion/src/datahub/api/entities/dataprocess/dataprocess_instance.py @@ -15,16 +15,21 @@ ) from datahub.metadata.schema_classes import ( AuditStampClass, + DataPlatformInstanceClass, DataProcessInstanceRunEventClass, DataProcessInstanceRunResultClass, DataProcessRunStatusClass, DataProcessTypeClass, + SubTypesClass, + ContainerClass, ) +from datahub.metadata.urns import DataPlatformInstanceUrn, DataPlatformUrn, ContainerUrn from datahub.utilities.str_enum import StrEnum from datahub.utilities.urns.data_flow_urn import DataFlowUrn from datahub.utilities.urns.data_job_urn import DataJobUrn from datahub.utilities.urns.data_process_instance_urn import DataProcessInstanceUrn from datahub.utilities.urns.dataset_urn import DatasetUrn +from datahub.emitter.mcp_builder import ContainerKey class DataProcessInstanceKey(DatahubKey): @@ -61,7 +66,9 @@ class DataProcessInstance: orchestrator: str cluster: Optional[str] = None type: str = DataProcessTypeClass.BATCH_SCHEDULED - template_urn: Optional[Union[DataJobUrn, DataFlowUrn, DatasetUrn]] = None + template_urn: Optional[ + Union[DataJobUrn, DataFlowUrn, DatasetUrn, ContainerUrn] + ] = None parent_instance: Optional[DataProcessInstanceUrn] = None properties: Dict[str, str] = field(default_factory=dict) url: Optional[str] = None @@ -71,6 +78,10 @@ class DataProcessInstance: _template_object: Optional[Union[DataJob, DataFlow]] = field( init=False, default=None, repr=False ) + data_platform: Optional[str] = None + data_plaform_instance: Optional[str] = None + subtype: Optional[str] = None + container_urn: Optional[str] = None def __post_init__(self): self.urn = DataProcessInstanceUrn( @@ -80,6 +91,36 @@ def __post_init__(self): id=self.id, ).guid() ) + if self.data_platform is None: + self.data_platform = self.orchestrator + + try: + # We first try to create from string assuming its an urn + self.data_platform = str( + DataPlatformUrn.create_from_string(self.data_platform) + ) + except Exception: + # If it fails, we assume its an id + self.data_platform = str(DataPlatformUrn.create_from_id(self.data_platform)) + + if self.data_plaform_instance is None and self.cluster is not None: + self.data_plaform_instance = self.cluster + + if self.data_plaform_instance is not None: + try: + # We first try to create from string assuming its an urn + self.data_plaform_instance = str( + DataPlatformInstanceUrn.create_from_string( + self.data_plaform_instance + ) + ) + except Exception: + # If it fails, we assume its an id + self.data_plaform_instance = str( + DataPlatformInstanceUrn( + platform=self.data_platform, instance=self.data_plaform_instance + ) + ) def start_event_mcp( self, start_timestamp_millis: int, attempt: Optional[int] = None @@ -269,6 +310,29 @@ def generate_mcp( ) yield mcp + assert self.data_platform + + mcp = MetadataChangeProposalWrapper( + entityUrn=str(self.urn), + aspect=DataPlatformInstanceClass( + platform=self.data_platform, instance=self.data_plaform_instance + ), + ) + yield mcp + + if self.subtype: + mcp = MetadataChangeProposalWrapper( + entityUrn=str(self.urn), aspect=SubTypesClass(typeNames=[self.subtype]) + ) + yield mcp + + if self.container_urn: + mcp = MetadataChangeProposalWrapper( + entityUrn=str(self.urn), + aspect=ContainerClass(container=self.container_urn), + ) + yield mcp + yield from self.generate_inlet_outlet_mcp(materialize_iolets=materialize_iolets) @staticmethod @@ -331,6 +395,31 @@ def from_datajob( dpi.outlets = datajob.outlets return dpi + @staticmethod + def from_container( + container_key: ContainerKey, + id: str, + ) -> "DataProcessInstance": + """ + Generates DataProcessInstance from a Container + + :param datajob: (DataJob) the datajob from generate the DataProcessInstance + :param id: (str) the id for the DataProcessInstance + :param clone_inlets: (bool) whether to clone datajob's inlets + :param clone_outlets: (bool) whether to clone datajob's outlets + :return: DataProcessInstance + """ + dpi: DataProcessInstance = DataProcessInstance( + id=id, + orchestrator=DataPlatformUrn.from_string( + container_key.platform + ).platform_name, + template_urn=None, + container_urn=container_key.as_urn(), + ) + + return dpi + @staticmethod def from_dataflow(dataflow: DataFlow, id: str) -> "DataProcessInstance": """ diff --git a/metadata-ingestion/src/datahub/ingestion/source/mlflow.py b/metadata-ingestion/src/datahub/ingestion/source/mlflow.py index 26d160acf330c..1a934f6aab173 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/mlflow.py +++ b/metadata-ingestion/src/datahub/ingestion/source/mlflow.py @@ -1,13 +1,15 @@ from dataclasses import dataclass -from typing import Any, Callable, Iterable, Optional, TypeVar, Union +from typing import Any, Callable, Iterable, Optional, TypeVar, Union, List +import time from mlflow import MlflowClient from mlflow.entities import Run -from mlflow.entities.model_registry import ModelVersion, RegisteredModel +from mlflow.entities.model_registry import ModelVersion, RegisteredModel, Experiment from mlflow.store.entities import PagedList from pydantic.fields import Field import datahub.emitter.mce_builder as builder +from datahub.emitter.mcp_builder import ContainerKey from datahub.configuration.source_common import EnvConfigMixin from datahub.emitter.mcp import MetadataChangeProposalWrapper from datahub.ingestion.api.common import PipelineContext @@ -29,39 +31,79 @@ TagAssociationClass, TagPropertiesClass, VersionTagClass, - _Aspect, + DataProcessInstanceRunEventClass, + DataProcessInstancePropertiesClass, + ContainerPropertiesClass, + AuditStampClass, + TimeStampClass, + DataProcessRunStatusClass, + SubTypesClass, + DataPlatformInstanceClass, + BrowsePathsV2Class, + MetadataChangeProposalClass, + MLTrainingRunPropertiesClass, + DataProcessInstanceRunResultClass, + DataProcessInstanceOutputClass, +) +from datahub.metadata.urns import DatasetUrn, DataPlatformUrn, MlModelUrn, MlModelGroupUrn, DataProcessInstanceUrn, DataPlatformInstanceUrn +from datahub.api.entities.dataprocess.dataprocess_instance import ( + DataProcessInstance, + InstanceRunResult, ) T = TypeVar("T") +class ContainerKeyWithId(ContainerKey): + id: str + + +@dataclass +class Container: + key: ContainerKeyWithId + subtype: str + name: Optional[str] = None + description: Optional[str] = None + + def generate_mcp( + self, + ) -> Iterable[Union[MetadataChangeProposalClass, MetadataChangeProposalWrapper]]: + container_urn = self.key.as_urn() + + container_subtype = SubTypesClass(typeNames=[self.subtype]) + + container_info = ContainerPropertiesClass( + name=self.name or self.key.id, + description=self.description, + customProperties={}, + ) + + browse_path = BrowsePathsV2Class(path=[]) + + dpi = DataPlatformInstanceClass( + platform=self.key.platform, + instance=self.key.instance, + ) + + return MetadataChangeProposalWrapper.construct_many( + entityUrn=container_urn, + aspects=[container_subtype, container_info, browse_path, dpi], + ) + + class MLflowConfig(EnvConfigMixin): tracking_uri: Optional[str] = Field( default=None, - description=( - "Tracking server URI. If not set, an MLflow default tracking_uri is used" - " (local `mlruns/` directory or `MLFLOW_TRACKING_URI` environment variable)" - ), + description="Tracking server URI. If not set, an MLflow default tracking_uri is used (local `mlruns/` directory or `MLFLOW_TRACKING_URI` environment variable)", ) registry_uri: Optional[str] = Field( default=None, - description=( - "Registry server URI. If not set, an MLflow default registry_uri is used" - " (value of tracking_uri or `MLFLOW_REGISTRY_URI` environment variable)" - ), + description="Registry server URI. If not set, an MLflow default registry_uri is used (value of tracking_uri or `MLFLOW_REGISTRY_URI` environment variable)", ) model_name_separator: str = Field( default="_", description="A string which separates model name from its version (e.g. model_1 or model-1)", ) - base_external_url: Optional[str] = Field( - default=None, - description=( - "Base URL to use when constructing external URLs to MLflow." - " If not set, tracking_uri is used if it's an HTTP URL." - " If neither is set, external URLs are not generated." - ), - ) @dataclass @@ -118,12 +160,10 @@ def get_report(self) -> SourceReport: def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: yield from self._get_tags_workunits() + yield from self._get_experiment_workunits() yield from self._get_ml_model_workunits() def _get_tags_workunits(self) -> Iterable[MetadataWorkUnit]: - """ - Create tags for each Stage in MLflow Model Registry. - """ for stage_info in self.registered_model_stages_info: tag_urn = self._make_stage_tag_urn(stage_info.name) tag_properties = TagPropertiesClass( @@ -142,36 +182,205 @@ def _make_stage_tag_urn(self, stage_name: str) -> str: def _make_stage_tag_name(self, stage_name: str) -> str: return f"{self.platform}_{stage_name.lower()}" - def _create_workunit(self, urn: str, aspect: _Aspect) -> MetadataWorkUnit: - """ - Utility to create an MCP workunit. - """ + def _create_workunit(self, urn: str, aspect: Any) -> MetadataWorkUnit: return MetadataChangeProposalWrapper( entityUrn=urn, aspect=aspect, ).as_workunit() - def _get_ml_model_workunits(self) -> Iterable[MetadataWorkUnit]: - """ - Traverse each Registered Model in Model Registry and generate a corresponding workunit. - """ - registered_models = self._get_mlflow_registered_models() - for registered_model in registered_models: - yield self._get_ml_group_workunit(registered_model) - model_versions = self._get_mlflow_model_versions(registered_model) - for model_version in model_versions: - run = self._get_mlflow_run(model_version) - yield self._get_ml_model_properties_workunit( - registered_model=registered_model, - model_version=model_version, - run=run, - ) - yield self._get_global_tags_workunit(model_version=model_version) + def _get_experiment_workunits(self) -> Iterable[MetadataWorkUnit]: + experiments = self._get_mlflow_experiments() + for experiment in experiments: + # Yield each workunit from the container workunits + for wu in self._get_experiment_container_workunit(experiment): + yield wu + + runs = self._get_mlflow_runs_from_experiment(experiment) + if runs: + for run in runs: + for wu in self._get_run_workunits(experiment, run): + yield wu + + def _get_experiment_custom_properties(self, experiment): + experiment_custom_props = getattr(experiment, "tags", {}) or {} + experiment_custom_props.pop("mlflow.note.content", None) + experiment_custom_props["artifacts_location"] = experiment.artifact_location + return experiment_custom_props + + def _get_experiment_container_workunit( + self, experiment: Experiment + ) -> List[MetadataWorkUnit]: + experiment_container = Container( + key=ContainerKeyWithId( + platform=str(DataPlatformUrn.create_from_id("mlflow")), + id=experiment.name + ), + subtype="ML Experiment", + name=experiment.name, + description=experiment.tags.get("mlflow.note.content"), + ) # TODO: this generates a urn as guid, should we change this to use experiment.id? + + print("experiment.key.id:", experiment.key.id) # this should be same as container key as urn + print("experiment.key.as_urn(): ", experiment.key.as_urn()) + + workunits = [mcp.as_workunit() for mcp in experiment.generate_mcp()] + return workunits + + def _get_run_custom_properties(self, run: Run): + custom_props = {} + custom_props.update(getattr(run, "tags", {}) or {}) + return custom_props + + def _get_run_metrics(self, run: Run): + return [ + MLMetricClass(name=k, value=str(v)) for k, v in run.data.metrics.items() + ] + + def _get_run_params(self, run: Run): + return [ + MLHyperParamClass(name=k, value=str(v)) for k, v in run.data.params.items() + ] + + def _convert_run_result_type( + self, status: str + ) -> DataProcessInstanceRunResultClass: + if status == "FINISHED": + return DataProcessInstanceRunResultClass( + type="SUCCESS", nativeResultType="mlflow" + ) + elif status == "FAILED": + return DataProcessInstanceRunResultClass( + type="FAILURE", nativeResultType="mlflow" + ) + else: + return DataProcessInstanceRunResultClass( + type="SKIPPED", nativeResultType="mlflow" + ) + + def _get_run_workunits( + self, experiment: Experiment, run: Run + ) -> List[MetadataWorkUnit]: + experiment_key = ContainerKeyWithId( + platform=str(DataPlatformUrn.create_from_id("mlflow")), id=experiment.name + ) + + data_process_instance = DataProcessInstance.from_container( + container_key=experiment_key, id=run.info.run_name + ) # TODO: this generates a urn as guid, should we change this to use run.info.run_id? + workunits = [] + + run_custom_props = self._get_run_custom_properties(run) + created_time = run.info.start_time or int(time.time() * 1000) + created_actor = ( + f"urn:li:platformResource:{run.info.user_id}" if run.info.user_id else None + ) + + workunits.append( + MetadataChangeProposalWrapper( + entityUrn=str(data_process_instance.urn), + aspect=DataProcessInstancePropertiesClass( + name=run.info.run_name or run.info.run_id, + created=AuditStampClass( + time=created_time, + actor=created_actor, + ), + externalUrl=self._make_external_url_from_run(experiment, run), + customProperties=run_custom_props, + ), + ).as_workunit() + ) + + # get model from run + model_versions = self.get_mlflow_model_versions_from_run(run.info.run_id) + model_version_urn = self._make_ml_model_urn(model_versions[0]) + model_version_urn = "urn:li:dataset:(urn:li:dataPlatform:mlflow,sk-learn-random-forest-reg_1,PROD)" + if model_versions: + workunits.append( + MetadataChangeProposalWrapper( + entityUrn=str(data_process_instance.urn), + aspect=DataProcessInstanceOutputClass( + outputs=[model_version_urn] + ), + ).as_workunit() + ) + + metrics = self._get_run_metrics(run) + hyperparams = self._get_run_params(run) + workunits.append( + MetadataChangeProposalWrapper( + entityUrn=str(data_process_instance.urn), + aspect=MLTrainingRunPropertiesClass( + hyperParams=hyperparams, + trainingMetrics=metrics, + outputUrls=[run.info.artifact_uri], + id=run.info.run_id, + ), + ).as_workunit() + ) + + result = ( + run.info.status + ) # TODO: this should be SUCCESS, SKIPPED, FAILURE, UP_FOR_RETRY + duration_millis = run.info.end_time - run.info.start_time + + + if run.info.end_time: + workunits.append( + MetadataChangeProposalWrapper( + entityUrn=str(data_process_instance.urn), + aspect=DataProcessInstanceRunEventClass( + status=DataProcessRunStatusClass.COMPLETE, + timestampMillis=run.info.end_time, + result=DataProcessInstanceRunResultClass( + type=self._convert_run_result_type(result).type, + nativeResultType="mlflow", + ), + durationMillis=duration_millis, + ), + ).as_workunit() + ) + + workunits.append( + MetadataChangeProposalWrapper( + entityUrn=str(data_process_instance.urn), + aspect=DataPlatformInstanceClass( + platform=str(DataPlatformUrn.create_from_id("mlflow")) + ), + ).as_workunit() + ) + + workunits.append( + MetadataChangeProposalWrapper( + entityUrn=str(data_process_instance.urn), + aspect=DataProcessInstancePropertiesClass( # Changed from RunEventClass + name=run.info.run_name or run.info.run_id, + created=AuditStampClass( + time=created_time, + actor=created_actor, + ) + ), + ).as_workunit() + ) + + workunits.append( + MetadataChangeProposalWrapper( + entityUrn=str(data_process_instance.urn), + aspect=DataPlatformInstanceClass( + platform=str(DataPlatformUrn.create_from_id("mlflow")) + ), + ).as_workunit() + ) + + workunits.append( + MetadataChangeProposalWrapper( + entityUrn=str(data_process_instance.urn), + aspect=SubTypesClass(typeNames=["ML Training Run"]), + ).as_workunit() + ) + + return workunits def _get_mlflow_registered_models(self) -> Iterable[RegisteredModel]: - """ - Get all Registered Models in MLflow Model Registry. - """ registered_models: Iterable[ RegisteredModel ] = self._traverse_mlflow_search_func( @@ -179,34 +388,57 @@ def _get_mlflow_registered_models(self) -> Iterable[RegisteredModel]: ) return registered_models + def _get_mlflow_experiments(self) -> Iterable[Experiment]: + experiments: Iterable[Experiment] = self._traverse_mlflow_search_func( + search_func=self.client.search_experiments, + ) + return experiments + + def _get_mlflow_runs_from_experiment(self, experiment: Experiment) -> List[Run]: + runs: List[Run] = self._traverse_mlflow_search_func( + search_func=self.client.search_runs, + experiment_ids=[experiment.experiment_id], + ) + return runs + @staticmethod def _traverse_mlflow_search_func( search_func: Callable[..., PagedList[T]], **kwargs: Any, ) -> Iterable[T]: - """ - Utility to traverse an MLflow search_* functions which return PagedList. - """ next_page_token = None while True: paged_list = search_func(page_token=next_page_token, **kwargs) - yield from paged_list.to_list() + yield from paged_list next_page_token = paged_list.token if not next_page_token: return + def _get_latest_version(self, registered_model: RegisteredModel) -> Optional[str]: + return ( + str(registered_model.latest_versions[0].version) + if registered_model.latest_versions + else None + ) + def _get_ml_group_workunit( self, registered_model: RegisteredModel, ) -> MetadataWorkUnit: - """ - Generate an MLModelGroup workunit for an MLflow Registered Model. - """ ml_model_group_urn = self._make_ml_model_group_urn(registered_model) ml_model_group_properties = MLModelGroupPropertiesClass( customProperties=registered_model.tags, description=registered_model.description, - createdAt=registered_model.creation_timestamp, + created=TimeStampClass( + time=registered_model.creation_timestamp, actor=None + ), + lastModified=TimeStampClass( + time=registered_model.last_updated_timestamp, + actor=None, + ), + version=VersionTagClass( + versionTag=self._get_latest_version(registered_model) + ), ) wu = self._create_workunit( urn=ml_model_group_urn, @@ -226,9 +458,6 @@ def _get_mlflow_model_versions( self, registered_model: RegisteredModel, ) -> Iterable[ModelVersion]: - """ - Get all Model Versions for each Registered Model. - """ filter_string = f"name = '{registered_model.name}'" model_versions: Iterable[ModelVersion] = self._traverse_mlflow_search_func( search_func=self.client.search_model_versions, @@ -236,51 +465,78 @@ def _get_mlflow_model_versions( ) return model_versions + def get_mlflow_model_versions_from_run(self, run_id): + filter_string = f"run_id = '{run_id}'" + + model_versions: Iterable[ModelVersion] = self._traverse_mlflow_search_func( + search_func=self.client.search_model_versions, + filter_string=filter_string, + ) + + return list(model_versions) + def _get_mlflow_run(self, model_version: ModelVersion) -> Union[None, Run]: - """ - Get a Run associated with a Model Version. Some MVs may exist without Run. - """ if model_version.run_id: run = self.client.get_run(model_version.run_id) return run else: return None + def _get_ml_model_workunits(self) -> Iterable[MetadataWorkUnit]: + registered_models = self._get_mlflow_registered_models() + for registered_model in registered_models: + yield self._get_ml_group_workunit(registered_model) + model_versions = self._get_mlflow_model_versions(registered_model) + for model_version in model_versions: + run = self._get_mlflow_run(model_version) + yield self._get_ml_model_properties_workunit( + registered_model=registered_model, + model_version=model_version, + run=run, + ) + yield self._get_global_tags_workunit(model_version=model_version) + def _get_ml_model_properties_workunit( self, registered_model: RegisteredModel, model_version: ModelVersion, run: Union[None, Run], ) -> MetadataWorkUnit: - """ - Generate an MLModel workunit for an MLflow Model Version. - Every Model Version is a DataHub MLModel entity associated with an MLModelGroup corresponding to a Registered Model. - If a model was registered without an associated Run then hyperparams and metrics are not available. - """ ml_model_group_urn = self._make_ml_model_group_urn(registered_model) ml_model_urn = self._make_ml_model_urn(model_version) + if run: - hyperparams = [ - MLHyperParamClass(name=k, value=str(v)) - for k, v in run.data.params.items() - ] - training_metrics = [ - MLMetricClass(name=k, value=str(v)) for k, v in run.data.metrics.items() - ] + # Use the same metrics and hyperparams from the run + hyperparams = self._get_run_params(run) + training_metrics = self._get_run_metrics(run) else: hyperparams = None training_metrics = None + + created_time = model_version.creation_timestamp + created_actor = ( + f"urn:li:platformResource:{model_version.user_id}" + if model_version.user_id + else None + ) + ml_model_properties = MLModelPropertiesClass( customProperties=model_version.tags, - externalUrl=self._make_external_url(model_version), + lastModified=TimeStampClass( + time=model_version.last_updated_timestamp, + actor=None, + ), + externalUrl=self._make_external_url_from_model_version(model_version), description=model_version.description, - date=model_version.creation_timestamp, + created=TimeStampClass( + time=created_time, + actor=created_actor, + ), version=VersionTagClass(versionTag=str(model_version.version)), hyperParams=hyperparams, trainingMetrics=training_metrics, - # mlflow tags are dicts, but datahub tags are lists. currently use only keys from mlflow tags tags=list(model_version.tags.keys()), - groups=[ml_model_group_urn], + groups=[str(ml_model_group_urn)], ) wu = self._create_workunit(urn=ml_model_urn, aspect=ml_model_properties) return wu @@ -293,24 +549,21 @@ def _make_ml_model_urn(self, model_version: ModelVersion) -> str: ) return urn - def _get_base_external_url_from_tracking_uri(self) -> Optional[str]: - if isinstance( - self.client.tracking_uri, str - ) and self.client.tracking_uri.startswith("http"): - return self.client.tracking_uri + def _make_external_url_from_model_version( + self, model_version: ModelVersion + ) -> Union[None, str]: + base_uri = self.client.tracking_uri + if base_uri.startswith("http"): + return f"{base_uri.rstrip('/')}/#/models/{model_version.name}/versions/{model_version.version}" else: return None - def _make_external_url(self, model_version: ModelVersion) -> Optional[str]: - """ - Generate URL for a Model Version to MLflow UI. - """ - base_uri = ( - self.config.base_external_url - or self._get_base_external_url_from_tracking_uri() - ) - if base_uri: - return f"{base_uri.rstrip('/')}/#/models/{model_version.name}/versions/{model_version.version}" + def _make_external_url_from_run( + self, experiment: Experiment, run: Run + ) -> Union[None, str]: + base_uri = self.client.tracking_uri + if base_uri.startswith("http"): + return f"{base_uri.rstrip('/')}/#/experiments/{experiment.experiment_id}/runs/{run.info.run_id}" else: return None