Skip to content

Commit

Permalink
Hparams: Generate map of runs to hparams by matching with session nam…
Browse files Browse the repository at this point in the history
…es. (#6600)

The way we generate the mapping of runs to hparams for the dashboard
table must be changed.

The primary reason for this is that the current algorithm relies on
metric names from the session_group call to match runs with sessions
before mapping the session's hparams to the run. But this is not always
an accurate way of identifying the complete set of runs that belong to a
session and, besides, we will be turning off metric retrieval for
dashboard's session_group call in a change later this week.

We instead rely on the property of sessions that their name is the
prefix for all runs that belong to the session.

Note: The new algorithm requires getting all run ids from the runs state
so we move the algorithm from hparams_selectors to runs_selectors in
order to avoid a circular BUILD dependency.

Note: Some of the comments in this change claim that the hparams data
source does not retrieve metrics data. This is not true, yet, but will
be by the end of the week.
  • Loading branch information
bmd3k authored Oct 4, 2023
1 parent d295fad commit 0793c72
Show file tree
Hide file tree
Showing 4 changed files with 378 additions and 252 deletions.
40 changes: 3 additions & 37 deletions tensorboard/webapp/hparams/_redux/hparams_selectors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -193,44 +193,10 @@ export const getDashboardHparamsAndMetricsSpecs = createSelector(
}
);

export const getDashboardRunsToHparamsAndMetrics = createSelector(
export const getDashboardSessionGroups = createSelector(
getHparamsState,
(state): RunToHparamsAndMetrics => {
const runToHparamsAndMetrics: RunToHparamsAndMetrics = {};

for (const sessionGroup of state.dashboardSessionGroups) {
const hparams: HparamValue[] = Object.entries(sessionGroup.hparams).map(
(keyValue) => {
const [hparam, value] = keyValue;
return {name: hparam, value};
}
);

for (const session of sessionGroup.sessions) {
runToHparamsAndMetrics[session.name] = {
metrics: [],
hparams,
};

for (const metricValue of session.metricValues) {
const runId = metricValue.name.group
? `${session.name}/${metricValue.name.group}`
: session.name;

const hparamsAndMetrics = runToHparamsAndMetrics[runId] || {
metrics: [],
hparams,
};
hparamsAndMetrics.metrics.push({
tag: metricValue.name.tag,
trainingStep: metricValue.trainingStep,
value: metricValue.value,
});
runToHparamsAndMetrics[runId] = hparamsAndMetrics;
}
}
}
return runToHparamsAndMetrics;
(state: HparamsState) => {
return state.dashboardSessionGroups;
}
);

Expand Down
218 changes: 11 additions & 207 deletions tensorboard/webapp/hparams/_redux/hparams_selectors_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -487,218 +487,22 @@ describe('hparams/_redux/hparams_selectors_test', () => {
});
});

