diff --git a/tensorboard/webapp/BUILD b/tensorboard/webapp/BUILD index bd2230c338..25e45c27cf 100644 --- a/tensorboard/webapp/BUILD +++ b/tensorboard/webapp/BUILD @@ -301,6 +301,7 @@ tf_ng_web_test_suite( "//tensorboard/webapp/widgets/content_wrapping_input:content_wrapping_input_tests", "//tensorboard/webapp/widgets/custom_modal:custom_modal_test", "//tensorboard/webapp/widgets/data_table:data_table_test", + "//tensorboard/webapp/widgets/data_table:utils_test", "//tensorboard/webapp/widgets/dropdown:dropdown_tests", "//tensorboard/webapp/widgets/experiment_alias:experiment_alias_test", "//tensorboard/webapp/widgets/filter_input:filter_input_test", diff --git a/tensorboard/webapp/hparams/_redux/hparams_selectors.ts b/tensorboard/webapp/hparams/_redux/hparams_selectors.ts index 90c8b360d0..bd6054e3a5 100644 --- a/tensorboard/webapp/hparams/_redux/hparams_selectors.ts +++ b/tensorboard/webapp/hparams/_redux/hparams_selectors.ts @@ -47,8 +47,14 @@ export const getDashboardDefaultHparamFilters = createSelector( ); export const getDashboardDisplayedHparamColumns = createSelector( + getDashboardHparamsAndMetricsSpecs, getHparamsState, - (state) => state.dashboardDisplayedHparamColumns + ({hparams}, state) => { + const hparamSet = new Set(hparams.map((hparam) => hparam.name)); + return state.dashboardDisplayedHparamColumns.filter((column) => + hparamSet.has(column.name) + ); + } ); export const getDashboardHparamFilterMap = createSelector( diff --git a/tensorboard/webapp/hparams/_redux/hparams_selectors_test.ts b/tensorboard/webapp/hparams/_redux/hparams_selectors_test.ts index cc43c0dfa2..d57be572bb 100644 --- a/tensorboard/webapp/hparams/_redux/hparams_selectors_test.ts +++ b/tensorboard/webapp/hparams/_redux/hparams_selectors_test.ts @@ -15,6 +15,7 @@ limitations under the License. import {ColumnHeaderType} from '../../widgets/data_table/types'; import {DomainType} from '../types'; +import {State} from './types'; import * as selectors from './hparams_selectors'; import { buildHparamSpec, @@ -114,30 +115,63 @@ describe('hparams/_redux/hparams_selectors_test', () => { }); describe('#getDashboardDisplayedHparamColumns', () => { - it('returns dashboard displayed hparam columns', () => { - const fakeColumns = [ - { - type: ColumnHeaderType.HPARAM, - name: 'conv_layers', - displayName: 'Conv Layers', - enabled: true, - }, - { - type: ColumnHeaderType.HPARAM, - name: 'conv_kernel_size', - displayName: 'Conv Kernel Size', - enabled: true, - }, - ]; + it('returns no columns if no hparam specs', () => { const state = buildStateFromHparamsState( buildHparamsState({ - dashboardDisplayedHparamColumns: fakeColumns, + dashboardSpecs: { + hparams: [], + }, + dashboardDisplayedHparamColumns: [ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + ], }) ); - expect(selectors.getDashboardDisplayedHparamColumns(state)).toEqual( - fakeColumns + expect(selectors.getDashboardDisplayedHparamColumns(state)).toEqual([]); + }); + + it('returns only hparam columns that have specs', () => { + const state = buildStateFromHparamsState( + buildHparamsState({ + dashboardSpecs: { + hparams: [buildHparamSpec({name: 'conv_layers'})], + }, + dashboardDisplayedHparamColumns: [ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + ], + }) ); + + expect(selectors.getDashboardDisplayedHparamColumns(state)).toEqual([ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + ]); }); }); }); diff --git a/tensorboard/webapp/metrics/store/BUILD b/tensorboard/webapp/metrics/store/BUILD index 8db5c763f9..349dcfe15b 100644 --- a/tensorboard/webapp/metrics/store/BUILD +++ b/tensorboard/webapp/metrics/store/BUILD @@ -18,6 +18,7 @@ tf_ts_library( "//tensorboard/webapp/app_routing:types", "//tensorboard/webapp/app_routing/actions", "//tensorboard/webapp/core/actions", + "//tensorboard/webapp/hparams/_redux:hparams_selectors", "//tensorboard/webapp/metrics:types", "//tensorboard/webapp/metrics:utils", "//tensorboard/webapp/metrics/actions", @@ -33,6 +34,7 @@ tf_ts_library( "//tensorboard/webapp/util:types", "//tensorboard/webapp/widgets/card_fob:types", "//tensorboard/webapp/widgets/data_table:types", + "//tensorboard/webapp/widgets/data_table:utils", "//tensorboard/webapp/widgets/line_chart_v2/lib:public_types", "@npm//@ngrx/store", ], @@ -97,14 +99,18 @@ tf_ts_library( "//tensorboard/webapp/app_routing:types", "//tensorboard/webapp/app_routing/actions", "//tensorboard/webapp/core/actions", + "//tensorboard/webapp/hparams:testing", + "//tensorboard/webapp/hparams/_redux:types", "//tensorboard/webapp/metrics:test_lib", "//tensorboard/webapp/metrics:types", "//tensorboard/webapp/metrics/actions", "//tensorboard/webapp/metrics/data_source", "//tensorboard/webapp/persistent_settings", "//tensorboard/webapp/routes:testing", + "//tensorboard/webapp/testing:utils", "//tensorboard/webapp/types", "//tensorboard/webapp/util:dom", + "//tensorboard/webapp/util:types", "//tensorboard/webapp/widgets/card_fob:types", "//tensorboard/webapp/widgets/data_table:types", "@npm//@types/jasmine", diff --git a/tensorboard/webapp/metrics/store/metrics_selectors.ts b/tensorboard/webapp/metrics/store/metrics_selectors.ts index 9892241fb5..f01db8e327 100644 --- a/tensorboard/webapp/metrics/store/metrics_selectors.ts +++ b/tensorboard/webapp/metrics/store/metrics_selectors.ts @@ -50,6 +50,8 @@ import { import {ColumnHeader, DataTableMode} from '../../widgets/data_table/types'; import {Extent} from '../../widgets/line_chart_v2/lib/public_types'; import {memoize} from '../../util/memoize'; +import {getDashboardDisplayedHparamColumns} from '../../hparams/_redux/hparams_selectors'; +import {DataTableUtils} from '../../widgets/data_table/utils'; const selectMetricsState = createFeatureSelector(METRICS_FEATURE_KEY); @@ -661,3 +663,12 @@ export const getColumnHeadersForCard = memoize((cardId: string) => { } ); }); + +export const getGroupedHeadersForCard = memoize((cardId: string) => + createSelector( + getColumnHeadersForCard(cardId), + getDashboardDisplayedHparamColumns, + (standardColumns, hparamColumns) => + DataTableUtils.groupColumns([...standardColumns, ...hparamColumns]) + ) +); diff --git a/tensorboard/webapp/metrics/store/metrics_selectors_test.ts b/tensorboard/webapp/metrics/store/metrics_selectors_test.ts index a30cb45a0e..28b63e1a70 100644 --- a/tensorboard/webapp/metrics/store/metrics_selectors_test.ts +++ b/tensorboard/webapp/metrics/store/metrics_selectors_test.ts @@ -33,6 +33,14 @@ import { } from '../../widgets/data_table/types'; import * as selectors from './metrics_selectors'; import {CardFeatureOverride, MetricsState} from './metrics_types'; +import {buildMockState} from '../../testing/utils'; +import { + buildHparamSpec, + buildHparamsState, + buildStateFromHparamsState, +} from '../../hparams/testing'; +import {DeepPartial} from '../../util/types'; +import {HparamsState} from '../../hparams/_redux/types'; describe('metrics selectors', () => { beforeEach(() => { @@ -1745,4 +1753,192 @@ describe('metrics selectors', () => { ).toEqual(rangeSelectionHeaders); }); }); + + describe('getGroupedHeadersForCard', () => { + let singleSelectionHeaders: ColumnHeader[]; + let rangeSelectionHeaders: ColumnHeader[]; + let hparamsState: DeepPartial; + + beforeEach(() => { + singleSelectionHeaders = [ + { + type: ColumnHeaderType.COLOR, + name: 'color', + displayName: 'Color', + enabled: true, + }, + { + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'My Run name', + enabled: false, + }, + ]; + rangeSelectionHeaders = [ + { + type: ColumnHeaderType.MEAN, + name: 'mean', + displayName: 'Mean', + enabled: true, + }, + { + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'My Run name', + enabled: false, + }, + ]; + hparamsState = { + dashboardSpecs: { + hparams: [ + buildHparamSpec({name: 'conv_layers'}), + buildHparamSpec({name: 'conv_kernel_size'}), + ], + }, + dashboardDisplayedHparamColumns: [ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + ], + }; + }); + + it('returns grouped single selection headers when card range selection is disabled', () => { + const state = buildMockState({ + ...appStateFromMetricsState( + buildMetricsState({ + singleSelectionHeaders, + rangeSelectionHeaders, + cardStateMap: { + card1: { + rangeSelectionOverride: + CardFeatureOverride.OVERRIDE_AS_DISABLED, + }, + }, + }) + ), + ...buildStateFromHparamsState(buildHparamsState(hparamsState)), + }); + + expect(selectors.getGroupedHeadersForCard('card1')(state)).toEqual([ + { + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'My Run name', + enabled: false, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + { + type: ColumnHeaderType.COLOR, + name: 'color', + displayName: 'Color', + enabled: true, + }, + ]); + }); + + it('returns grouped range selection headers when card range selection is enabled', () => { + const state = buildMockState({ + ...appStateFromMetricsState( + buildMetricsState({ + singleSelectionHeaders, + rangeSelectionHeaders, + cardStateMap: { + card1: { + rangeSelectionOverride: CardFeatureOverride.OVERRIDE_AS_ENABLED, + }, + }, + }) + ), + ...buildStateFromHparamsState(buildHparamsState(hparamsState)), + }); + + expect(selectors.getGroupedHeadersForCard('card1')(state)).toEqual([ + { + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'My Run name', + enabled: false, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + { + type: ColumnHeaderType.MEAN, + name: 'mean', + displayName: 'Mean', + enabled: true, + }, + ]); + }); + + it('returns grouped range selection headers when global range selection is enabled', () => { + const state = buildMockState({ + ...appStateFromMetricsState( + buildMetricsState({ + singleSelectionHeaders, + rangeSelectionHeaders, + rangeSelectionEnabled: true, + }) + ), + ...buildStateFromHparamsState(buildHparamsState(hparamsState)), + }); + + expect(selectors.getGroupedHeadersForCard('card1')(state)).toEqual([ + { + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'My Run name', + enabled: false, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + { + type: ColumnHeaderType.MEAN, + name: 'mean', + displayName: 'Mean', + enabled: true, + }, + ]); + }); + }); }); diff --git a/tensorboard/webapp/metrics/testing.ts b/tensorboard/webapp/metrics/testing.ts index de33b3b66d..4dc9757d89 100644 --- a/tensorboard/webapp/metrics/testing.ts +++ b/tensorboard/webapp/metrics/testing.ts @@ -28,14 +28,12 @@ import { TimeSeriesRequest, TimeSeriesResponse, } from './data_source'; +import * as selectors from './store/metrics_selectors'; import { MetricsState, METRICS_FEATURE_KEY, TagMetadata, TimeSeriesData, -} from './store'; -import * as selectors from './store/metrics_selectors'; -import { CardStepIndexMetaData, MetricsSettings, RunToSeries, diff --git a/tensorboard/webapp/metrics/views/main_view/BUILD b/tensorboard/webapp/metrics/views/main_view/BUILD index a0a2f17691..13e2f82ef5 100644 --- a/tensorboard/webapp/metrics/views/main_view/BUILD +++ b/tensorboard/webapp/metrics/views/main_view/BUILD @@ -91,6 +91,7 @@ tf_ts_library( "//tensorboard/webapp/metrics/views:utils", "//tensorboard/webapp/metrics/views/card_renderer:scalar_card_types", "//tensorboard/webapp/runs:types", + "//tensorboard/webapp/runs/store:selectors", "//tensorboard/webapp/runs/views/runs_table:types", "//tensorboard/webapp/util:matcher", "//tensorboard/webapp/util:memoize", diff --git a/tensorboard/webapp/metrics/views/main_view/common_selectors.ts b/tensorboard/webapp/metrics/views/main_view/common_selectors.ts index 20be847c66..0836af7fb4 100644 --- a/tensorboard/webapp/metrics/views/main_view/common_selectors.ts +++ b/tensorboard/webapp/metrics/views/main_view/common_selectors.ts @@ -33,6 +33,7 @@ import { getDashboardHparamsAndMetricsSpecs, getDashboardHparamFilterMap, getDashboardDefaultHparamFilters, + getDashboardDisplayedHparamColumns, } from '../../../hparams/_redux/hparams_selectors'; import { DiscreteFilter, @@ -287,10 +288,22 @@ export const getPotentialHparamColumns = createSelector( sortable: true, movable: true, filterable: true, + hidable: true, })); } ); +export const getSelectableColumns = createSelector( + getPotentialHparamColumns, + getDashboardDisplayedHparamColumns, + (potentialColumns, currentColumns) => { + const currentColumnNames = new Set(currentColumns.map(({name}) => name)); + return potentialColumns.filter((columnHeader) => { + return !currentColumnNames.has(columnHeader.name); + }); + } +); + export const getAllPotentialColumnsForCard = memoize((cardId: string) => { return createSelector( getColumnHeadersForCard(cardId), diff --git a/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts b/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts index 56e59784be..11b133f54b 100644 --- a/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts +++ b/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts @@ -161,7 +161,35 @@ describe('common selectors', () => { run4, }, } as any, - ui: {} as any, + ui: { + runsTableHeaders: [ + { + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'Run', + enabled: true, + sortable: true, + removable: false, + movable: false, + filterable: false, + hidable: false, + }, + { + type: ColumnHeaderType.CUSTOM, + name: 'experimentAlias', + displayName: 'Experiment', + enabled: true, + movable: false, + sortable: true, + }, + { + type: ColumnHeaderType.CUSTOM, + name: 'fakeRunsHeader', + displayName: 'Fake Runs Header', + enabled: true, + }, + ], + } as any, }, experiments: { data: { @@ -182,10 +210,35 @@ describe('common selectors', () => { }, hparams: { dashboardSpecs: { - hparams: [buildHparamSpec({name: 'foo', displayName: 'Foo'})], + hparams: [ + buildHparamSpec({name: 'conv_layers', displayName: 'Conv Layers'}), + buildHparamSpec({ + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + }), + buildHparamSpec({ + name: 'dense_layers', + displayName: 'Dense Layers', + }), + buildHparamSpec({name: 'dropout', displayName: 'Dropout'}), + ], metrics: [buildMetricSpec({displayName: 'Bar'})], }, dashboardSessionGroups: [], + dashboardDisplayedHparamColumns: [ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'dense_layers', + displayName: 'Dense Layers', + enabled: true, + }, + ], } as any, }); }); @@ -962,6 +1015,15 @@ describe('common selectors', () => { }); describe('getPotentialHparamColumns', () => { + const expectedBooleanFlags = { + enabled: false, + removable: true, + sortable: true, + movable: true, + filterable: true, + hidable: true, + }; + it('returns empty list when there are no experiments', () => { state.app_routing!.activeRoute!.routeKind = RouteKind.EXPERIMENTS; @@ -972,42 +1034,87 @@ describe('common selectors', () => { expect(selectors.getPotentialHparamColumns(state)).toEqual([ { type: ColumnHeaderType.HPARAM, - name: 'foo', - displayName: 'Foo', - enabled: false, - removable: true, - sortable: true, - movable: true, - filterable: true, + name: 'conv_layers', + displayName: 'Conv Layers', + ...expectedBooleanFlags, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + ...expectedBooleanFlags, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'dense_layers', + displayName: 'Dense Layers', + ...expectedBooleanFlags, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'dropout', + displayName: 'Dropout', + ...expectedBooleanFlags, }, ]); }); it('sets name as display name when a display name is not provided', () => { - state.hparams!.dashboardSpecs.hparams.push( - buildHparamSpec({name: 'bar', displayName: ''}) - ); + state.hparams!.dashboardSpecs.hparams = [ + buildHparamSpec({name: 'conv_layers', displayName: ''}), + ]; + expect(selectors.getPotentialHparamColumns(state)).toEqual([ { type: ColumnHeaderType.HPARAM, - name: 'foo', - displayName: 'Foo', - enabled: false, - removable: true, - sortable: true, - movable: true, - filterable: true, + name: 'conv_layers', + displayName: 'conv_layers', + ...expectedBooleanFlags, }, - { + ]); + }); + }); + + describe('getSelectableColumns', () => { + it('returns the full list of hparam columns if none are currently displayed', () => { + state.hparams!.dashboardDisplayedHparamColumns = []; + + expect(selectors.getSelectableColumns(state)).toEqual([ + jasmine.objectContaining({ type: ColumnHeaderType.HPARAM, - name: 'bar', - displayName: 'bar', - enabled: false, - removable: true, - sortable: true, - movable: true, - filterable: true, - }, + name: 'conv_layers', + displayName: 'Conv Layers', + }), + jasmine.objectContaining({ + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + }), + jasmine.objectContaining({ + type: ColumnHeaderType.HPARAM, + name: 'dense_layers', + displayName: 'Dense Layers', + }), + jasmine.objectContaining({ + type: ColumnHeaderType.HPARAM, + name: 'dropout', + displayName: 'Dropout', + }), + ]); + }); + + it('returns only columns that are not displayed', () => { + expect(selectors.getSelectableColumns(state)).toEqual([ + jasmine.objectContaining({ + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + }), + jasmine.objectContaining({ + type: ColumnHeaderType.HPARAM, + name: 'dropout', + displayName: 'Dropout', + }), ]); }); }); diff --git a/tensorboard/webapp/runs/store/BUILD b/tensorboard/webapp/runs/store/BUILD index ef603b75ec..c0f9fa24bd 100644 --- a/tensorboard/webapp/runs/store/BUILD +++ b/tensorboard/webapp/runs/store/BUILD @@ -52,6 +52,7 @@ tf_ts_library( "//tensorboard/webapp/types", "//tensorboard/webapp/types:ui", "//tensorboard/webapp/widgets/data_table:types", + "//tensorboard/webapp/widgets/data_table:utils", "@npm//@ngrx/store", ], ) diff --git a/tensorboard/webapp/runs/store/runs_selectors.ts b/tensorboard/webapp/runs/store/runs_selectors.ts index e8cff472f2..5e5d091636 100644 --- a/tensorboard/webapp/runs/store/runs_selectors.ts +++ b/tensorboard/webapp/runs/store/runs_selectors.ts @@ -26,9 +26,13 @@ import { } from './runs_types'; import {createGroupBy} from './utils'; import {getExperimentIdsFromRoute} from '../../app_routing/store/app_routing_selectors'; -import {getDashboardSessionGroups} from '../../hparams/_redux/hparams_selectors'; +import { + getDashboardDisplayedHparamColumns, + getDashboardSessionGroups, +} from '../../hparams/_redux/hparams_selectors'; import {HparamValue, RunToHparamsAndMetrics} from '../../hparams/types'; import {ColumnHeader, SortingInfo} from '../../widgets/data_table/types'; +import {DataTableUtils} from '../../widgets/data_table/utils'; const getRunsState = createFeatureSelector(RUNS_FEATURE_KEY); @@ -301,7 +305,7 @@ export const getColorGroupRegexString = createSelector( ); /** - * Gets the columns to be displayed by the runs table. + * Gets the standard columns to be displayed by the runs table. */ export const getRunsTableHeaders = createSelector( getUiState, @@ -319,3 +323,13 @@ export const getRunsTableSortingInfo = createSelector( return state.sortingInfo; } ); + +/** + * Gets the grouped columns to be displayed by the runs table. + */ +export const getGroupedRunsTableHeaders = createSelector( + getRunsTableHeaders, + getDashboardDisplayedHparamColumns, + (runsTableHeaders, hparamColumns) => + DataTableUtils.groupColumns([...runsTableHeaders, ...hparamColumns]) +); diff --git a/tensorboard/webapp/runs/store/runs_selectors_test.ts b/tensorboard/webapp/runs/store/runs_selectors_test.ts index 01363d880c..3d31b4e72f 100644 --- a/tensorboard/webapp/runs/store/runs_selectors_test.ts +++ b/tensorboard/webapp/runs/store/runs_selectors_test.ts @@ -21,6 +21,7 @@ import { buildSessionGroup, buildStateFromHparamsState, buildHparamsState, + buildHparamSpec, } from '../../hparams/testing'; import {buildMockState} from '../../testing/utils'; import {DataLoadState} from '../../types/data'; @@ -1027,6 +1028,95 @@ describe('runs_selectors', () => { }); }); + describe('#getGroupedRunsTableHeaders', () => { + it('returns runs table headers grouped with other headers', () => { + const state = buildMockState({ + runs: buildRunsState( + {}, + { + runsTableHeaders: [ + { + type: ColumnHeaderType.COLOR, + name: 'color', + displayName: 'Color', + enabled: true, + }, + { + type: ColumnHeaderType.CUSTOM, + name: 'experimentAlias', + displayName: 'Experiment Alias', + enabled: true, + }, + { + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'Run', + enabled: true, + }, + ], + } + ), + ...buildStateFromHparamsState( + buildHparamsState({ + dashboardSpecs: { + hparams: [ + buildHparamSpec({name: 'conv_layers'}), + buildHparamSpec({name: 'conv_kernel_size'}), + ], + }, + dashboardDisplayedHparamColumns: [ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + ], + }) + ), + }); + + expect(selectors.getGroupedRunsTableHeaders(state)).toEqual([ + { + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'Run', + enabled: true, + }, + { + type: ColumnHeaderType.CUSTOM, + name: 'experimentAlias', + displayName: 'Experiment Alias', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + { + type: ColumnHeaderType.COLOR, + name: 'color', + displayName: 'Color', + enabled: true, + }, + ]); + }); + }); + describe('#getRunsTableSortingInfo', () => { it('returns the runs data table sorting info', () => { const state = buildMockState({ diff --git a/tensorboard/webapp/testing/BUILD b/tensorboard/webapp/testing/BUILD index 7fa69f955f..1c2722990b 100644 --- a/tensorboard/webapp/testing/BUILD +++ b/tensorboard/webapp/testing/BUILD @@ -83,7 +83,7 @@ tf_ts_library( "//tensorboard/webapp/hparams/_redux:testing", "//tensorboard/webapp/hparams/_redux:types", "//tensorboard/webapp/metrics:test_lib", - "//tensorboard/webapp/metrics/store", + "//tensorboard/webapp/metrics/store:types", "//tensorboard/webapp/notification_center/_redux:testing", "//tensorboard/webapp/notification_center/_redux:types", "//tensorboard/webapp/persistent_settings/_redux:testing", diff --git a/tensorboard/webapp/testing/utils.ts b/tensorboard/webapp/testing/utils.ts index ad63ce8a86..f2c78daad4 100644 --- a/tensorboard/webapp/testing/utils.ts +++ b/tensorboard/webapp/testing/utils.ts @@ -45,7 +45,7 @@ import { buildHparamsState, buildStateFromHparamsState, } from '../hparams/_redux/testing'; -import {METRICS_FEATURE_KEY} from '../metrics/store'; +import {METRICS_FEATURE_KEY} from '../metrics/store/metrics_types'; import {appStateFromMetricsState, buildMetricsState} from '../metrics/testing'; import {NOTIFICATION_FEATURE_KEY} from '../notification_center/_redux/notification_center_types'; import { diff --git a/tensorboard/webapp/widgets/data_table/BUILD b/tensorboard/webapp/widgets/data_table/BUILD index c43110f490..043e98805a 100644 --- a/tensorboard/webapp/widgets/data_table/BUILD +++ b/tensorboard/webapp/widgets/data_table/BUILD @@ -164,6 +164,29 @@ tf_ts_library( ], ) +tf_ts_library( + name = "utils", + srcs = [ + "utils.ts", + ], + deps = [ + ":types", + ], +) + +tf_ts_library( + name = "utils_test", + testonly = True, + srcs = [ + "utils_test.ts", + ], + deps = [ + ":types", + ":utils", + "@npm//@types/jasmine", + ], +) + tf_ts_library( name = "data_table_test", testonly = True, diff --git a/tensorboard/webapp/widgets/data_table/types.ts b/tensorboard/webapp/widgets/data_table/types.ts index 7b004c0a00..d5f6b566e7 100644 --- a/tensorboard/webapp/widgets/data_table/types.ts +++ b/tensorboard/webapp/widgets/data_table/types.ts @@ -129,3 +129,10 @@ export interface AddColumnEvent { nextTo?: ColumnHeader | undefined; side?: Side | undefined; } + +export enum ColumnGroup { + RUN = 'RUN', + EXPERIMENT_ALIAS = 'EXPERIMENT_ALIAS', + HPARAM = 'HPARAM', + OTHER = 'OTHER', +} diff --git a/tensorboard/webapp/widgets/data_table/utils.ts b/tensorboard/webapp/widgets/data_table/utils.ts new file mode 100644 index 0000000000..3a64980dcb --- /dev/null +++ b/tensorboard/webapp/widgets/data_table/utils.ts @@ -0,0 +1,48 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {ColumnHeader, ColumnGroup} from './types'; + +function columnToGroup(column: ColumnHeader): ColumnGroup { + if (column.type === 'RUN') { + return ColumnGroup.RUN; + } else if (column.type === 'CUSTOM' && column.name === 'experimentAlias') { + return ColumnGroup.EXPERIMENT_ALIAS; + } else if (column.type === 'HPARAM') { + return ColumnGroup.HPARAM; + } else { + return ColumnGroup.OTHER; + } +} + +/** + * Sorts columns into predefined groups. + * + * Preserves relative column order within groups. + */ +function groupColumns(columns: ColumnHeader[]): ColumnHeader[] { + // Using Map ensures that keys preserve order. + const headerGroups = new Map([ + [ColumnGroup.RUN, []], + [ColumnGroup.EXPERIMENT_ALIAS, []], + [ColumnGroup.HPARAM, []], + [ColumnGroup.OTHER, []], + ]); + columns.forEach((column) => { + headerGroups.get(columnToGroup(column))?.push(column); + }); + return Array.from(headerGroups.values()).flat(); +} + +export const DataTableUtils = { + groupColumns, +}; diff --git a/tensorboard/webapp/widgets/data_table/utils_test.ts b/tensorboard/webapp/widgets/data_table/utils_test.ts new file mode 100644 index 0000000000..4f4ae5c8c3 --- /dev/null +++ b/tensorboard/webapp/widgets/data_table/utils_test.ts @@ -0,0 +1,98 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {ColumnHeaderType} from './types'; +import {DataTableUtils} from './utils'; + +describe('data table utils', () => { + describe('groupColumns', () => { + it('groups columns according to a predefined order', () => { + const inputColumns = [ + { + type: ColumnHeaderType.VALUE, + name: 'value', + displayName: 'Value', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.SMOOTHED, + name: 'smoothed', + displayName: 'Smoothed', + enabled: true, + }, + { + type: ColumnHeaderType.CUSTOM, + name: 'experimentAlias', + displayName: 'Experiment Alias', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + { + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'Run', + enabled: true, + }, + ]; + + expect(DataTableUtils.groupColumns(inputColumns)).toEqual([ + { + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'Run', + enabled: true, + }, + { + type: ColumnHeaderType.CUSTOM, + name: 'experimentAlias', + displayName: 'Experiment Alias', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + { + type: ColumnHeaderType.VALUE, + name: 'value', + displayName: 'Value', + enabled: true, + }, + { + type: ColumnHeaderType.SMOOTHED, + name: 'smoothed', + displayName: 'Smoothed', + enabled: true, + }, + ]); + }); + }); +});