From 89fc04a3eb5a09d66694c72ff1b167816ee7c441 Mon Sep 17 00:00:00 2001 From: Hoonji <736199+hoonji@users.noreply.github.com> Date: Wed, 24 Jan 2024 12:18:53 +0900 Subject: [PATCH 1/6] Adds shared hparam related common selectors --- .../hparams/_redux/hparams_selectors.ts | 6 +- .../hparams/_redux/hparams_selectors_test.ts | 70 +++-- .../webapp/metrics/views/main_view/BUILD | 1 + .../views/main_view/common_selectors.ts | 60 ++++- .../views/main_view/common_selectors_test.ts | 246 ++++++++++++++---- 5 files changed, 299 insertions(+), 84 deletions(-) diff --git a/tensorboard/webapp/hparams/_redux/hparams_selectors.ts b/tensorboard/webapp/hparams/_redux/hparams_selectors.ts index 90c8b360d0..b5312c2cf9 100644 --- a/tensorboard/webapp/hparams/_redux/hparams_selectors.ts +++ b/tensorboard/webapp/hparams/_redux/hparams_selectors.ts @@ -47,8 +47,12 @@ export const getDashboardDefaultHparamFilters = createSelector( ); export const getDashboardDisplayedHparamColumns = createSelector( + getDashboardHparamsAndMetricsSpecs, getHparamsState, - (state) => state.dashboardDisplayedHparamColumns + ({hparams}, state) => { + const hparamSet = new Set(hparams.map(hparam => hparam.name)); + return state.dashboardDisplayedHparamColumns.filter(column => hparamSet.has(column.name)); + } ); export const getDashboardHparamFilterMap = createSelector( diff --git a/tensorboard/webapp/hparams/_redux/hparams_selectors_test.ts b/tensorboard/webapp/hparams/_redux/hparams_selectors_test.ts index cc43c0dfa2..fcdc2e5a6e 100644 --- a/tensorboard/webapp/hparams/_redux/hparams_selectors_test.ts +++ b/tensorboard/webapp/hparams/_redux/hparams_selectors_test.ts @@ -15,6 +15,7 @@ limitations under the License. import {ColumnHeaderType} from '../../widgets/data_table/types'; import {DomainType} from '../types'; +import {State} from './types'; import * as selectors from './hparams_selectors'; import { buildHparamSpec, @@ -114,30 +115,63 @@ describe('hparams/_redux/hparams_selectors_test', () => { }); describe('#getDashboardDisplayedHparamColumns', () => { - it('returns dashboard displayed hparam columns', () => { - const fakeColumns = [ - { - type: ColumnHeaderType.HPARAM, - name: 'conv_layers', - displayName: 'Conv Layers', - enabled: true, - }, - { - type: ColumnHeaderType.HPARAM, - name: 'conv_kernel_size', - displayName: 'Conv Kernel Size', - enabled: true, - }, - ]; + it('returns no columns if no hparam specs', () => { const state = buildStateFromHparamsState( buildHparamsState({ - dashboardDisplayedHparamColumns: fakeColumns, + dashboardSpecs: { + hparams: [], + }, + dashboardDisplayedHparamColumns: [ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + ], }) ); - expect(selectors.getDashboardDisplayedHparamColumns(state)).toEqual( - fakeColumns + expect(selectors.getDashboardDisplayedHparamColumns(state)).toEqual([]); + }); + + it('returns only dashboard displayed hparam columns', () => { + const state = buildStateFromHparamsState( + buildHparamsState({ + dashboardSpecs: { + hparams: [buildHparamSpec({name: 'conv_layers'})], + }, + dashboardDisplayedHparamColumns: [ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + ], + }) ); + + expect(selectors.getDashboardDisplayedHparamColumns(state)).toEqual([ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + ]); }); }); }); diff --git a/tensorboard/webapp/metrics/views/main_view/BUILD b/tensorboard/webapp/metrics/views/main_view/BUILD index a0a2f17691..13e2f82ef5 100644 --- a/tensorboard/webapp/metrics/views/main_view/BUILD +++ b/tensorboard/webapp/metrics/views/main_view/BUILD @@ -91,6 +91,7 @@ tf_ts_library( "//tensorboard/webapp/metrics/views:utils", "//tensorboard/webapp/metrics/views/card_renderer:scalar_card_types", "//tensorboard/webapp/runs:types", + "//tensorboard/webapp/runs/store:selectors", "//tensorboard/webapp/runs/views/runs_table:types", "//tensorboard/webapp/util:matcher", "//tensorboard/webapp/util:memoize", diff --git a/tensorboard/webapp/metrics/views/main_view/common_selectors.ts b/tensorboard/webapp/metrics/views/main_view/common_selectors.ts index 20be847c66..c709fefad3 100644 --- a/tensorboard/webapp/metrics/views/main_view/common_selectors.ts +++ b/tensorboard/webapp/metrics/views/main_view/common_selectors.ts @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import {createSelector} from '@ngrx/store'; -import {State} from '../../../app_state'; +import { createSelector, Selector } from '@ngrx/store'; +import { State } from '../../../app_state'; import { getCurrentRouteRunSelection, getMetricsHideEmptyCards, @@ -27,12 +27,13 @@ import { getColumnHeadersForCard, getDashboardExperimentNames, } from '../../../selectors'; -import {DeepReadonly} from '../../../util/types'; +import { DeepReadonly } from '../../../util/types'; import { getDashboardMetricsFilterMap, getDashboardHparamsAndMetricsSpecs, getDashboardHparamFilterMap, getDashboardDefaultHparamFilters, + getDashboardDisplayedHparamColumns, } from '../../../hparams/_redux/hparams_selectors'; import { DiscreteFilter, @@ -44,13 +45,14 @@ import { RunTableItem, RunTableExperimentItem, } from '../../../runs/views/runs_table/types'; -import {matchRunToRegex} from '../../../util/matcher'; -import {isSingleRunPlugin, PluginType} from '../../data_source'; -import {getNonEmptyCardIdsWithMetadata, TagMetadata} from '../../store'; -import {compareTagNames} from '../../utils'; -import {CardIdWithMetadata} from '../metrics_view_types'; -import {RouteKind} from '../../../app_routing/types'; -import {memoize} from '../../../util/memoize'; +import { getRunsTableHeaders } from '../../../runs/store/runs_selectors'; +import { matchRunToRegex } from '../../../util/matcher'; +import { isSingleRunPlugin, PluginType } from '../../data_source'; +import { getNonEmptyCardIdsWithMetadata, TagMetadata } from '../../store'; +import { compareTagNames } from '../../utils'; +import { CardIdWithMetadata } from '../metrics_view_types'; +import { RouteKind } from '../../../app_routing/types'; +import { memoize } from '../../../util/memoize'; import { ColumnHeader, ColumnHeaderType, @@ -166,7 +168,7 @@ const utils = { hparamFilters: Map, metricFilters: Map ) { - return runItems.filter(({hparams, metrics}) => { + return runItems.filter(({ hparams, metrics }) => { const hparamMatches = [...hparamFilters.entries()].every( ([hparamName, filter]) => { const value = hparams.get(hparamName); @@ -264,14 +266,14 @@ export const getFilteredRenderableRuns = createSelector( export const getFilteredRenderableRunsIds = createSelector( getFilteredRenderableRuns, (filteredRenderableRuns) => { - return new Set(filteredRenderableRuns.map(({run: {id}}) => id)); + return new Set(filteredRenderableRuns.map(({ run: { id } }) => id)); } ); export const getPotentialHparamColumns = createSelector( getDashboardHparamsAndMetricsSpecs, getExperimentIdsFromRoute, - ({hparams}, experimentIds): ColumnHeader[] => { + ({ hparams }, experimentIds): ColumnHeader[] => { if (!experimentIds) { return []; } @@ -287,10 +289,42 @@ export const getPotentialHparamColumns = createSelector( sortable: true, movable: true, filterable: true, + hidable: true, })); } ); +export const getSelectableColumns = createSelector( + getPotentialHparamColumns, + getDashboardDisplayedHparamColumns, + (potentialColumns, currentColumns) => { + const currentColumnNames = new Set(currentColumns.map(({ name }) => name)); + return potentialColumns.filter((columnHeader) => { + return !currentColumnNames.has(columnHeader.name); + }); + } +); + +/** Returns a list of columns that have been sorted into logical groups. + * + * Column order: | RUN | experimentAlias | HPARAMs | other | +*/ +export const getGroupedColumns = ( + headersSelector: Selector +) => + createSelector( + headersSelector, + getDashboardDisplayedHparamColumns, + (tableHeaders, hparamHeaders): ColumnHeader[] => { + return [ + ...tableHeaders.filter((header) => header.type === 'RUN'), + ...tableHeaders.filter((header) => header.type === 'CUSTOM' && header.name === 'experimentAlias'), + ...hparamHeaders, + ...tableHeaders.filter((header) => header.type !== 'RUN' && !(header.type === 'CUSTOM' && header.name === 'experimentAlias')), + ]; + } + ); + export const getAllPotentialColumnsForCard = memoize((cardId: string) => { return createSelector( getColumnHeadersForCard(cardId), diff --git a/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts b/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts index 56e59784be..f8838684fb 100644 --- a/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts +++ b/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import {RouteKind} from '../../../app_routing'; +import { RouteKind } from '../../../app_routing'; import { buildHparamSpec, buildMetricSpec, @@ -21,25 +21,26 @@ import { buildAppRoutingState, buildStateFromAppRoutingState, } from '../../../app_routing/store/testing'; -import {buildRoute} from '../../../app_routing/testing'; -import {buildExperiment} from '../../../experiments/store/testing'; -import {IntervalFilter, DiscreteFilter} from '../../../hparams/types'; -import {DomainType, Run} from '../../../runs/store/runs_types'; +import { buildRoute } from '../../../app_routing/testing'; +import { buildExperiment } from '../../../experiments/store/testing'; +import { IntervalFilter, DiscreteFilter } from '../../../hparams/types'; +import { DomainType, Run } from '../../../runs/store/runs_types'; +import { getRunsTableHeaders } from '../../../runs/store/runs_selectors'; import { buildRun, buildRunsState, buildStateFromRunsState, } from '../../../runs/store/testing'; -import {RunTableItem} from '../../../runs/views/runs_table/types'; -import {buildMockState} from '../../../testing/utils'; +import { RunTableItem } from '../../../runs/views/runs_table/types'; +import { buildMockState } from '../../../testing/utils'; import { appStateFromMetricsState, buildMetricsSettingsState, buildMetricsState, } from '../../testing'; -import {PluginType} from '../../types'; +import { PluginType } from '../../types'; import * as selectors from './common_selectors'; -import {ColumnHeaderType} from '../card_renderer/scalar_card_types'; +import { ColumnHeaderType } from '../card_renderer/scalar_card_types'; describe('common selectors', () => { let runIds: Record; @@ -55,7 +56,7 @@ describe('common selectors', () => { let state: ReturnType; beforeEach(() => { - runIds = {defaultExperimentId: ['run1', 'run2', 'run3']}; + runIds = { defaultExperimentId: ['run1', 'run2', 'run3'] }; runIdToExpId = { run1: 'defaultExperimentId', run2: 'defaultExperimentId', @@ -142,10 +143,10 @@ describe('common selectors', () => { }, ]; - run1 = buildRun({name: 'run 1'}); - run2 = buildRun({id: '2', name: 'run 2'}); - run3 = buildRun({id: '3', name: 'run 3'}); - run4 = buildRun({id: '4', name: 'run 4'}); + run1 = buildRun({ name: 'run 1' }); + run2 = buildRun({ id: '2', name: 'run 2' }); + run3 = buildRun({ id: '3', name: 'run 3' }); + run4 = buildRun({ id: '4', name: 'run 4' }); state = buildMockState({ runs: { data: { @@ -161,13 +162,41 @@ describe('common selectors', () => { run4, }, } as any, - ui: {} as any, + ui: { + runsTableHeaders: [ + { + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'Run', + enabled: true, + sortable: true, + removable: false, + movable: false, + filterable: false, + hidable: false, + }, + { + type: ColumnHeaderType.CUSTOM, + name: 'experimentAlias', + displayName: 'Experiment', + enabled: true, + movable: false, + sortable: true, + }, + { + type: ColumnHeaderType.CUSTOM, + name: 'fakeRunsHeader', + displayName: 'Fake Runs Header', + enabled: true, + }, + ] + } as any, }, experiments: { data: { experimentMap: { - exp1: buildExperiment({name: 'experiment1', id: 'exp1'}), - exp2: buildExperiment({name: 'experiment2', id: 'exp2'}), + exp1: buildExperiment({ name: 'experiment1', id: 'exp1' }), + exp2: buildExperiment({ name: 'experiment2', id: 'exp2' }), }, }, }, @@ -182,10 +211,35 @@ describe('common selectors', () => { }, hparams: { dashboardSpecs: { - hparams: [buildHparamSpec({name: 'foo', displayName: 'Foo'})], - metrics: [buildMetricSpec({displayName: 'Bar'})], + hparams: [ + buildHparamSpec({ name: 'conv_layers', displayName: 'Conv Layers' }), + buildHparamSpec({ + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + }), + buildHparamSpec({ + name: 'dense_layers', + displayName: 'Dense Layers', + }), + buildHparamSpec({ name: 'dropout', displayName: 'Dropout' }), + ], + metrics: [buildMetricSpec({ displayName: 'Bar' })], }, dashboardSessionGroups: [], + dashboardDisplayedHparamColumns: [ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'dense_layers', + displayName: 'Dense Layers', + enabled: true, + }, + ], } as any, }); }); @@ -715,11 +769,11 @@ describe('common selectors', () => { state.app_routing!.activeRoute!.routeKind = RouteKind.COMPARE_EXPERIMENT; const results = selectors.TEST_ONLY.getRenderableRuns(state); expect(results.length).toEqual(5); - expect(results[0].run).toEqual({...run1, experimentId: 'exp1'}); - expect(results[1].run).toEqual({...run2, experimentId: 'exp1'}); - expect(results[2].run).toEqual({...run2, experimentId: 'exp2'}); - expect(results[3].run).toEqual({...run3, experimentId: 'exp2'}); - expect(results[4].run).toEqual({...run4, experimentId: 'exp2'}); + expect(results[0].run).toEqual({ ...run1, experimentId: 'exp1' }); + expect(results[1].run).toEqual({ ...run2, experimentId: 'exp1' }); + expect(results[2].run).toEqual({ ...run2, experimentId: 'exp2' }); + expect(results[3].run).toEqual({ ...run3, experimentId: 'exp2' }); + expect(results[4].run).toEqual({ ...run4, experimentId: 'exp2' }); }); it('returns empty list when route does not contain experiments', () => { @@ -911,7 +965,7 @@ describe('common selectors', () => { state.runs!.data.regexFilter = 'foo'; state.app_routing!.activeRoute = { routeKind: RouteKind.EXPERIMENT, - params: {experimentIds: 'exp1'}, + params: { experimentIds: 'exp1' }, }; const result = selectors.getFilteredRenderableRuns(state); expect(result).toEqual([]); @@ -933,7 +987,7 @@ describe('common selectors', () => { ).and.callThrough(); state.app_routing!.activeRoute = { routeKind: RouteKind.EXPERIMENT, - params: {experimentIds: 'exp1'}, + params: { experimentIds: 'exp1' }, }; const results = selectors.getFilteredRenderableRuns(state); expect(spy).toHaveBeenCalledOnceWith(results, new Map(), new Map()); @@ -962,6 +1016,15 @@ describe('common selectors', () => { }); describe('getPotentialHparamColumns', () => { + const expectedBooleanFlags = { + enabled: false, + removable: true, + sortable: true, + movable: true, + filterable: true, + hidable: true, + }; + it('returns empty list when there are no experiments', () => { state.app_routing!.activeRoute!.routeKind = RouteKind.EXPERIMENTS; @@ -972,42 +1035,121 @@ describe('common selectors', () => { expect(selectors.getPotentialHparamColumns(state)).toEqual([ { type: ColumnHeaderType.HPARAM, - name: 'foo', - displayName: 'Foo', - enabled: false, - removable: true, - sortable: true, - movable: true, - filterable: true, + name: 'conv_layers', + displayName: 'Conv Layers', + ...expectedBooleanFlags, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + ...expectedBooleanFlags, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'dense_layers', + displayName: 'Dense Layers', + ...expectedBooleanFlags, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'dropout', + displayName: 'Dropout', + ...expectedBooleanFlags, }, ]); }); it('sets name as display name when a display name is not provided', () => { - state.hparams!.dashboardSpecs.hparams.push( - buildHparamSpec({name: 'bar', displayName: ''}) - ); + state.hparams!.dashboardSpecs.hparams = [ + buildHparamSpec({ name: 'conv_layers', displayName: '' }), + ]; + expect(selectors.getPotentialHparamColumns(state)).toEqual([ { type: ColumnHeaderType.HPARAM, - name: 'foo', - displayName: 'Foo', - enabled: false, - removable: true, - sortable: true, - movable: true, - filterable: true, + name: 'conv_layers', + displayName: 'conv_layers', + ...expectedBooleanFlags, }, - { + ]); + }); + }); + + describe('getSelectableColumns', () => { + it('returns the full list of hparam columns if none are currently displayed', () => { + state.hparams!.dashboardDisplayedHparamColumns = []; + + expect(selectors.getSelectableColumns(state)).toEqual([ + jasmine.objectContaining({ type: ColumnHeaderType.HPARAM, - name: 'bar', - displayName: 'bar', - enabled: false, - removable: true, - sortable: true, - movable: true, - filterable: true, - }, + name: 'conv_layers', + displayName: 'Conv Layers', + }), + jasmine.objectContaining({ + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + }), + jasmine.objectContaining({ + type: ColumnHeaderType.HPARAM, + name: 'dense_layers', + displayName: 'Dense Layers', + }), + jasmine.objectContaining({ + type: ColumnHeaderType.HPARAM, + name: 'dropout', + displayName: 'Dropout', + }), + ]); + }); + + it('returns only columns that are not displayed', () => { + expect(selectors.getSelectableColumns(state)).toEqual([ + jasmine.objectContaining({ + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + }), + jasmine.objectContaining({ + type: ColumnHeaderType.HPARAM, + name: 'dropout', + displayName: 'Dropout', + }), + ]); + }); + }); + + describe('getGroupedColumns', () => { + it('returns a grouped list of columns given a list of standard columns', () => { + expect(selectors.getGroupedColumns(getRunsTableHeaders)(state)).toEqual([ + jasmine.objectContaining({ + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'Run', + }), + jasmine.objectContaining({ + type: ColumnHeaderType.CUSTOM, + name: 'experimentAlias', + displayName: 'Experiment', + }), + jasmine.objectContaining({ + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }), + jasmine.objectContaining({ + type: ColumnHeaderType.HPARAM, + name: 'dense_layers', + displayName: 'Dense Layers', + enabled: true, + }), + jasmine.objectContaining({ + type: ColumnHeaderType.CUSTOM, + name: 'fakeRunsHeader', + displayName: 'Fake Runs Header', + }), ]); }); }); From bed9604fe0116bab83f5bfda59bbf049a7517d33 Mon Sep 17 00:00:00 2001 From: Hoonji <736199+hoonji@users.noreply.github.com> Date: Wed, 24 Jan 2024 21:01:50 +0900 Subject: [PATCH 2/6] Fixes test description --- tensorboard/webapp/hparams/_redux/hparams_selectors_test.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorboard/webapp/hparams/_redux/hparams_selectors_test.ts b/tensorboard/webapp/hparams/_redux/hparams_selectors_test.ts index fcdc2e5a6e..d57be572bb 100644 --- a/tensorboard/webapp/hparams/_redux/hparams_selectors_test.ts +++ b/tensorboard/webapp/hparams/_redux/hparams_selectors_test.ts @@ -141,7 +141,7 @@ describe('hparams/_redux/hparams_selectors_test', () => { expect(selectors.getDashboardDisplayedHparamColumns(state)).toEqual([]); }); - it('returns only dashboard displayed hparam columns', () => { + it('returns only hparam columns that have specs', () => { const state = buildStateFromHparamsState( buildHparamsState({ dashboardSpecs: { From 9eef73499564ae35a23fe70d15d90a4715ed037b Mon Sep 17 00:00:00 2001 From: Hoonji <736199+hoonji@users.noreply.github.com> Date: Wed, 24 Jan 2024 21:04:59 +0900 Subject: [PATCH 3/6] Fixes formatting --- .../hparams/_redux/hparams_selectors.ts | 6 +- .../views/main_view/common_selectors.ts | 45 ++++++++------ .../views/main_view/common_selectors_test.ts | 58 +++++++++---------- 3 files changed, 59 insertions(+), 50 deletions(-) diff --git a/tensorboard/webapp/hparams/_redux/hparams_selectors.ts b/tensorboard/webapp/hparams/_redux/hparams_selectors.ts index b5312c2cf9..bd6054e3a5 100644 --- a/tensorboard/webapp/hparams/_redux/hparams_selectors.ts +++ b/tensorboard/webapp/hparams/_redux/hparams_selectors.ts @@ -50,8 +50,10 @@ export const getDashboardDisplayedHparamColumns = createSelector( getDashboardHparamsAndMetricsSpecs, getHparamsState, ({hparams}, state) => { - const hparamSet = new Set(hparams.map(hparam => hparam.name)); - return state.dashboardDisplayedHparamColumns.filter(column => hparamSet.has(column.name)); + const hparamSet = new Set(hparams.map((hparam) => hparam.name)); + return state.dashboardDisplayedHparamColumns.filter((column) => + hparamSet.has(column.name) + ); } ); diff --git a/tensorboard/webapp/metrics/views/main_view/common_selectors.ts b/tensorboard/webapp/metrics/views/main_view/common_selectors.ts index c709fefad3..76354a4ad9 100644 --- a/tensorboard/webapp/metrics/views/main_view/common_selectors.ts +++ b/tensorboard/webapp/metrics/views/main_view/common_selectors.ts @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import { createSelector, Selector } from '@ngrx/store'; -import { State } from '../../../app_state'; +import {createSelector, Selector} from '@ngrx/store'; +import {State} from '../../../app_state'; import { getCurrentRouteRunSelection, getMetricsHideEmptyCards, @@ -27,7 +27,7 @@ import { getColumnHeadersForCard, getDashboardExperimentNames, } from '../../../selectors'; -import { DeepReadonly } from '../../../util/types'; +import {DeepReadonly} from '../../../util/types'; import { getDashboardMetricsFilterMap, getDashboardHparamsAndMetricsSpecs, @@ -45,14 +45,14 @@ import { RunTableItem, RunTableExperimentItem, } from '../../../runs/views/runs_table/types'; -import { getRunsTableHeaders } from '../../../runs/store/runs_selectors'; -import { matchRunToRegex } from '../../../util/matcher'; -import { isSingleRunPlugin, PluginType } from '../../data_source'; -import { getNonEmptyCardIdsWithMetadata, TagMetadata } from '../../store'; -import { compareTagNames } from '../../utils'; -import { CardIdWithMetadata } from '../metrics_view_types'; -import { RouteKind } from '../../../app_routing/types'; -import { memoize } from '../../../util/memoize'; +import {getRunsTableHeaders} from '../../../runs/store/runs_selectors'; +import {matchRunToRegex} from '../../../util/matcher'; +import {isSingleRunPlugin, PluginType} from '../../data_source'; +import {getNonEmptyCardIdsWithMetadata, TagMetadata} from '../../store'; +import {compareTagNames} from '../../utils'; +import {CardIdWithMetadata} from '../metrics_view_types'; +import {RouteKind} from '../../../app_routing/types'; +import {memoize} from '../../../util/memoize'; import { ColumnHeader, ColumnHeaderType, @@ -168,7 +168,7 @@ const utils = { hparamFilters: Map, metricFilters: Map ) { - return runItems.filter(({ hparams, metrics }) => { + return runItems.filter(({hparams, metrics}) => { const hparamMatches = [...hparamFilters.entries()].every( ([hparamName, filter]) => { const value = hparams.get(hparamName); @@ -266,14 +266,14 @@ export const getFilteredRenderableRuns = createSelector( export const getFilteredRenderableRunsIds = createSelector( getFilteredRenderableRuns, (filteredRenderableRuns) => { - return new Set(filteredRenderableRuns.map(({ run: { id } }) => id)); + return new Set(filteredRenderableRuns.map(({run: {id}}) => id)); } ); export const getPotentialHparamColumns = createSelector( getDashboardHparamsAndMetricsSpecs, getExperimentIdsFromRoute, - ({ hparams }, experimentIds): ColumnHeader[] => { + ({hparams}, experimentIds): ColumnHeader[] => { if (!experimentIds) { return []; } @@ -298,7 +298,7 @@ export const getSelectableColumns = createSelector( getPotentialHparamColumns, getDashboardDisplayedHparamColumns, (potentialColumns, currentColumns) => { - const currentColumnNames = new Set(currentColumns.map(({ name }) => name)); + const currentColumnNames = new Set(currentColumns.map(({name}) => name)); return potentialColumns.filter((columnHeader) => { return !currentColumnNames.has(columnHeader.name); }); @@ -306,9 +306,9 @@ export const getSelectableColumns = createSelector( ); /** Returns a list of columns that have been sorted into logical groups. - * + * * Column order: | RUN | experimentAlias | HPARAMs | other | -*/ + */ export const getGroupedColumns = ( headersSelector: Selector ) => @@ -318,9 +318,16 @@ export const getGroupedColumns = ( (tableHeaders, hparamHeaders): ColumnHeader[] => { return [ ...tableHeaders.filter((header) => header.type === 'RUN'), - ...tableHeaders.filter((header) => header.type === 'CUSTOM' && header.name === 'experimentAlias'), + ...tableHeaders.filter( + (header) => + header.type === 'CUSTOM' && header.name === 'experimentAlias' + ), ...hparamHeaders, - ...tableHeaders.filter((header) => header.type !== 'RUN' && !(header.type === 'CUSTOM' && header.name === 'experimentAlias')), + ...tableHeaders.filter( + (header) => + header.type !== 'RUN' && + !(header.type === 'CUSTOM' && header.name === 'experimentAlias') + ), ]; } ); diff --git a/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts b/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts index f8838684fb..6667e7beea 100644 --- a/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts +++ b/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import { RouteKind } from '../../../app_routing'; +import {RouteKind} from '../../../app_routing'; import { buildHparamSpec, buildMetricSpec, @@ -21,26 +21,26 @@ import { buildAppRoutingState, buildStateFromAppRoutingState, } from '../../../app_routing/store/testing'; -import { buildRoute } from '../../../app_routing/testing'; -import { buildExperiment } from '../../../experiments/store/testing'; -import { IntervalFilter, DiscreteFilter } from '../../../hparams/types'; -import { DomainType, Run } from '../../../runs/store/runs_types'; -import { getRunsTableHeaders } from '../../../runs/store/runs_selectors'; +import {buildRoute} from '../../../app_routing/testing'; +import {buildExperiment} from '../../../experiments/store/testing'; +import {IntervalFilter, DiscreteFilter} from '../../../hparams/types'; +import {DomainType, Run} from '../../../runs/store/runs_types'; +import {getRunsTableHeaders} from '../../../runs/store/runs_selectors'; import { buildRun, buildRunsState, buildStateFromRunsState, } from '../../../runs/store/testing'; -import { RunTableItem } from '../../../runs/views/runs_table/types'; -import { buildMockState } from '../../../testing/utils'; +import {RunTableItem} from '../../../runs/views/runs_table/types'; +import {buildMockState} from '../../../testing/utils'; import { appStateFromMetricsState, buildMetricsSettingsState, buildMetricsState, } from '../../testing'; -import { PluginType } from '../../types'; +import {PluginType} from '../../types'; import * as selectors from './common_selectors'; -import { ColumnHeaderType } from '../card_renderer/scalar_card_types'; +import {ColumnHeaderType} from '../card_renderer/scalar_card_types'; describe('common selectors', () => { let runIds: Record; @@ -56,7 +56,7 @@ describe('common selectors', () => { let state: ReturnType; beforeEach(() => { - runIds = { defaultExperimentId: ['run1', 'run2', 'run3'] }; + runIds = {defaultExperimentId: ['run1', 'run2', 'run3']}; runIdToExpId = { run1: 'defaultExperimentId', run2: 'defaultExperimentId', @@ -143,10 +143,10 @@ describe('common selectors', () => { }, ]; - run1 = buildRun({ name: 'run 1' }); - run2 = buildRun({ id: '2', name: 'run 2' }); - run3 = buildRun({ id: '3', name: 'run 3' }); - run4 = buildRun({ id: '4', name: 'run 4' }); + run1 = buildRun({name: 'run 1'}); + run2 = buildRun({id: '2', name: 'run 2'}); + run3 = buildRun({id: '3', name: 'run 3'}); + run4 = buildRun({id: '4', name: 'run 4'}); state = buildMockState({ runs: { data: { @@ -189,14 +189,14 @@ describe('common selectors', () => { displayName: 'Fake Runs Header', enabled: true, }, - ] + ], } as any, }, experiments: { data: { experimentMap: { - exp1: buildExperiment({ name: 'experiment1', id: 'exp1' }), - exp2: buildExperiment({ name: 'experiment2', id: 'exp2' }), + exp1: buildExperiment({name: 'experiment1', id: 'exp1'}), + exp2: buildExperiment({name: 'experiment2', id: 'exp2'}), }, }, }, @@ -212,7 +212,7 @@ describe('common selectors', () => { hparams: { dashboardSpecs: { hparams: [ - buildHparamSpec({ name: 'conv_layers', displayName: 'Conv Layers' }), + buildHparamSpec({name: 'conv_layers', displayName: 'Conv Layers'}), buildHparamSpec({ name: 'conv_kernel_size', displayName: 'Conv Kernel Size', @@ -221,9 +221,9 @@ describe('common selectors', () => { name: 'dense_layers', displayName: 'Dense Layers', }), - buildHparamSpec({ name: 'dropout', displayName: 'Dropout' }), + buildHparamSpec({name: 'dropout', displayName: 'Dropout'}), ], - metrics: [buildMetricSpec({ displayName: 'Bar' })], + metrics: [buildMetricSpec({displayName: 'Bar'})], }, dashboardSessionGroups: [], dashboardDisplayedHparamColumns: [ @@ -769,11 +769,11 @@ describe('common selectors', () => { state.app_routing!.activeRoute!.routeKind = RouteKind.COMPARE_EXPERIMENT; const results = selectors.TEST_ONLY.getRenderableRuns(state); expect(results.length).toEqual(5); - expect(results[0].run).toEqual({ ...run1, experimentId: 'exp1' }); - expect(results[1].run).toEqual({ ...run2, experimentId: 'exp1' }); - expect(results[2].run).toEqual({ ...run2, experimentId: 'exp2' }); - expect(results[3].run).toEqual({ ...run3, experimentId: 'exp2' }); - expect(results[4].run).toEqual({ ...run4, experimentId: 'exp2' }); + expect(results[0].run).toEqual({...run1, experimentId: 'exp1'}); + expect(results[1].run).toEqual({...run2, experimentId: 'exp1'}); + expect(results[2].run).toEqual({...run2, experimentId: 'exp2'}); + expect(results[3].run).toEqual({...run3, experimentId: 'exp2'}); + expect(results[4].run).toEqual({...run4, experimentId: 'exp2'}); }); it('returns empty list when route does not contain experiments', () => { @@ -965,7 +965,7 @@ describe('common selectors', () => { state.runs!.data.regexFilter = 'foo'; state.app_routing!.activeRoute = { routeKind: RouteKind.EXPERIMENT, - params: { experimentIds: 'exp1' }, + params: {experimentIds: 'exp1'}, }; const result = selectors.getFilteredRenderableRuns(state); expect(result).toEqual([]); @@ -987,7 +987,7 @@ describe('common selectors', () => { ).and.callThrough(); state.app_routing!.activeRoute = { routeKind: RouteKind.EXPERIMENT, - params: { experimentIds: 'exp1' }, + params: {experimentIds: 'exp1'}, }; const results = selectors.getFilteredRenderableRuns(state); expect(spy).toHaveBeenCalledOnceWith(results, new Map(), new Map()); @@ -1062,7 +1062,7 @@ describe('common selectors', () => { it('sets name as display name when a display name is not provided', () => { state.hparams!.dashboardSpecs.hparams = [ - buildHparamSpec({ name: 'conv_layers', displayName: '' }), + buildHparamSpec({name: 'conv_layers', displayName: ''}), ]; expect(selectors.getPotentialHparamColumns(state)).toEqual([ From f17fe2290aac8a99537509190f2f8a66acb789b0 Mon Sep 17 00:00:00 2001 From: Hoonji <736199+hoonji@users.noreply.github.com> Date: Fri, 26 Jan 2024 16:55:21 +0900 Subject: [PATCH 4/6] Turns groupColumns into a projector function helper --- tensorboard/webapp/BUILD | 1 + tensorboard/webapp/metrics/store/BUILD | 6 + .../webapp/metrics/store/metrics_selectors.ts | 11 + .../metrics/store/metrics_selectors_test.ts | 196 ++++++++++++++++++ tensorboard/webapp/metrics/testing.ts | 4 +- .../views/main_view/common_selectors.ts | 30 +-- .../views/main_view/common_selectors_test.ts | 35 ---- tensorboard/webapp/runs/store/BUILD | 1 + .../webapp/runs/store/runs_selectors.ts | 18 +- .../webapp/runs/store/runs_selectors_test.ts | 90 ++++++++ tensorboard/webapp/testing/BUILD | 2 +- tensorboard/webapp/testing/utils.ts | 2 +- tensorboard/webapp/widgets/data_table/BUILD | 23 ++ .../webapp/widgets/data_table/types.ts | 2 + .../webapp/widgets/data_table/utils.ts | 46 ++++ .../webapp/widgets/data_table/utils_test.ts | 98 +++++++++ 16 files changed, 494 insertions(+), 71 deletions(-) create mode 100644 tensorboard/webapp/widgets/data_table/utils.ts create mode 100644 tensorboard/webapp/widgets/data_table/utils_test.ts diff --git a/tensorboard/webapp/BUILD b/tensorboard/webapp/BUILD index bd2230c338..25e45c27cf 100644 --- a/tensorboard/webapp/BUILD +++ b/tensorboard/webapp/BUILD @@ -301,6 +301,7 @@ tf_ng_web_test_suite( "//tensorboard/webapp/widgets/content_wrapping_input:content_wrapping_input_tests", "//tensorboard/webapp/widgets/custom_modal:custom_modal_test", "//tensorboard/webapp/widgets/data_table:data_table_test", + "//tensorboard/webapp/widgets/data_table:utils_test", "//tensorboard/webapp/widgets/dropdown:dropdown_tests", "//tensorboard/webapp/widgets/experiment_alias:experiment_alias_test", "//tensorboard/webapp/widgets/filter_input:filter_input_test", diff --git a/tensorboard/webapp/metrics/store/BUILD b/tensorboard/webapp/metrics/store/BUILD index 8db5c763f9..349dcfe15b 100644 --- a/tensorboard/webapp/metrics/store/BUILD +++ b/tensorboard/webapp/metrics/store/BUILD @@ -18,6 +18,7 @@ tf_ts_library( "//tensorboard/webapp/app_routing:types", "//tensorboard/webapp/app_routing/actions", "//tensorboard/webapp/core/actions", + "//tensorboard/webapp/hparams/_redux:hparams_selectors", "//tensorboard/webapp/metrics:types", "//tensorboard/webapp/metrics:utils", "//tensorboard/webapp/metrics/actions", @@ -33,6 +34,7 @@ tf_ts_library( "//tensorboard/webapp/util:types", "//tensorboard/webapp/widgets/card_fob:types", "//tensorboard/webapp/widgets/data_table:types", + "//tensorboard/webapp/widgets/data_table:utils", "//tensorboard/webapp/widgets/line_chart_v2/lib:public_types", "@npm//@ngrx/store", ], @@ -97,14 +99,18 @@ tf_ts_library( "//tensorboard/webapp/app_routing:types", "//tensorboard/webapp/app_routing/actions", "//tensorboard/webapp/core/actions", + "//tensorboard/webapp/hparams:testing", + "//tensorboard/webapp/hparams/_redux:types", "//tensorboard/webapp/metrics:test_lib", "//tensorboard/webapp/metrics:types", "//tensorboard/webapp/metrics/actions", "//tensorboard/webapp/metrics/data_source", "//tensorboard/webapp/persistent_settings", "//tensorboard/webapp/routes:testing", + "//tensorboard/webapp/testing:utils", "//tensorboard/webapp/types", "//tensorboard/webapp/util:dom", + "//tensorboard/webapp/util:types", "//tensorboard/webapp/widgets/card_fob:types", "//tensorboard/webapp/widgets/data_table:types", "@npm//@types/jasmine", diff --git a/tensorboard/webapp/metrics/store/metrics_selectors.ts b/tensorboard/webapp/metrics/store/metrics_selectors.ts index 9892241fb5..f01db8e327 100644 --- a/tensorboard/webapp/metrics/store/metrics_selectors.ts +++ b/tensorboard/webapp/metrics/store/metrics_selectors.ts @@ -50,6 +50,8 @@ import { import {ColumnHeader, DataTableMode} from '../../widgets/data_table/types'; import {Extent} from '../../widgets/line_chart_v2/lib/public_types'; import {memoize} from '../../util/memoize'; +import {getDashboardDisplayedHparamColumns} from '../../hparams/_redux/hparams_selectors'; +import {DataTableUtils} from '../../widgets/data_table/utils'; const selectMetricsState = createFeatureSelector(METRICS_FEATURE_KEY); @@ -661,3 +663,12 @@ export const getColumnHeadersForCard = memoize((cardId: string) => { } ); }); + +export const getGroupedHeadersForCard = memoize((cardId: string) => + createSelector( + getColumnHeadersForCard(cardId), + getDashboardDisplayedHparamColumns, + (standardColumns, hparamColumns) => + DataTableUtils.groupColumns([...standardColumns, ...hparamColumns]) + ) +); diff --git a/tensorboard/webapp/metrics/store/metrics_selectors_test.ts b/tensorboard/webapp/metrics/store/metrics_selectors_test.ts index a30cb45a0e..28b63e1a70 100644 --- a/tensorboard/webapp/metrics/store/metrics_selectors_test.ts +++ b/tensorboard/webapp/metrics/store/metrics_selectors_test.ts @@ -33,6 +33,14 @@ import { } from '../../widgets/data_table/types'; import * as selectors from './metrics_selectors'; import {CardFeatureOverride, MetricsState} from './metrics_types'; +import {buildMockState} from '../../testing/utils'; +import { + buildHparamSpec, + buildHparamsState, + buildStateFromHparamsState, +} from '../../hparams/testing'; +import {DeepPartial} from '../../util/types'; +import {HparamsState} from '../../hparams/_redux/types'; describe('metrics selectors', () => { beforeEach(() => { @@ -1745,4 +1753,192 @@ describe('metrics selectors', () => { ).toEqual(rangeSelectionHeaders); }); }); + + describe('getGroupedHeadersForCard', () => { + let singleSelectionHeaders: ColumnHeader[]; + let rangeSelectionHeaders: ColumnHeader[]; + let hparamsState: DeepPartial; + + beforeEach(() => { + singleSelectionHeaders = [ + { + type: ColumnHeaderType.COLOR, + name: 'color', + displayName: 'Color', + enabled: true, + }, + { + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'My Run name', + enabled: false, + }, + ]; + rangeSelectionHeaders = [ + { + type: ColumnHeaderType.MEAN, + name: 'mean', + displayName: 'Mean', + enabled: true, + }, + { + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'My Run name', + enabled: false, + }, + ]; + hparamsState = { + dashboardSpecs: { + hparams: [ + buildHparamSpec({name: 'conv_layers'}), + buildHparamSpec({name: 'conv_kernel_size'}), + ], + }, + dashboardDisplayedHparamColumns: [ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + ], + }; + }); + + it('returns grouped single selection headers when card range selection is disabled', () => { + const state = buildMockState({ + ...appStateFromMetricsState( + buildMetricsState({ + singleSelectionHeaders, + rangeSelectionHeaders, + cardStateMap: { + card1: { + rangeSelectionOverride: + CardFeatureOverride.OVERRIDE_AS_DISABLED, + }, + }, + }) + ), + ...buildStateFromHparamsState(buildHparamsState(hparamsState)), + }); + + expect(selectors.getGroupedHeadersForCard('card1')(state)).toEqual([ + { + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'My Run name', + enabled: false, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + { + type: ColumnHeaderType.COLOR, + name: 'color', + displayName: 'Color', + enabled: true, + }, + ]); + }); + + it('returns grouped range selection headers when card range selection is enabled', () => { + const state = buildMockState({ + ...appStateFromMetricsState( + buildMetricsState({ + singleSelectionHeaders, + rangeSelectionHeaders, + cardStateMap: { + card1: { + rangeSelectionOverride: CardFeatureOverride.OVERRIDE_AS_ENABLED, + }, + }, + }) + ), + ...buildStateFromHparamsState(buildHparamsState(hparamsState)), + }); + + expect(selectors.getGroupedHeadersForCard('card1')(state)).toEqual([ + { + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'My Run name', + enabled: false, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + { + type: ColumnHeaderType.MEAN, + name: 'mean', + displayName: 'Mean', + enabled: true, + }, + ]); + }); + + it('returns grouped range selection headers when global range selection is enabled', () => { + const state = buildMockState({ + ...appStateFromMetricsState( + buildMetricsState({ + singleSelectionHeaders, + rangeSelectionHeaders, + rangeSelectionEnabled: true, + }) + ), + ...buildStateFromHparamsState(buildHparamsState(hparamsState)), + }); + + expect(selectors.getGroupedHeadersForCard('card1')(state)).toEqual([ + { + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'My Run name', + enabled: false, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + { + type: ColumnHeaderType.MEAN, + name: 'mean', + displayName: 'Mean', + enabled: true, + }, + ]); + }); + }); }); diff --git a/tensorboard/webapp/metrics/testing.ts b/tensorboard/webapp/metrics/testing.ts index de33b3b66d..4dc9757d89 100644 --- a/tensorboard/webapp/metrics/testing.ts +++ b/tensorboard/webapp/metrics/testing.ts @@ -28,14 +28,12 @@ import { TimeSeriesRequest, TimeSeriesResponse, } from './data_source'; +import * as selectors from './store/metrics_selectors'; import { MetricsState, METRICS_FEATURE_KEY, TagMetadata, TimeSeriesData, -} from './store'; -import * as selectors from './store/metrics_selectors'; -import { CardStepIndexMetaData, MetricsSettings, RunToSeries, diff --git a/tensorboard/webapp/metrics/views/main_view/common_selectors.ts b/tensorboard/webapp/metrics/views/main_view/common_selectors.ts index 76354a4ad9..0836af7fb4 100644 --- a/tensorboard/webapp/metrics/views/main_view/common_selectors.ts +++ b/tensorboard/webapp/metrics/views/main_view/common_selectors.ts @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import {createSelector, Selector} from '@ngrx/store'; +import {createSelector} from '@ngrx/store'; import {State} from '../../../app_state'; import { getCurrentRouteRunSelection, @@ -45,7 +45,6 @@ import { RunTableItem, RunTableExperimentItem, } from '../../../runs/views/runs_table/types'; -import {getRunsTableHeaders} from '../../../runs/store/runs_selectors'; import {matchRunToRegex} from '../../../util/matcher'; import {isSingleRunPlugin, PluginType} from '../../data_source'; import {getNonEmptyCardIdsWithMetadata, TagMetadata} from '../../store'; @@ -305,33 +304,6 @@ export const getSelectableColumns = createSelector( } ); -/** Returns a list of columns that have been sorted into logical groups. - * - * Column order: | RUN | experimentAlias | HPARAMs | other | - */ -export const getGroupedColumns = ( - headersSelector: Selector -) => - createSelector( - headersSelector, - getDashboardDisplayedHparamColumns, - (tableHeaders, hparamHeaders): ColumnHeader[] => { - return [ - ...tableHeaders.filter((header) => header.type === 'RUN'), - ...tableHeaders.filter( - (header) => - header.type === 'CUSTOM' && header.name === 'experimentAlias' - ), - ...hparamHeaders, - ...tableHeaders.filter( - (header) => - header.type !== 'RUN' && - !(header.type === 'CUSTOM' && header.name === 'experimentAlias') - ), - ]; - } - ); - export const getAllPotentialColumnsForCard = memoize((cardId: string) => { return createSelector( getColumnHeadersForCard(cardId), diff --git a/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts b/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts index 6667e7beea..11b133f54b 100644 --- a/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts +++ b/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts @@ -25,7 +25,6 @@ import {buildRoute} from '../../../app_routing/testing'; import {buildExperiment} from '../../../experiments/store/testing'; import {IntervalFilter, DiscreteFilter} from '../../../hparams/types'; import {DomainType, Run} from '../../../runs/store/runs_types'; -import {getRunsTableHeaders} from '../../../runs/store/runs_selectors'; import { buildRun, buildRunsState, @@ -1119,38 +1118,4 @@ describe('common selectors', () => { ]); }); }); - - describe('getGroupedColumns', () => { - it('returns a grouped list of columns given a list of standard columns', () => { - expect(selectors.getGroupedColumns(getRunsTableHeaders)(state)).toEqual([ - jasmine.objectContaining({ - type: ColumnHeaderType.RUN, - name: 'run', - displayName: 'Run', - }), - jasmine.objectContaining({ - type: ColumnHeaderType.CUSTOM, - name: 'experimentAlias', - displayName: 'Experiment', - }), - jasmine.objectContaining({ - type: ColumnHeaderType.HPARAM, - name: 'conv_layers', - displayName: 'Conv Layers', - enabled: true, - }), - jasmine.objectContaining({ - type: ColumnHeaderType.HPARAM, - name: 'dense_layers', - displayName: 'Dense Layers', - enabled: true, - }), - jasmine.objectContaining({ - type: ColumnHeaderType.CUSTOM, - name: 'fakeRunsHeader', - displayName: 'Fake Runs Header', - }), - ]); - }); - }); }); diff --git a/tensorboard/webapp/runs/store/BUILD b/tensorboard/webapp/runs/store/BUILD index ef603b75ec..c0f9fa24bd 100644 --- a/tensorboard/webapp/runs/store/BUILD +++ b/tensorboard/webapp/runs/store/BUILD @@ -52,6 +52,7 @@ tf_ts_library( "//tensorboard/webapp/types", "//tensorboard/webapp/types:ui", "//tensorboard/webapp/widgets/data_table:types", + "//tensorboard/webapp/widgets/data_table:utils", "@npm//@ngrx/store", ], ) diff --git a/tensorboard/webapp/runs/store/runs_selectors.ts b/tensorboard/webapp/runs/store/runs_selectors.ts index e8cff472f2..5e5d091636 100644 --- a/tensorboard/webapp/runs/store/runs_selectors.ts +++ b/tensorboard/webapp/runs/store/runs_selectors.ts @@ -26,9 +26,13 @@ import { } from './runs_types'; import {createGroupBy} from './utils'; import {getExperimentIdsFromRoute} from '../../app_routing/store/app_routing_selectors'; -import {getDashboardSessionGroups} from '../../hparams/_redux/hparams_selectors'; +import { + getDashboardDisplayedHparamColumns, + getDashboardSessionGroups, +} from '../../hparams/_redux/hparams_selectors'; import {HparamValue, RunToHparamsAndMetrics} from '../../hparams/types'; import {ColumnHeader, SortingInfo} from '../../widgets/data_table/types'; +import {DataTableUtils} from '../../widgets/data_table/utils'; const getRunsState = createFeatureSelector(RUNS_FEATURE_KEY); @@ -301,7 +305,7 @@ export const getColorGroupRegexString = createSelector( ); /** - * Gets the columns to be displayed by the runs table. + * Gets the standard columns to be displayed by the runs table. */ export const getRunsTableHeaders = createSelector( getUiState, @@ -319,3 +323,13 @@ export const getRunsTableSortingInfo = createSelector( return state.sortingInfo; } ); + +/** + * Gets the grouped columns to be displayed by the runs table. + */ +export const getGroupedRunsTableHeaders = createSelector( + getRunsTableHeaders, + getDashboardDisplayedHparamColumns, + (runsTableHeaders, hparamColumns) => + DataTableUtils.groupColumns([...runsTableHeaders, ...hparamColumns]) +); diff --git a/tensorboard/webapp/runs/store/runs_selectors_test.ts b/tensorboard/webapp/runs/store/runs_selectors_test.ts index 01363d880c..3d31b4e72f 100644 --- a/tensorboard/webapp/runs/store/runs_selectors_test.ts +++ b/tensorboard/webapp/runs/store/runs_selectors_test.ts @@ -21,6 +21,7 @@ import { buildSessionGroup, buildStateFromHparamsState, buildHparamsState, + buildHparamSpec, } from '../../hparams/testing'; import {buildMockState} from '../../testing/utils'; import {DataLoadState} from '../../types/data'; @@ -1027,6 +1028,95 @@ describe('runs_selectors', () => { }); }); + describe('#getGroupedRunsTableHeaders', () => { + it('returns runs table headers grouped with other headers', () => { + const state = buildMockState({ + runs: buildRunsState( + {}, + { + runsTableHeaders: [ + { + type: ColumnHeaderType.COLOR, + name: 'color', + displayName: 'Color', + enabled: true, + }, + { + type: ColumnHeaderType.CUSTOM, + name: 'experimentAlias', + displayName: 'Experiment Alias', + enabled: true, + }, + { + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'Run', + enabled: true, + }, + ], + } + ), + ...buildStateFromHparamsState( + buildHparamsState({ + dashboardSpecs: { + hparams: [ + buildHparamSpec({name: 'conv_layers'}), + buildHparamSpec({name: 'conv_kernel_size'}), + ], + }, + dashboardDisplayedHparamColumns: [ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + ], + }) + ), + }); + + expect(selectors.getGroupedRunsTableHeaders(state)).toEqual([ + { + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'Run', + enabled: true, + }, + { + type: ColumnHeaderType.CUSTOM, + name: 'experimentAlias', + displayName: 'Experiment Alias', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + { + type: ColumnHeaderType.COLOR, + name: 'color', + displayName: 'Color', + enabled: true, + }, + ]); + }); + }); + describe('#getRunsTableSortingInfo', () => { it('returns the runs data table sorting info', () => { const state = buildMockState({ diff --git a/tensorboard/webapp/testing/BUILD b/tensorboard/webapp/testing/BUILD index 7fa69f955f..1c2722990b 100644 --- a/tensorboard/webapp/testing/BUILD +++ b/tensorboard/webapp/testing/BUILD @@ -83,7 +83,7 @@ tf_ts_library( "//tensorboard/webapp/hparams/_redux:testing", "//tensorboard/webapp/hparams/_redux:types", "//tensorboard/webapp/metrics:test_lib", - "//tensorboard/webapp/metrics/store", + "//tensorboard/webapp/metrics/store:types", "//tensorboard/webapp/notification_center/_redux:testing", "//tensorboard/webapp/notification_center/_redux:types", "//tensorboard/webapp/persistent_settings/_redux:testing", diff --git a/tensorboard/webapp/testing/utils.ts b/tensorboard/webapp/testing/utils.ts index ad63ce8a86..f2c78daad4 100644 --- a/tensorboard/webapp/testing/utils.ts +++ b/tensorboard/webapp/testing/utils.ts @@ -45,7 +45,7 @@ import { buildHparamsState, buildStateFromHparamsState, } from '../hparams/_redux/testing'; -import {METRICS_FEATURE_KEY} from '../metrics/store'; +import {METRICS_FEATURE_KEY} from '../metrics/store/metrics_types'; import {appStateFromMetricsState, buildMetricsState} from '../metrics/testing'; import {NOTIFICATION_FEATURE_KEY} from '../notification_center/_redux/notification_center_types'; import { diff --git a/tensorboard/webapp/widgets/data_table/BUILD b/tensorboard/webapp/widgets/data_table/BUILD index c43110f490..043e98805a 100644 --- a/tensorboard/webapp/widgets/data_table/BUILD +++ b/tensorboard/webapp/widgets/data_table/BUILD @@ -164,6 +164,29 @@ tf_ts_library( ], ) +tf_ts_library( + name = "utils", + srcs = [ + "utils.ts", + ], + deps = [ + ":types", + ], +) + +tf_ts_library( + name = "utils_test", + testonly = True, + srcs = [ + "utils_test.ts", + ], + deps = [ + ":types", + ":utils", + "@npm//@types/jasmine", + ], +) + tf_ts_library( name = "data_table_test", testonly = True, diff --git a/tensorboard/webapp/widgets/data_table/types.ts b/tensorboard/webapp/widgets/data_table/types.ts index 7b004c0a00..a3ed4bf1bd 100644 --- a/tensorboard/webapp/widgets/data_table/types.ts +++ b/tensorboard/webapp/widgets/data_table/types.ts @@ -129,3 +129,5 @@ export interface AddColumnEvent { nextTo?: ColumnHeader | undefined; side?: Side | undefined; } + +export type ColumnGroup = 'RUN' | 'EXPERIMENT_ALIAS' | 'HPARAM' | 'OTHER'; diff --git a/tensorboard/webapp/widgets/data_table/utils.ts b/tensorboard/webapp/widgets/data_table/utils.ts new file mode 100644 index 0000000000..e0d29509ca --- /dev/null +++ b/tensorboard/webapp/widgets/data_table/utils.ts @@ -0,0 +1,46 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {ColumnHeader, ColumnGroup} from './types'; + +function columnToGroup(column: ColumnHeader): ColumnGroup { + if (column.type === 'RUN') { + return 'RUN'; + } else if (column.type === 'CUSTOM' && column.name === 'experimentAlias') { + return 'EXPERIMENT_ALIAS'; + } else if (column.type === 'HPARAM') { + return 'HPARAM'; + } else { + return 'OTHER'; + } +} + +/** Sorts columns into predefined groups. + * + * Preserves relative order within groups. + */ +function groupColumns(columns: ColumnHeader[]): ColumnHeader[] { + const headerGroups = new Map([ + ['RUN', []], + ['EXPERIMENT_ALIAS', []], + ['HPARAM', []], + ['OTHER', []], + ]); + columns.forEach((column) => { + headerGroups.get(columnToGroup(column))?.push(column); + }); + return Array.from(headerGroups.values()).flat(); +} + +export const DataTableUtils = { + groupColumns, +}; diff --git a/tensorboard/webapp/widgets/data_table/utils_test.ts b/tensorboard/webapp/widgets/data_table/utils_test.ts new file mode 100644 index 0000000000..4f4ae5c8c3 --- /dev/null +++ b/tensorboard/webapp/widgets/data_table/utils_test.ts @@ -0,0 +1,98 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {ColumnHeaderType} from './types'; +import {DataTableUtils} from './utils'; + +describe('data table utils', () => { + describe('groupColumns', () => { + it('groups columns according to a predefined order', () => { + const inputColumns = [ + { + type: ColumnHeaderType.VALUE, + name: 'value', + displayName: 'Value', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.SMOOTHED, + name: 'smoothed', + displayName: 'Smoothed', + enabled: true, + }, + { + type: ColumnHeaderType.CUSTOM, + name: 'experimentAlias', + displayName: 'Experiment Alias', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + { + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'Run', + enabled: true, + }, + ]; + + expect(DataTableUtils.groupColumns(inputColumns)).toEqual([ + { + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'Run', + enabled: true, + }, + { + type: ColumnHeaderType.CUSTOM, + name: 'experimentAlias', + displayName: 'Experiment Alias', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + { + type: ColumnHeaderType.VALUE, + name: 'value', + displayName: 'Value', + enabled: true, + }, + { + type: ColumnHeaderType.SMOOTHED, + name: 'smoothed', + displayName: 'Smoothed', + enabled: true, + }, + ]); + }); + }); +}); From ec93a93e6affa51ccf63cac6094f09f70812d199 Mon Sep 17 00:00:00 2001 From: Hoonji <736199+hoonji@users.noreply.github.com> Date: Sat, 27 Jan 2024 13:28:33 +0900 Subject: [PATCH 5/6] Changes ColumnGroup to enum --- .../webapp/widgets/data_table/types.ts | 7 +++++- .../webapp/widgets/data_table/utils.ts | 22 ++++++++++--------- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/tensorboard/webapp/widgets/data_table/types.ts b/tensorboard/webapp/widgets/data_table/types.ts index a3ed4bf1bd..cf12975277 100644 --- a/tensorboard/webapp/widgets/data_table/types.ts +++ b/tensorboard/webapp/widgets/data_table/types.ts @@ -130,4 +130,9 @@ export interface AddColumnEvent { side?: Side | undefined; } -export type ColumnGroup = 'RUN' | 'EXPERIMENT_ALIAS' | 'HPARAM' | 'OTHER'; +export enum ColumnGroup { + RUN = 'RUN', + EXPERIMENT_ALIAS = 'EXPERIMENT_ALIAS', + HPARAM = 'HPARAM', + OTHER = 'OTHER' +} \ No newline at end of file diff --git a/tensorboard/webapp/widgets/data_table/utils.ts b/tensorboard/webapp/widgets/data_table/utils.ts index e0d29509ca..3a64980dcb 100644 --- a/tensorboard/webapp/widgets/data_table/utils.ts +++ b/tensorboard/webapp/widgets/data_table/utils.ts @@ -14,26 +14,28 @@ import {ColumnHeader, ColumnGroup} from './types'; function columnToGroup(column: ColumnHeader): ColumnGroup { if (column.type === 'RUN') { - return 'RUN'; + return ColumnGroup.RUN; } else if (column.type === 'CUSTOM' && column.name === 'experimentAlias') { - return 'EXPERIMENT_ALIAS'; + return ColumnGroup.EXPERIMENT_ALIAS; } else if (column.type === 'HPARAM') { - return 'HPARAM'; + return ColumnGroup.HPARAM; } else { - return 'OTHER'; + return ColumnGroup.OTHER; } } -/** Sorts columns into predefined groups. +/** + * Sorts columns into predefined groups. * - * Preserves relative order within groups. + * Preserves relative column order within groups. */ function groupColumns(columns: ColumnHeader[]): ColumnHeader[] { + // Using Map ensures that keys preserve order. const headerGroups = new Map([ - ['RUN', []], - ['EXPERIMENT_ALIAS', []], - ['HPARAM', []], - ['OTHER', []], + [ColumnGroup.RUN, []], + [ColumnGroup.EXPERIMENT_ALIAS, []], + [ColumnGroup.HPARAM, []], + [ColumnGroup.OTHER, []], ]); columns.forEach((column) => { headerGroups.get(columnToGroup(column))?.push(column); From 901dbc2be137d03c0597bd7e18fff31f3f960039 Mon Sep 17 00:00:00 2001 From: Hoonji <736199+hoonji@users.noreply.github.com> Date: Mon, 29 Jan 2024 19:00:18 +0900 Subject: [PATCH 6/6] Fixes lint error --- tensorboard/webapp/widgets/data_table/types.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorboard/webapp/widgets/data_table/types.ts b/tensorboard/webapp/widgets/data_table/types.ts index cf12975277..d5f6b566e7 100644 --- a/tensorboard/webapp/widgets/data_table/types.ts +++ b/tensorboard/webapp/widgets/data_table/types.ts @@ -134,5 +134,5 @@ export enum ColumnGroup { RUN = 'RUN', EXPERIMENT_ALIAS = 'EXPERIMENT_ALIAS', HPARAM = 'HPARAM', - OTHER = 'OTHER' -} \ No newline at end of file + OTHER = 'OTHER', +}