diff --git a/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/types/common/mappers/TimeStampMapper.java b/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/types/common/mappers/TimeStampMapper.java new file mode 100644 index 00000000000000..9792079e66f646 --- /dev/null +++ b/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/types/common/mappers/TimeStampMapper.java @@ -0,0 +1,28 @@ +package com.linkedin.datahub.graphql.types.common.mappers; + +import com.linkedin.datahub.graphql.QueryContext; +import com.linkedin.datahub.graphql.generated.TimeStamp; +import com.linkedin.datahub.graphql.types.mappers.ModelMapper; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +public class TimeStampMapper implements ModelMapper { + + public static final TimeStampMapper INSTANCE = new TimeStampMapper(); + + public static TimeStamp map( + @Nullable QueryContext context, @Nonnull final com.linkedin.common.TimeStamp timestamp) { + return INSTANCE.apply(context, timestamp); + } + + @Override + public TimeStamp apply( + @Nullable QueryContext context, @Nonnull final com.linkedin.common.TimeStamp timestamp) { + final TimeStamp result = new TimeStamp(); + result.setTime(timestamp.getTime()); + if (timestamp.hasActor()) { + result.setActor(timestamp.getActor().toString()); + } + return result; + } +} diff --git a/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/types/dataprocessinst/mappers/DataProcessInstanceMapper.java b/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/types/dataprocessinst/mappers/DataProcessInstanceMapper.java index 7a4d342281fe54..4e37dc80002ffa 100644 --- a/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/types/dataprocessinst/mappers/DataProcessInstanceMapper.java +++ b/datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/types/dataprocessinst/mappers/DataProcessInstanceMapper.java @@ -7,7 +7,7 @@ import com.linkedin.datahub.graphql.QueryContext; import com.linkedin.datahub.graphql.generated.DataProcessInstance; import com.linkedin.datahub.graphql.generated.EntityType; -import com.linkedin.datahub.graphql.types.common.mappers.AuditStampMapper; +import com.linkedin.datahub.graphql.types.common.mappers.TimeStampMapper; import com.linkedin.datahub.graphql.types.common.mappers.util.MappingHelper; import com.linkedin.datahub.graphql.types.mappers.ModelMapper; import com.linkedin.dataprocess.DataProcessInstanceProperties; @@ -50,11 +50,13 @@ private void mapDataProcessProperties( DataProcessInstanceProperties dataProcessInstanceProperties = new DataProcessInstanceProperties(dataMap); dpi.setName(dataProcessInstanceProperties.getName()); - if (dataProcessInstanceProperties.hasCreated()) { - dpi.setCreated(AuditStampMapper.map(context, dataProcessInstanceProperties.getCreated())); - } + + com.linkedin.datahub.graphql.generated.DataProcessInstanceProperties properties = + new com.linkedin.datahub.graphql.generated.DataProcessInstanceProperties(); + properties.setCreated(TimeStampMapper.map(context, dataProcessInstanceProperties.getCreated())); if (dataProcessInstanceProperties.hasExternalUrl()) { dpi.setExternalUrl(dataProcessInstanceProperties.getExternalUrl().toString()); } + dpi.setProperties(properties); } } diff --git a/datahub-graphql-core/src/main/resources/entity.graphql b/datahub-graphql-core/src/main/resources/entity.graphql index 049527e5d77e3b..4c92080d5bca3c 100644 --- a/datahub-graphql-core/src/main/resources/entity.graphql +++ b/datahub-graphql-core/src/main/resources/entity.graphql @@ -14,6 +14,11 @@ Root type used for fetching DataHub Metadata Coming soon listEntity queries for listing all entities of a given type """ type Query { + _empty: String +} + + +extend type Query { """ Fetch a Data Platform by primary key (urn) """ @@ -262,8 +267,16 @@ type Query { Fetch all Business Attributes """ listBusinessAttributes(input: ListBusinessAttributesInput!): ListBusinessAttributesResult + + """ + Fetch a Data Process Instance by primary key (urn) + """ + dataProcessInstance(urn: String!): DataProcessInstance + + } + """ An ERModelRelationship is a high-level abstraction that dictates what datasets fields are erModelRelationshiped. """ @@ -6581,12 +6594,12 @@ type DataProcessInstance implements EntityWithRelationships & Entity { """ When the run was kicked off """ - created: AuditStamp + created: AuditStamp! """ The name of the data process """ - name: String + name: String! """ Edges extending from this entity. @@ -9483,6 +9496,18 @@ type AuditStamp { actor: String } +type TimeStamp { + """ + When did the event occur + """ + time: Long! + + """ + The actor involved in the event + """ + actor: String +} + """ Input for creating a new group """ @@ -9829,9 +9854,13 @@ type MLModelGroup implements EntityWithRelationships & Entity & BrowsableEntity type MLModelGroupProperties { -description: String + description: String - createdAt: Long + createdAt: Long @deprecated(reason: "Use `created` instead") + + created: TimeStamp + + lastModified: TimeStamp version: VersionTag } @@ -9994,9 +10023,16 @@ type MLMetric { type MLModelProperties { + """ + The name of the model used in display + """ + name: String! + description: String - date: Long + date: Long @deprecated(reason: "Use `created` instead") + + lastModified: TimeStamp version: String @@ -10008,6 +10044,8 @@ type MLModelProperties { trainingMetrics: [MLMetric] + trainingJobs: [String!] + mlFeatures: [String!] tags: [String!] @@ -10017,6 +10055,8 @@ type MLModelProperties { customProperties: [CustomPropertiesEntry!] externalUrl: String + + created: TimeStamp } type MLFeatureProperties { @@ -12793,3 +12833,65 @@ type CronSchedule { """ timezone: String! } + + +type DataProcessInstanceProperties { + + name: String! + + externalUrl: String + + created: TimeStamp + + customProperties: [CustomPropertiesEntry!] +} + + +type MLTrainingRunProperties { + + id: String + + outputUrls: [String] + + hyperParams: [MLHyperParam] + + trainingMetrics: [MLMetric] +} + +extend type DataProcessInstance { + + """ + Additional read only properties associated with the Data Job + """ + properties: DataProcessInstanceProperties + + """ + Additional properties when subtype is Training Run + """ + mlTrainingRunProperties: MLTrainingRunProperties + + """ + The specific instance of the data platform that this entity belongs to + """ + dataPlatformInstance: DataPlatformInstance + + """ + Sub Types that this entity implements + """ + subTypes: SubTypes + + """ + The parent container in which the entity resides + """ + container: Container + + """ + Standardized platform urn where the data process instance is defined + """ + platform: DataPlatform! + + """ + Recursively get the lineage of containers for this entity + """ + parentContainers: ParentContainersResult +} \ No newline at end of file diff --git a/datahub-web-react/src/graphql/dataProcessInstance.graphql b/datahub-web-react/src/graphql/dataProcessInstance.graphql new file mode 100644 index 00000000000000..aacf85bb6bc6eb --- /dev/null +++ b/datahub-web-react/src/graphql/dataProcessInstance.graphql @@ -0,0 +1,182 @@ +fragment processInstanceRelationshipResults on EntityRelationshipsResult { + start + count + total + relationships { + type + direction + entity { + urn + type + ... on Dataset { + name + properties { + name + description + qualifiedName + } + editableProperties { + description + } + platform { + ...platformFields + } + subTypes { + typeNames + } + status { + removed + } + } + ... on DataJob { + urn + type + dataFlow { + ...nonRecursiveDataFlowFields + } + jobId + properties { + name + description + externalUrl + customProperties { + key + value + } + } + deprecation { + ...deprecationFields + } + dataPlatformInstance { + ...dataPlatformInstanceFields + } + subTypes { + typeNames + } + editableProperties { + description + } + status { + removed + } + } + } + } +} + +fragment dataProcessInstanceFields on DataProcessInstance { + urn + type + name + platform { + ...platformFields + } + parentContainers { + ...parentContainersFields + } + container { + ...entityContainer + } + subTypes { + typeNames + } + properties { + name + createdTS: created { + time + actor + } + customProperties { + key + value + } + } + mlTrainingRunProperties { + outputUrls + trainingMetrics { + name + description + value + } + hyperParams { + name + description + value + } + } + dataPlatformInstance { + ...dataPlatformInstanceFields + } + state(startTimeMillis: null, endTimeMillis: null, limit: 1) { + status + attempt + result { + resultType + nativeResultType + } + timestampMillis + durationMillis + } + relationships(input: { types: ["InstanceOf", "Consumes", "Produces"], direction: OUTGOING, start: 0, count: 50 }) { + ...processInstanceRelationshipResults + } +} + +query getDataProcessInstance($urn: String!) { + dataProcessInstance(urn: $urn) { + urn + type + platform { + ...platformFields + } + parentContainers { + ...parentContainersFields + } + subTypes { + typeNames + } + container { + ...entityContainer + } + name + properties { + name + created { + time + actor + } + } + mlTrainingRunProperties { + id + outputUrls + trainingMetrics { + name + description + value + } + hyperParams { + name + description + value + } + } + relationships( + input: { types: ["InstanceOf", "Consumes", "Produces"], direction: OUTGOING, start: 0, count: 50 } + ) { + ...processInstanceRelationshipResults + } + dataPlatformInstance { + ...dataPlatformInstanceFields + } + state(startTimeMillis: null, endTimeMillis: null, limit: 1) { + status + attempt + result { + resultType + nativeResultType + } + timestampMillis + durationMillis + } + } +} \ No newline at end of file diff --git a/datahub-web-react/src/graphql/fragments.graphql b/datahub-web-react/src/graphql/fragments.graphql index 788c68349b4268..fbca516c3c7347 100644 --- a/datahub-web-react/src/graphql/fragments.graphql +++ b/datahub-web-react/src/graphql/fragments.graphql @@ -860,8 +860,16 @@ fragment nonRecursiveMLModel on MLModel { ...ownershipFields } properties { + name description - date + created { + time + actor + } + lastModified { + time + actor + } externalUrl version type @@ -870,6 +878,7 @@ fragment nonRecursiveMLModel on MLModel { description value } + trainingJobs hyperParams { name description @@ -954,6 +963,10 @@ fragment nonRecursiveMLModelGroupFields on MLModelGroup { } properties { description + created { + time + actor + } } browsePathV2 { ...browsePathV2Fields @@ -1544,4 +1557,4 @@ fragment businessAttribute on BusinessAttributeAssociation { } } associatedUrn -} +} \ No newline at end of file diff --git a/datahub-web-react/src/graphql/lineage.graphql b/datahub-web-react/src/graphql/lineage.graphql index ee05811cbb72de..c90bf236227fe0 100644 --- a/datahub-web-react/src/graphql/lineage.graphql +++ b/datahub-web-react/src/graphql/lineage.graphql @@ -259,6 +259,9 @@ fragment lineageNodeProperties on EntityWithRelationships { name description origin + tags { + ...globalTagsFields + } platform { ...platformFields } @@ -289,6 +292,21 @@ fragment lineageNodeProperties on EntityWithRelationships { status { removed } + properties { + createdTS: created { + time + actor + } + modelVersion: version + tags + customProperties { + key + value + } + } + editableProperties { + description + } structuredProperties { properties { ...structuredPropertiesFields @@ -328,6 +346,9 @@ fragment lineageNodeProperties on EntityWithRelationships { urn type } + ... on DataProcessInstance { + ...dataProcessInstanceFields + } } fragment lineageFields on EntityWithRelationships { diff --git a/datahub-web-react/src/graphql/mlModel.graphql b/datahub-web-react/src/graphql/mlModel.graphql index 2192888caef701..c02defb990e61b 100644 --- a/datahub-web-react/src/graphql/mlModel.graphql +++ b/datahub-web-react/src/graphql/mlModel.graphql @@ -20,6 +20,23 @@ query getMLModel($urn: String!) { } } } + trainedBy: relationships(input: { types: ["TrainedBy"], direction: OUTGOING, start: 0, count: 100 }) { + start + count + total + relationships { + type + direction + entity { + ... on DataProcessInstance { + urn + name + type + ...dataProcessInstanceFields + } + } + } + } privileges { ...entityPrivileges } diff --git a/datahub-web-react/src/graphql/mlModelGroup.graphql b/datahub-web-react/src/graphql/mlModelGroup.graphql index 81ab65d0b9a08d..ff4148579452bc 100644 --- a/datahub-web-react/src/graphql/mlModelGroup.graphql +++ b/datahub-web-react/src/graphql/mlModelGroup.graphql @@ -2,6 +2,18 @@ query getMLModelGroup($urn: String!) { mlModelGroup(urn: $urn) { urn type + properties { + description + created { + time + actor + } + lastModified { + time + actor + } + } + ...nonRecursiveMLModelGroupFields incoming: relationships( input: { diff --git a/datahub-web-react/src/graphql/search.graphql b/datahub-web-react/src/graphql/search.graphql index 58c9a51f3d7e90..90fc8a619d599c 100644 --- a/datahub-web-react/src/graphql/search.graphql +++ b/datahub-web-react/src/graphql/search.graphql @@ -930,6 +930,9 @@ fragment searchResultsWithoutSchemaField on Entity { ... on BusinessAttribute { ...businessAttributeFields } + ... on DataProcessInstance { + ...dataProcessInstanceFields + } ... on StructuredPropertyEntity { ...structuredPropertyFields } diff --git a/metadata-ingestion/examples/ai/data_job_instance.py b/metadata-ingestion/examples/ai/data_job_instance.py new file mode 100644 index 00000000000000..138d99faa32d40 --- /dev/null +++ b/metadata-ingestion/examples/ai/data_job_instance.py @@ -0,0 +1,155 @@ +from dataclasses import dataclass +from typing import Iterable, Optional, Union + +import datahub.metadata.schema_classes as models +from datahub.api.entities.datajob import DataFlow, DataJob +from datahub.api.entities.dataprocess.dataprocess_instance import ( + DataProcessInstance, + InstanceRunResult, +) +from datahub.api.entities.dataset.dataset import Dataset +from datahub.emitter.mcp import MetadataChangeProposalWrapper +from datahub.ingestion.graph.client import get_default_graph +from datahub.metadata.urns import DatasetUrn, DataPlatformUrn +from datahub.emitter.mcp_builder import ContainerKey + +ORCHESTRATOR_MLFLOW = "mlflow" +ORCHESTRATOR_AIRFLOW = "airflow" + + +class ContainerKeyWithId(ContainerKey): + id: str + + +@dataclass +class Container: + key: ContainerKeyWithId + subtype: str + name: Optional[str] = None + description: Optional[str] = None + + def generate_mcp( + self, + ) -> Iterable[ + Union[models.MetadataChangeProposalClass, MetadataChangeProposalWrapper] + ]: + container_urn = self.key.as_urn() + + container_subtype = models.SubTypesClass(typeNames=[self.subtype]) + + container_info = models.ContainerPropertiesClass( + name=self.name or self.key.id, + description=self.description, + customProperties={}, + ) + + browse_path = models.BrowsePathsV2Class(path=[]) + + dpi = models.DataPlatformInstanceClass( + platform=self.key.platform, + instance=self.key.instance, + ) + + yield from MetadataChangeProposalWrapper.construct_many( + entityUrn=container_urn, + aspects=[container_subtype, container_info, browse_path, dpi], + ) + + +def generate_pipeline( + pipeline_name: str, + orchestrator: str, +) -> Iterable[Union[models.MetadataChangeProposalClass, MetadataChangeProposalWrapper]]: + data_flow = DataFlow( + id=pipeline_name, + orchestrator=orchestrator, + cluster="default", + name=pipeline_name, + ) + + data_job = DataJob(id="training", flow_urn=data_flow.urn, name="Training") + + dataset_1 = Dataset( + id="input_data", + name="input_data", + description="Input data", + properties={}, + platform="s3", + schema=None, + ) + + dataset_2 = Dataset( + id="output_data", + name="output_data", + description="Output data", + properties={}, + platform="s3", + schema=None, + ) + + if orchestrator == ORCHESTRATOR_MLFLOW: + # For Mlflow we create an experiment and a run + + experiment = Container( + key=ContainerKeyWithId( + platform=str(DataPlatformUrn.create_from_id("mlflow")), + id="experiment_1", + ), + subtype="Experiment", + name="Experiment 1", + description="Experiment 1 description", + ) + + yield from experiment.generate_mcp() + + data_process_instance = DataProcessInstance.from_container( + container_key=experiment.key, id="training_2024_01_01" + ) + + if orchestrator == ORCHESTRATOR_AIRFLOW: + # For Airflow we create a DAG and a task + data_process_instance = DataProcessInstance.from_datajob( + datajob=data_job, id="training_2024_01_01" + ) + yield from data_flow.generate_mcp() + yield from data_job.generate_mcp() + + # data_process_instance = DataProcessInstance.from_datajob( + # datajob=data_job, id="training_2024_01_01" + # ) + data_process_instance.subtype = "Training Run" + data_process_instance.inlets = [DatasetUrn.from_string(dataset_1.urn)] + data_process_instance.outlets = [DatasetUrn.from_string(dataset_2.urn)] + + yield from dataset_1.generate_mcp() + yield from dataset_2.generate_mcp() + print(f"Generating for {data_process_instance.urn}") + yield from data_process_instance.generate_mcp( + created_ts_millis=None, materialize_iolets=False + ) + # Finally generate the start and end events + # start date is Dec 3rd 2024 at 10am UTC + start_time_millis = 1735689600000 + # the job ran for 1 hour + end_time_millis = start_time_millis + 60 * 60 * 1000 + yield from data_process_instance.start_event_mcp( + # 5 days ago + start_timestamp_millis=start_time_millis + ) + yield from data_process_instance.end_event_mcp( + end_timestamp_millis=end_time_millis, + result=InstanceRunResult.SUCCESS, + start_timestamp_millis=start_time_millis, + ) + + +if __name__ == "__main__": + with get_default_graph() as graph: + for mcp in generate_pipeline( + "training_pipeline_mlflow", orchestrator=ORCHESTRATOR_MLFLOW + ): + graph.emit(mcp) + for mcp in generate_pipeline( + "training_pipeline_airflow", orchestrator=ORCHESTRATOR_AIRFLOW + ): + graph.emit(mcp) \ No newline at end of file diff --git a/metadata-ingestion/examples/ai/demo_script.py b/metadata-ingestion/examples/ai/demo_script.py new file mode 100644 index 00000000000000..46a1bd56b5603a --- /dev/null +++ b/metadata-ingestion/examples/ai/demo_script.py @@ -0,0 +1,312 @@ +import time +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import Iterable, List, Optional, Union +import random + +import datahub.metadata.schema_classes as models +from datahub.api.entities.datajob import DataFlow, DataJob +from datahub.api.entities.dataprocess.dataprocess_instance import ( + DataProcessInstance, + InstanceRunResult, +) +from datahub.api.entities.dataset.dataset import Dataset +from datahub.emitter.mcp import MetadataChangeProposalWrapper +from datahub.ingestion.graph.client import get_default_graph +from datahub.metadata.urns import DatasetUrn, DataPlatformUrn, MlModelGroupUrn, MlModelUrn +from datahub.emitter.mcp_builder import ContainerKey + +ORCHESTRATOR_MLFLOW = "mlflow" +ORCHESTRATOR_AIRFLOW = "airflow" + + +class ContainerKeyWithId(ContainerKey): + id: str + + +@dataclass +class Container: + key: ContainerKeyWithId + subtype: str + name: Optional[str] = None + description: Optional[str] = None + + def generate_mcp( + self, + ) -> Iterable[ + Union[models.MetadataChangeProposalClass, MetadataChangeProposalWrapper] + ]: + container_urn = self.key.as_urn() + + container_subtype = models.SubTypesClass(typeNames=[self.subtype]) + + container_info = models.ContainerPropertiesClass( + name=self.name or self.key.id, + description=self.description, + customProperties={}, + ) + + browse_path = models.BrowsePathsV2Class(path=[]) + + dpi = models.DataPlatformInstanceClass( + platform=self.key.platform, + instance=self.key.instance, + ) + + yield from MetadataChangeProposalWrapper.construct_many( + entityUrn=container_urn, + aspects=[container_subtype, container_info, browse_path, dpi], + ) + + +def create_model( + model_name: str, + model_group_urn: str, + data_process_instance_urn: str, + tags: List[str], + version_aliases: List[str], + index: int, + training_metrics: List[models.MLMetricClass], + hyper_params: List[models.MLHyperParamClass], + model_description: str, + created_at: int, +) -> Iterable[MetadataChangeProposalWrapper]: + model_urn = MlModelUrn(platform="mlflow", name=model_name) + model_info = models.MLModelPropertiesClass( + displayName=f"{model_name}", + description=model_description, + version=models.VersionTagClass(versionTag=f"{index}"), + groups=[str(model_group_urn)], + trainingJobs=[str(data_process_instance_urn)], + date=created_at, + lastModified=created_at, + createdBy=f"user_{index}", + versionAliases=[ + models.VersionAssociationClass( + version=models.VersionTagClass(versionTag=alias), + ) + for alias in version_aliases + ], + tags=tags, + trainingMetrics=training_metrics, + hyperParams=hyper_params, + ) + + yield MetadataChangeProposalWrapper( + entityUrn=model_urn, + aspect=model_info, + ) + + +def generate_pipeline( + pipeline_name: str, + orchestrator: str, +) -> Iterable[Union[models.MetadataChangeProposalClass, MetadataChangeProposalWrapper]]: + data_flow = DataFlow( + id=pipeline_name, + orchestrator=orchestrator, + cluster="default", + name=pipeline_name, + ) + + data_job = DataJob(id="training", flow_urn=data_flow.urn, name="Training") + + input_dataset = Dataset( + id="airline_passengers", + name="Airline Passengers", + description="Monthly airline passenger data", + properties={}, + platform="s3", + schema=None, + ) + + if orchestrator == ORCHESTRATOR_MLFLOW: + experiment = Container( + key=ContainerKeyWithId( + platform=str(DataPlatformUrn.create_from_id("mlflow")), + id="airline_forecast_experiment", + ), + subtype="Experiment", + name="Airline Forecast Experiment", + description="Experiment for forecasting airline passengers", + ) + + yield from experiment.generate_mcp() + + model_group_urn = MlModelGroupUrn(platform="mlflow", name="airline_forecast_models") + current_time = int(time.time() * 1000) + model_group_info = models.MLModelGroupPropertiesClass( + description="ML models for airline passenger forecasting", + customProperties={ + "stage": "production", + "team": "data_science", + }, + createdAt=current_time, + lastModified=current_time, + createdBy="john_doe", + ) + + yield MetadataChangeProposalWrapper( + entityUrn=model_group_urn, + aspect=model_group_info, + ) + + print("model_group_urn: ", model_group_urn) + + model_aliases = ["challenger", "champion", "production", "experimental", "deprecated"] + model_tags = ["stage:production", "stage:development", "team:data_science", "team:ml_engineering", "team:analytics"] + + model_dict = { + "arima_model_1": "ARIMA model for airline passenger forecasting", + "arima_model_2": "Enhanced ARIMA model with seasonal components", + "arima_model_3": "ARIMA model optimized for long-term forecasting", + "arima_model_4": "ARIMA model with hyperparameter tuning", + "arima_model_5": "ARIMA model trained on extended dataset", + } + + # Generate run timestamps within the last month + end_time = int(time.time() * 1000) # Current timestamp in milliseconds + start_time = end_time - (30 * 24 * 60 * 60 * 1000) # 30 days ago in milliseconds + run_timestamps = [ + start_time + (i * 5 * 24 * 60 * 60 * 1000) # 5-day intervals + for i in range(5) + ] + + run_dict = { + "run_1": {"start_time": run_timestamps[0], "duration": 45, "result": InstanceRunResult.SUCCESS}, + "run_2": {"start_time": run_timestamps[1], "duration": 60, "result": InstanceRunResult.FAILURE}, + "run_3": {"start_time": run_timestamps[2], "duration": 55, "result": InstanceRunResult.SUCCESS}, + "run_4": {"start_time": run_timestamps[3], "duration": 70, "result": InstanceRunResult.SUCCESS}, + "run_5": {"start_time": run_timestamps[4], "duration": 50, "result": InstanceRunResult.FAILURE}, + } + + for i, (model_name, model_description) in enumerate(model_dict.items(), start=1): + run_id = f"run_{i}" + data_process_instance = DataProcessInstance.from_container( + container_key=experiment.key, id=run_id + ) + + data_process_instance.subtype = "Training Run" + data_process_instance.inlets = [DatasetUrn.from_string(input_dataset.urn)] + + output_dataset = Dataset( + id=f"passenger_forecast_24_12_0{i}", + name=f"Passenger Forecast 24_12_0{i}", + description=f"Forecasted airline passenger numbers for run {i}", + properties={}, + platform="s3", + schema=None, + ) + yield from output_dataset.generate_mcp() + + data_process_instance.outlets = [DatasetUrn.from_string(output_dataset.urn)] + + # Training metrics and hyperparameters + training_metrics = [ + models.MLMetricClass( + name="accuracy", + value=str(random.uniform(0.7, 0.99)), + description="Test accuracy" + ), + models.MLMetricClass( + name="f1_score", + value=str(random.uniform(0.7, 0.99)), + description="Test F1 score" + ) + ] + hyper_params = [ + models.MLHyperParamClass( + name="n_estimators", + value=str(random.randint(50, 200)), + description="Number of trees" + ), + models.MLHyperParamClass( + name="max_depth", + value=str(random.randint(5, 15)), + description="Maximum tree depth" + ) + ] + + # DPI properties + created_at = int(time.time() * 1000) + print(start_time) + dpi_props = models.DataProcessInstancePropertiesClass( + name=f"Training {run_id}", + created=models.AuditStampClass(time=created_at, actor="urn:li:corpuser:datahub"), + createdAt=int(created_at/1000), + createdBy="jane_doe", + loggedModels=["sklearn"], + artifactsLocation="s3://mlflow/artifacts", + externalUrl="http://mlflow:5000", + customProperties={ + "framework": "statsmodels", + "python_version": "3.8", + }, + id=run_id, + trainingMetrics=training_metrics, + hyperParams=hyper_params, + ) + + yield from data_process_instance.generate_mcp( + created_ts_millis=created_at, materialize_iolets=True + ) + + yield MetadataChangeProposalWrapper( + entityUrn=str(data_process_instance.urn), + aspect=dpi_props, + ) + + # Generate start and end events + start_time_millis = run_dict[run_id]["start_time"] + duration_minutes = run_dict[run_id]["duration"] + end_time_millis = start_time_millis + duration_minutes * 60000 + result = run_dict[run_id]["result"] + result_type = "SUCCESS" if result == InstanceRunResult.SUCCESS else "FAILURE" + + yield from data_process_instance.start_event_mcp( + start_timestamp_millis=start_time_millis + ) + yield from data_process_instance.end_event_mcp( + end_timestamp_millis=end_time_millis, + result=result, + result_type=result_type, + start_timestamp_millis=start_time_millis, + ) + + print("data_process_instance.urn: ", data_process_instance.urn) + print("start Time:", start_time_millis) + print("start Time:", time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time_millis/1000))) + + # Model + selected_aliases = random.sample(model_aliases, k=random.randint(1, 2)) + selected_tags = random.sample(model_tags, 2) + yield from create_model( + model_name=model_name, + model_group_urn=str(model_group_urn), + data_process_instance_urn=str(data_process_instance.urn), + tags=selected_tags, + version_aliases=selected_aliases, + index=i, + training_metrics=training_metrics, + hyper_params=hyper_params, + model_description=model_description, + created_at=created_at, + ) + + if orchestrator == ORCHESTRATOR_AIRFLOW: + yield from data_flow.generate_mcp() + yield from data_job.generate_mcp() + + yield from input_dataset.generate_mcp() + + +if __name__ == "__main__": + with get_default_graph() as graph: + for mcp in generate_pipeline( + "airline_forecast_pipeline_mlflow", orchestrator=ORCHESTRATOR_MLFLOW + ): + graph.emit(mcp) + for mcp in generate_pipeline( + "airline_forecast_pipeline_airflow", orchestrator=ORCHESTRATOR_AIRFLOW + ): + graph.emit(mcp) \ No newline at end of file diff --git a/metadata-ingestion/examples/recipes/mlflow_to_datahub.dhub.yaml b/metadata-ingestion/examples/recipes/mlflow_to_datahub.dhub.yaml new file mode 100644 index 00000000000000..07e9ed5d786cd9 --- /dev/null +++ b/metadata-ingestion/examples/recipes/mlflow_to_datahub.dhub.yaml @@ -0,0 +1,9 @@ +source: + type: mlflow + config: + tracking_uri: "http://127.0.0.1:5000" + +sink: + type: datahub-rest + config: + server: "http://localhost:8080" \ No newline at end of file diff --git a/metadata-ingestion/src/datahub/api/entities/dataprocess/dataprocess_instance.py b/metadata-ingestion/src/datahub/api/entities/dataprocess/dataprocess_instance.py index d406fa36e00db6..f594ac1f646560 100644 --- a/metadata-ingestion/src/datahub/api/entities/dataprocess/dataprocess_instance.py +++ b/metadata-ingestion/src/datahub/api/entities/dataprocess/dataprocess_instance.py @@ -14,17 +14,22 @@ RunResultType, ) from datahub.metadata.schema_classes import ( - AuditStampClass, + DataPlatformInstanceClass, DataProcessInstanceRunEventClass, DataProcessInstanceRunResultClass, DataProcessRunStatusClass, DataProcessTypeClass, + TimeStampClass, + SubTypesClass, + ContainerClass ) +from datahub.metadata.urns import DataPlatformInstanceUrn, DataPlatformUrn, ContainerUrn from datahub.utilities.str_enum import StrEnum from datahub.utilities.urns.data_flow_urn import DataFlowUrn from datahub.utilities.urns.data_job_urn import DataJobUrn from datahub.utilities.urns.data_process_instance_urn import DataProcessInstanceUrn from datahub.utilities.urns.dataset_urn import DatasetUrn +from datahub.emitter.mcp_builder import ContainerKey class DataProcessInstanceKey(DatahubKey): @@ -61,7 +66,7 @@ class DataProcessInstance: orchestrator: str cluster: Optional[str] = None type: str = DataProcessTypeClass.BATCH_SCHEDULED - template_urn: Optional[Union[DataJobUrn, DataFlowUrn, DatasetUrn]] = None + template_urn: Optional[Union[DataJobUrn, DataFlowUrn, DatasetUrn, ContainerUrn]] = None parent_instance: Optional[DataProcessInstanceUrn] = None properties: Dict[str, str] = field(default_factory=dict) url: Optional[str] = None @@ -71,6 +76,10 @@ class DataProcessInstance: _template_object: Optional[Union[DataJob, DataFlow]] = field( init=False, default=None, repr=False ) + data_platform: Optional[str] = None + data_plaform_instance: Optional[str] = None + subtype: Optional[str] = None + container_urn: Optional[str] = None def __post_init__(self): self.urn = DataProcessInstanceUrn( @@ -80,6 +89,36 @@ def __post_init__(self): id=self.id, ).guid() ) + if self.data_platform is None: + self.data_platform = self.orchestrator + + try: + # We first try to create from string assuming its an urn + self.data_platform = str( + DataPlatformUrn.create_from_string(self.data_platform) + ) + except Exception: + # If it fails, we assume its an id + self.data_platform = str(DataPlatformUrn.create_from_id(self.data_platform)) + + if self.data_plaform_instance is None and self.cluster is not None: + self.data_plaform_instance = self.cluster + + if self.data_plaform_instance is not None: + try: + # We first try to create from string assuming its an urn + self.data_plaform_instance = str( + DataPlatformInstanceUrn.create_from_string( + self.data_plaform_instance + ) + ) + except Exception: + # If it fails, we assume its an id + self.data_plaform_instance = str( + DataPlatformInstanceUrn( + platform=self.data_platform, instance=self.data_plaform_instance + ) + ) def start_event_mcp( self, start_timestamp_millis: int, attempt: Optional[int] = None @@ -244,7 +283,7 @@ def generate_mcp( entityUrn=str(self.urn), aspect=DataProcessInstanceProperties( name=self.id, - created=AuditStampClass( + created=TimeStampClass( time=created_ts_millis or int(time.time() * 1000), actor="urn:li:corpuser:datahub", ), @@ -269,6 +308,29 @@ def generate_mcp( ) yield mcp + assert self.data_platform + + mcp = MetadataChangeProposalWrapper( + entityUrn=str(self.urn), + aspect=DataPlatformInstanceClass( + platform=self.data_platform, instance=self.data_plaform_instance + ), + ) + yield mcp + + if self.subtype: + mcp = MetadataChangeProposalWrapper( + entityUrn=str(self.urn), aspect=SubTypesClass(typeNames=[self.subtype]) + ) + yield mcp + + if self.container_urn: + mcp = MetadataChangeProposalWrapper( + entityUrn=str(self.urn), + aspect=ContainerClass(container=self.container_urn), + ) + yield mcp + yield from self.generate_inlet_outlet_mcp(materialize_iolets=materialize_iolets) @staticmethod @@ -331,6 +393,29 @@ def from_datajob( dpi.outlets = datajob.outlets return dpi + @staticmethod + def from_container( + container_key: ContainerKey, + id: str, + ) -> "DataProcessInstance": + """ + Generates DataProcessInstance from a Container + + :param datajob: (DataJob) the datajob from generate the DataProcessInstance + :param id: (str) the id for the DataProcessInstance + :param clone_inlets: (bool) whether to clone datajob's inlets + :param clone_outlets: (bool) whether to clone datajob's outlets + :return: DataProcessInstance + """ + dpi: DataProcessInstance = DataProcessInstance( + id=id, + orchestrator=DataPlatformUrn.from_string(container_key.platform).platform_name, + template_urn=None, + container_urn = container_key.as_urn(), + ) + + return dpi + @staticmethod def from_dataflow(dataflow: DataFlow, id: str) -> "DataProcessInstance": """ @@ -376,4 +461,4 @@ def generate_inlet_outlet_mcp( yield MetadataChangeProposalWrapper( entityUrn=str(iolet), aspect=iolet.to_key_aspect(), - ) + ) \ No newline at end of file diff --git a/metadata-ingestion/src/datahub/ingestion/source/mlflow.py b/metadata-ingestion/src/datahub/ingestion/source/mlflow.py index 26d160acf330cf..7a82c270821000 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/mlflow.py +++ b/metadata-ingestion/src/datahub/ingestion/source/mlflow.py @@ -1,13 +1,15 @@ from dataclasses import dataclass -from typing import Any, Callable, Iterable, Optional, TypeVar, Union +from typing import Any, Callable, Iterable, Optional, TypeVar, Union, List +import time from mlflow import MlflowClient from mlflow.entities import Run -from mlflow.entities.model_registry import ModelVersion, RegisteredModel +from mlflow.entities.model_registry import ModelVersion, RegisteredModel, Experiment from mlflow.store.entities import PagedList from pydantic.fields import Field import datahub.emitter.mce_builder as builder +from datahub.emitter.mcp_builder import ContainerKey from datahub.configuration.source_common import EnvConfigMixin from datahub.emitter.mcp import MetadataChangeProposalWrapper from datahub.ingestion.api.common import PipelineContext @@ -29,39 +31,78 @@ TagAssociationClass, TagPropertiesClass, VersionTagClass, - _Aspect, + DataProcessInstanceRunEventClass, + DataProcessInstancePropertiesClass, + ContainerPropertiesClass, + TimeStampClass, + DataProcessRunStatusClass, + SubTypesClass, + DataPlatformInstanceClass, + BrowsePathsV2Class, + MetadataChangeProposalClass, + MLTrainingRunPropertiesClass, + DataProcessInstanceRunResultClass, +) +from datahub.metadata.urns import DataPlatformUrn +from datahub.api.entities.dataprocess.dataprocess_instance import ( + DataProcessInstance, ) T = TypeVar("T") +class ContainerKeyWithId(ContainerKey): + id: str + + +@dataclass +class Container: + key: ContainerKeyWithId + subtype: str + name: Optional[str] = None + description: Optional[str] = None + + def generate_mcp( + self, + ) -> Iterable[ + Union[MetadataChangeProposalClass, MetadataChangeProposalWrapper] + ]: + container_urn = self.key.as_urn() + + container_subtype = SubTypesClass(typeNames=[self.subtype]) + + container_info = ContainerPropertiesClass( + name=self.name or self.key.id, + description=self.description, + customProperties={}, + ) + + browse_path = BrowsePathsV2Class(path=[]) + + dpi = DataPlatformInstanceClass( + platform=self.key.platform, + instance=self.key.instance, + ) + + return MetadataChangeProposalWrapper.construct_many( + entityUrn=container_urn, + aspects=[container_subtype, container_info, browse_path, dpi], + ) + + class MLflowConfig(EnvConfigMixin): tracking_uri: Optional[str] = Field( default=None, - description=( - "Tracking server URI. If not set, an MLflow default tracking_uri is used" - " (local `mlruns/` directory or `MLFLOW_TRACKING_URI` environment variable)" - ), + description="Tracking server URI. If not set, an MLflow default tracking_uri is used (local `mlruns/` directory or `MLFLOW_TRACKING_URI` environment variable)", ) registry_uri: Optional[str] = Field( default=None, - description=( - "Registry server URI. If not set, an MLflow default registry_uri is used" - " (value of tracking_uri or `MLFLOW_REGISTRY_URI` environment variable)" - ), + description="Registry server URI. If not set, an MLflow default registry_uri is used (value of tracking_uri or `MLFLOW_REGISTRY_URI` environment variable)", ) model_name_separator: str = Field( default="_", description="A string which separates model name from its version (e.g. model_1 or model-1)", ) - base_external_url: Optional[str] = Field( - default=None, - description=( - "Base URL to use when constructing external URLs to MLflow." - " If not set, tracking_uri is used if it's an HTTP URL." - " If neither is set, external URLs are not generated." - ), - ) @dataclass @@ -119,11 +160,9 @@ def get_report(self) -> SourceReport: def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: yield from self._get_tags_workunits() yield from self._get_ml_model_workunits() + yield from self._get_experiment_workunits() def _get_tags_workunits(self) -> Iterable[MetadataWorkUnit]: - """ - Create tags for each Stage in MLflow Model Registry. - """ for stage_info in self.registered_model_stages_info: tag_urn = self._make_stage_tag_urn(stage_info.name) tag_properties = TagPropertiesClass( @@ -142,71 +181,208 @@ def _make_stage_tag_urn(self, stage_name: str) -> str: def _make_stage_tag_name(self, stage_name: str) -> str: return f"{self.platform}_{stage_name.lower()}" - def _create_workunit(self, urn: str, aspect: _Aspect) -> MetadataWorkUnit: - """ - Utility to create an MCP workunit. - """ + def _create_workunit(self, urn: str, aspect: Any) -> MetadataWorkUnit: return MetadataChangeProposalWrapper( entityUrn=urn, aspect=aspect, ).as_workunit() - def _get_ml_model_workunits(self) -> Iterable[MetadataWorkUnit]: - """ - Traverse each Registered Model in Model Registry and generate a corresponding workunit. - """ - registered_models = self._get_mlflow_registered_models() - for registered_model in registered_models: - yield self._get_ml_group_workunit(registered_model) - model_versions = self._get_mlflow_model_versions(registered_model) - for model_version in model_versions: - run = self._get_mlflow_run(model_version) - yield self._get_ml_model_properties_workunit( - registered_model=registered_model, - model_version=model_version, - run=run, + def _get_experiment_workunits(self) -> Iterable[MetadataWorkUnit]: + experiments = self._get_mlflow_experiments() + for experiment in experiments: + # Yield each workunit from the container workunits + for wu in self._get_experiment_container_workunit(experiment): + yield wu + + runs = self._get_mlflow_runs_from_experiment(experiment) + if runs: + for run in runs: + for wu in self._get_run_workunits(experiment, run): + yield wu + + def _get_experiment_custom_properties(self, experiment): + experiment_custom_props = getattr(experiment, 'tags', {}) or {} + experiment_custom_props.pop('mlflow.note.content', None) + experiment_custom_props['artifacts_location'] = experiment.artifact_location + return experiment_custom_props + + def _get_experiment_container_workunit(self, experiment: Experiment) -> List[MetadataWorkUnit]: + experiment_container = Container( + key=ContainerKeyWithId( + platform=str(DataPlatformUrn.create_from_id("mlflow")), + id=experiment.name, + ), + subtype="ML Experiment", + name=experiment.name, + description=experiment.tags.get('mlflow.note.content') + ) # TODO: this generates a urn as guid, should we change this to use experiment.id? + + workunits = [mcp.as_workunit() for mcp in experiment_container.generate_mcp()] + return workunits + + + def _get_run_custom_properties(self, run: Run): + custom_props = {} + custom_props.update(getattr(run, 'tags', {}) or {}) + return custom_props + + def _get_run_metrics(self, run: Run): + return [MLMetricClass(name=k, value=str(v)) for k, v in run.data.metrics.items()] + + def _get_run_params(self, run: Run): + return [MLHyperParamClass(name=k, value=str(v)) for k, v in run.data.params.items()] + + def _convert_run_result_type(self, status: str) -> DataProcessInstanceRunResultClass: + if status == "FINISHED": + return DataProcessInstanceRunResultClass(type="SUCCESS", nativeResultType="mlflow") + elif status == "FAILED": + return DataProcessInstanceRunResultClass(type="FAILURE", nativeResultType="mlflow") + else: + return DataProcessInstanceRunResultClass(type="SKIPPED", nativeResultType="mlflow") + + def _get_run_workunits(self, experiment: Experiment, run: Run) -> List[MetadataWorkUnit]: + experiment_key = ContainerKeyWithId( + platform=str(DataPlatformUrn.create_from_id("mlflow")), + id=experiment.name + ) + + data_process_instance = DataProcessInstance.from_container( + container_key=experiment_key, + id=run.info.run_name + ) # TODO: this generates a urn as guid, should we change this to use run.info.run_id? + workunits = [] + + run_custom_props = self._get_run_custom_properties(run) + created_time = run.info.start_time or int(time.time() * 1000) + created_actor = f"urn:li:platformResource:{run.info.user_id}" if run.info.user_id else None + + workunits.append(MetadataChangeProposalWrapper( + entityUrn=str(data_process_instance.urn), + aspect=DataProcessInstancePropertiesClass( + name=run.info.run_name or run.info.run_id, + created=TimeStampClass( + time=created_time, + actor=created_actor, + ), + externalUrl=self._make_external_url_from_run(experiment, run), + customProperties=run_custom_props, + ) + ).as_workunit()) + + metrics = self._get_run_metrics(run) + hyperparams = self._get_run_params(run) + + workunits.append(MetadataChangeProposalWrapper( + entityUrn=str(data_process_instance.urn), + aspect=MLTrainingRunPropertiesClass( + hyperParams=hyperparams, + trainingMetrics=metrics, + outputUrls=[run.info.artifact_uri], + id=run.info.run_id, ) - yield self._get_global_tags_workunit(model_version=model_version) + ).as_workunit() + ) + + result = run.info.status # TODO: this should be SUCCESS, SKIPPED, FAILURE, UP_FOR_RETRY + duration_millis = run.info.end_time - run.info.start_time + + # start event + if run.info.start_time: + workunits.append(MetadataChangeProposalWrapper( + entityUrn=str(data_process_instance.urn), + aspect=DataProcessInstanceRunEventClass( + status=DataProcessRunStatusClass.STARTED, + timestampMillis=run.info.start_time, + ) + ).as_workunit()) + + + # end event + if run.info.end_time: + workunits.append(MetadataChangeProposalWrapper( + entityUrn=str(data_process_instance.urn), + aspect=DataProcessInstanceRunEventClass( + status=DataProcessRunStatusClass.COMPLETE, + timestampMillis=run.info.end_time, + result=DataProcessInstanceRunResultClass( + type=self._convert_run_result_type(result).type, + nativeResultType="mlflow", + ), + durationMillis=duration_millis + + ) + ).as_workunit()) + + workunits.append(MetadataChangeProposalWrapper( + entityUrn=str(data_process_instance.urn), + aspect=DataPlatformInstanceClass(platform=str(DataPlatformUrn.create_from_id("mlflow"))) + ).as_workunit()) + + workunits.append(MetadataChangeProposalWrapper( + entityUrn=str(data_process_instance.urn), + aspect=SubTypesClass(typeNames=["ML Training Run"]) + ).as_workunit()) + + return workunits def _get_mlflow_registered_models(self) -> Iterable[RegisteredModel]: - """ - Get all Registered Models in MLflow Model Registry. - """ - registered_models: Iterable[ - RegisteredModel - ] = self._traverse_mlflow_search_func( + registered_models: Iterable[RegisteredModel] = self._traverse_mlflow_search_func( search_func=self.client.search_registered_models, ) return registered_models + def _get_mlflow_experiments(self) -> Iterable[Experiment]: + experiments: Iterable[Experiment] = self._traverse_mlflow_search_func( + search_func=self.client.search_experiments, + ) + return experiments + + def _get_mlflow_runs_from_experiment(self, experiment: Experiment) -> List[Run]: + runs: List[Run] = self._traverse_mlflow_search_func( + search_func=self.client.search_runs, + experiment_ids=[experiment.experiment_id], + ) + return runs + @staticmethod def _traverse_mlflow_search_func( - search_func: Callable[..., PagedList[T]], - **kwargs: Any, + search_func: Callable[..., PagedList[T]], + **kwargs: Any, ) -> Iterable[T]: - """ - Utility to traverse an MLflow search_* functions which return PagedList. - """ next_page_token = None while True: paged_list = search_func(page_token=next_page_token, **kwargs) - yield from paged_list.to_list() + yield from paged_list next_page_token = paged_list.token if not next_page_token: return + def _get_latest_version(self, registered_model: RegisteredModel) -> Optional[str]: + return ( + str(registered_model.latest_versions[0].version) + if registered_model.latest_versions + else None + ) + def _get_ml_group_workunit( - self, - registered_model: RegisteredModel, + self, + registered_model: RegisteredModel, ) -> MetadataWorkUnit: - """ - Generate an MLModelGroup workunit for an MLflow Registered Model. - """ ml_model_group_urn = self._make_ml_model_group_urn(registered_model) ml_model_group_properties = MLModelGroupPropertiesClass( customProperties=registered_model.tags, description=registered_model.description, - createdAt=registered_model.creation_timestamp, + created=TimeStampClass( + time=registered_model.creation_timestamp, + actor=None + ), + lastModified=TimeStampClass( + time=registered_model.last_updated_timestamp, + actor=None, + ), + version=VersionTagClass( + versionTag=self._get_latest_version(registered_model) + ), ) wu = self._create_workunit( urn=ml_model_group_urn, @@ -223,64 +399,76 @@ def _make_ml_model_group_urn(self, registered_model: RegisteredModel) -> str: return urn def _get_mlflow_model_versions( - self, - registered_model: RegisteredModel, + self, + registered_model: RegisteredModel, ) -> Iterable[ModelVersion]: - """ - Get all Model Versions for each Registered Model. - """ filter_string = f"name = '{registered_model.name}'" model_versions: Iterable[ModelVersion] = self._traverse_mlflow_search_func( search_func=self.client.search_model_versions, filter_string=filter_string, ) return model_versions - def _get_mlflow_run(self, model_version: ModelVersion) -> Union[None, Run]: - """ - Get a Run associated with a Model Version. Some MVs may exist without Run. - """ if model_version.run_id: run = self.client.get_run(model_version.run_id) return run else: return None + def _get_ml_model_workunits(self) -> Iterable[MetadataWorkUnit]: + registered_models = self._get_mlflow_registered_models() + for registered_model in registered_models: + yield self._get_ml_group_workunit(registered_model) + model_versions = self._get_mlflow_model_versions(registered_model) + for model_version in model_versions: + run = self._get_mlflow_run(model_version) + yield self._get_ml_model_properties_workunit( + registered_model=registered_model, + model_version=model_version, + run=run, + ) + yield self._get_global_tags_workunit(model_version=model_version) + def _get_ml_model_properties_workunit( - self, - registered_model: RegisteredModel, - model_version: ModelVersion, - run: Union[None, Run], + self, + registered_model: RegisteredModel, + model_version: ModelVersion, + run: Union[None, Run], ) -> MetadataWorkUnit: - """ - Generate an MLModel workunit for an MLflow Model Version. - Every Model Version is a DataHub MLModel entity associated with an MLModelGroup corresponding to a Registered Model. - If a model was registered without an associated Run then hyperparams and metrics are not available. - """ ml_model_group_urn = self._make_ml_model_group_urn(registered_model) ml_model_urn = self._make_ml_model_urn(model_version) + if run: - hyperparams = [ - MLHyperParamClass(name=k, value=str(v)) - for k, v in run.data.params.items() - ] - training_metrics = [ - MLMetricClass(name=k, value=str(v)) for k, v in run.data.metrics.items() - ] + # Use the same metrics and hyperparams from the run + hyperparams = self._get_run_params(run) + training_metrics = self._get_run_metrics(run) + training_jobs = [str(builder.make_data_process_instance_urn(run.info.run_id))] else: hyperparams = None training_metrics = None + training_jobs = None + + created_time = model_version.creation_timestamp + created_actor = f"urn:li:platformResource:{model_version.user_id}" if model_version.user_id else None + ml_model_properties = MLModelPropertiesClass( customProperties=model_version.tags, - externalUrl=self._make_external_url(model_version), + lastModified=TimeStampClass( + time=model_version.last_updated_timestamp, + actor=None, + ), + externalUrl=self._make_external_url_from_model_version(model_version), description=model_version.description, - date=model_version.creation_timestamp, + created=TimeStampClass( + time=created_time, + actor=created_actor, + ), version=VersionTagClass(versionTag=str(model_version.version)), hyperParams=hyperparams, trainingMetrics=training_metrics, - # mlflow tags are dicts, but datahub tags are lists. currently use only keys from mlflow tags + trainingJobs=training_jobs, tags=list(model_version.tags.keys()), - groups=[ml_model_group_urn], + groups=[str(ml_model_group_urn)], ) wu = self._create_workunit(urn=ml_model_urn, aspect=ml_model_properties) return wu @@ -293,30 +481,23 @@ def _make_ml_model_urn(self, model_version: ModelVersion) -> str: ) return urn - def _get_base_external_url_from_tracking_uri(self) -> Optional[str]: - if isinstance( - self.client.tracking_uri, str - ) and self.client.tracking_uri.startswith("http"): - return self.client.tracking_uri + def _make_external_url_from_model_version(self, model_version: ModelVersion) -> Union[None, str]: + base_uri = self.client.tracking_uri + if base_uri.startswith("http"): + return f"{base_uri.rstrip('/')}/#/models/{model_version.name}/versions/{model_version.version}" else: return None - def _make_external_url(self, model_version: ModelVersion) -> Optional[str]: - """ - Generate URL for a Model Version to MLflow UI. - """ - base_uri = ( - self.config.base_external_url - or self._get_base_external_url_from_tracking_uri() - ) - if base_uri: - return f"{base_uri.rstrip('/')}/#/models/{model_version.name}/versions/{model_version.version}" + def _make_external_url_from_run(self, experiment: Experiment, run: Run) -> Union[None, str]: + base_uri = self.client.tracking_uri + if base_uri.startswith("http"): + return f"{base_uri.rstrip('/')}/#/experiments/{experiment.experiment_id}/runs/{run.info.run_id}" else: return None def _get_global_tags_workunit( - self, - model_version: ModelVersion, + self, + model_version: ModelVersion, ) -> MetadataWorkUnit: """ Associate a Model Version Stage with a corresponding tag. @@ -337,4 +518,4 @@ def _get_global_tags_workunit( @classmethod def create(cls, config_dict: dict, ctx: PipelineContext) -> Source: config = MLflowConfig.parse_obj(config_dict) - return cls(ctx, config) + return cls(ctx, config) \ No newline at end of file diff --git a/metadata-ingestion/tests/unit/stateful_ingestion/state/test_stateful_ingestion.py b/metadata-ingestion/tests/unit/stateful_ingestion/state/test_stateful_ingestion.py index 96ab8f7a01a386..d1b5073bcea2fe 100644 --- a/metadata-ingestion/tests/unit/stateful_ingestion/state/test_stateful_ingestion.py +++ b/metadata-ingestion/tests/unit/stateful_ingestion/state/test_stateful_ingestion.py @@ -31,9 +31,9 @@ DataProcessInstanceProperties, ) from datahub.metadata.schema_classes import ( - AuditStampClass, DataPlatformInstanceClass, StatusClass, + TimeStampClass, ) from datahub.metadata.urns import DataPlatformUrn, QueryUrn from datahub.utilities.urns.dataset_urn import DatasetUrn @@ -137,7 +137,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: entityUrn=str(dpi.urn), aspect=DataProcessInstanceProperties( name=dpi.id, - created=AuditStampClass( + created=TimeStampClass( time=int(time.time() * 1000), actor="urn:li:corpuser:datahub", ), diff --git a/metadata-integration/java/openlineage-converter/src/main/java/io/datahubproject/openlineage/converter/OpenLineageToDataHub.java b/metadata-integration/java/openlineage-converter/src/main/java/io/datahubproject/openlineage/converter/OpenLineageToDataHub.java index 9237ee60f473b4..8fe85a8bbc9359 100644 --- a/metadata-integration/java/openlineage-converter/src/main/java/io/datahubproject/openlineage/converter/OpenLineageToDataHub.java +++ b/metadata-integration/java/openlineage-converter/src/main/java/io/datahubproject/openlineage/converter/OpenLineageToDataHub.java @@ -706,12 +706,10 @@ private static DataProcessInstanceProperties getJobDataProcessInstanceProperties OpenLineage.RunEvent event) throws URISyntaxException { DataProcessInstanceProperties dpiProperties = new DataProcessInstanceProperties(); dpiProperties.setName(event.getRun().getRunId().toString()); - AuditStamp auditStamp = new AuditStamp(); + TimeStamp timeStamp = new TimeStamp(); if (event.getEventTime() != null) { - auditStamp.setTime(event.getEventTime().toInstant().toEpochMilli()); + dpiProperties.setCreated(timeStamp.setTime(event.getEventTime().toInstant().toEpochMilli())); } - auditStamp.setActor(Urn.createFromString(URN_LI_CORPUSER_DATAHUB)); - dpiProperties.setCreated(auditStamp); return dpiProperties; } diff --git a/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceInput.pdl b/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceInput.pdl index d005cd557cf77e..32329f60bfaa70 100644 --- a/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceInput.pdl +++ b/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceInput.pdl @@ -15,7 +15,8 @@ record DataProcessInstanceInput { @Relationship = { "/*": { "name": "Consumes", - "entityTypes": [ "dataset" ] + "entityTypes": [ "dataset" ], + "isLineage": true } } @Searchable = { diff --git a/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceOutput.pdl b/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceOutput.pdl index f33c41e63efed6..223abfb7e6f4c5 100644 --- a/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceOutput.pdl +++ b/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceOutput.pdl @@ -15,7 +15,9 @@ record DataProcessInstanceOutput { @Relationship = { "/*": { "name": "Produces", - "entityTypes": [ "dataset" ] + "entityTypes": [ "dataset" ], + "isLineage": true, + "isUpstream": false } } @Searchable = { diff --git a/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceProperties.pdl b/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceProperties.pdl index c63cb1a97c017d..59c482bd91e806 100644 --- a/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceProperties.pdl +++ b/metadata-models/src/main/pegasus/com/linkedin/dataprocess/DataProcessInstanceProperties.pdl @@ -1,6 +1,6 @@ namespace com.linkedin.dataprocess -import com.linkedin.common.AuditStamp +import com.linkedin.common.TimeStamp import com.linkedin.common.CustomProperties import com.linkedin.common.ExternalReference import com.linkedin.common.Urn @@ -50,6 +50,5 @@ record DataProcessInstanceProperties includes CustomProperties, ExternalReferenc "fieldName": "created" } } - created: AuditStamp - -} + created: TimeStamp +} \ No newline at end of file diff --git a/metadata-models/src/main/pegasus/com/linkedin/ml/metadata/MLModelGroupProperties.pdl b/metadata-models/src/main/pegasus/com/linkedin/ml/metadata/MLModelGroupProperties.pdl index b54e430038082d..a84b9c9d3994b4 100644 --- a/metadata-models/src/main/pegasus/com/linkedin/ml/metadata/MLModelGroupProperties.pdl +++ b/metadata-models/src/main/pegasus/com/linkedin/ml/metadata/MLModelGroupProperties.pdl @@ -4,6 +4,7 @@ import com.linkedin.common.Urn import com.linkedin.common.Time import com.linkedin.common.VersionTag import com.linkedin.common.CustomProperties +import com.linkedin.common.TimeStamp /** * Properties associated with an ML Model Group @@ -25,10 +26,21 @@ record MLModelGroupProperties includes CustomProperties { /** * Date when the MLModelGroup was developed */ + @deprecated createdAt: optional Time + /** + * Time and Actor who created the MLModelGroup + */ + created: optional TimeStamp + + /** + * Date when the MLModelGroup was last modified + */ + lastModified: optional TimeStamp + /** * Version of the MLModelGroup */ version: optional VersionTag -} +} \ No newline at end of file diff --git a/metadata-models/src/main/pegasus/com/linkedin/ml/metadata/MLModelProperties.pdl b/metadata-models/src/main/pegasus/com/linkedin/ml/metadata/MLModelProperties.pdl index 621a3e1747b504..e0cab82a0943ef 100644 --- a/metadata-models/src/main/pegasus/com/linkedin/ml/metadata/MLModelProperties.pdl +++ b/metadata-models/src/main/pegasus/com/linkedin/ml/metadata/MLModelProperties.pdl @@ -6,6 +6,7 @@ import com.linkedin.common.Time import com.linkedin.common.VersionTag import com.linkedin.common.CustomProperties import com.linkedin.common.ExternalReference +import com.linkedin.common.TimeStamp /** * Properties associated with a ML Model @@ -15,6 +16,18 @@ import com.linkedin.common.ExternalReference } record MLModelProperties includes CustomProperties, ExternalReference { + /** + * Display name of the MLModel + */ + @Searchable = { + "fieldType": "WORD_GRAM", + "enableAutocomplete": true, + "boostScore": 10.0, + "queryByDefault": true, + } + name: optional string + + /** * Documentation of the MLModel */ @@ -27,8 +40,19 @@ record MLModelProperties includes CustomProperties, ExternalReference { /** * Date when the MLModel was developed */ + @deprecated date: optional Time + /** + * Audit stamp containing who created this and when + */ + created: optional TimeStamp + + /** + * Date when the MLModel was last modified + */ + lastModified: optional TimeStamp + /** * Version of the MLModel */ @@ -93,24 +117,24 @@ record MLModelProperties includes CustomProperties, ExternalReference { deployments: optional array[Urn] /** - * List of jobs (if any) used to train the model + * List of jobs or process instances (if any) used to train the model */ @Relationship = { "/*": { "name": "TrainedBy", - "entityTypes": [ "dataJob" ], + "entityTypes": [ "dataJob", "dataProcessInstance" ], "isLineage": true } } trainingJobs: optional array[Urn] /** - * List of jobs (if any) that use the model + * List of jobs or process instance (if any) that use the model */ @Relationship = { "/*": { "name": "UsedBy", - "entityTypes": [ "dataJob" ], + "entityTypes": [ "dataJob" , "dataProcessInstance" ], "isLineage": true, "isUpstream": false } @@ -129,4 +153,4 @@ record MLModelProperties includes CustomProperties, ExternalReference { } } groups: optional array[Urn] -} +} \ No newline at end of file diff --git a/metadata-models/src/main/pegasus/com/linkedin/ml/metadata/MLTrainingRunProperties.pdl b/metadata-models/src/main/pegasus/com/linkedin/ml/metadata/MLTrainingRunProperties.pdl new file mode 100644 index 00000000000000..f8b8eeafe908b7 --- /dev/null +++ b/metadata-models/src/main/pegasus/com/linkedin/ml/metadata/MLTrainingRunProperties.pdl @@ -0,0 +1,36 @@ +namespace com.linkedin.ml.metadata + +import com.linkedin.common.AuditStamp +import com.linkedin.common.CustomProperties +import com.linkedin.common.ExternalReference +import com.linkedin.common.Urn +import com.linkedin.common.JobFlowUrn +import com.linkedin.common.DataJobUrn +/** + * The inputs and outputs of this training run + */ +@Aspect = { + "name": "mlTrainingRunProperties", +} +record MLTrainingRunProperties includes CustomProperties, ExternalReference { + + /** + * Run Id of the ML Training Run + */ + id: optional string + + /** + * List of URLs for the Outputs of the ML Training Run + */ + outputUrls: optional array[string] + + /** + * Hyperparameters of the ML Training Run + */ + hyperParams: optional array[MLHyperParam] + + /** + * Metrics of the ML Training Run + */ + trainingMetrics: optional array[MLMetric] +} \ No newline at end of file diff --git a/metadata-models/src/main/resources/entity-registry.yml b/metadata-models/src/main/resources/entity-registry.yml index 1c3eb5b574e204..4fe170ced69f33 100644 --- a/metadata-models/src/main/resources/entity-registry.yml +++ b/metadata-models/src/main/resources/entity-registry.yml @@ -116,6 +116,10 @@ entities: - dataProcessInstanceRunEvent - status - testResults + - dataPlatformInstance + - subTypes + - container + - mlTrainingRunProperties - name: chart category: core keyAspect: chartKey