-
Notifications
You must be signed in to change notification settings - Fork 178
Commit
…aram selector
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
import { testHook, standardUseFetchState } from '~/__tests__/unit/testUtils/hooks'; | ||
import { | ||
MetadataStoreServicePromiseClient, | ||
Artifact, | ||
Execution, | ||
Event, | ||
Context, | ||
} from '~/third_party/mlmd'; | ||
import { usePipelinesAPI } from '~/concepts/pipelines/context'; | ||
import { getMlmdContext } from '~/concepts/pipelines/apiHooks/mlmd/useMlmdContext'; | ||
import { PipelineRunKFv2 } from '~/concepts/pipelines/kfTypes'; | ||
import { | ||
GetArtifactsByContextResponse, | ||
GetExecutionsByContextResponse, | ||
GetEventsByExecutionIDsResponse, | ||
} from '~/third_party/mlmd/generated/ml_metadata/proto/metadata_store_service_pb'; | ||
import { useGetArtifactsByRuns } from '~/concepts/pipelines/apiHooks/mlmd/useGetArtifactsByRuns'; | ||
|
||
// Mock the usePipelinesAPI and getMlmdContext hooks | ||
jest.mock('~/concepts/pipelines/context', () => ({ | ||
usePipelinesAPI: jest.fn(), | ||
})); | ||
|
||
jest.mock('~/concepts/pipelines/apiHooks/mlmd/useMlmdContext', () => ({ | ||
getMlmdContext: jest.fn(), | ||
})); | ||
|
||
// Mock the MetadataStoreServicePromiseClient | ||
jest.mock('~/third_party/mlmd', () => { | ||
const originalModule = jest.requireActual('~/third_party/mlmd'); | ||
return { | ||
...originalModule, | ||
MetadataStoreServicePromiseClient: jest.fn().mockImplementation(() => ({ | ||
getArtifactsByContext: jest.fn(), | ||
getExecutionsByContext: jest.fn(), | ||
getEventsByExecutionIDs: jest.fn(), | ||
})), | ||
GetArtifactsByContextRequest: originalModule.GetArtifactsByContextRequest, | ||
GetExecutionsByContextRequest: originalModule.GetExecutionsByContextRequest, | ||
GetEventsByExecutionIDsRequest: originalModule.GetEventsByExecutionIDsRequest, | ||
}; | ||
}); | ||
|
||
describe('useGetArtifactsByRuns', () => { | ||
const mockClient = new MetadataStoreServicePromiseClient(''); | ||
const mockUsePipelinesAPI = jest.mocked( | ||
usePipelinesAPI as () => Partial<ReturnType<typeof usePipelinesAPI>>, | ||
); | ||
const mockGetMlmdContext = jest.mocked(getMlmdContext); | ||
const mockGetArtifactsByContext = jest.mocked(mockClient.getArtifactsByContext); | ||
const mockGetExecutionsByContext = jest.mocked(mockClient.getExecutionsByContext); | ||
const mockGetEventsByExecutionIDs = jest.mocked(mockClient.getEventsByExecutionIDs); | ||
|
||
const mockContext = new Context(); | ||
mockContext.setId(1); | ||
|
||
const mockArtifact = new Artifact(); | ||
mockArtifact.setId(1); | ||
mockArtifact.setName('artifact1'); | ||
|
||
const mockExecution = new Execution(); | ||
mockExecution.setId(1); | ||
|
||
const mockEvent = new Event(); | ||
mockEvent.getArtifactId = jest.fn().mockReturnValue(1); | ||
mockEvent.getExecutionId = jest.fn().mockReturnValue(1); | ||
|
||
// eslint-disable-next-line camelcase | ||
const mockRun = { run_id: 'test-run-id' } as PipelineRunKFv2; | ||
|
||
beforeEach(() => { | ||
jest.clearAllMocks(); | ||
mockUsePipelinesAPI.mockReturnValue({ | ||
metadataStoreServiceClient: mockClient, | ||
}); | ||
}); | ||
|
||
it('throws error when no MLMD context is found', async () => { | ||
mockGetMlmdContext.mockResolvedValue(undefined); | ||
const renderResult = testHook(useGetArtifactsByRuns)([mockRun]); | ||
|
||
// wait for update | ||
await renderResult.waitForNextUpdate(); | ||
|
||
expect(renderResult.result.current).toEqual( | ||
standardUseFetchState([], false, new Error('No context for run: test-run-id')), | ||
); | ||
}); | ||
|
||
it('should fetch and return MLMD packages for pipeline runs', async () => { | ||
mockGetMlmdContext.mockResolvedValue(mockContext); | ||
mockGetArtifactsByContext.mockResolvedValue({ | ||
getArtifactsList: () => [mockArtifact], | ||
} as GetArtifactsByContextResponse); | ||
mockGetExecutionsByContext.mockResolvedValue({ | ||
getExecutionsList: () => [mockExecution], | ||
} as GetExecutionsByContextResponse); | ||
mockGetEventsByExecutionIDs.mockResolvedValue({ | ||
getEventsList: () => [mockEvent], | ||
} as GetEventsByExecutionIDsResponse); | ||
|
||
const renderResult = testHook(useGetArtifactsByRuns)([mockRun]); | ||
|
||
expect(renderResult.result.current).toStrictEqual(standardUseFetchState([])); | ||
expect(renderResult).hookToHaveUpdateCount(1); | ||
|
||
// wait for update | ||
await renderResult.waitForNextUpdate(); | ||
|
||
expect(renderResult.result.current).toStrictEqual( | ||
standardUseFetchState( | ||
[ | ||
{ | ||
[mockRun.run_id]: [mockArtifact], | ||
}, | ||
], | ||
true, | ||
), | ||
); | ||
expect(renderResult).hookToHaveUpdateCount(2); | ||
}); | ||
|
||
it('should handle errors from getMlmdContext', async () => { | ||
const error = new Error('Cannot fetch context'); | ||
mockGetMlmdContext.mockRejectedValue(error); | ||
|
||
const renderResult = testHook(useGetArtifactsByRuns)([mockRun]); | ||
|
||
expect(renderResult.result.current).toStrictEqual(standardUseFetchState([])); | ||
expect(renderResult).hookToHaveUpdateCount(1); | ||
|
||
// wait for update | ||
await renderResult.waitForNextUpdate(); | ||
|
||
expect(renderResult.result.current).toStrictEqual(standardUseFetchState([], false, error)); | ||
expect(renderResult).hookToHaveUpdateCount(2); | ||
}); | ||
|
||
it('should handle errors from getArtifactsByContext', async () => { | ||
const error = new Error('Cannot fetch artifacts'); | ||
mockGetMlmdContext.mockResolvedValue(mockContext); | ||
mockGetArtifactsByContext.mockRejectedValue(error); | ||
|
||
const renderResult = testHook(useGetArtifactsByRuns)([mockRun]); | ||
|
||
expect(renderResult.result.current).toStrictEqual(standardUseFetchState([])); | ||
expect(renderResult).hookToHaveUpdateCount(1); | ||
|
||
// wait for update | ||
await renderResult.waitForNextUpdate(); | ||
|
||
expect(renderResult.result.current).toStrictEqual(standardUseFetchState([], false, error)); | ||
expect(renderResult).hookToHaveUpdateCount(2); | ||
}); | ||
}); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import React from 'react'; | ||
|
||
import { Artifact } from '~/third_party/mlmd'; | ||
import { GetArtifactsByContextRequest } from '~/third_party/mlmd/generated/ml_metadata/proto/metadata_store_service_pb'; | ||
import useFetchState, { FetchState, FetchStateCallbackPromise } from '~/utilities/useFetchState'; | ||
import { usePipelinesAPI } from '~/concepts/pipelines/context'; | ||
import { PipelineRunKFv2 } from '~/concepts/pipelines/kfTypes'; | ||
import { MlmdContextTypes } from './types'; | ||
import { getMlmdContext } from './useMlmdContext'; | ||
|
||
export const useGetArtifactsByRuns = ( | ||
runs: PipelineRunKFv2[], | ||
): FetchState<Record<string, Artifact[]>[]> => { | ||
const { metadataStoreServiceClient } = usePipelinesAPI(); | ||
|
||
const call = React.useCallback<FetchStateCallbackPromise<Record<string, Artifact[]>[]>>( | ||
() => | ||
Promise.all( | ||
runs.map((run) => | ||
getMlmdContext(metadataStoreServiceClient, run.run_id, MlmdContextTypes.RUN).then( | ||
async (context) => { | ||
if (!context) { | ||
throw new Error(`No context for run: ${run.run_id}`); | ||
} | ||
|
||
const request = new GetArtifactsByContextRequest(); | ||
request.setContextId(context.getId()); | ||
|
||
const response = await metadataStoreServiceClient.getArtifactsByContext(request); | ||
const artifacts = response.getArtifactsList(); | ||
|
||
return { | ||
[run.run_id]: artifacts, | ||
}; | ||
}, | ||
), | ||
), | ||
), | ||
[metadataStoreServiceClient, runs], | ||
); | ||
|
||
return useFetchState(call, []); | ||
}; |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import { Artifact } from '~/third_party/mlmd'; | ||
|
||
export interface ScalarMetrics { | ||
name: string; | ||
value: string; | ||
} | ||
|
||
export const getScalarMetrics = (artifact: Artifact): ScalarMetrics[] => | ||
artifact | ||
.toObject() | ||
.customPropertiesMap.reduce( | ||
( | ||
acc: { name: string; value: string }[], | ||
[customPropKey, { stringValue, intValue, doubleValue, boolValue }], | ||
) => { | ||
if (customPropKey !== 'display_name') { | ||
acc.push({ | ||
Check warning on line 17 in frontend/src/concepts/pipelines/content/pipelinesDetails/pipelineRun/artifacts/utils.ts Codecov / codecov/patchfrontend/src/concepts/pipelines/content/pipelinesDetails/pipelineRun/artifacts/utils.ts#L14-L17
|
||
name: customPropKey, | ||
value: stringValue || (intValue || doubleValue || boolValue).toString(), | ||
Check warning on line 19 in frontend/src/concepts/pipelines/content/pipelinesDetails/pipelineRun/artifacts/utils.ts Codecov / codecov/patchfrontend/src/concepts/pipelines/content/pipelinesDetails/pipelineRun/artifacts/utils.ts#L19
|
||
}); | ||
} | ||
|
||
return acc; | ||
Check warning on line 23 in frontend/src/concepts/pipelines/content/pipelinesDetails/pipelineRun/artifacts/utils.ts Codecov / codecov/patchfrontend/src/concepts/pipelines/content/pipelinesDetails/pipelineRun/artifacts/utils.ts#L23
|
||
}, | ||
[], | ||
); |