Skip to content

Commit

Permalink
fix mlflow connector
Browse files Browse the repository at this point in the history
  • Loading branch information
yoonhyejin committed Dec 19, 2024
1 parent 08044cc commit 2e50b0d
Showing 1 changed file with 10 additions and 24 deletions.
34 changes: 10 additions & 24 deletions metadata-ingestion/src/datahub/ingestion/source/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
DataProcessInstancePropertiesClass,
DataProcessInstanceRelationshipsClass,
ContainerPropertiesClass,
AuditStampClass,
TimeStampClass,
DataProcessRunStatusClass,
SubTypesClass,
Expand Down Expand Up @@ -161,9 +160,10 @@ def get_report(self) -> SourceReport:
return self.report

def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
    """Yield every work unit this source emits.

    Work units are produced in a fixed order: tag work units first, then
    experiment work units, then ML model work units. Each group is
    delegated to the corresponding ``_get_*_workunits`` helper.
    """
    # No stdout debug output here: a source should not write to stdout as
    # a side effect of being iterated.
    yield from self._get_tags_workunits()
    yield from self._get_experiment_workunits()
    yield from self._get_ml_model_workunits()

def _get_tags_workunits(self) -> Iterable[MetadataWorkUnit]:
for stage_info in self.registered_model_stages_info:
Expand Down Expand Up @@ -218,7 +218,7 @@ def _get_experiment_container_workunit(self, experiment: Experiment) -> List[Met
subtype="ML Experiment",
name=experiment.name,
description=experiment.tags.get('mlflow.note.content')
)
) # TODO: urn should be experiment id

print("experiment.key.id:", experiment.key.id) # this should be same as container key as urn
print("experiment.key.as_urn(): ", experiment.key.as_urn())
Expand Down Expand Up @@ -247,7 +247,6 @@ def _convert_run_result_type(self, status: str) -> DataProcessInstanceRunResultC
return DataProcessInstanceRunResultClass(type="SKIPPED", nativeResultType="mlflow")

def _get_run_workunits(self, experiment: Experiment, run: Run) -> List[MetadataWorkUnit]:
# TODO: this does not map to the correct experiment
experiment_key = ContainerKeyWithId(
platform=str(DataPlatformUrn.create_from_id("mlflow")),
id=experiment.name
Expand All @@ -258,6 +257,7 @@ def _get_run_workunits(self, experiment: Experiment, run: Run) -> List[MetadataW
id=run.info.run_name
)

# TODO: urn should be run id

print("dpi id", run.info.run_name)
print("experiment_key.id:", experiment_key.id)
Expand Down Expand Up @@ -398,6 +398,7 @@ def _get_ml_group_workunit(
versionTag=self._get_latest_version(registered_model)
),
)
print("ml_model_group_properties: ", ml_model_group_properties)
wu = self._create_workunit(
urn=ml_model_group_urn,
aspect=ml_model_group_properties,
Expand Down Expand Up @@ -452,30 +453,15 @@ def _get_ml_model_properties_workunit(
ml_model_group_urn = self._make_ml_model_group_urn(registered_model)
ml_model_urn = self._make_ml_model_urn(model_version)

hyperparams = None
training_metrics = None
training_jobs = None

if run:
# Use the same metrics and hyperparams from the run
hyperparams = self._get_run_params(run)
training_metrics = self._get_run_metrics(run)

# Create proper relationship with the run

# get experiment name from experiment id
experiment_id = run.info.experiment_id
experiment = self.client.get_experiment(experiment_id)
experiment_key = ContainerKeyWithId(
platform=str(DataPlatformUrn.create_from_id("mlflow")),
id=experiment.name
)

data_process_instance = DataProcessInstance.from_container(
container_key=experiment_key,
id=run.info.run_name
)
training_jobs = [str(data_process_instance.urn)]
else:
hyperparams = None
training_metrics = None
training_jobs = None
# training_jobs = [str(builder.make_data_process_instance_urn(run.info.run_id))]

created_time = model_version.creation_timestamp
created_actor = f"urn:li:platformResource:{model_version.user_id}" if model_version.user_id else None
Expand Down

0 comments on commit 2e50b0d

Please sign in to comment.