describe('#getDashboardRunsToHparamsAndMetrics', () => {
it('contains entry for each runId/group', () => {
const mockSessionGroups = [
buildSessionGroup({
name: 'session_group_1',
hparams: {
hp1: 1,
hp2: true,
hp3: 'foo',
},
sessions: [
{
name: 'exp1/run1',
metricValues: [
buildMetricsValue({name: {tag: 'foo', group: '1'}, value: 2}),
buildMetricsValue({
name: {tag: 'bar', group: '2'},
value: 103,
trainingStep: 4,
}),
buildMetricsValue({
name: {tag: 'bar', group: '2'},
value: 107,
trainingStep: 5,
}),
buildMetricsValue({name: {tag: 'abc123', group: ''}, value: 2}),
],
},
{
name: 'exp1/run2',
metricValues: [
buildMetricsValue({name: {tag: 'foo', group: '1'}, value: 3}),
buildMetricsValue({name: {tag: 'bar', group: '2'}, value: 104}),
buildMetricsValue({name: {tag: 'baz', group: '3'}, value: 201}),
],
},
],
}),
buildSessionGroup({
name: 'session_group_2',
hparams: {
hp1: 2,
hp2: false,
hp3: 'bar',
},
sessions: [
{
name: 'exp1/run3',
metricValues: [
buildMetricsValue({name: {tag: 'foo', group: '1'}, value: 4}),
buildMetricsValue({name: {tag: 'bar', group: '2'}, value: 105}),
],
},
],
}),
buildSessionGroup({
name: 'session_group_3',
hparams: {
hp4: 'hyperparameter4',
},
sessions: [
{
name: 'exp1/run4',
metricValues: [],
},
],
}),
buildSessionGroup({
name: 'session_group_4',
hparams: {
hp1: 7,
hp2: false,
hp3: 'foobar',
},
sessions: [
describe('#getDashboardSessionGroups', () => {
it('returns dashboard session groups', () => {
const state = buildStateFromHparamsState(
buildHparamsState({
dashboardSessionGroups: [
{
name: 'exp2/run1',
metricValues: [
buildMetricsValue({name: {tag: 'foo', group: '1'}, value: 4}),
buildMetricsValue({name: {tag: 'bar', group: '2'}, value: 105}),
buildMetricsValue({
name: {tag: 'baz', group: '2'},
value: 1000,
}),
],
name: 'SessionGroup1',
hparams: {hparam1: 'value1'},
sessions: [],
},
],
}),
];

const state = buildStateFromHparamsState(
buildHparamsState({
dashboardSessionGroups: mockSessionGroups,
})
);

expect(selectors.getDashboardRunsToHparamsAndMetrics(state)).toEqual({
'exp1/run1': {
metrics: [{tag: 'abc123', trainingStep: 0, value: 2}],
hparams: [
{name: 'hp1', value: 1},
{name: 'hp2', value: true},
{name: 'hp3', value: 'foo'},
],
},
'exp1/run1/1': {
metrics: [{tag: 'foo', trainingStep: 0, value: 2}],
hparams: [
{name: 'hp1', value: 1},
{name: 'hp2', value: true},
{name: 'hp3', value: 'foo'},
],
},
'exp1/run1/2': {
metrics: [
{tag: 'bar', trainingStep: 4, value: 103},
{tag: 'bar', trainingStep: 5, value: 107},
],
hparams: [
{name: 'hp1', value: 1},
{name: 'hp2', value: true},
{name: 'hp3', value: 'foo'},
],
},
'exp1/run2': {
metrics: [],
hparams: [
{name: 'hp1', value: 1},
{name: 'hp2', value: true},
{name: 'hp3', value: 'foo'},
],
},
'exp1/run2/1': {
metrics: [{tag: 'foo', trainingStep: 0, value: 3}],
hparams: [
{name: 'hp1', value: 1},
{name: 'hp2', value: true},
{name: 'hp3', value: 'foo'},
],
},
'exp1/run2/2': {
metrics: [{tag: 'bar', trainingStep: 0, value: 104}],
hparams: [
{name: 'hp1', value: 1},
{name: 'hp2', value: true},
{name: 'hp3', value: 'foo'},
],
},
'exp1/run2/3': {
metrics: [{tag: 'baz', trainingStep: 0, value: 201}],
hparams: [
{name: 'hp1', value: 1},
{name: 'hp2', value: true},
{name: 'hp3', value: 'foo'},
],
},
'exp1/run3': {
metrics: [],
hparams: [
{name: 'hp1', value: 2},
{name: 'hp2', value: false},
{name: 'hp3', value: 'bar'},
],
},
'exp1/run3/1': {
metrics: [{tag: 'foo', trainingStep: 0, value: 4}],
hparams: [
{name: 'hp1', value: 2},
{name: 'hp2', value: false},
{name: 'hp3', value: 'bar'},
],
},
'exp1/run3/2': {
metrics: [{tag: 'bar', trainingStep: 0, value: 105}],
hparams: [
{name: 'hp1', value: 2},
{name: 'hp2', value: false},
{name: 'hp3', value: 'bar'},
],
},
'exp1/run4': {
metrics: [],
hparams: [{name: 'hp4', value: 'hyperparameter4'}],
},
'exp2/run1': {
metrics: [],
hparams: [
{name: 'hp1', value: 7},
{name: 'hp2', value: false},
{name: 'hp3', value: 'foobar'},
],
},
'exp2/run1/1': {
metrics: [{tag: 'foo', trainingStep: 0, value: 4}],
hparams: [
{name: 'hp1', value: 7},
{name: 'hp2', value: false},
{name: 'hp3', value: 'foobar'},
],
},
'exp2/run1/2': {
metrics: [
{tag: 'bar', trainingStep: 0, value: 105},
{tag: 'baz', trainingStep: 0, value: 1000},
],
hparams: [
{name: 'hp1', value: 7},
{name: 'hp2', value: false},
{name: 'hp3', value: 'foobar'},
],
},
});
expect(selectors.getDashboardSessionGroups(state)).toEqual([
{name: 'SessionGroup1', hparams: {hparam1: 'value1'}, sessions: []},
]);
});
});

Expand Down
85 changes: 77 additions & 8 deletions tensorboard/webapp/runs/store/runs_selectors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ import {
RUNS_FEATURE_KEY,
} from './runs_types';
import {createGroupBy} from './utils';
import {ColumnHeader, SortingInfo} from '../../widgets/data_table/types';
import {getDashboardRunsToHparamsAndMetrics} from '../../hparams/_redux/hparams_selectors';
import {RunToHparamsAndMetrics} from '../../hparams/types';
import {getExperimentIdsFromRoute} from '../../app_routing/store/app_routing_selectors';
import {getDashboardSessionGroups} from '../../hparams/_redux/hparams_selectors';
import {HparamValue, RunToHparamsAndMetrics} from '../../hparams/types';
import {ColumnHeader, SortingInfo} from '../../widgets/data_table/types';

const getRunsState = createFeatureSelector<RunsState>(RUNS_FEATURE_KEY);

Expand Down Expand Up @@ -85,16 +85,77 @@ export const getRuns = createSelector(
}
);

/**
* Determines hparam data for each run in the active route.
*
* Attempts to match each run with a session and, if found, copies the hparam
* values from the corresponding session group to the run.
*
* Note, it returns an RunToHparamsAndMetrics but leaves the `metrics` field
* blank since the Hparams data sources does not actually retrieve metrics
* data.
*
* Meant for usage in the Dashboard views.
*/
export const getDashboardRunsToHparams = createSelector(
getDashboardSessionGroups,
getExperimentIdsFromRoute,
getDataState,
(dashboardSessionGroups, experimentIds, state): RunToHparamsAndMetrics => {
if (!experimentIds) {
return {};
}

const runIds: string[] = [];
for (const experimentId of experimentIds) {
runIds.push(...(state.runIds[experimentId] || []));
}

const sessionToHparams: Record<string, HparamValue[]> = {};
for (const sessionGroup of dashboardSessionGroups) {
const hparams: HparamValue[] = Object.entries(sessionGroup.hparams).map(
(keyValue) => {
const [hparam, value] = keyValue;
return {name: hparam, value};
}
);
for (const session of sessionGroup.sessions) {
sessionToHparams[session.name] = hparams;
}
}

// Sort sessions based on length of name. We want to match runs with the
// longest matching session name. So, for example, given sessions "1" and
// "11", a run with name "11/train" should match with session "11".
const sortedSessionKeys = Object.keys(sessionToHparams).sort(
(a, b) => b.length - a.length
);

const runToHparamsAndMetrics: RunToHparamsAndMetrics = {};
for (const runId of runIds) {
for (const sessionName of sortedSessionKeys) {
if (runId.startsWith(sessionName)) {
runToHparamsAndMetrics[runId] = {
hparams: sessionToHparams[sessionName],
// The underlying data source that fetches the session groups data
// does not retrieve metrics.
metrics: [],
};
break;
}
}
}
return runToHparamsAndMetrics;
}
);

/**
* Get the runs used on the dashboard.
* TODO(rileyajones) get the experiment ids from the state rather than as an argument.
* @param experimentIds
* @returns
*/
export const getDashboardRuns = createSelector(
getDataState,
getExperimentIdsFromRoute,
getDashboardRunsToHparamsAndMetrics,
getDashboardRunsToHparams,
(
state: RunsDataState,
experimentIds: string[] | null,
Expand All @@ -109,9 +170,17 @@ export const getDashboardRuns = createSelector(
.filter((id) => Boolean(state.runMetadata[id]))
.map((runId) => {
const run = {...state.runMetadata[runId], experimentId};
// runMetadata contains hparam and metric values that were retrieved
// for the run in isolation. This data is incorrect for use in the
// dashboard, where we might be comparing multiple experiments and
// the set of hparams and metrics may be a superset.
//
// Instead we override the hparam and metric values with those
// calculated by getDashboardRunsToHparamsAndMetrics, which is based
// on the hparam and metric data for all active experiments
// together.
run.hparams = runsToHparamsAndMetrics[runId]?.hparams ?? null;
run.metrics = runsToHparamsAndMetrics[runId]?.metrics ?? null;

return run;
});
})
Expand Down
Loading

0 comments on commit 0793c72

Please sign in to comment.