Skip to content

Commit

Permalink
Adds common selectors for shared hparams (#6732)
Browse files Browse the repository at this point in the history
## Motivation for features / changes
Displaying shared hparam columns in the runs and scalar tables requires
new selectors.

## Technical description of changes
Adds two new common selectors that will be used by both runs and scalar
tables:
- getSelectableColumns
- getGroupedColumns

Also updates getDashboardDisplayedHparamColumns to only return relevant
hparams (i.e. those with specs defined by selected experiments)

## Detailed steps to verify changes work correctly (as executed by you)
Unit tests pass
  • Loading branch information
hoonji authored Jan 31, 2024
1 parent d6ad97e commit 304046d
Show file tree
Hide file tree
Showing 19 changed files with 708 additions and 54 deletions.
1 change: 1 addition & 0 deletions tensorboard/webapp/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ tf_ng_web_test_suite(
"//tensorboard/webapp/widgets/content_wrapping_input:content_wrapping_input_tests",
"//tensorboard/webapp/widgets/custom_modal:custom_modal_test",
"//tensorboard/webapp/widgets/data_table:data_table_test",
"//tensorboard/webapp/widgets/data_table:utils_test",
"//tensorboard/webapp/widgets/dropdown:dropdown_tests",
"//tensorboard/webapp/widgets/experiment_alias:experiment_alias_test",
"//tensorboard/webapp/widgets/filter_input:filter_input_test",
Expand Down
8 changes: 7 additions & 1 deletion tensorboard/webapp/hparams/_redux/hparams_selectors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,14 @@ export const getDashboardDefaultHparamFilters = createSelector(
);

export const getDashboardDisplayedHparamColumns = createSelector(
getDashboardHparamsAndMetricsSpecs,
getHparamsState,
(state) => state.dashboardDisplayedHparamColumns
({hparams}, state) => {
const hparamSet = new Set(hparams.map((hparam) => hparam.name));
return state.dashboardDisplayedHparamColumns.filter((column) =>
hparamSet.has(column.name)
);
}
);

export const getDashboardHparamFilterMap = createSelector(
Expand Down
70 changes: 52 additions & 18 deletions tensorboard/webapp/hparams/_redux/hparams_selectors_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.

import {ColumnHeaderType} from '../../widgets/data_table/types';
import {DomainType} from '../types';
import {State} from './types';
import * as selectors from './hparams_selectors';
import {
buildHparamSpec,
Expand Down Expand Up @@ -114,30 +115,63 @@ describe('hparams/_redux/hparams_selectors_test', () => {
});

describe('#getDashboardDisplayedHparamColumns', () => {
it('returns dashboard displayed hparam columns', () => {
const fakeColumns = [
{
type: ColumnHeaderType.HPARAM,
name: 'conv_layers',
displayName: 'Conv Layers',
enabled: true,
},
{
type: ColumnHeaderType.HPARAM,
name: 'conv_kernel_size',
displayName: 'Conv Kernel Size',
enabled: true,
},
];
it('returns no columns if no hparam specs', () => {
const state = buildStateFromHparamsState(
buildHparamsState({
dashboardDisplayedHparamColumns: fakeColumns,
dashboardSpecs: {
hparams: [],
},
dashboardDisplayedHparamColumns: [
{
type: ColumnHeaderType.HPARAM,
name: 'conv_layers',
displayName: 'Conv Layers',
enabled: true,
},
{
type: ColumnHeaderType.HPARAM,
name: 'conv_kernel_size',
displayName: 'Conv Kernel Size',
enabled: true,
},
],
})
);

expect(selectors.getDashboardDisplayedHparamColumns(state)).toEqual(
fakeColumns
expect(selectors.getDashboardDisplayedHparamColumns(state)).toEqual([]);
});

it('returns only hparam columns that have specs', () => {
const state = buildStateFromHparamsState(
buildHparamsState({
dashboardSpecs: {
hparams: [buildHparamSpec({name: 'conv_layers'})],
},
dashboardDisplayedHparamColumns: [
{
type: ColumnHeaderType.HPARAM,
name: 'conv_layers',
displayName: 'Conv Layers',
enabled: true,
},
{
type: ColumnHeaderType.HPARAM,
name: 'conv_kernel_size',
displayName: 'Conv Kernel Size',
enabled: true,
},
],
})
);

expect(selectors.getDashboardDisplayedHparamColumns(state)).toEqual([
{
type: ColumnHeaderType.HPARAM,
name: 'conv_layers',
displayName: 'Conv Layers',
enabled: true,
},
]);
});
});
});
6 changes: 6 additions & 0 deletions tensorboard/webapp/metrics/store/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ tf_ts_library(
"//tensorboard/webapp/app_routing:types",
"//tensorboard/webapp/app_routing/actions",
"//tensorboard/webapp/core/actions",
"//tensorboard/webapp/hparams/_redux:hparams_selectors",
"//tensorboard/webapp/metrics:types",
"//tensorboard/webapp/metrics:utils",
"//tensorboard/webapp/metrics/actions",
Expand All @@ -33,6 +34,7 @@ tf_ts_library(
"//tensorboard/webapp/util:types",
"//tensorboard/webapp/widgets/card_fob:types",
"//tensorboard/webapp/widgets/data_table:types",
"//tensorboard/webapp/widgets/data_table:utils",
"//tensorboard/webapp/widgets/line_chart_v2/lib:public_types",
"@npm//@ngrx/store",
],
Expand Down Expand Up @@ -97,14 +99,18 @@ tf_ts_library(
"//tensorboard/webapp/app_routing:types",
"//tensorboard/webapp/app_routing/actions",
"//tensorboard/webapp/core/actions",
"//tensorboard/webapp/hparams:testing",
"//tensorboard/webapp/hparams/_redux:types",
"//tensorboard/webapp/metrics:test_lib",
"//tensorboard/webapp/metrics:types",
"//tensorboard/webapp/metrics/actions",
"//tensorboard/webapp/metrics/data_source",
"//tensorboard/webapp/persistent_settings",
"//tensorboard/webapp/routes:testing",
"//tensorboard/webapp/testing:utils",
"//tensorboard/webapp/types",
"//tensorboard/webapp/util:dom",
"//tensorboard/webapp/util:types",
"//tensorboard/webapp/widgets/card_fob:types",
"//tensorboard/webapp/widgets/data_table:types",
"@npm//@types/jasmine",
Expand Down
11 changes: 11 additions & 0 deletions tensorboard/webapp/metrics/store/metrics_selectors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ import {
import {ColumnHeader, DataTableMode} from '../../widgets/data_table/types';
import {Extent} from '../../widgets/line_chart_v2/lib/public_types';
import {memoize} from '../../util/memoize';
import {getDashboardDisplayedHparamColumns} from '../../hparams/_redux/hparams_selectors';
import {DataTableUtils} from '../../widgets/data_table/utils';

const selectMetricsState =
createFeatureSelector<MetricsState>(METRICS_FEATURE_KEY);
Expand Down Expand Up @@ -661,3 +663,12 @@ export const getColumnHeadersForCard = memoize((cardId: string) => {
}
);
});

export const getGroupedHeadersForCard = memoize((cardId: string) =>
createSelector(
getColumnHeadersForCard(cardId),
getDashboardDisplayedHparamColumns,
(standardColumns, hparamColumns) =>
DataTableUtils.groupColumns([...standardColumns, ...hparamColumns])
)
);
196 changes: 196 additions & 0 deletions tensorboard/webapp/metrics/store/metrics_selectors_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ import {
} from '../../widgets/data_table/types';
import * as selectors from './metrics_selectors';
import {CardFeatureOverride, MetricsState} from './metrics_types';
import {buildMockState} from '../../testing/utils';
import {
buildHparamSpec,
buildHparamsState,
buildStateFromHparamsState,
} from '../../hparams/testing';
import {DeepPartial} from '../../util/types';
import {HparamsState} from '../../hparams/_redux/types';

describe('metrics selectors', () => {
beforeEach(() => {
Expand Down Expand Up @@ -1745,4 +1753,192 @@ describe('metrics selectors', () => {
).toEqual(rangeSelectionHeaders);
});
});

describe('getGroupedHeadersForCard', () => {
let singleSelectionHeaders: ColumnHeader[];
let rangeSelectionHeaders: ColumnHeader[];
let hparamsState: DeepPartial<HparamsState>;

beforeEach(() => {
singleSelectionHeaders = [
{
type: ColumnHeaderType.COLOR,
name: 'color',
displayName: 'Color',
enabled: true,
},
{
type: ColumnHeaderType.RUN,
name: 'run',
displayName: 'My Run name',
enabled: false,
},
];
rangeSelectionHeaders = [
{
type: ColumnHeaderType.MEAN,
name: 'mean',
displayName: 'Mean',
enabled: true,
},
{
type: ColumnHeaderType.RUN,
name: 'run',
displayName: 'My Run name',
enabled: false,
},
];
hparamsState = {
dashboardSpecs: {
hparams: [
buildHparamSpec({name: 'conv_layers'}),
buildHparamSpec({name: 'conv_kernel_size'}),
],
},
dashboardDisplayedHparamColumns: [
{
type: ColumnHeaderType.HPARAM,
name: 'conv_layers',
displayName: 'Conv Layers',
enabled: true,
},
{
type: ColumnHeaderType.HPARAM,
name: 'conv_kernel_size',
displayName: 'Conv Kernel Size',
enabled: true,
},
],
};
});

it('returns grouped single selection headers when card range selection is disabled', () => {
const state = buildMockState({
...appStateFromMetricsState(
buildMetricsState({
singleSelectionHeaders,
rangeSelectionHeaders,
cardStateMap: {
card1: {
rangeSelectionOverride:
CardFeatureOverride.OVERRIDE_AS_DISABLED,
},
},
})
),
...buildStateFromHparamsState(buildHparamsState(hparamsState)),
});

expect(selectors.getGroupedHeadersForCard('card1')(state)).toEqual([
{
type: ColumnHeaderType.RUN,
name: 'run',
displayName: 'My Run name',
enabled: false,
},
{
type: ColumnHeaderType.HPARAM,
name: 'conv_layers',
displayName: 'Conv Layers',
enabled: true,
},
{
type: ColumnHeaderType.HPARAM,
name: 'conv_kernel_size',
displayName: 'Conv Kernel Size',
enabled: true,
},
{
type: ColumnHeaderType.COLOR,
name: 'color',
displayName: 'Color',
enabled: true,
},
]);
});

it('returns grouped range selection headers when card range selection is enabled', () => {
const state = buildMockState({
...appStateFromMetricsState(
buildMetricsState({
singleSelectionHeaders,
rangeSelectionHeaders,
cardStateMap: {
card1: {
rangeSelectionOverride: CardFeatureOverride.OVERRIDE_AS_ENABLED,
},
},
})
),
...buildStateFromHparamsState(buildHparamsState(hparamsState)),
});

expect(selectors.getGroupedHeadersForCard('card1')(state)).toEqual([
{
type: ColumnHeaderType.RUN,
name: 'run',
displayName: 'My Run name',
enabled: false,
},
{
type: ColumnHeaderType.HPARAM,
name: 'conv_layers',
displayName: 'Conv Layers',
enabled: true,
},
{
type: ColumnHeaderType.HPARAM,
name: 'conv_kernel_size',
displayName: 'Conv Kernel Size',
enabled: true,
},
{
type: ColumnHeaderType.MEAN,
name: 'mean',
displayName: 'Mean',
enabled: true,
},
]);
});

it('returns grouped range selection headers when global range selection is enabled', () => {
const state = buildMockState({
...appStateFromMetricsState(
buildMetricsState({
singleSelectionHeaders,
rangeSelectionHeaders,
rangeSelectionEnabled: true,
})
),
...buildStateFromHparamsState(buildHparamsState(hparamsState)),
});

expect(selectors.getGroupedHeadersForCard('card1')(state)).toEqual([
{
type: ColumnHeaderType.RUN,
name: 'run',
displayName: 'My Run name',
enabled: false,
},
{
type: ColumnHeaderType.HPARAM,
name: 'conv_layers',
displayName: 'Conv Layers',
enabled: true,
},
{
type: ColumnHeaderType.HPARAM,
name: 'conv_kernel_size',
displayName: 'Conv Kernel Size',
enabled: true,
},
{
type: ColumnHeaderType.MEAN,
name: 'mean',
displayName: 'Mean',
enabled: true,
},
]);
});
});
});
4 changes: 1 addition & 3 deletions tensorboard/webapp/metrics/testing.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,12 @@ import {
TimeSeriesRequest,
TimeSeriesResponse,
} from './data_source';
import * as selectors from './store/metrics_selectors';
import {
MetricsState,
METRICS_FEATURE_KEY,
TagMetadata,
TimeSeriesData,
} from './store';
import * as selectors from './store/metrics_selectors';
import {
CardStepIndexMetaData,
MetricsSettings,
RunToSeries,
Expand Down
Loading

0 comments on commit 304046d

Please sign in to comment.