diff --git a/tensorboard/webapp/BUILD b/tensorboard/webapp/BUILD index bd2230c338..25e45c27cf 100644 --- a/tensorboard/webapp/BUILD +++ b/tensorboard/webapp/BUILD @@ -301,6 +301,7 @@ tf_ng_web_test_suite( "//tensorboard/webapp/widgets/content_wrapping_input:content_wrapping_input_tests", "//tensorboard/webapp/widgets/custom_modal:custom_modal_test", "//tensorboard/webapp/widgets/data_table:data_table_test", + "//tensorboard/webapp/widgets/data_table:utils_test", "//tensorboard/webapp/widgets/dropdown:dropdown_tests", "//tensorboard/webapp/widgets/experiment_alias:experiment_alias_test", "//tensorboard/webapp/widgets/filter_input:filter_input_test", diff --git a/tensorboard/webapp/hparams/_redux/BUILD b/tensorboard/webapp/hparams/_redux/BUILD index 2b3608bd5b..5073a2c3f0 100644 --- a/tensorboard/webapp/hparams/_redux/BUILD +++ b/tensorboard/webapp/hparams/_redux/BUILD @@ -65,8 +65,10 @@ tf_ts_library( ":types", ":utils", "//tensorboard/webapp/hparams:_types", + "//tensorboard/webapp/persistent_settings", "//tensorboard/webapp/runs/actions", "//tensorboard/webapp/widgets/data_table:types", + "//tensorboard/webapp/widgets/data_table:utils", "@npm//@ngrx/store", ], ) @@ -167,6 +169,7 @@ tf_ts_library( "//tensorboard/webapp/app_routing/actions", "//tensorboard/webapp/core/actions", "//tensorboard/webapp/hparams:types", + "//tensorboard/webapp/persistent_settings", "//tensorboard/webapp/runs/actions", "//tensorboard/webapp/runs/data_source:testing", "//tensorboard/webapp/runs/store:testing", @@ -174,6 +177,7 @@ tf_ts_library( "//tensorboard/webapp/util:types", "//tensorboard/webapp/webapp_data_source:http_client_testing", "//tensorboard/webapp/widgets/data_table:types", + "//tensorboard/webapp/widgets/data_table:utils", "@npm//@ngrx/effects", "@npm//@ngrx/store", "@npm//@types/jasmine", diff --git a/tensorboard/webapp/hparams/_redux/hparams_reducers.ts b/tensorboard/webapp/hparams/_redux/hparams_reducers.ts index 883de392c6..eaca5cb3f0 100644 --- a/tensorboard/webapp/hparams/_redux/hparams_reducers.ts +++ b/tensorboard/webapp/hparams/_redux/hparams_reducers.ts @@ -13,7 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ import {Action, ActionReducer, createReducer, on} from '@ngrx/store'; -import {ColumnHeader, Side} from '../../widgets/data_table/types'; +import {DataTableUtils} from '../../widgets/data_table/utils'; +import {persistentSettingsLoaded} from '../../persistent_settings'; +import {Side} from '../../widgets/data_table/types'; import * as actions from './hparams_actions'; import {HparamsState} from './types'; @@ -32,6 +34,16 @@ const initialState: HparamsState = { const reducer: ActionReducer = createReducer( initialState, + on(persistentSettingsLoaded, (state, {partialSettings}) => { + const {dashboardDisplayedHparamColumns: storedColumns} = partialSettings; + if (storedColumns) { + return { + ...state, + dashboardDisplayedHparamColumns: storedColumns, + }; + } + return state; + }), on(actions.hparamsFetchSessionGroupsSucceeded, (state, action) => { const nextDashboardSpecs = action.hparamsAndMetricsSpecs; const nextDashboardSessionGroups = action.sessionGroups; @@ -141,28 +153,12 @@ const reducer: ActionReducer = createReducer( actions.dashboardHparamColumnOrderChanged, (state, {source, destination, side}) => { const {dashboardDisplayedHparamColumns: columns} = state; - const sourceIndex = columns.findIndex( - (column: ColumnHeader) => column.name === source.name + const newColumns = DataTableUtils.moveColumn( + columns, + source, + destination, + side ); - let destinationIndex = columns.findIndex( - (column: ColumnHeader) => column.name === destination.name - ); - if (sourceIndex === -1 || sourceIndex === destinationIndex) { - return state; - } - if (destinationIndex === -1) { - // Use side as a backup to determine source position if destination isn't found. - if (side !== undefined) { - destinationIndex = side === Side.LEFT ? 0 : columns.length - 1; - } else { - return state; - } - } - - const newColumns = [...columns]; - newColumns.splice(sourceIndex, 1); - newColumns.splice(destinationIndex, 0, source); - return { ...state, dashboardDisplayedHparamColumns: newColumns, diff --git a/tensorboard/webapp/hparams/_redux/hparams_reducers_test.ts b/tensorboard/webapp/hparams/_redux/hparams_reducers_test.ts index d671f762dc..6ed102ac59 100644 --- a/tensorboard/webapp/hparams/_redux/hparams_reducers_test.ts +++ b/tensorboard/webapp/hparams/_redux/hparams_reducers_test.ts @@ -18,8 +18,82 @@ import * as actions from './hparams_actions'; import {reducers} from './hparams_reducers'; import {buildHparamSpec, buildHparamsState, buildMetricSpec} from './testing'; import {ColumnHeaderType, Side} from '../../widgets/data_table/types'; +import {DataTableUtils} from '../../widgets/data_table/utils'; +import {persistentSettingsLoaded} from '../../persistent_settings'; describe('hparams/_redux/hparams_reducers_test', () => { + describe('#persistentSettingsLoaded', () => { + it('loads dashboardDisplayedHparamColumns from the persistent settings storage', () => { + const state = buildHparamsState({ + dashboardDisplayedHparamColumns: [], + }); + const state2 = reducers( + state, + persistentSettingsLoaded({ + partialSettings: { + dashboardDisplayedHparamColumns: [ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + ], + }, + }) + ); + + expect(state2.dashboardDisplayedHparamColumns).toEqual([ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + ]); + }); + + it('does nothing if persistent settings does not contain dashboardDisplayedHparamColumns', () => { + const state = buildHparamsState({ + dashboardDisplayedHparamColumns: [ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + ], + }); + const state2 = reducers( + state, + persistentSettingsLoaded({ + partialSettings: {}, + }) + ); + + expect(state2.dashboardDisplayedHparamColumns).toEqual([ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + ]); + }); + }); + describe('hparamsFetchSessionGroupsSucceeded', () => { it('saves action.hparamsAndMetricsSpecs as dashboardSpecs', () => { const state = buildHparamsState({ @@ -594,104 +668,15 @@ describe('hparams/_redux/hparams_reducers_test', () => { }, ]; - it('does nothing if source is not found', () => { + it('moves source to destination using moveColumn', () => { const state = buildHparamsState({ dashboardDisplayedHparamColumns: fakeColumns, }); - const state2 = reducers( - state, - actions.dashboardHparamColumnOrderChanged({ - source: { - type: ColumnHeaderType.HPARAM, - name: 'nonexistent_param', - displayName: 'Nonexistent param', - enabled: false, - }, - destination: { - type: ColumnHeaderType.HPARAM, - name: 'conv_kernel_size', - displayName: 'Conv Kernel Size', - enabled: true, - }, - side: Side.LEFT, - }) - ); - - expect(state2.dashboardDisplayedHparamColumns).toEqual(fakeColumns); - }); - - it('does nothing if source equals dest', () => { - const state = buildHparamsState({ - dashboardDisplayedHparamColumns: fakeColumns, - }); - const state2 = reducers( - state, - actions.dashboardHparamColumnOrderChanged({ - source: { - type: ColumnHeaderType.HPARAM, - name: 'conv_kernel_size', - displayName: 'Conv Kernel Size', - enabled: false, - }, - destination: { - type: ColumnHeaderType.HPARAM, - name: 'conv_kernel_size', - displayName: 'Conv Kernel Size', - enabled: true, - }, - side: Side.LEFT, - }) - ); - - expect(state2.dashboardDisplayedHparamColumns).toEqual(fakeColumns); - }); - - [ - { - testDesc: 'to front if side is left', - side: Side.LEFT, - expectedResult: [ - fakeColumns[1], - fakeColumns[0], - ...fakeColumns.slice(2), - ], - }, - { - testDesc: 'to back if side is right', - side: Side.RIGHT, - expectedResult: [ - fakeColumns[0], - ...fakeColumns.slice(2), - fakeColumns[1], - ], - }, - ].forEach(({testDesc, side, expectedResult}) => { - it(`if destination not found, moves source ${testDesc}`, () => { - const state = buildHparamsState({ - dashboardDisplayedHparamColumns: fakeColumns, - }); - const state2 = reducers( - state, - actions.dashboardHparamColumnOrderChanged({ - source: fakeColumns[1], - destination: { - type: ColumnHeaderType.HPARAM, - name: 'nonexistent param', - displayName: 'Nonexistent param', - enabled: true, - }, - side, - }) - ); - - expect(state2.dashboardDisplayedHparamColumns).toEqual(expectedResult); - }); - }); + const moveColumnSpy = spyOn( + DataTableUtils, + 'moveColumn' + ).and.callThrough(); - it('swaps source and destination positions if destination is found', () => { - const state = buildHparamsState({ - dashboardDisplayedHparamColumns: fakeColumns, - }); const state2 = reducers( state, actions.dashboardHparamColumnOrderChanged({ @@ -701,6 +686,13 @@ describe('hparams/_redux/hparams_reducers_test', () => { }) ); + // Edge cases are tested by moveColumn tests. + expect(moveColumnSpy).toHaveBeenCalledWith( + fakeColumns, + fakeColumns[1], + fakeColumns[0], + Side.LEFT + ); expect(state2.dashboardDisplayedHparamColumns).toEqual([ fakeColumns[1], fakeColumns[0], diff --git a/tensorboard/webapp/hparams/_redux/hparams_selectors.ts b/tensorboard/webapp/hparams/_redux/hparams_selectors.ts index 90c8b360d0..bd6054e3a5 100644 --- a/tensorboard/webapp/hparams/_redux/hparams_selectors.ts +++ b/tensorboard/webapp/hparams/_redux/hparams_selectors.ts @@ -47,8 +47,14 @@ export const getDashboardDefaultHparamFilters = createSelector( ); export const getDashboardDisplayedHparamColumns = createSelector( + getDashboardHparamsAndMetricsSpecs, getHparamsState, - (state) => state.dashboardDisplayedHparamColumns + ({hparams}, state) => { + const hparamSet = new Set(hparams.map((hparam) => hparam.name)); + return state.dashboardDisplayedHparamColumns.filter((column) => + hparamSet.has(column.name) + ); + } ); export const getDashboardHparamFilterMap = createSelector( diff --git a/tensorboard/webapp/hparams/_redux/hparams_selectors_test.ts b/tensorboard/webapp/hparams/_redux/hparams_selectors_test.ts index cc43c0dfa2..d57be572bb 100644 --- a/tensorboard/webapp/hparams/_redux/hparams_selectors_test.ts +++ b/tensorboard/webapp/hparams/_redux/hparams_selectors_test.ts @@ -15,6 +15,7 @@ limitations under the License. import {ColumnHeaderType} from '../../widgets/data_table/types'; import {DomainType} from '../types'; +import {State} from './types'; import * as selectors from './hparams_selectors'; import { buildHparamSpec, @@ -114,30 +115,63 @@ describe('hparams/_redux/hparams_selectors_test', () => { }); describe('#getDashboardDisplayedHparamColumns', () => { - it('returns dashboard displayed hparam columns', () => { - const fakeColumns = [ - { - type: ColumnHeaderType.HPARAM, - name: 'conv_layers', - displayName: 'Conv Layers', - enabled: true, - }, - { - type: ColumnHeaderType.HPARAM, - name: 'conv_kernel_size', - displayName: 'Conv Kernel Size', - enabled: true, - }, - ]; + it('returns no columns if no hparam specs', () => { const state = buildStateFromHparamsState( buildHparamsState({ - dashboardDisplayedHparamColumns: fakeColumns, + dashboardSpecs: { + hparams: [], + }, + dashboardDisplayedHparamColumns: [ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + ], }) ); - expect(selectors.getDashboardDisplayedHparamColumns(state)).toEqual( - fakeColumns + expect(selectors.getDashboardDisplayedHparamColumns(state)).toEqual([]); + }); + + it('returns only hparam columns that have specs', () => { + const state = buildStateFromHparamsState( + buildHparamsState({ + dashboardSpecs: { + hparams: [buildHparamSpec({name: 'conv_layers'})], + }, + dashboardDisplayedHparamColumns: [ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + ], + }) ); + + expect(selectors.getDashboardDisplayedHparamColumns(state)).toEqual([ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + ]); }); }); }); diff --git a/tensorboard/webapp/metrics/actions/index.ts b/tensorboard/webapp/metrics/actions/index.ts index a7f8f52d10..8fcfb1b9fd 100644 --- a/tensorboard/webapp/metrics/actions/index.ts +++ b/tensorboard/webapp/metrics/actions/index.ts @@ -29,12 +29,15 @@ import { HeaderEditInfo, HeaderToggleInfo, HistogramMode, - MinMaxStep, PluginType, TooltipSort, XAxisType, } from '../types'; -import {SortingInfo, DataTableMode} from '../../widgets/data_table/types'; +import { + SortingInfo, + DataTableMode, + ColumnHeader, +} from '../../widgets/data_table/types'; import {Extent} from '../../widgets/line_chart_v2/lib/public_types'; export const metricsSettingsPaneClosed = createAction( @@ -234,8 +237,8 @@ export const sortingDataTable = createAction( props() ); -export const dataTableColumnEdited = createAction( - '[Metrics] Data table columns edited in edit menu', +export const dataTableColumnOrderChanged = createAction( + '[Metrics] Data table columns order changed', props() ); @@ -244,6 +247,11 @@ export const dataTableColumnToggled = createAction( props() ); +export const dataTableColumnAdded = createAction( + '[Metrics] Data table column added in edit menu', + props<{header: ColumnHeader}>() +); + export const stepSelectorToggled = createAction( '[Metrics] Time Selector Enable Toggle', props<{ diff --git a/tensorboard/webapp/metrics/internal_types.ts b/tensorboard/webapp/metrics/internal_types.ts index d73a6c12e3..c664baa7e4 100644 --- a/tensorboard/webapp/metrics/internal_types.ts +++ b/tensorboard/webapp/metrics/internal_types.ts @@ -16,8 +16,8 @@ import {TimeSelection} from '../widgets/card_fob/card_fob_types'; import {HistogramMode} from '../widgets/histogram/histogram_types'; import { ColumnHeader, - ColumnHeaderType, DataTableMode, + ReorderColumnEvent, } from '../widgets/data_table/types'; export {HistogramMode, TimeSelection}; @@ -94,15 +94,14 @@ export interface URLDeserializedState { }; } -export interface HeaderEditInfo { +export interface HeaderEditInfo extends ReorderColumnEvent { dataTableMode: DataTableMode; - headers: ColumnHeader[]; } export interface HeaderToggleInfo { header: ColumnHeader; cardId?: CardId; - dataTableMode?: DataTableMode; + dataTableMode?: DataTableMode | undefined; } export const SCALARS_SMOOTHING_MIN = 0; diff --git a/tensorboard/webapp/metrics/store/BUILD b/tensorboard/webapp/metrics/store/BUILD index 8db5c763f9..349dcfe15b 100644 --- a/tensorboard/webapp/metrics/store/BUILD +++ b/tensorboard/webapp/metrics/store/BUILD @@ -18,6 +18,7 @@ tf_ts_library( "//tensorboard/webapp/app_routing:types", "//tensorboard/webapp/app_routing/actions", "//tensorboard/webapp/core/actions", + "//tensorboard/webapp/hparams/_redux:hparams_selectors", "//tensorboard/webapp/metrics:types", "//tensorboard/webapp/metrics:utils", "//tensorboard/webapp/metrics/actions", @@ -33,6 +34,7 @@ tf_ts_library( "//tensorboard/webapp/util:types", "//tensorboard/webapp/widgets/card_fob:types", "//tensorboard/webapp/widgets/data_table:types", + "//tensorboard/webapp/widgets/data_table:utils", "//tensorboard/webapp/widgets/line_chart_v2/lib:public_types", "@npm//@ngrx/store", ], @@ -97,14 +99,18 @@ tf_ts_library( "//tensorboard/webapp/app_routing:types", "//tensorboard/webapp/app_routing/actions", "//tensorboard/webapp/core/actions", + "//tensorboard/webapp/hparams:testing", + "//tensorboard/webapp/hparams/_redux:types", "//tensorboard/webapp/metrics:test_lib", "//tensorboard/webapp/metrics:types", "//tensorboard/webapp/metrics/actions", "//tensorboard/webapp/metrics/data_source", "//tensorboard/webapp/persistent_settings", "//tensorboard/webapp/routes:testing", + "//tensorboard/webapp/testing:utils", "//tensorboard/webapp/types", "//tensorboard/webapp/util:dom", + "//tensorboard/webapp/util:types", "//tensorboard/webapp/widgets/card_fob:types", "//tensorboard/webapp/widgets/data_table:types", "@npm//@types/jasmine", diff --git a/tensorboard/webapp/metrics/store/metrics_reducers.ts b/tensorboard/webapp/metrics/store/metrics_reducers.ts index d2f899c2c2..956af53f6b 100644 --- a/tensorboard/webapp/metrics/store/metrics_reducers.ts +++ b/tensorboard/webapp/metrics/store/metrics_reducers.ts @@ -46,11 +46,7 @@ import { URLDeserializedState, } from '../types'; import {groupCardIdWithMetdata} from '../utils'; -import { - ColumnHeader, - ColumnHeaderType, - DataTableMode, -} from '../../widgets/data_table/types'; +import {ColumnHeaderType, DataTableMode} from '../../widgets/data_table/types'; import { buildOrReturnStateWithPinnedCopy, buildOrReturnStateWithUnresolvedImportedPins, @@ -82,6 +78,7 @@ import { TimeSeriesData, TimeSeriesLoadable, } from './metrics_types'; +import {DataTableUtils} from '../../widgets/data_table/utils'; function buildCardMetadataList(tagMetadata: TagMetadata): CardMetadata[] { const results: CardMetadata[] = []; @@ -1434,34 +1431,30 @@ const reducer = createReducer( tableEditorSelectedTab: tab, }; }), - on(actions.dataTableColumnEdited, (state, {dataTableMode, headers}) => { - const enabledNewHeaders: ColumnHeader[] = []; - const disabledNewHeaders: ColumnHeader[] = []; - - // All enabled headers appear above all disabled headers. - headers.forEach((header) => { - if (header.enabled) { - enabledNewHeaders.push(header); - } else { - disabledNewHeaders.push(header); + on( + actions.dataTableColumnOrderChanged, + (state, {source, destination, side, dataTableMode}) => { + let headers = + dataTableMode === DataTableMode.RANGE + ? [...state.rangeSelectionHeaders] + : [...state.singleSelectionHeaders]; + headers = DataTableUtils.moveColumn(headers, source, destination, side); + + if (dataTableMode === DataTableMode.RANGE) { + return { + ...state, + rangeSelectionHeaders: headers, + }; } - }); - - if (dataTableMode === DataTableMode.RANGE) { return { ...state, - rangeSelectionHeaders: enabledNewHeaders.concat(disabledNewHeaders), + singleSelectionHeaders: headers, }; } - - return { - ...state, - singleSelectionHeaders: enabledNewHeaders.concat(disabledNewHeaders), - }; - }), + ), on( actions.dataTableColumnToggled, - (state, {dataTableMode, header, cardId}) => { + (state, {dataTableMode, header: toggledHeader, cardId}) => { const {cardStateMap, rangeSelectionEnabled, linkedTimeEnabled} = state; const rangeEnabled = cardId ? cardRangeSelectionEnabled( @@ -1471,32 +1464,17 @@ const reducer = createReducer( cardId ) : dataTableMode === DataTableMode.RANGE; - const targetedHeaders = rangeEnabled ? state.rangeSelectionHeaders : state.singleSelectionHeaders; - const currentToggledHeaderIndex = targetedHeaders.findIndex( - (element) => element.name === header.name - ); - - // If the header is being enabled it goes at the bottom of the currently - // enabled headers. If it is being disabled it goes to the top of the - // currently disabled headers. - let newToggledHeaderIndex = getEnabledCount(targetedHeaders); - if (targetedHeaders[currentToggledHeaderIndex].enabled) { - newToggledHeaderIndex--; - } - const newHeaders = moveHeader( - currentToggledHeaderIndex, - newToggledHeaderIndex, - targetedHeaders - ); - - newHeaders[newToggledHeaderIndex] = { - ...newHeaders[newToggledHeaderIndex], - enabled: !newHeaders[newToggledHeaderIndex].enabled, - }; + const newHeaders = targetedHeaders.map((header) => { + const newHeader = {...header}; + if (header.name === toggledHeader.name) { + newHeader.enabled = !newHeader.enabled; + } + return newHeader; + }); if (rangeEnabled) { return { @@ -1504,7 +1482,6 @@ const reducer = createReducer( rangeSelectionHeaders: newHeaders, }; } - return { ...state, singleSelectionHeaders: newHeaders, @@ -1583,30 +1560,3 @@ function buildTagToRuns(runTagInfo: {[run: string]: string[]}) { } return tagToRuns; } - -/** - * Returns a copy of the headers array with item at sourceIndex moved to - * destinationIndex. - */ -function moveHeader( - sourceIndex: number, - destinationIndex: number, - headers: ColumnHeader[] -) { - const newHeaders = [...headers]; - // Delete from original location - newHeaders.splice(sourceIndex, 1); - // Insert at destinationIndex. - newHeaders.splice(destinationIndex, 0, headers[sourceIndex]); - return newHeaders; -} - -function getEnabledCount(headers: ColumnHeader[]) { - let count = 0; - headers.forEach((header) => { - if (header.enabled) { - count++; - } - }); - return count; -} diff --git a/tensorboard/webapp/metrics/store/metrics_reducers_test.ts b/tensorboard/webapp/metrics/store/metrics_reducers_test.ts index ba42d51e0b..6631f580af 100644 --- a/tensorboard/webapp/metrics/store/metrics_reducers_test.ts +++ b/tensorboard/webapp/metrics/store/metrics_reducers_test.ts @@ -1601,40 +1601,20 @@ describe('metrics reducers', () => { const nextState = reducers( beforeState, - actions.dataTableColumnEdited({ + actions.dataTableColumnOrderChanged({ + source: { + type: ColumnHeaderType.END_VALUE, + name: 'endValue', + displayName: 'End Value', + enabled: true, + }, + destination: { + type: ColumnHeaderType.START_VALUE, + name: 'startValue', + displayName: 'Start Value', + enabled: true, + }, dataTableMode: DataTableMode.RANGE, - headers: [ - { - type: ColumnHeaderType.RUN, - name: 'run', - displayName: 'Run', - enabled: true, - }, - { - type: ColumnHeaderType.END_VALUE, - name: 'endValue', - displayName: 'End Value', - enabled: true, - }, - { - type: ColumnHeaderType.START_VALUE, - name: 'startValue', - displayName: 'Start Value', - enabled: true, - }, - { - type: ColumnHeaderType.MIN_VALUE, - name: 'minValue', - displayName: 'Min', - enabled: false, - }, - { - type: ColumnHeaderType.MAX_VALUE, - name: 'maxValue', - displayName: 'Max', - enabled: false, - }, - ], }) ); @@ -1762,34 +1742,20 @@ describe('metrics reducers', () => { const nextState = reducers( beforeState, - actions.dataTableColumnEdited({ + actions.dataTableColumnOrderChanged({ + source: { + type: ColumnHeaderType.STEP, + name: 'step', + displayName: 'Step', + enabled: true, + }, + destination: { + type: ColumnHeaderType.VALUE, + name: 'value', + displayName: 'Value', + enabled: true, + }, dataTableMode: DataTableMode.SINGLE, - headers: [ - { - type: ColumnHeaderType.RUN, - name: 'run', - displayName: 'Run', - enabled: true, - }, - { - type: ColumnHeaderType.STEP, - name: 'step', - displayName: 'Step', - enabled: true, - }, - { - type: ColumnHeaderType.VALUE, - name: 'value', - displayName: 'Value', - enabled: true, - }, - { - type: ColumnHeaderType.RELATIVE_TIME, - name: 'relativeTime', - displayName: 'Relative', - enabled: false, - }, - ], }) ); @@ -1852,115 +1818,6 @@ describe('metrics reducers', () => { }, ]); }); - - it('ensures ordering keeps enabled headers first', () => { - const beforeState = buildMetricsState({ - rangeSelectionHeaders: [ - { - type: ColumnHeaderType.RUN, - name: 'run', - displayName: 'Run', - enabled: true, - }, - { - type: ColumnHeaderType.START_VALUE, - name: 'startValue', - displayName: 'Start Value', - enabled: true, - }, - { - type: ColumnHeaderType.END_VALUE, - name: 'endValue', - displayName: 'End Value', - enabled: true, - }, - { - type: ColumnHeaderType.MIN_VALUE, - name: 'minValue', - displayName: 'Min', - enabled: false, - }, - { - type: ColumnHeaderType.MAX_VALUE, - name: 'maxValue', - displayName: 'Max', - enabled: false, - }, - ], - }); - - const nextState = reducers( - beforeState, - actions.dataTableColumnEdited({ - dataTableMode: DataTableMode.RANGE, - headers: [ - { - type: ColumnHeaderType.RUN, - name: 'run', - displayName: 'Run', - enabled: true, - }, - { - type: ColumnHeaderType.MAX_VALUE, - name: 'maxValue', - displayName: 'Max', - enabled: false, - }, - { - type: ColumnHeaderType.START_VALUE, - name: 'startValue', - displayName: 'Start Value', - enabled: true, - }, - { - type: ColumnHeaderType.END_VALUE, - name: 'endValue', - displayName: 'End Value', - enabled: true, - }, - { - type: ColumnHeaderType.MIN_VALUE, - name: 'minValue', - displayName: 'Min', - enabled: false, - }, - ], - }) - ); - - expect(nextState.rangeSelectionHeaders).toEqual([ - { - type: ColumnHeaderType.RUN, - name: 'run', - displayName: 'Run', - enabled: true, - }, - { - type: ColumnHeaderType.START_VALUE, - name: 'startValue', - displayName: 'Start Value', - enabled: true, - }, - { - type: ColumnHeaderType.END_VALUE, - name: 'endValue', - displayName: 'End Value', - enabled: true, - }, - { - type: ColumnHeaderType.MAX_VALUE, - name: 'maxValue', - displayName: 'Max', - enabled: false, - }, - { - type: ColumnHeaderType.MIN_VALUE, - name: 'minValue', - displayName: 'Min', - enabled: false, - }, - ]); - }); }); describe('dataTableColumnToggled', () => { @@ -2029,55 +1886,7 @@ describe('metrics reducers', () => { }); }); - it('moves header down to the disabled headers when toggling to disabled with data table mode input', () => { - const nextState = reducers( - beforeState, - actions.dataTableColumnToggled({ - dataTableMode: DataTableMode.RANGE, - header: { - type: ColumnHeaderType.RUN, - name: 'run', - displayName: 'Run', - enabled: false, - }, - }) - ); - - expect(nextState.rangeSelectionHeaders).toEqual([ - { - type: ColumnHeaderType.START_VALUE, - name: 'startValue', - displayName: 'Start Value', - enabled: true, - }, - { - type: ColumnHeaderType.END_VALUE, - name: 'endValue', - displayName: 'End Value', - enabled: true, - }, - { - type: ColumnHeaderType.RUN, - name: 'run', - displayName: 'Run', - enabled: false, - }, - { - type: ColumnHeaderType.MIN_VALUE, - name: 'minValue', - displayName: 'Min', - enabled: false, - }, - { - type: ColumnHeaderType.MAX_VALUE, - name: 'maxValue', - displayName: 'Max', - enabled: false, - }, - ]); - }); - - it('moves header up to the enabled headers when toggling to enabled with data table mode input', () => { + it('only changes range selection headers when dataTableMode is RANGE', () => { const nextState = reducers( beforeState, actions.dataTableColumnToggled({ @@ -2110,66 +1919,18 @@ describe('metrics reducers', () => { displayName: 'End Value', enabled: true, }, - { - type: ColumnHeaderType.MAX_VALUE, - name: 'maxValue', - displayName: 'Max', - enabled: true, - }, { type: ColumnHeaderType.MIN_VALUE, name: 'minValue', displayName: 'Min', enabled: false, }, - ]); - }); - - it('only changes range selection headers when dataTableMode is RANGE', () => { - const nextState = reducers( - beforeState, - actions.dataTableColumnToggled({ - dataTableMode: DataTableMode.RANGE, - header: { - type: ColumnHeaderType.MAX_VALUE, - name: 'maxValue', - displayName: 'Max', - enabled: true, - }, - }) - ); - - expect(nextState.rangeSelectionHeaders).toEqual([ - { - type: ColumnHeaderType.RUN, - name: 'run', - displayName: 'Run', - enabled: true, - }, - { - type: ColumnHeaderType.START_VALUE, - name: 'startValue', - displayName: 'Start Value', - enabled: true, - }, - { - type: ColumnHeaderType.END_VALUE, - name: 'endValue', - displayName: 'End Value', - enabled: true, - }, { type: ColumnHeaderType.MAX_VALUE, name: 'maxValue', displayName: 'Max', enabled: true, }, - { - type: ColumnHeaderType.MIN_VALUE, - name: 'minValue', - displayName: 'Min', - enabled: false, - }, ]); expect(nextState.singleSelectionHeaders).toEqual([ @@ -2275,34 +2036,6 @@ describe('metrics reducers', () => { ]); }); - it('moves header down to the disabled headers when column is removed with card id input', () => { - beforeState = { - ...beforeState, - cardStateMap: { - card1: { - rangeSelectionOverride: CardFeatureOverride.OVERRIDE_AS_ENABLED, - }, - }, - }; - - const nextState = reducers( - beforeState, - actions.dataTableColumnToggled({ - cardId: 'card1', - header: { - type: ColumnHeaderType.RUN, - name: 'run', - displayName: 'Run', - enabled: true, - }, - }) - ); - - expect( - nextState.rangeSelectionHeaders.map((header) => header.enabled) - ).toEqual([true, true, false, false, false]); - }); - it('only changes range selection headers when given card has rangeSelectionOverride ENABLED', () => { beforeState = { ...beforeState, @@ -2345,20 +2078,19 @@ describe('metrics reducers', () => { displayName: 'End Value', enabled: true, }, - { - type: ColumnHeaderType.MAX_VALUE, - name: 'maxValue', - displayName: 'Max', - enabled: true, - }, { type: ColumnHeaderType.MIN_VALUE, name: 'minValue', displayName: 'Min', enabled: false, }, + { + type: ColumnHeaderType.MAX_VALUE, + name: 'maxValue', + displayName: 'Max', + enabled: true, + }, ]); - expect(nextState.singleSelectionHeaders).toEqual([ { type: ColumnHeaderType.RUN, diff --git a/tensorboard/webapp/metrics/store/metrics_selectors.ts b/tensorboard/webapp/metrics/store/metrics_selectors.ts index 9892241fb5..8d821c1512 100644 --- a/tensorboard/webapp/metrics/store/metrics_selectors.ts +++ b/tensorboard/webapp/metrics/store/metrics_selectors.ts @@ -50,6 +50,8 @@ import { import {ColumnHeader, DataTableMode} from '../../widgets/data_table/types'; import {Extent} from '../../widgets/line_chart_v2/lib/public_types'; import {memoize} from '../../util/memoize'; +import {getDashboardDisplayedHparamColumns} from '../../hparams/_redux/hparams_selectors'; +import {DataTableUtils} from '../../widgets/data_table/utils'; const selectMetricsState = createFeatureSelector(METRICS_FEATURE_KEY); @@ -661,3 +663,22 @@ export const getColumnHeadersForCard = memoize((cardId: string) => { } ); }); + +export const getGroupedHeadersForCard = memoize((cardId: string) => + createSelector( + getColumnHeadersForCard(cardId), + getDashboardDisplayedHparamColumns, + (standardColumns, hparamColumns) => { + // Override hparam options to match scalar card table requirements. + const columns = [...standardColumns, ...hparamColumns].map((column) => { + const newColumn = {...column}; + if (column.type === 'HPARAM') { + newColumn.removable = false; + newColumn.hidable = true; + } + return newColumn; + }); + return DataTableUtils.groupColumns(columns); + } + ) +); diff --git a/tensorboard/webapp/metrics/store/metrics_selectors_test.ts b/tensorboard/webapp/metrics/store/metrics_selectors_test.ts index a30cb45a0e..1c7a3b81ec 100644 --- a/tensorboard/webapp/metrics/store/metrics_selectors_test.ts +++ b/tensorboard/webapp/metrics/store/metrics_selectors_test.ts @@ -33,6 +33,14 @@ import { } from '../../widgets/data_table/types'; import * as selectors from './metrics_selectors'; import {CardFeatureOverride, MetricsState} from './metrics_types'; +import {buildMockState} from '../../testing/utils'; +import { + buildHparamSpec, + buildHparamsState, + buildStateFromHparamsState, +} from '../../hparams/testing'; +import {DeepPartial} from '../../util/types'; +import {HparamsState} from '../../hparams/_redux/types'; describe('metrics selectors', () => { beforeEach(() => { @@ -1745,4 +1753,240 @@ describe('metrics selectors', () => { ).toEqual(rangeSelectionHeaders); }); }); + + describe('getGroupedHeadersForCard', () => { + let singleSelectionHeaders: ColumnHeader[]; + let rangeSelectionHeaders: ColumnHeader[]; + let hparamsState: DeepPartial; + + beforeEach(() => { + singleSelectionHeaders = [ + { + type: ColumnHeaderType.COLOR, + name: 'color', + displayName: 'Color', + enabled: true, + }, + { + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'My Run name', + enabled: false, + }, + ]; + rangeSelectionHeaders = [ + { + type: ColumnHeaderType.MEAN, + name: 'mean', + displayName: 'Mean', + enabled: true, + }, + { + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'My Run name', + enabled: false, + }, + ]; + hparamsState = { + dashboardSpecs: { + hparams: [ + buildHparamSpec({name: 'conv_layers'}), + buildHparamSpec({name: 'conv_kernel_size'}), + ], + }, + dashboardDisplayedHparamColumns: [ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + ], + }; + }); + + it('returns grouped single selection headers when card range selection is disabled', () => { + const state = buildMockState({ + ...appStateFromMetricsState( + buildMetricsState({ + singleSelectionHeaders, + rangeSelectionHeaders, + cardStateMap: { + card1: { + rangeSelectionOverride: + CardFeatureOverride.OVERRIDE_AS_DISABLED, + }, + }, + }) + ), + ...buildStateFromHparamsState(buildHparamsState(hparamsState)), + }); + + expect(selectors.getGroupedHeadersForCard('card1')(state)).toEqual([ + jasmine.objectContaining({ + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'My Run name', + enabled: false, + }), + jasmine.objectContaining({ + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }), + jasmine.objectContaining({ + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }), + jasmine.objectContaining({ + type: ColumnHeaderType.COLOR, + name: 'color', + displayName: 'Color', + enabled: true, + }), + ]); + }); + + it('returns grouped range selection headers when card range selection is enabled', () => { + const state = buildMockState({ + ...appStateFromMetricsState( + buildMetricsState({ + singleSelectionHeaders, + rangeSelectionHeaders, + cardStateMap: { + card1: { + rangeSelectionOverride: CardFeatureOverride.OVERRIDE_AS_ENABLED, + }, + }, + }) + ), + ...buildStateFromHparamsState(buildHparamsState(hparamsState)), + }); + + expect(selectors.getGroupedHeadersForCard('card1')(state)).toEqual([ + jasmine.objectContaining({ + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'My Run name', + enabled: false, + }), + jasmine.objectContaining({ + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }), + jasmine.objectContaining({ + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }), + jasmine.objectContaining({ + type: ColumnHeaderType.MEAN, + name: 'mean', + displayName: 'Mean', + enabled: true, + }), + ]); + }); + + it('returns grouped range selection headers when global range selection is enabled', () => { + const state = buildMockState({ + ...appStateFromMetricsState( + buildMetricsState({ + singleSelectionHeaders, + rangeSelectionHeaders, + rangeSelectionEnabled: true, + }) + ), + ...buildStateFromHparamsState(buildHparamsState(hparamsState)), + }); + + expect(selectors.getGroupedHeadersForCard('card1')(state)).toEqual([ + jasmine.objectContaining({ + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'My Run name', + enabled: false, + }), + jasmine.objectContaining({ + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }), + jasmine.objectContaining({ + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }), + jasmine.objectContaining({ + type: ColumnHeaderType.MEAN, + name: 'mean', + displayName: 'Mean', + enabled: true, + }), + ]); + }); + + [ + { + testDesc: 'for single selection', + rangeSelectionOverride: CardFeatureOverride.OVERRIDE_AS_DISABLED, + }, + { + testDesc: 'for range selection', + rangeSelectionOverride: CardFeatureOverride.OVERRIDE_AS_ENABLED, + }, + ].forEach(({testDesc, rangeSelectionOverride}) => { + it(`sets proper context menu options for hparam columns ${testDesc}`, () => { + const state = buildMockState({ + ...appStateFromMetricsState( + buildMetricsState({ + singleSelectionHeaders, + rangeSelectionHeaders, + cardStateMap: { + card1: { + rangeSelectionOverride: + CardFeatureOverride.OVERRIDE_AS_ENABLED, + }, + }, + }) + ), + ...buildStateFromHparamsState(buildHparamsState(hparamsState)), + }); + + expect(selectors.getGroupedHeadersForCard('card1')(state)).toEqual([ + jasmine.objectContaining({ + type: ColumnHeaderType.RUN, + }), + jasmine.objectContaining({ + type: ColumnHeaderType.HPARAM, + removable: false, + hidable: true, + }), + jasmine.objectContaining({ + type: ColumnHeaderType.HPARAM, + removable: false, + hidable: true, + }), + jasmine.objectContaining({ + type: ColumnHeaderType.MEAN, + }), + ]); + }); + }); + }); }); diff --git a/tensorboard/webapp/metrics/testing.ts b/tensorboard/webapp/metrics/testing.ts index de33b3b66d..4dc9757d89 100644 --- a/tensorboard/webapp/metrics/testing.ts +++ b/tensorboard/webapp/metrics/testing.ts @@ -28,14 +28,12 @@ import { TimeSeriesRequest, TimeSeriesResponse, } from './data_source'; +import * as selectors from './store/metrics_selectors'; import { MetricsState, METRICS_FEATURE_KEY, TagMetadata, TimeSeriesData, -} from './store'; -import * as selectors from './store/metrics_selectors'; -import { CardStepIndexMetaData, MetricsSettings, RunToSeries, diff --git a/tensorboard/webapp/metrics/views/card_renderer/BUILD b/tensorboard/webapp/metrics/views/card_renderer/BUILD index 180c83eb0f..3e1dbb2d98 100644 --- a/tensorboard/webapp/metrics/views/card_renderer/BUILD +++ b/tensorboard/webapp/metrics/views/card_renderer/BUILD @@ -302,6 +302,7 @@ tf_ng_module( ":scalar_card_types", ":utils", "//tensorboard/webapp/metrics:types", + "//tensorboard/webapp/runs:types", "//tensorboard/webapp/widgets/card_fob:types", "//tensorboard/webapp/widgets/data_table", "//tensorboard/webapp/widgets/data_table:types", @@ -339,6 +340,8 @@ tf_ng_module( "//tensorboard/webapp/angular:expect_angular_material_progress_spinner", "//tensorboard/webapp/experiments:types", "//tensorboard/webapp/feature_flag/store", + "//tensorboard/webapp/hparams", + "//tensorboard/webapp/hparams:types", "//tensorboard/webapp/metrics:types", "//tensorboard/webapp/metrics/actions", "//tensorboard/webapp/metrics/data_source", @@ -346,6 +349,7 @@ tf_ng_module( "//tensorboard/webapp/metrics/views:types", "//tensorboard/webapp/metrics/views:utils", "//tensorboard/webapp/metrics/views/main_view:common_selectors", + "//tensorboard/webapp/runs:types", "//tensorboard/webapp/runs/store:types", "//tensorboard/webapp/types", "//tensorboard/webapp/types:ui", @@ -459,6 +463,9 @@ tf_ts_library( "//tensorboard/webapp/angular:expect_angular_platform_browser_animations", "//tensorboard/webapp/angular:expect_ngrx_store_testing", "//tensorboard/webapp/experiments:types", + "//tensorboard/webapp/hparams/_redux:hparams_actions", + "//tensorboard/webapp/hparams/_redux:hparams_selectors", + "//tensorboard/webapp/hparams/_redux:types", "//tensorboard/webapp/metrics:test_lib", "//tensorboard/webapp/metrics:types", "//tensorboard/webapp/metrics/actions", @@ -466,6 +473,7 @@ tf_ts_library( "//tensorboard/webapp/metrics/store", "//tensorboard/webapp/metrics/store:types", "//tensorboard/webapp/metrics/views/main_view:common_selectors", + "//tensorboard/webapp/runs/store:selectors", "//tensorboard/webapp/runs/store:testing", "//tensorboard/webapp/runs/store:types", "//tensorboard/webapp/testing:mat_icon", diff --git a/tensorboard/webapp/metrics/views/card_renderer/scalar_card_component.ng.html b/tensorboard/webapp/metrics/views/card_renderer/scalar_card_component.ng.html index 11113af833..0eaffad40c 100644 --- a/tensorboard/webapp/metrics/views/card_renderer/scalar_card_component.ng.html +++ b/tensorboard/webapp/metrics/views/card_renderer/scalar_card_component.ng.html @@ -200,9 +200,14 @@ [columnCustomizationEnabled]="columnCustomizationEnabled" [smoothingEnabled]="smoothingEnabled" [hparamsEnabled]="hparamsEnabled" + [columnFilters]="columnFilters" + [runToHparams]="runToHparams" + [selectableColumns]="selectableColumns" (sortDataBy)="sortDataBy($event)" (editColumnHeaders)="editColumnHeaders.emit($event)" - (removeColumn)="removeColumn.emit($event)" + (addColumn)="addColumn.emit($event)" + (hideColumn)="hideColumn.emit($event)" + (addFilter)="addFilter.emit($event)" > diff --git a/tensorboard/webapp/metrics/views/card_renderer/scalar_card_component.ts b/tensorboard/webapp/metrics/views/card_renderer/scalar_card_component.ts index 74e2390576..df0e0e9cf8 100644 --- a/tensorboard/webapp/metrics/views/card_renderer/scalar_card_component.ts +++ b/tensorboard/webapp/metrics/views/card_renderer/scalar_card_component.ts @@ -44,7 +44,12 @@ import { TooltipDatum, } from '../../../widgets/line_chart_v2/types'; import {CardState} from '../../store'; -import {HeaderEditInfo, TooltipSort, XAxisType} from '../../types'; +import { + HeaderEditInfo, + HeaderToggleInfo, + TooltipSort, + XAxisType, +} from '../../types'; import { MinMaxStep, ScalarCardDataSeries, @@ -56,8 +61,13 @@ import { DataTableMode, SortingInfo, SortingOrder, + DiscreteFilter, + IntervalFilter, + FilterAddedEvent, + AddColumnEvent, } from '../../../widgets/data_table/types'; import {isDatumVisible, TimeSelectionView} from './utils'; +import {RunToHparams} from '../../../runs/types'; type ScalarTooltipDatum = TooltipDatum< ScalarCardSeriesMetadata & { @@ -102,6 +112,9 @@ export class ScalarCardComponent { @Input() columnHeaders!: ColumnHeader[]; @Input() rangeEnabled!: boolean; @Input() hparamsEnabled?: boolean; + @Input() columnFilters!: Map; + @Input() selectableColumns!: ColumnHeader[]; + @Input() runToHparams!: RunToHparams[]; @Output() onFullSizeToggle = new EventEmitter(); @Output() onPinClicked = new EventEmitter(); @@ -114,10 +127,11 @@ export class ScalarCardComponent { @Output() onDataTableSorting = new EventEmitter(); @Output() editColumnHeaders = new EventEmitter(); @Output() openTableEditMenuToMode = new EventEmitter(); - @Output() removeColumn = new EventEmitter(); + @Output() addColumn = new EventEmitter(); + @Output() hideColumn = new EventEmitter(); + @Output() addFilter = new EventEmitter(); @Output() onLineChartZoom = new EventEmitter(); - @Output() onCardStateChanged = new EventEmitter>(); // Line chart may not exist when was never visible (*ngIf). diff --git a/tensorboard/webapp/metrics/views/card_renderer/scalar_card_container.ts b/tensorboard/webapp/metrics/views/card_renderer/scalar_card_container.ts index abbd9d7421..b3b316f6f5 100644 --- a/tensorboard/webapp/metrics/views/card_renderer/scalar_card_container.ts +++ b/tensorboard/webapp/metrics/views/card_renderer/scalar_card_container.ts @@ -37,6 +37,7 @@ import { } from 'rxjs/operators'; import {State} from '../../../app_state'; import {ExperimentAlias} from '../../../experiments/types'; +import {actions as hparamsActions} from '../../../hparams'; import { getEnableHparamsInTimeSeries, getForceSvgFeatureFlag, @@ -57,7 +58,8 @@ import { getRun, getRunColorMap, getCurrentRouteRunSelection, - getColumnHeadersForCard, + getDashboardRunsToHparams, + getGroupedHeadersForCard, } from '../../../selectors'; import {DataLoadState} from '../../../types/data'; import { @@ -70,14 +72,14 @@ import {Extent} from '../../../widgets/line_chart_v2/lib/public_types'; import {ScaleType} from '../../../widgets/line_chart_v2/types'; import { cardViewBoxChanged, - dataTableColumnEdited, - dataTableColumnToggled, metricsCardFullSizeToggled, metricsCardStateUpdated, sortingDataTable, stepSelectorToggled, timeSelectionChanged, metricsSlideoutMenuOpened, + dataTableColumnOrderChanged, + dataTableColumnToggled, } from '../../actions'; import {PluginType, ScalarStepDatum} from '../../data_source'; import { @@ -100,7 +102,12 @@ import { HeaderToggleInfo, XAxisType, } from '../../types'; -import {getFilteredRenderableRunsIds} from '../main_view/common_selectors'; +import {RunToHparams} from '../../../runs/types'; +import { + getFilteredRenderableRunsIds, + getCurrentColumnFilters, + getSelectableColumns, +} from '../main_view/common_selectors'; import {CardRenderer} from '../metrics_view_types'; import {getTagDisplayName} from '../utils'; import {DataDownloadDialogContainer} from './data_download_dialog_container'; @@ -117,12 +124,15 @@ import { ColumnHeader, DataTableMode, SortingInfo, + FilterAddedEvent, + AddColumnEvent, } from '../../../widgets/data_table/types'; import { maybeClipTimeSelectionView, partitionSeries, TimeSelectionView, } from './utils'; +import {RunToHparamsAndMetrics} from '../../../hparams/types'; type ScalarCardMetadata = CardMetadata & { plugin: PluginType.SCALARS; @@ -150,11 +160,6 @@ function areSeriesEqual( }); } -function isMinMaxStepValid(minMax: MinMaxStep | undefined): boolean { - if (!minMax) return false; - return !(minMax.minStep === -Infinity && minMax.maxStep === Infinity); -} - @Component({ selector: 'scalar-card', template: ` @@ -185,6 +190,9 @@ function isMinMaxStepValid(minMax: MinMaxStep | undefined): boolean { [columnHeaders]="columnHeaders$ | async" [rangeEnabled]="rangeEnabled$ | async" [hparamsEnabled]="hparamsEnabled$ | async" + [columnFilters]="columnFilters$ | async" + [runToHparams]="runToHparams$ | async" + [selectableColumns]="selectableColumns$ | async" (onFullSizeToggle)="onFullSizeToggle()" (onPinClicked)="pinStateChanged.emit($event)" observeIntersection @@ -196,7 +204,9 @@ function isMinMaxStepValid(minMax: MinMaxStep | undefined): boolean { (editColumnHeaders)="editColumnHeaders($event)" (onCardStateChanged)="onCardStateChanged($event)" (openTableEditMenuToMode)="openTableEditMenuToMode($event)" - (removeColumn)="onRemoveColumn($event)" + (addColumn)="onAddColumn($event)" + (hideColumn)="onHideColumn($event)" + (addFilter)="addHparamFilter($event)" > `, styles: [ @@ -236,6 +246,9 @@ export class ScalarCardContainer implements CardRenderer, OnInit, OnDestroy { cardState$?: Observable>; rangeEnabled$?: Observable; hparamsEnabled$?: Observable; + columnFilters$ = this.store.select(getCurrentColumnFilters); + runToHparams$?: Observable; + selectableColumns$?: Observable; onVisibilityChange({visible}: {visible: boolean}) { this.isVisible = visible; @@ -471,7 +484,7 @@ export class ScalarCardContainer implements CardRenderer, OnInit, OnDestroy { ); this.columnHeaders$ = this.store.select( - getColumnHeadersForCard(this.cardId) + getGroupedHeadersForCard(this.cardId) ); this.chartMetadataMap$ = partitionedSeries$.pipe( @@ -600,6 +613,23 @@ export class ScalarCardContainer implements CardRenderer, OnInit, OnDestroy { ); this.hparamsEnabled$ = this.store.select(getEnableHparamsInTimeSeries); + + this.runToHparams$ = this.store.select(getDashboardRunsToHparams).pipe( + map((runToHparamsAndMetrics: RunToHparamsAndMetrics): RunToHparams => { + const runToHparams: RunToHparams = {}; + for (const [runName, {hparams}] of Object.entries( + runToHparamsAndMetrics + )) { + runToHparams[runName] = {}; + for (const {name: hparamName, value} of hparams) { + runToHparams[runName][hparamName] = value; + } + } + return runToHparams; + }) + ); + + this.selectableColumns$ = this.store.select(getSelectableColumns); } ngOnDestroy() { @@ -686,15 +716,55 @@ export class ScalarCardContainer implements CardRenderer, OnInit, OnDestroy { ); } - editColumnHeaders(headerEditInfo: HeaderEditInfo) { - this.store.dispatch(dataTableColumnEdited(headerEditInfo)); + editColumnHeaders({ + source, + destination, + side, + dataTableMode, + }: HeaderEditInfo) { + if (source.type === 'HPARAM') { + this.store.dispatch( + hparamsActions.dashboardHparamColumnOrderChanged({ + source, + destination, + side, + }) + ); + } else { + this.store.dispatch( + dataTableColumnOrderChanged({source, destination, side, dataTableMode}) + ); + } } openTableEditMenuToMode(tableMode: DataTableMode) { this.store.dispatch(metricsSlideoutMenuOpened({mode: tableMode})); } - onRemoveColumn(header: ColumnHeader) { - this.store.dispatch(dataTableColumnToggled({header, cardId: this.cardId})); + onAddColumn(addColumnEvent: AddColumnEvent) { + this.store.dispatch( + hparamsActions.dashboardHparamColumnAdded(addColumnEvent) + ); + } + + onHideColumn({header, dataTableMode}: HeaderToggleInfo) { + if (header.type === 'HPARAM') { + this.store.dispatch( + hparamsActions.dashboardHparamColumnToggled({column: header}) + ); + } else { + this.store.dispatch( + dataTableColumnToggled({header, cardId: this.cardId, dataTableMode}) + ); + } + } + + addHparamFilter(event: FilterAddedEvent) { + this.store.dispatch( + hparamsActions.dashboardHparamFilterAdded({ + name: event.name, + filter: event.value, + }) + ); } } diff --git a/tensorboard/webapp/metrics/views/card_renderer/scalar_card_data_table.ng.html b/tensorboard/webapp/metrics/views/card_renderer/scalar_card_data_table.ng.html index 9f797a6da6..62d61eacd8 100644 --- a/tensorboard/webapp/metrics/views/card_renderer/scalar_card_data_table.ng.html +++ b/tensorboard/webapp/metrics/views/card_renderer/scalar_card_data_table.ng.html @@ -18,10 +18,13 @@ @@ -30,7 +33,6 @@ [header]="header" [sortingInfo]="sortingInfo" [hparamsEnabled]="hparamsEnabled" - disableContextMenu="true" > diff --git a/tensorboard/webapp/metrics/views/card_renderer/scalar_card_data_table.ts b/tensorboard/webapp/metrics/views/card_renderer/scalar_card_data_table.ts index e73194447a..4c184723bc 100644 --- a/tensorboard/webapp/metrics/views/card_renderer/scalar_card_data_table.ts +++ b/tensorboard/webapp/metrics/views/card_renderer/scalar_card_data_table.ts @@ -21,11 +21,13 @@ import { } from '@angular/core'; import {TimeSelection} from '../../../widgets/card_fob/card_fob_types'; import {findClosestIndex} from '../../../widgets/line_chart_v2/sub_view/line_chart_interactive_utils'; -import {HeaderEditInfo} from '../../types'; +import {HeaderEditInfo, HeaderToggleInfo} from '../../types'; +import {RunToHparams} from '../../../runs/types'; import { ScalarCardDataSeries, ScalarCardPoint, ScalarCardSeriesMetadataMap, + SmoothedSeriesMetadata, } from './scalar_card_types'; import { ColumnHeader, @@ -34,6 +36,11 @@ import { TableData, SortingInfo, SortingOrder, + DiscreteFilter, + IntervalFilter, + FilterAddedEvent, + ReorderColumnEvent, + AddColumnEvent, } from '../../../widgets/data_table/types'; import {isDatumVisible} from './utils'; @@ -52,12 +59,15 @@ export class ScalarCardDataTable { @Input() columnCustomizationEnabled!: boolean; @Input() smoothingEnabled!: boolean; @Input() hparamsEnabled?: boolean; + @Input() columnFilters!: Map; + @Input() selectableColumns!: ColumnHeader[]; + @Input() runToHparams!: RunToHparams; @Output() sortDataBy = new EventEmitter(); @Output() editColumnHeaders = new EventEmitter(); - @Output() removeColumn = new EventEmitter<{ - headerType: ColumnHeaderType; - }>(); + @Output() hideColumn = new EventEmitter(); + @Output() addColumn = new EventEmitter(); + @Output() addFilter = new EventEmitter(); ColumnHeaderType = ColumnHeaderType; @@ -71,6 +81,7 @@ export class ScalarCardDataTable { }, ].concat(this.columnHeaders); } + getMinPointInRange( points: ScalarCardPoint[], startPointIndex: number, @@ -253,6 +264,16 @@ export class ScalarCardDataTable { selectedStepData[header.name] = closestEndPoint.value - closestStartPoint.value; continue; + case ColumnHeaderType.HPARAM: + let runId: string; + if ((metadata as SmoothedSeriesMetadata).originalSeriesId) { + runId = (metadata as SmoothedSeriesMetadata).originalSeriesId; + } else { + runId = metadata.id; + } + selectedStepData[header.name] = + this.runToHparams?.[runId]?.[header.name] ?? ''; + continue; default: continue; } @@ -292,14 +313,24 @@ export class ScalarCardDataTable { return makeValueSortable(point[header.name]); } - orderColumns(headers: ColumnHeader[]) { + private getDataTableMode(): DataTableMode { + return this.stepOrLinkedTimeSelection.end + ? DataTableMode.RANGE + : DataTableMode.SINGLE; + } + + onOrderColumns({source, destination, side}: ReorderColumnEvent) { this.editColumnHeaders.emit({ - headers: headers, - dataTableMode: this.stepOrLinkedTimeSelection.end - ? DataTableMode.RANGE - : DataTableMode.SINGLE, + source, + destination, + side, + dataTableMode: this.getDataTableMode(), }); } + + onHideColumn(header: ColumnHeader) { + this.hideColumn.emit({header, dataTableMode: this.getDataTableMode()}); + } } function makeValueSortable( diff --git a/tensorboard/webapp/metrics/views/card_renderer/scalar_card_test.ts b/tensorboard/webapp/metrics/views/card_renderer/scalar_card_test.ts index bbe739df71..c9093625f7 100644 --- a/tensorboard/webapp/metrics/views/card_renderer/scalar_card_test.ts +++ b/tensorboard/webapp/metrics/views/card_renderer/scalar_card_test.ts @@ -81,7 +81,7 @@ import { stepSelectorToggled, timeSelectionChanged, metricsSlideoutMenuOpened, - dataTableColumnEdited, + dataTableColumnOrderChanged, dataTableColumnToggled, } from '../../actions'; import {PluginType} from '../../data_source'; @@ -114,19 +114,28 @@ import { SeriesType, } from './scalar_card_types'; import { + AddColumnEvent, ColumnHeader, ColumnHeaderType, DataTableMode, + DomainType, + FilterAddedEvent, + IntervalFilter, + ReorderColumnEvent, + Side, SortingOrder, } from '../../../widgets/data_table/types'; import {VisLinkedTimeSelectionWarningModule} from './vis_linked_time_selection_warning_module'; import {Extent} from '../../../widgets/line_chart_v2/lib/public_types'; import {provideMockTbStore} from '../../../testing/utils'; import * as commonSelectors from '../main_view/common_selectors'; -import {CardFeatureOverride} from '../../store/metrics_types'; import {ContentCellComponent} from '../../../widgets/data_table/content_cell_component'; import {ContentRowComponent} from '../../../widgets/data_table/content_row_component'; import {HeaderCellComponent} from '../../../widgets/data_table/header_cell_component'; +import {HparamFilter} from '../../../hparams/_redux/types'; +import * as hparamsSelectors from '../../../hparams/_redux/hparams_selectors'; +import * as hparamsActions from '../../../hparams/_redux/hparams_actions'; +import * as runsSelectors from '../../../runs/store/runs_selectors'; @Component({ selector: 'line-chart', @@ -2779,6 +2788,103 @@ describe('scalar card', () => { contentCellTypes.find((type) => type === ColumnHeaderType.SMOOTHED) ).toBeFalsy(); })); + + it('passes columnFilters to table', fakeAsync(() => { + store.overrideSelector( + commonSelectors.getCurrentColumnFilters, + new Map([ + [ + 'discrete hparam', + { + type: DomainType.DISCRETE, + includeUndefined: true, + possibleValues: [2, 4, 6, 8], + filterValues: [2, 4, 6, 8], + }, + ], + [ + 'interval metric', + { + type: DomainType.INTERVAL, + includeUndefined: true, + minValue: 2, + maxValue: 5, + filterLowerValue: 2, + filterUpperValue: 5, + }, + ], + ]) + ); + const fixture = createComponent('card1'); + fixture.detectChanges(); + + const dataTableComponentInstance = fixture.debugElement.query( + By.directive(DataTableComponent) + ).componentInstance; + + expect(dataTableComponentInstance.columnFilters).toEqual( + new Map([ + [ + 'discrete hparam', + { + type: DomainType.DISCRETE, + includeUndefined: true, + possibleValues: [2, 4, 6, 8], + filterValues: [2, 4, 6, 8], + }, + ], + [ + 'interval metric', + { + type: DomainType.INTERVAL, + includeUndefined: true, + minValue: 2, + maxValue: 5, + filterLowerValue: 2, + filterUpperValue: 5, + }, + ], + ]) + ); + })); + + it('passes selectableColumns to table', fakeAsync(() => { + store.overrideSelector(commonSelectors.getSelectableColumns, [ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + ]); + const fixture = createComponent('card1'); + fixture.detectChanges(); + + const dataTableComponentInstance = fixture.debugElement.query( + By.directive(DataTableComponent) + ).componentInstance; + + expect(dataTableComponentInstance.selectableColumns).toEqual([ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + ]); + })); }); describe('line chart integration', () => { @@ -3854,20 +3960,54 @@ describe('scalar card', () => { ]) ); + store.overrideSelector(getMetricsLinkedTimeSelection, { + start: {step: 1}, + end: null, + }); + store.overrideSelector( commonSelectors.getFilteredRenderableRunsIds, new Set(['run1']) ); - store.overrideSelector(getMetricsLinkedTimeSelection, { - start: {step: 1}, - end: null, + store.overrideSelector(selectors.getEnableHparamsInTimeSeries, true); + + store.overrideSelector( + hparamsSelectors.getDashboardDisplayedHparamColumns, + [ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + ] + ); + store.overrideSelector(runsSelectors.getDashboardRunsToHparams, { + run1: { + hparams: [ + {name: 'conv_layers', value: 1}, + {name: 'conv_kernel_size', value: 2}, + ], + metrics: [], + }, + run2: { + hparams: [ + {name: 'conv_layers', value: 3}, + {name: 'conv_kernel_size', value: 4}, + ], + metrics: [], + }, }); }); it('filters runs by hparam when enableHparamsInTimeSeries is true', fakeAsync(() => { - store.overrideSelector(selectors.getEnableHparamsInTimeSeries, true); - const fixture = createComponent('card1'); const scalarCardDataTable = fixture.debugElement.query( By.directive(ScalarCardDataTable) @@ -3875,13 +4015,13 @@ describe('scalar card', () => { const data = scalarCardDataTable.componentInstance.getTimeSelectionTableData(); + expect(data.length).toEqual(1); expect(data[0].run).toEqual('run1'); })); it('does not filter runs by hparam when enableHparamsInTimeSeries is false', fakeAsync(() => { store.overrideSelector(selectors.getEnableHparamsInTimeSeries, false); - const fixture = createComponent('card1'); const scalarCardDataTable = fixture.debugElement.query( By.directive(ScalarCardDataTable) @@ -3889,10 +4029,66 @@ describe('scalar card', () => { const data = scalarCardDataTable.componentInstance.getTimeSelectionTableData(); + expect(data.length).toEqual(2); expect(data[0].run).toEqual('run1'); expect(data[1].run).toEqual('run2'); })); + + it('shows hparam values for selected hparam columns', fakeAsync(() => { + store.overrideSelector( + commonSelectors.getFilteredRenderableRunsIds, + new Set(['run1', 'run2']) + ); + const fixture = createComponent('card1'); + const scalarCardDataTable = fixture.debugElement.query( + By.directive(ScalarCardDataTable) + ); + + const data = + scalarCardDataTable.componentInstance.getTimeSelectionTableData(); + + expect(data).toEqual([ + jasmine.objectContaining({ + id: 'run1', + conv_layers: 1, + conv_kernel_size: 2, + }), + jasmine.objectContaining({ + id: 'run2', + conv_layers: 3, + conv_kernel_size: 4, + }), + ]); + })); + + it('shows hparam values with smoothing enabled', fakeAsync(() => { + store.overrideSelector( + commonSelectors.getFilteredRenderableRunsIds, + new Set(['run1', 'run2']) + ); + store.overrideSelector(selectors.getMetricsScalarSmoothing, 0.3); + const fixture = createComponent('card1'); + const scalarCardDataTable = fixture.debugElement.query( + By.directive(ScalarCardDataTable) + ); + + const data = + scalarCardDataTable.componentInstance.getTimeSelectionTableData(); + + expect(data).toEqual([ + jasmine.objectContaining({ + id: '["smoothed","run1"]', + conv_layers: 1, + conv_kernel_size: 2, + }), + jasmine.objectContaining({ + id: '["smoothed","run2"]', + conv_layers: 3, + conv_kernel_size: 4, + }), + ]); + })); }); }); @@ -4411,7 +4607,7 @@ describe('scalar card', () => { expect(dataTableComponent).toBeFalsy(); })); - it('emits dataTableColumnEdited with DataTableMode.SINGLE when orderColumns is called while in Single Selection', fakeAsync(() => { + it('emits dataTableColumnOrderChanged with DataTableMode.SINGLE when orderColumns is called while in Single Selection', fakeAsync(() => { store.overrideSelector(getCardStateMap, { card1: { dataMinMax: { @@ -4430,32 +4626,35 @@ describe('scalar card', () => { const scalarCardDataTable = fixture.debugElement.query( By.directive(ScalarCardDataTable) ); - - const headers = [ - { + const reorderColumnEvent: ReorderColumnEvent = { + source: { type: ColumnHeaderType.RUN, name: 'run', displayName: 'Run', enabled: true, }, - { + destination: { type: ColumnHeaderType.VALUE, name: 'value', displayName: 'Value', enabled: true, }, - ]; - scalarCardDataTable.componentInstance.orderColumns(headers); + side: Side.RIGHT, + }; + + scalarCardDataTable.componentInstance.onOrderColumns( + reorderColumnEvent + ); expect(dispatchedActions).toEqual([ - dataTableColumnEdited({ - headers, + dataTableColumnOrderChanged({ + ...reorderColumnEvent, dataTableMode: DataTableMode.SINGLE, }), ]); })); - it('emits dataTableColumnEdited with DataTableMode.RANGE when orderColumns is called while in Range Selection', fakeAsync(() => { + it('emits dataTableColumnOrderChanged with DataTableMode.RANGE when orderColumns is called while in Range Selection', fakeAsync(() => { store.overrideSelector(getCardStateMap, { card1: { dataMinMax: { @@ -4474,113 +4673,242 @@ describe('scalar card', () => { const scalarCardDataTable = fixture.debugElement.query( By.directive(ScalarCardDataTable) ); - - const headers = [ - { + const reorderColumnEvent: ReorderColumnEvent = { + source: { type: ColumnHeaderType.RUN, name: 'run', displayName: 'Run', enabled: true, }, - { + destination: { type: ColumnHeaderType.VALUE, name: 'value', displayName: 'Value', enabled: true, }, - ]; - scalarCardDataTable.componentInstance.orderColumns(headers); + side: Side.RIGHT, + }; + + scalarCardDataTable.componentInstance.onOrderColumns( + reorderColumnEvent + ); expect(dispatchedActions).toEqual([ - dataTableColumnEdited({ - headers, + dataTableColumnOrderChanged({ + ...reorderColumnEvent, dataTableMode: DataTableMode.RANGE, }), ]); })); - it('emits dataTableColumnToggled when onRemoveColumn is called with range selection disabled', fakeAsync(() => { - store.overrideSelector(getSingleSelectionHeaders, [ - { - type: ColumnHeaderType.RUN, - name: 'run', - displayName: 'Run', + it('dispatches dashboardHparamColumnOrderChanged when reordering hparam columns', fakeAsync(() => { + store.overrideSelector(getCardStateMap, { + card1: { + dataMinMax: { + minStep: 0, + maxStep: 100, + }, + }, + }); + store.overrideSelector(getMetricsCardTimeSelection, { + start: {step: 1}, + end: null, + }); + store.overrideSelector(selectors.getMetricsStepSelectorEnabled, true); + const fixture = createComponent('card1'); + fixture.detectChanges(); + const dataTableComponentInstance = fixture.debugElement.query( + By.directive(DataTableComponent) + ).componentInstance; + const reorderColumnEvent: ReorderColumnEvent = { + source: { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', enabled: true, }, - { - type: ColumnHeaderType.VALUE, - name: 'value', - displayName: 'Value', - enabled: false, + destination: { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, }, + side: Side.RIGHT, + }; + + dataTableComponentInstance.orderColumns.emit(reorderColumnEvent); + + expect(dispatchedActions).toEqual([ + hparamsActions.dashboardHparamColumnOrderChanged(reorderColumnEvent), ]); + })); + + it('dispatches dashboardHparamColumnAdded on column add event', fakeAsync(() => { store.overrideSelector(getCardStateMap, { card1: { - rangeSelectionOverride: CardFeatureOverride.OVERRIDE_AS_DISABLED, + dataMinMax: { + minStep: 0, + maxStep: 100, + }, }, }); + store.overrideSelector(getMetricsCardTimeSelection, { + start: {step: 1}, + end: null, + }); + store.overrideSelector(selectors.getMetricsStepSelectorEnabled, true); const fixture = createComponent('card1'); fixture.detectChanges(); + const dataTableComponentInstance = fixture.debugElement.query( + By.directive(DataTableComponent) + ).componentInstance; + const addColumnEvent: AddColumnEvent = { + column: { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + nextTo: { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + side: Side.RIGHT, + }; - fixture.componentInstance.onRemoveColumn({ - type: ColumnHeaderType.RUN, - name: 'run', - displayName: 'Run', - enabled: true, - }); + dataTableComponentInstance.addColumn.emit(addColumnEvent); expect(dispatchedActions).toEqual([ - dataTableColumnToggled({ - cardId: 'card1', - header: { - type: ColumnHeaderType.RUN, - name: 'run', - displayName: 'Run', - enabled: true, - }, - }), + hparamsActions.dashboardHparamColumnAdded(addColumnEvent), ]); })); - it('emits dataTableColumnToggled when onRemoveColumn is called with range selection enabled', fakeAsync(() => { - store.overrideSelector(getRangeSelectionHeaders, [ - { - type: ColumnHeaderType.RUN, - name: 'run', - displayName: 'Run', - enabled: true, + [ + { + testDesc: 'for single selection', + timeSelectionOverride: { + start: {step: 1}, + end: null, }, - { - type: ColumnHeaderType.MIN_VALUE, - name: 'minValue', - displayName: 'Min Value', + expectedDataTableMode: DataTableMode.SINGLE, + }, + { + testDesc: 'for range selection', + timeSelectionOverride: { + start: {step: 1}, + end: {step: 20}, + }, + expectedDataTableMode: DataTableMode.RANGE, + }, + ].forEach(({testDesc, timeSelectionOverride, expectedDataTableMode}) => { + it(`dispatches dataTableColumnToggled on column hide event ${testDesc}`, fakeAsync(() => { + store.overrideSelector(getCardStateMap, { + card1: { + dataMinMax: { + minStep: 0, + maxStep: 100, + }, + }, + }); + store.overrideSelector( + getMetricsCardTimeSelection, + timeSelectionOverride + ); + store.overrideSelector(selectors.getMetricsStepSelectorEnabled, true); + const fixture = createComponent('card1'); + fixture.detectChanges(); + const dataTableComponentInstance = fixture.debugElement.query( + By.directive(DataTableComponent) + ).componentInstance; + const columnToHide: ColumnHeader = { + type: ColumnHeaderType.VALUE, + name: 'value', + displayName: 'Value', enabled: true, + }; + + dataTableComponentInstance.hideColumn.emit(columnToHide); + + expect(dispatchedActions).toEqual([ + dataTableColumnToggled({ + header: columnToHide, + cardId: 'card1', + dataTableMode: expectedDataTableMode, + }), + ]); + })); + }); + + it('dispatches dashboardHparamColumnToggled on column hide event for hparam columns', fakeAsync(() => { + store.overrideSelector(getCardStateMap, { + card1: { + dataMinMax: { + minStep: 0, + maxStep: 100, + }, }, + }); + store.overrideSelector(getMetricsCardTimeSelection, { + start: {step: 1}, + end: null, + }); + store.overrideSelector(selectors.getMetricsStepSelectorEnabled, true); + const fixture = createComponent('card1'); + fixture.detectChanges(); + const dataTableComponentInstance = fixture.debugElement.query( + By.directive(DataTableComponent) + ).componentInstance; + const columnToHide: ColumnHeader = { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }; + + dataTableComponentInstance.hideColumn.emit(columnToHide); + + expect(dispatchedActions).toEqual([ + hparamsActions.dashboardHparamColumnToggled({column: columnToHide}), ]); + })); + + it('dispatches dashboardHparamFilterAdded on column filter event', fakeAsync(() => { store.overrideSelector(getCardStateMap, { card1: { - rangeSelectionOverride: CardFeatureOverride.OVERRIDE_AS_ENABLED, + dataMinMax: { + minStep: 0, + maxStep: 100, + }, }, }); + store.overrideSelector(getMetricsCardTimeSelection, { + start: {step: 1}, + end: null, + }); + store.overrideSelector(selectors.getMetricsStepSelectorEnabled, true); const fixture = createComponent('card1'); fixture.detectChanges(); + const dataTableComponentInstance = fixture.debugElement.query( + By.directive(DataTableComponent) + ).componentInstance; + const filterAddedEvent: FilterAddedEvent = { + name: 'conv_kernel_size', + value: { + type: DomainType.DISCRETE, + includeUndefined: true, + filterValues: [5], + possibleValues: [5, 7, 8], + }, + }; - fixture.componentInstance.onRemoveColumn({ - type: ColumnHeaderType.MIN_VALUE, - name: 'minValue', - displayName: 'Min Value', - enabled: true, - }); + dataTableComponentInstance.addFilter.emit(filterAddedEvent); expect(dispatchedActions).toEqual([ - dataTableColumnToggled({ - cardId: 'card1', - header: { - type: ColumnHeaderType.MIN_VALUE, - name: 'minValue', - displayName: 'Min Value', - enabled: true, - }, + hparamsActions.dashboardHparamFilterAdded({ + name: filterAddedEvent.name, + filter: filterAddedEvent.value, }), ]); })); diff --git a/tensorboard/webapp/metrics/views/main_view/BUILD b/tensorboard/webapp/metrics/views/main_view/BUILD index a0a2f17691..13e2f82ef5 100644 --- a/tensorboard/webapp/metrics/views/main_view/BUILD +++ b/tensorboard/webapp/metrics/views/main_view/BUILD @@ -91,6 +91,7 @@ tf_ts_library( "//tensorboard/webapp/metrics/views:utils", "//tensorboard/webapp/metrics/views/card_renderer:scalar_card_types", "//tensorboard/webapp/runs:types", + "//tensorboard/webapp/runs/store:selectors", "//tensorboard/webapp/runs/views/runs_table:types", "//tensorboard/webapp/util:matcher", "//tensorboard/webapp/util:memoize", diff --git a/tensorboard/webapp/metrics/views/main_view/common_selectors.ts b/tensorboard/webapp/metrics/views/main_view/common_selectors.ts index 20be847c66..0836af7fb4 100644 --- a/tensorboard/webapp/metrics/views/main_view/common_selectors.ts +++ b/tensorboard/webapp/metrics/views/main_view/common_selectors.ts @@ -33,6 +33,7 @@ import { getDashboardHparamsAndMetricsSpecs, getDashboardHparamFilterMap, getDashboardDefaultHparamFilters, + getDashboardDisplayedHparamColumns, } from '../../../hparams/_redux/hparams_selectors'; import { DiscreteFilter, @@ -287,10 +288,22 @@ export const getPotentialHparamColumns = createSelector( sortable: true, movable: true, filterable: true, + hidable: true, })); } ); +export const getSelectableColumns = createSelector( + getPotentialHparamColumns, + getDashboardDisplayedHparamColumns, + (potentialColumns, currentColumns) => { + const currentColumnNames = new Set(currentColumns.map(({name}) => name)); + return potentialColumns.filter((columnHeader) => { + return !currentColumnNames.has(columnHeader.name); + }); + } +); + export const getAllPotentialColumnsForCard = memoize((cardId: string) => { return createSelector( getColumnHeadersForCard(cardId), diff --git a/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts b/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts index 56e59784be..11b133f54b 100644 --- a/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts +++ b/tensorboard/webapp/metrics/views/main_view/common_selectors_test.ts @@ -161,7 +161,35 @@ describe('common selectors', () => { run4, }, } as any, - ui: {} as any, + ui: { + runsTableHeaders: [ + { + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'Run', + enabled: true, + sortable: true, + removable: false, + movable: false, + filterable: false, + hidable: false, + }, + { + type: ColumnHeaderType.CUSTOM, + name: 'experimentAlias', + displayName: 'Experiment', + enabled: true, + movable: false, + sortable: true, + }, + { + type: ColumnHeaderType.CUSTOM, + name: 'fakeRunsHeader', + displayName: 'Fake Runs Header', + enabled: true, + }, + ], + } as any, }, experiments: { data: { @@ -182,10 +210,35 @@ describe('common selectors', () => { }, hparams: { dashboardSpecs: { - hparams: [buildHparamSpec({name: 'foo', displayName: 'Foo'})], + hparams: [ + buildHparamSpec({name: 'conv_layers', displayName: 'Conv Layers'}), + buildHparamSpec({ + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + }), + buildHparamSpec({ + name: 'dense_layers', + displayName: 'Dense Layers', + }), + buildHparamSpec({name: 'dropout', displayName: 'Dropout'}), + ], metrics: [buildMetricSpec({displayName: 'Bar'})], }, dashboardSessionGroups: [], + dashboardDisplayedHparamColumns: [ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'dense_layers', + displayName: 'Dense Layers', + enabled: true, + }, + ], } as any, }); }); @@ -962,6 +1015,15 @@ describe('common selectors', () => { }); describe('getPotentialHparamColumns', () => { + const expectedBooleanFlags = { + enabled: false, + removable: true, + sortable: true, + movable: true, + filterable: true, + hidable: true, + }; + it('returns empty list when there are no experiments', () => { state.app_routing!.activeRoute!.routeKind = RouteKind.EXPERIMENTS; @@ -972,42 +1034,87 @@ describe('common selectors', () => { expect(selectors.getPotentialHparamColumns(state)).toEqual([ { type: ColumnHeaderType.HPARAM, - name: 'foo', - displayName: 'Foo', - enabled: false, - removable: true, - sortable: true, - movable: true, - filterable: true, + name: 'conv_layers', + displayName: 'Conv Layers', + ...expectedBooleanFlags, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + ...expectedBooleanFlags, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'dense_layers', + displayName: 'Dense Layers', + ...expectedBooleanFlags, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'dropout', + displayName: 'Dropout', + ...expectedBooleanFlags, }, ]); }); it('sets name as display name when a display name is not provided', () => { - state.hparams!.dashboardSpecs.hparams.push( - buildHparamSpec({name: 'bar', displayName: ''}) - ); + state.hparams!.dashboardSpecs.hparams = [ + buildHparamSpec({name: 'conv_layers', displayName: ''}), + ]; + expect(selectors.getPotentialHparamColumns(state)).toEqual([ { type: ColumnHeaderType.HPARAM, - name: 'foo', - displayName: 'Foo', - enabled: false, - removable: true, - sortable: true, - movable: true, - filterable: true, + name: 'conv_layers', + displayName: 'conv_layers', + ...expectedBooleanFlags, }, - { + ]); + }); + }); + + describe('getSelectableColumns', () => { + it('returns the full list of hparam columns if none are currently displayed', () => { + state.hparams!.dashboardDisplayedHparamColumns = []; + + expect(selectors.getSelectableColumns(state)).toEqual([ + jasmine.objectContaining({ type: ColumnHeaderType.HPARAM, - name: 'bar', - displayName: 'bar', - enabled: false, - removable: true, - sortable: true, - movable: true, - filterable: true, - }, + name: 'conv_layers', + displayName: 'Conv Layers', + }), + jasmine.objectContaining({ + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + }), + jasmine.objectContaining({ + type: ColumnHeaderType.HPARAM, + name: 'dense_layers', + displayName: 'Dense Layers', + }), + jasmine.objectContaining({ + type: ColumnHeaderType.HPARAM, + name: 'dropout', + displayName: 'Dropout', + }), + ]); + }); + + it('returns only columns that are not displayed', () => { + expect(selectors.getSelectableColumns(state)).toEqual([ + jasmine.objectContaining({ + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + }), + jasmine.objectContaining({ + type: ColumnHeaderType.HPARAM, + name: 'dropout', + displayName: 'Dropout', + }), ]); }); }); diff --git a/tensorboard/webapp/metrics/views/right_pane/scalar_column_editor/BUILD b/tensorboard/webapp/metrics/views/right_pane/scalar_column_editor/BUILD index 4de20c5ee4..a0e21eac3e 100644 --- a/tensorboard/webapp/metrics/views/right_pane/scalar_column_editor/BUILD +++ b/tensorboard/webapp/metrics/views/right_pane/scalar_column_editor/BUILD @@ -27,10 +27,16 @@ tf_ng_module( "//tensorboard/webapp:app_state", "//tensorboard/webapp:selectors", "//tensorboard/webapp/angular:expect_angular_material_checkbox", + "//tensorboard/webapp/angular:expect_angular_material_icon", "//tensorboard/webapp/angular:expect_angular_material_tabs", + "//tensorboard/webapp/feature_flag/store", + "//tensorboard/webapp/hparams", "//tensorboard/webapp/metrics:types", "//tensorboard/webapp/metrics/actions", "//tensorboard/webapp/metrics/store", + "//tensorboard/webapp/metrics/views/main_view:common_selectors", + "//tensorboard/webapp/widgets/custom_modal", + "//tensorboard/webapp/widgets/data_table:column_selector", "//tensorboard/webapp/widgets/data_table:data_table_header", "//tensorboard/webapp/widgets/data_table:types", "@npm//@angular/common", @@ -54,8 +60,14 @@ tf_ts_library( "//tensorboard/webapp/angular:expect_angular_material_tabs", "//tensorboard/webapp/angular:expect_angular_platform_browser_animations", "//tensorboard/webapp/angular:expect_ngrx_store_testing", + "//tensorboard/webapp/feature_flag/store", + "//tensorboard/webapp/hparams/_redux:hparams_actions", + "//tensorboard/webapp/hparams/_redux:hparams_selectors", "//tensorboard/webapp/metrics/actions", "//tensorboard/webapp/metrics/store", + "//tensorboard/webapp/testing:utils", + "//tensorboard/webapp/widgets/custom_modal", + "//tensorboard/webapp/widgets/data_table:column_selector", "//tensorboard/webapp/widgets/data_table:data_table_header", "//tensorboard/webapp/widgets/data_table:types", "@npm//@angular/core", diff --git a/tensorboard/webapp/metrics/views/right_pane/scalar_column_editor/scalar_column_editor_component.ng.html b/tensorboard/webapp/metrics/views/right_pane/scalar_column_editor/scalar_column_editor_component.ng.html index fc0f611463..1ba1791bd4 100644 --- a/tensorboard/webapp/metrics/views/right_pane/scalar_column_editor/scalar_column_editor_component.ng.html +++ b/tensorboard/webapp/metrics/views/right_pane/scalar_column_editor/scalar_column_editor_component.ng.html @@ -21,16 +21,16 @@ (selectedTabChange)="tabChange($event)" > - + > - + > @@ -66,4 +66,41 @@ > + +
+

Hyperparameters

+ +
+
+
+ +
+
+
+ + + + diff --git a/tensorboard/webapp/metrics/views/right_pane/scalar_column_editor/scalar_column_editor_component.scss b/tensorboard/webapp/metrics/views/right_pane/scalar_column_editor/scalar_column_editor_component.scss index 2a3f3c0464..43133547d6 100644 --- a/tensorboard/webapp/metrics/views/right_pane/scalar_column_editor/scalar_column_editor_component.scss +++ b/tensorboard/webapp/metrics/views/right_pane/scalar_column_editor/scalar_column_editor_component.scss @@ -85,3 +85,22 @@ $_accent: map-get(mat.get-color-config($tb-theme), accent); ::ng-deep .mat-mdc-tab-body-wrapper { flex: 1; } + +.hparams { + &-header { + display: flex; + align-items: center; + margin: 10px 10px 0; + + &-add-button { + height: 30px; + width: 30px; + } + } + + &-title { + margin: 0; + font-size: 14px; + font-weight: normal; + } +} diff --git a/tensorboard/webapp/metrics/views/right_pane/scalar_column_editor/scalar_column_editor_component.ts b/tensorboard/webapp/metrics/views/right_pane/scalar_column_editor/scalar_column_editor_component.ts index b17aa03896..87668123b3 100644 --- a/tensorboard/webapp/metrics/views/right_pane/scalar_column_editor/scalar_column_editor_component.ts +++ b/tensorboard/webapp/metrics/views/right_pane/scalar_column_editor/scalar_column_editor_component.ts @@ -20,32 +20,23 @@ import { Input, OnDestroy, Output, + ViewChild, } from '@angular/core'; import {MatTabChangeEvent} from '@angular/material/tabs'; +import {CustomModalComponent} from '../../../../widgets/custom_modal/custom_modal_component'; +import {ColumnSelectorComponent} from '../../../../widgets/data_table/column_selector_component'; import { + AddColumnEvent, ColumnHeader, - ColumnHeaderType, DataTableMode, + Side, } from '../../../../widgets/data_table/types'; +import {HeaderEditInfo} from '../../../types'; const preventDefault = (e: MouseEvent) => { e.preventDefault(); }; -// Move the item at sourceIndex to destinationIndex -const moveHeader = ( - sourceIndex: number, - destinationIndex: number, - headers: ColumnHeader[] -) => { - const newHeaders = [...headers]; - // Delete from original location - newHeaders.splice(sourceIndex, 1); - // Insert at destinationIndex. - newHeaders.splice(destinationIndex, 0, headers[sourceIndex]); - return newHeaders; -}; - const getIndexOfColumn = (column: ColumnHeader, headers: ColumnHeader[]) => { return headers.findIndex((header) => { return header.name === column.name; @@ -70,18 +61,24 @@ export class ScalarColumnEditorComponent implements OnDestroy { highlightEdge: Edge = Edge.TOP; @Input() rangeHeaders!: ColumnHeader[]; @Input() singleHeaders!: ColumnHeader[]; + @Input() hparamHeaders!: ColumnHeader[]; + @Input() hparamsEnabled!: boolean; @Input() selectedTab!: DataTableMode; + @Input() selectableColumns!: ColumnHeader[]; - @Output() onScalarTableColumnEdit = new EventEmitter<{ - dataTableMode: DataTableMode; - headers: ColumnHeader[]; - }>(); + @Output() onScalarTableColumnEdit = new EventEmitter(); @Output() onScalarTableColumnToggled = new EventEmitter<{ dataTableMode: DataTableMode; header: ColumnHeader; }>(); @Output() onScalarTableColumnEditorClosed = new EventEmitter(); @Output() onTabChange = new EventEmitter(); + @Output() onColumnAdded = new EventEmitter(); + + @ViewChild('columnSelectorModal', {static: false}) + private readonly columnSelectorModal!: CustomModalComponent; + @ViewChild(ColumnSelectorComponent, {static: false}) + private readonly columnSelector!: ColumnSelectorComponent; constructor(private readonly hostElement: ElementRef) {} @@ -107,15 +104,23 @@ export class ScalarColumnEditorComponent implements OnDestroy { if (!this.draggingHeader || !this.highlightedHeader) { return; } - const headers = this.getHeadersForMode(dataTableMode); - this.onScalarTableColumnEdit.emit({ - dataTableMode: dataTableMode, - headers: moveHeader( - getIndexOfColumn(this.draggingHeader, headers), - getIndexOfColumn(this.highlightedHeader, headers), - headers - ), - }); + let headers: ColumnHeader[]; + if (this.draggingHeader.type === 'HPARAM') { + headers = this.hparamHeaders; + } else { + headers = this.getHeadersForMode(dataTableMode); + } + const source = {...this.draggingHeader}; + const destination = {...this.highlightedHeader}; + if (source && destination && source.name !== destination.name) { + this.onScalarTableColumnEdit.emit({ + source, + destination, + side: this.highlightEdge === Edge.TOP ? Side.LEFT : Side.RIGHT, + dataTableMode, + }); + } + this.draggingHeader = undefined; this.highlightedHeader = undefined; this.hostElement.nativeElement.removeEventListener( @@ -129,8 +134,21 @@ export class ScalarColumnEditorComponent implements OnDestroy { return; } + // Prevent hparam columns from interacting with standard columns. + if ( + [this.draggingHeader, header].some((h) => h.type === 'HPARAM') && + this.draggingHeader.type !== header.type + ) { + return; + } + // Highlight the position which the dragging header will go when dropped. - const headers = this.getHeadersForMode(dataTableMode); + let headers: ColumnHeader[]; + if (this.draggingHeader.type === 'HPARAM') { + headers = this.hparamHeaders; + } else { + headers = this.getHeadersForMode(dataTableMode); + } if ( getIndexOfColumn(header, headers) < getIndexOfColumn(this.draggingHeader, headers) @@ -171,4 +189,23 @@ export class ScalarColumnEditorComponent implements OnDestroy { ? this.singleHeaders : this.rangeHeaders; } + + openColumnSelector(event: MouseEvent) { + const rect = ( + (event.target as HTMLElement).closest('button') as HTMLButtonElement + ).getBoundingClientRect(); + this.columnSelectorModal.openAtPosition({ + x: rect.x + rect.width, + y: rect.y, + }); + this.columnSelector.activate(); + } + + focusColumnSelector() { + this.columnSelector.focus(); + } + + onColumnSelected(header: ColumnHeader) { + this.onColumnAdded.emit({column: header}); + } } diff --git a/tensorboard/webapp/metrics/views/right_pane/scalar_column_editor/scalar_column_editor_container.ts b/tensorboard/webapp/metrics/views/right_pane/scalar_column_editor/scalar_column_editor_container.ts index 22136ad63a..adf9f988d1 100644 --- a/tensorboard/webapp/metrics/views/right_pane/scalar_column_editor/scalar_column_editor_container.ts +++ b/tensorboard/webapp/metrics/views/right_pane/scalar_column_editor/scalar_column_editor_container.ts @@ -16,7 +16,7 @@ import {ChangeDetectionStrategy, Component} from '@angular/core'; import {Store} from '@ngrx/store'; import {State} from '../../../../app_state'; import { - dataTableColumnEdited, + dataTableColumnOrderChanged, dataTableColumnToggled, metricsSlideoutMenuClosed, tableEditorTabChanged, @@ -27,7 +27,22 @@ import { getTableEditorSelectedTab, } from '../../../store/metrics_selectors'; import {HeaderEditInfo, HeaderToggleInfo} from '../../../types'; -import {DataTableMode} from '../../../../widgets/data_table/types'; +import { + AddColumnEvent, + ColumnHeader, + DataTableMode, +} from '../../../../widgets/data_table/types'; +import {map} from 'rxjs'; +import {getSelectableColumns} from '../../main_view/common_selectors'; +import { + actions as hparamsActions, + selectors as hparamSelectors, +} from '../../../../hparams'; +import {getEnableHparamsInTimeSeries} from '../../../../feature_flag/store/feature_flag_selectors'; + +function headersWithoutRuns(headers: ColumnHeader[]) { + return headers.filter((header) => header.type !== 'RUN'); +} @Component({ selector: 'metrics-scalar-column-editor', @@ -35,11 +50,15 @@ import {DataTableMode} from '../../../../widgets/data_table/types'; `, @@ -48,16 +67,48 @@ import {DataTableMode} from '../../../../widgets/data_table/types'; export class ScalarColumnEditorContainer { constructor(private readonly store: Store) {} - readonly singleHeaders$ = this.store.select(getSingleSelectionHeaders); - readonly rangeHeaders$ = this.store.select(getRangeSelectionHeaders); + readonly singleHeaders$ = this.store + .select(getSingleSelectionHeaders) + .pipe(map(headersWithoutRuns)); + readonly rangeHeaders$ = this.store + .select(getRangeSelectionHeaders) + .pipe(map(headersWithoutRuns)); + readonly hparamHeaders$ = this.store.select( + hparamSelectors.getDashboardDisplayedHparamColumns + ); readonly selectedTab$ = this.store.select(getTableEditorSelectedTab); + readonly selectableColumns$ = this.store.select(getSelectableColumns); + readonly hparamsEnabled$ = this.store.select(getEnableHparamsInTimeSeries); - onScalarTableColumnToggled(toggleInfo: HeaderToggleInfo) { - this.store.dispatch(dataTableColumnToggled(toggleInfo)); + onScalarTableColumnToggled({dataTableMode, header}: HeaderToggleInfo) { + if (header.type === 'HPARAM') { + this.store.dispatch( + hparamsActions.dashboardHparamColumnToggled({column: header}) + ); + } else { + this.store.dispatch(dataTableColumnToggled({dataTableMode, header})); + } } - onScalarTableColumnEdit(editInfo: HeaderEditInfo) { - this.store.dispatch(dataTableColumnEdited(editInfo)); + onScalarTableColumnEdit({ + source, + destination, + side, + dataTableMode, + }: HeaderEditInfo) { + if (source.type === 'HPARAM') { + this.store.dispatch( + hparamsActions.dashboardHparamColumnOrderChanged({ + source, + destination, + side, + }) + ); + } else { + this.store.dispatch( + dataTableColumnOrderChanged({source, destination, side, dataTableMode}) + ); + } } onScalarTableColumnEditorClosed() { @@ -67,4 +118,8 @@ export class ScalarColumnEditorContainer { onTabChange(tab: DataTableMode) { this.store.dispatch(tableEditorTabChanged({tab})); } + + onColumnAdded(event: AddColumnEvent) { + this.store.dispatch(hparamsActions.dashboardHparamColumnAdded(event)); + } } diff --git a/tensorboard/webapp/metrics/views/right_pane/scalar_column_editor/scalar_column_editor_module.ts b/tensorboard/webapp/metrics/views/right_pane/scalar_column_editor/scalar_column_editor_module.ts index 508214d03d..9d6895c4de 100644 --- a/tensorboard/webapp/metrics/views/right_pane/scalar_column_editor/scalar_column_editor_module.ts +++ b/tensorboard/webapp/metrics/views/right_pane/scalar_column_editor/scalar_column_editor_module.ts @@ -21,6 +21,8 @@ import {MatButtonModule} from '@angular/material/button'; import {ScalarColumnEditorComponent} from './scalar_column_editor_component'; import {ScalarColumnEditorContainer} from './scalar_column_editor_container'; import {DataTableHeaderModule} from '../../../../widgets/data_table/data_table_header_module'; +import {CustomModalModule} from '../../../../widgets/custom_modal/custom_modal_module'; +import {ColumnSelectorModule} from '../../../../widgets/data_table/column_selector_module'; @NgModule({ declarations: [ScalarColumnEditorComponent, ScalarColumnEditorContainer], @@ -32,6 +34,8 @@ import {DataTableHeaderModule} from '../../../../widgets/data_table/data_table_h MatTabsModule, MatIconModule, MatButtonModule, + CustomModalModule, + ColumnSelectorModule, ], }) export class ScalarColumnEditorModule {} diff --git a/tensorboard/webapp/metrics/views/right_pane/scalar_column_editor/scalar_column_editor_test.ts b/tensorboard/webapp/metrics/views/right_pane/scalar_column_editor/scalar_column_editor_test.ts index 6731d9aa6a..78685e7eea 100644 --- a/tensorboard/webapp/metrics/views/right_pane/scalar_column_editor/scalar_column_editor_test.ts +++ b/tensorboard/webapp/metrics/views/right_pane/scalar_column_editor/scalar_column_editor_test.ts @@ -26,7 +26,7 @@ import {Action, Store} from '@ngrx/store'; import {MockStore, provideMockStore} from '@ngrx/store/testing'; import {State} from '../../../../app_state'; import { - dataTableColumnEdited, + dataTableColumnOrderChanged, dataTableColumnToggled, metricsSlideoutMenuClosed, tableEditorTabChanged, @@ -39,11 +39,20 @@ import { import { ColumnHeaderType, DataTableMode, + Side, } from '../../../../widgets/data_table/types'; import {DataTableHeaderModule} from '../../../../widgets/data_table/data_table_header_module'; import {ScalarColumnEditorComponent} from './scalar_column_editor_component'; import {ScalarColumnEditorContainer} from './scalar_column_editor_container'; import {MatTabsModule} from '@angular/material/tabs'; +import {getEnableHparamsInTimeSeries} from '../../../../feature_flag/store/feature_flag_selectors'; +import {getDashboardDisplayedHparamColumns} from '../../../../hparams/_redux/hparams_selectors'; +import {provideMockTbStore} from '../../../../testing/utils'; +import {CustomModalComponent} from '../../../../widgets/custom_modal/custom_modal_component'; +import {ColumnSelectorComponent} from '../../../../widgets/data_table/column_selector_component'; +import {CustomModalModule} from '../../../../widgets/custom_modal/custom_modal_module'; +import {ColumnSelectorModule} from '../../../../widgets/data_table/column_selector_module'; +import * as hparamsActions from '../../../../hparams/_redux/hparams_actions'; describe('scalar column editor', () => { let store: MockStore; @@ -78,15 +87,38 @@ describe('scalar column editor', () => { MatTabsModule, NoopAnimationsModule, MatCheckboxModule, + CustomModalModule, + ColumnSelectorModule, ], declarations: [ScalarColumnEditorContainer, ScalarColumnEditorComponent], - providers: [provideMockStore()], + providers: [provideMockTbStore()], schemas: [NO_ERRORS_SCHEMA], }).compileComponents(); store = TestBed.inject>(Store) as MockStore; store.overrideSelector(getRangeSelectionHeaders, []); store.overrideSelector(getSingleSelectionHeaders, []); store.overrideSelector(getTableEditorSelectedTab, DataTableMode.SINGLE); + store.overrideSelector(getDashboardDisplayedHparamColumns, [ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'dense_layers', + displayName: 'Dense Layers', + enabled: true, + }, + ]); + store.overrideSelector(getEnableHparamsInTimeSeries, false); }); afterEach(() => { @@ -101,9 +133,9 @@ describe('scalar column editor', () => { it('renders single selection headers when selectedTab is set to SINGLE', fakeAsync(() => { store.overrideSelector(getSingleSelectionHeaders, [ { - type: ColumnHeaderType.RUN, - name: 'run', - displayName: 'Run', + type: ColumnHeaderType.SMOOTHED, + name: 'smoothed', + displayName: 'Smoothed', enabled: true, }, { @@ -121,16 +153,16 @@ describe('scalar column editor', () => { ); expect(headerElements.length).toEqual(2); - expect(headerElements[0].nativeElement.innerText).toEqual('Run'); + expect(headerElements[0].nativeElement.innerText).toEqual('Smoothed'); expect(headerElements[1].nativeElement.innerText).toEqual('Value'); })); it('renders range selection headers when selectedTab is set to RANGE', fakeAsync(() => { store.overrideSelector(getRangeSelectionHeaders, [ { - type: ColumnHeaderType.RUN, - name: 'run', - displayName: 'Run', + type: ColumnHeaderType.SMOOTHED, + name: 'smoothed', + displayName: 'Smoothed', enabled: true, }, { @@ -148,16 +180,62 @@ describe('scalar column editor', () => { ); expect(headerElements.length).toEqual(2); - expect(headerElements[0].nativeElement.innerText).toEqual('Run'); + expect(headerElements[0].nativeElement.innerText).toEqual('Smoothed'); expect(headerElements[1].nativeElement.innerText).toEqual('Value'); })); + [ + { + testDesc: 'for singleSelectionHeaders', + selector: getSingleSelectionHeaders, + mode: DataTableMode.SINGLE, + }, + { + testDesc: 'for rangeSelectionHeaders', + selector: getRangeSelectionHeaders, + mode: DataTableMode.RANGE, + }, + ].forEach(({testDesc, selector, mode}) => { + it(`hides the runs column ${testDesc}`, fakeAsync(() => { + store.overrideSelector(selector, [ + { + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'Run', + enabled: true, + }, + { + type: ColumnHeaderType.SMOOTHED, + name: 'smoothed', + displayName: 'Smoothed', + enabled: true, + }, + { + type: ColumnHeaderType.VALUE, + name: 'value', + displayName: 'Value', + enabled: true, + }, + ]); + const fixture = createComponent(); + + switchTabs(fixture, mode); + const headerElements = fixture.debugElement.queryAll( + By.css('.header-list-item') + ); + + expect(headerElements.length).toEqual(2); + expect(headerElements[0].nativeElement.innerText).toEqual('Smoothed'); + expect(headerElements[1].nativeElement.innerText).toEqual('Value'); + })); + }); + it('checkboxes reflect enabled state', fakeAsync(() => { store.overrideSelector(getSingleSelectionHeaders, [ { - type: ColumnHeaderType.RUN, - name: 'run', - displayName: 'Run', + type: ColumnHeaderType.SMOOTHED, + name: 'smoothed', + displayName: 'Smoothed', enabled: true, }, { @@ -173,7 +251,7 @@ describe('scalar column editor', () => { const checkboxes = fixture.debugElement.queryAll(By.css('mat-checkbox')); expect(checkboxes.length).toEqual(2); - expect(checkboxes[0].nativeElement.innerText).toEqual('Run'); + expect(checkboxes[0].nativeElement.innerText).toEqual('Smoothed'); expect( checkboxes[0].nativeElement.attributes.getNamedItem('ng-reflect-checked') .value @@ -197,9 +275,9 @@ describe('scalar column editor', () => { it('dispatches dataTableColumnToggled action with singe selection when checkbox is clicked', fakeAsync(() => { store.overrideSelector(getSingleSelectionHeaders, [ { - type: ColumnHeaderType.RUN, - name: 'run', - displayName: 'Run', + type: ColumnHeaderType.SMOOTHED, + name: 'smoothed', + displayName: 'Smoothed', enabled: true, }, { @@ -222,9 +300,9 @@ describe('scalar column editor', () => { dataTableColumnToggled({ dataTableMode: DataTableMode.SINGLE, header: { - type: ColumnHeaderType.RUN, - name: 'run', - displayName: 'Run', + type: ColumnHeaderType.SMOOTHED, + name: 'smoothed', + displayName: 'Smoothed', enabled: true, }, }) @@ -278,12 +356,12 @@ describe('scalar column editor', () => { }); }); - it('dispatches dataTableColumnEdited action with singe selection when header is dragged', fakeAsync(() => { + it('dispatches dataTableColumnEdited action with single selection when header is dragged', fakeAsync(() => { store.overrideSelector(getSingleSelectionHeaders, [ { - type: ColumnHeaderType.RUN, - name: 'run', - displayName: 'Run', + type: ColumnHeaderType.SMOOTHED, + name: 'smoothed', + displayName: 'Smoothed', enabled: true, }, { @@ -312,28 +390,21 @@ describe('scalar column editor', () => { headerListItems[0].triggerEventHandler('dragend'); expect(dispatchedActions[0]).toEqual( - dataTableColumnEdited({ + dataTableColumnOrderChanged({ + source: { + type: ColumnHeaderType.SMOOTHED, + name: 'smoothed', + displayName: 'Smoothed', + enabled: true, + }, + destination: { + type: ColumnHeaderType.VALUE, + name: 'value', + displayName: 'Value', + enabled: true, + }, + side: Side.RIGHT, dataTableMode: DataTableMode.SINGLE, - headers: [ - { - type: ColumnHeaderType.VALUE, - name: 'value', - displayName: 'Value', - enabled: true, - }, - { - type: ColumnHeaderType.RUN, - name: 'run', - displayName: 'Run', - enabled: true, - }, - { - type: ColumnHeaderType.STEP, - name: 'step', - displayName: 'Step', - enabled: true, - }, - ], }) ); })); @@ -341,9 +412,9 @@ describe('scalar column editor', () => { it('dispatches dataTableColumnEdited action with range selection when header is dragged', fakeAsync(() => { store.overrideSelector(getRangeSelectionHeaders, [ { - type: ColumnHeaderType.RUN, - name: 'run', - displayName: 'Run', + type: ColumnHeaderType.SMOOTHED, + name: 'smoothed', + displayName: 'Smoothed', enabled: true, }, { @@ -372,28 +443,21 @@ describe('scalar column editor', () => { headerListItems[1].triggerEventHandler('dragend'); expect(dispatchedActions[0]).toEqual( - dataTableColumnEdited({ + dataTableColumnOrderChanged({ + source: { + type: ColumnHeaderType.MAX_VALUE, + name: 'maxValue', + displayName: 'Max', + enabled: true, + }, + destination: { + type: ColumnHeaderType.SMOOTHED, + name: 'smoothed', + displayName: 'Smoothed', + enabled: true, + }, + side: Side.LEFT, dataTableMode: DataTableMode.RANGE, - headers: [ - { - type: ColumnHeaderType.MAX_VALUE, - name: 'maxValue', - displayName: 'Max', - enabled: true, - }, - { - type: ColumnHeaderType.RUN, - name: 'run', - displayName: 'Run', - enabled: true, - }, - { - type: ColumnHeaderType.MIN_VALUE, - name: 'minValue', - displayName: 'Min', - enabled: true, - }, - ], }) ); })); @@ -401,9 +465,9 @@ describe('scalar column editor', () => { it('highlights item with bottom edge when dragging below item being dragged', fakeAsync(() => { store.overrideSelector(getRangeSelectionHeaders, [ { - type: ColumnHeaderType.RUN, - name: 'run', - displayName: 'Run', + type: ColumnHeaderType.SMOOTHED, + name: 'smoothed', + displayName: 'Smoothed', enabled: true, }, { @@ -437,9 +501,9 @@ describe('scalar column editor', () => { it('highlights item with top edge when dragging above item being dragged', fakeAsync(() => { store.overrideSelector(getRangeSelectionHeaders, [ { - type: ColumnHeaderType.RUN, - name: 'run', - displayName: 'Run', + type: ColumnHeaderType.SMOOTHED, + name: 'smoothed', + displayName: 'Smoothed', enabled: true, }, { @@ -540,4 +604,179 @@ describe('scalar column editor', () => { ).toBeTrue(); }); }); + + describe('hparam columns', () => { + beforeEach(() => { + store.overrideSelector(getEnableHparamsInTimeSeries, true); + }); + + it('hides hparam sections when hparams in time series is disabled', () => { + store.overrideSelector(getEnableHparamsInTimeSeries, false); + + const fixture = createComponent(); + + expect(fixture.debugElement.query(By.css('.hparams-header'))).toBeFalsy(); + }); + + it('shows hparam sections when hparams in time series is enabled', () => { + const fixture = createComponent(); + + expect( + fixture.debugElement.query(By.css('.hparams-header')) + ).toBeTruthy(); + }); + + it('opens column selector modal on add button click', async () => { + const focusSpy = spyOn(ColumnSelectorComponent.prototype, 'focus'); + const fixture = createComponent(); + + const addButton = fixture.debugElement.query( + By.css('.hparams-add-button') + ); + addButton.nativeElement.click(); + fixture.detectChanges(); + // Wait for modal init. + await new Promise((resolve) => window.requestAnimationFrame(resolve)); + + expect( + fixture.debugElement.query(By.directive(CustomModalComponent)) + ).toBeTruthy(); + expect( + fixture.debugElement.query(By.directive(ColumnSelectorComponent)) + ).toBeTruthy(); + expect(focusSpy).toHaveBeenCalled(); + }); + + it('dispatches dashboardHparamColumnAdded on column select', async () => { + const dispatchSpy = spyOn(store, 'dispatch'); + const fixture = createComponent(); + + const addButton = fixture.debugElement.query( + By.css('.hparams-add-button') + ); + addButton.nativeElement.click(); + fixture.detectChanges(); + // Wait for modal init. + await new Promise((resolve) => window.requestAnimationFrame(resolve)); + fixture.debugElement + .query(By.directive(ColumnSelectorComponent)) + .componentInstance.columnSelected.emit({ + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }); + + expect(dispatchSpy).toHaveBeenCalledOnceWith( + hparamsActions.dashboardHparamColumnAdded({ + column: { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + }) + ); + }); + + it('lists hparam columns', () => { + const fixture = createComponent(); + + const headerElements = fixture.debugElement.queryAll( + By.css('.hparams-list .header-list-item') + ); + expect(headerElements.length).toEqual(3); + expect(headerElements[0].nativeElement.innerText).toEqual('Conv Layers'); + expect(headerElements[1].nativeElement.innerText).toEqual( + 'Conv Kernel Size' + ); + expect(headerElements[2].nativeElement.innerText).toEqual('Dense Layers'); + }); + + it('dispatches dashboardHparamColumnOrderChanged hparam header is dragged', () => { + const dispatchSpy = spyOn(store, 'dispatch'); + const fixture = createComponent(); + const headerListItems = fixture.debugElement.queryAll( + By.css('.header-list-item') + ); + + headerListItems[0].triggerEventHandler('dragstart'); + headerListItems[1].triggerEventHandler('dragenter'); + headerListItems[0].triggerEventHandler('dragend'); + + expect(dispatchSpy).toHaveBeenCalledOnceWith( + hparamsActions.dashboardHparamColumnOrderChanged({ + source: { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + destination: { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + side: Side.RIGHT, + }) + ); + }); + + it('prevents interaction between hparam and standard headers', () => { + const dispatchSpy = spyOn(store, 'dispatch'); + store.overrideSelector(getSingleSelectionHeaders, [ + { + type: ColumnHeaderType.SMOOTHED, + name: 'smoothed', + displayName: 'Smoothed', + enabled: true, + }, + ]); + const fixture = createComponent(); + const standardHeaderDebugEl = fixture.debugElement + .queryAll(By.css('.header-list-item')) + .find((debugEl) => + debugEl.nativeElement.innerHTML.includes('Smoothed') + )!; + const hparamHeaderDebugEl = fixture.debugElement + .queryAll(By.css('.header-list-item')) + .find((debugEl) => + debugEl.nativeElement.innerHTML.includes('Conv Layers') + )!; + + // Moving standard to hparam header + standardHeaderDebugEl.triggerEventHandler('dragstart'); + hparamHeaderDebugEl.triggerEventHandler('dragenter'); + standardHeaderDebugEl.triggerEventHandler('dragend'); + // Moving hparam header to standard header + hparamHeaderDebugEl.triggerEventHandler('dragstart'); + standardHeaderDebugEl.triggerEventHandler('dragenter'); + hparamHeaderDebugEl.triggerEventHandler('dragend'); + + expect(dispatchSpy).not.toHaveBeenCalled(); + }); + + it('dispatches dashboardHparamColumnToggled on hparam header checkbox click', () => { + const dispatchSpy = spyOn(store, 'dispatch'); + const fixture = createComponent(); + const checkbox = fixture.debugElement.query( + By.css('.hparams-list mat-checkbox') + ); + + checkbox.triggerEventHandler('change'); + fixture.detectChanges(); + + expect(dispatchSpy).toHaveBeenCalledOnceWith( + hparamsActions.dashboardHparamColumnToggled({ + column: { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + }) + ); + }); + }); }); diff --git a/tensorboard/webapp/persistent_settings/_data_source/persistent_settings_data_source.ts b/tensorboard/webapp/persistent_settings/_data_source/persistent_settings_data_source.ts index 1f268d4ff7..612756f7ae 100644 --- a/tensorboard/webapp/persistent_settings/_data_source/persistent_settings_data_source.ts +++ b/tensorboard/webapp/persistent_settings/_data_source/persistent_settings_data_source.ts @@ -235,7 +235,8 @@ export class OSSSettingsConverter extends SettingsConverter< if ( Array.isArray(backendSettings.singleSelectionHeaders) && // If the settings stored in the backend are invalid, reset back to default. - backendSettings.singleSelectionHeaders[0].name !== undefined + backendSettings.singleSelectionHeaders[0].name !== undefined && + backendSettings.singleSelectionHeaders[0].type === 'RUN' ) { updateScalarContextMenuOptions(backendSettings.singleSelectionHeaders); settings.singleSelectionHeaders = backendSettings.singleSelectionHeaders; @@ -244,7 +245,8 @@ export class OSSSettingsConverter extends SettingsConverter< if ( Array.isArray(backendSettings.rangeSelectionHeaders) && // If the settings stored in the backend are invalid, reset back to default. - backendSettings.rangeSelectionHeaders[0].name !== undefined + backendSettings.rangeSelectionHeaders[0].name !== undefined && + backendSettings.rangeSelectionHeaders[0].type === 'RUN' ) { updateScalarContextMenuOptions(backendSettings.rangeSelectionHeaders); settings.rangeSelectionHeaders = backendSettings.rangeSelectionHeaders; diff --git a/tensorboard/webapp/persistent_settings/_data_source/persistent_settings_data_source_test.ts b/tensorboard/webapp/persistent_settings/_data_source/persistent_settings_data_source_test.ts index 1dccb63410..959c02ca83 100644 --- a/tensorboard/webapp/persistent_settings/_data_source/persistent_settings_data_source_test.ts +++ b/tensorboard/webapp/persistent_settings/_data_source/persistent_settings_data_source_test.ts @@ -300,6 +300,27 @@ describe('persistent_settings data_source test', () => { expect(actual).toEqual({}); }); + it('resets singleSelectionEnabled if runs header is not first', async () => { + getItemSpy.withArgs(TEST_ONLY.GLOBAL_LOCAL_STORAGE_KEY).and.returnValue( + JSON.stringify({ + singleSelectionHeaders: [ + { + type: ColumnHeaderType.VALUE, + enabled: false, + }, + { + type: ColumnHeaderType.RUN, + enabled: true, + }, + ], + }) + ); + + const actual = await firstValueFrom(dataSource.getSettings()); + + expect(actual).toEqual({}); + }); + it('resets rangeSelectionEnabled if old ColumnHeader is stored', async () => { getItemSpy.withArgs(TEST_ONLY.GLOBAL_LOCAL_STORAGE_KEY).and.returnValue( JSON.stringify({ @@ -321,6 +342,27 @@ describe('persistent_settings data_source test', () => { expect(actual).toEqual({}); }); + it('resets rangeSelectionEnabled if runs header is not first', async () => { + getItemSpy.withArgs(TEST_ONLY.GLOBAL_LOCAL_STORAGE_KEY).and.returnValue( + JSON.stringify({ + rangeSelectionHeaders: [ + { + type: ColumnHeaderType.MIN_VALUE, + enabled: true, + }, + { + type: ColumnHeaderType.RUN, + enabled: true, + }, + ], + }) + ); + + const actual = await firstValueFrom(dataSource.getSettings()); + + expect(actual).toEqual({}); + }); + it('properly converts dashboardDisplayedHparamColumns', async () => { getItemSpy.withArgs(TEST_ONLY.GLOBAL_LOCAL_STORAGE_KEY).and.returnValue( JSON.stringify({ diff --git a/tensorboard/webapp/runs/store/BUILD b/tensorboard/webapp/runs/store/BUILD index ef603b75ec..d8715bc8e8 100644 --- a/tensorboard/webapp/runs/store/BUILD +++ b/tensorboard/webapp/runs/store/BUILD @@ -52,6 +52,7 @@ tf_ts_library( "//tensorboard/webapp/types", "//tensorboard/webapp/types:ui", "//tensorboard/webapp/widgets/data_table:types", + "//tensorboard/webapp/widgets/data_table:utils", "@npm//@ngrx/store", ], ) @@ -98,6 +99,7 @@ tf_ts_library( ":testing", ":types", ":utils", + "//tensorboard/webapp:app_state", "//tensorboard/webapp/app_routing:testing", "//tensorboard/webapp/app_routing:types", "//tensorboard/webapp/app_routing/actions", diff --git a/tensorboard/webapp/runs/store/runs_selectors.ts b/tensorboard/webapp/runs/store/runs_selectors.ts index e8cff472f2..87ed5f73d9 100644 --- a/tensorboard/webapp/runs/store/runs_selectors.ts +++ b/tensorboard/webapp/runs/store/runs_selectors.ts @@ -26,9 +26,13 @@ import { } from './runs_types'; import {createGroupBy} from './utils'; import {getExperimentIdsFromRoute} from '../../app_routing/store/app_routing_selectors'; -import {getDashboardSessionGroups} from '../../hparams/_redux/hparams_selectors'; +import { + getDashboardDisplayedHparamColumns, + getDashboardSessionGroups, +} from '../../hparams/_redux/hparams_selectors'; import {HparamValue, RunToHparamsAndMetrics} from '../../hparams/types'; import {ColumnHeader, SortingInfo} from '../../widgets/data_table/types'; +import {DataTableUtils} from '../../widgets/data_table/utils'; const getRunsState = createFeatureSelector(RUNS_FEATURE_KEY); @@ -301,7 +305,7 @@ export const getColorGroupRegexString = createSelector( ); /** - * Gets the columns to be displayed by the runs table. + * Gets the standard columns to be displayed by the runs table. */ export const getRunsTableHeaders = createSelector( getUiState, @@ -319,3 +323,23 @@ export const getRunsTableSortingInfo = createSelector( return state.sortingInfo; } ); + +/** + * Gets the grouped columns to be displayed by the runs table. + */ +export const getGroupedRunsTableHeaders = createSelector( + getRunsTableHeaders, + getDashboardDisplayedHparamColumns, + (runsTableHeaders, hparamColumns) => { + // Override hparam options to match runs table requirements. + const columns = [...runsTableHeaders, ...hparamColumns].map((column) => { + const newColumn = {...column}; + if (column.type === 'HPARAM') { + newColumn.removable = true; + newColumn.hidable = false; + } + return newColumn; + }); + return DataTableUtils.groupColumns(columns); + } +); diff --git a/tensorboard/webapp/runs/store/runs_selectors_test.ts b/tensorboard/webapp/runs/store/runs_selectors_test.ts index 01363d880c..c1e083a69c 100644 --- a/tensorboard/webapp/runs/store/runs_selectors_test.ts +++ b/tensorboard/webapp/runs/store/runs_selectors_test.ts @@ -21,8 +21,10 @@ import { buildSessionGroup, buildStateFromHparamsState, buildHparamsState, + buildHparamSpec, } from '../../hparams/testing'; import {buildMockState} from '../../testing/utils'; +import {State} from '../../app_state'; import {DataLoadState} from '../../types/data'; import {ColumnHeaderType, SortingOrder} from '../../widgets/data_table/types'; import {GroupByKey} from '../types'; @@ -1027,6 +1029,129 @@ describe('runs_selectors', () => { }); }); + describe('#getGroupedRunsTableHeaders', () => { + let state: State; + + beforeEach(() => { + state = buildMockState({ + runs: buildRunsState( + {}, + { + runsTableHeaders: [ + { + type: ColumnHeaderType.COLOR, + name: 'color', + displayName: 'Color', + enabled: true, + }, + { + type: ColumnHeaderType.CUSTOM, + name: 'experimentAlias', + displayName: 'Experiment Alias', + enabled: true, + }, + { + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'Run', + enabled: true, + }, + ], + } + ), + ...buildStateFromHparamsState( + buildHparamsState({ + dashboardSpecs: { + hparams: [ + buildHparamSpec({name: 'conv_layers'}), + buildHparamSpec({name: 'conv_kernel_size'}), + ], + }, + dashboardDisplayedHparamColumns: [ + { + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }, + { + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }, + ], + }) + ), + }); + }); + + it('returns runs table headers grouped with other headers', () => { + expect(selectors.getGroupedRunsTableHeaders(state)).toEqual([ + jasmine.objectContaining({ + type: ColumnHeaderType.RUN, + name: 'run', + displayName: 'Run', + enabled: true, + }), + jasmine.objectContaining({ + type: ColumnHeaderType.CUSTOM, + name: 'experimentAlias', + displayName: 'Experiment Alias', + enabled: true, + }), + jasmine.objectContaining({ + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + }), + jasmine.objectContaining({ + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + }), + jasmine.objectContaining({ + type: ColumnHeaderType.COLOR, + name: 'color', + displayName: 'Color', + enabled: true, + }), + ]); + }); + + it('sets the hparam column context options for the runs table', () => { + expect(selectors.getGroupedRunsTableHeaders(state)).toEqual([ + jasmine.objectContaining({ + type: ColumnHeaderType.RUN, + }), + jasmine.objectContaining({ + type: ColumnHeaderType.CUSTOM, + }), + jasmine.objectContaining({ + type: ColumnHeaderType.HPARAM, + name: 'conv_layers', + displayName: 'Conv Layers', + enabled: true, + removable: true, + hidable: false, + }), + jasmine.objectContaining({ + type: ColumnHeaderType.HPARAM, + name: 'conv_kernel_size', + displayName: 'Conv Kernel Size', + enabled: true, + removable: true, + hidable: false, + }), + jasmine.objectContaining({ + type: ColumnHeaderType.COLOR, + }), + ]); + }); + }); + describe('#getRunsTableSortingInfo', () => { it('returns the runs data table sorting info', () => { const state = buildMockState({ diff --git a/tensorboard/webapp/runs/types.ts b/tensorboard/webapp/runs/types.ts index 8a284d1c39..adfe4421ec 100644 --- a/tensorboard/webapp/runs/types.ts +++ b/tensorboard/webapp/runs/types.ts @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import {Run} from './data_source/runs_data_source_types'; +import {Run, DiscreteHparamValue} from './data_source/runs_data_source_types'; -export {Run} from './data_source/runs_data_source_types'; +export {Run, DiscreteHparamValue} from './data_source/runs_data_source_types'; export type ExperimentIdToRuns = Record< string, @@ -58,3 +58,9 @@ export interface URLDeserializedState { regexFilter: string | null; }; } + +export type RunToHparams = { + [runName: string]: { + [hparamName: string]: DiscreteHparamValue; + }; +}; diff --git a/tensorboard/webapp/runs/views/runs_table/runs_data_table.ng.html b/tensorboard/webapp/runs/views/runs_table/runs_data_table.ng.html index 31e6552043..d91853e2c0 100644 --- a/tensorboard/webapp/runs/views/runs_table/runs_data_table.ng.html +++ b/tensorboard/webapp/runs/views/runs_table/runs_data_table.ng.html @@ -28,7 +28,6 @@ (); - @Output() orderColumns = new EventEmitter(); + @Output() orderColumns = new EventEmitter(); @Output() onSelectionToggle = new EventEmitter(); @Output() onAllSelectionToggle = new EventEmitter(); @Output() onRegexFilterChange = new EventEmitter(); @@ -57,10 +59,7 @@ export class RunsDataTable { runId: string; newColor: string; }>(); - @Output() addColumn = new EventEmitter<{ - header: ColumnHeader; - index?: number | undefined; - }>(); + @Output() addColumn = new EventEmitter(); @Output() removeColumn = new EventEmitter(); @Output() onSelectionDblClick = new EventEmitter(); @Output() addFilter = new EventEmitter(); diff --git a/tensorboard/webapp/runs/views/runs_table/runs_data_table_test.ts b/tensorboard/webapp/runs/views/runs_table/runs_data_table_test.ts index 0de6406a5d..a4d3bc30e2 100644 --- a/tensorboard/webapp/runs/views/runs_table/runs_data_table_test.ts +++ b/tensorboard/webapp/runs/views/runs_table/runs_data_table_test.ts @@ -140,21 +140,22 @@ describe('runs_data_table', () => { ).toBeTruthy(); }); - it('projects enabled headers plus color and selected column', () => { + it('projects headers plus color and selected column', () => { const fixture = createComponent({}); const dataTable = fixture.debugElement.query( By.directive(DataTableComponent) ); const headers = dataTable.queryAll(By.directive(HeaderCellComponent)); - expect(headers.length).toBe(4); + expect(headers.length).toBe(5); expect(headers[0].componentInstance.header.name).toEqual('selected'); expect(headers[1].componentInstance.header.name).toEqual('run'); - expect(headers[2].componentInstance.header.name).toEqual('other_header'); - expect(headers[3].componentInstance.header.name).toEqual('color'); + expect(headers[2].componentInstance.header.name).toEqual('disabled_header'); + expect(headers[3].componentInstance.header.name).toEqual('other_header'); + expect(headers[4].componentInstance.header.name).toEqual('color'); }); - it('projects content for each enabled header, selected, and color column', () => { + it('projects content for each header, selected, and color column', () => { const fixture = createComponent({ data: [{id: 'runid', run: 'run name', color: 'red', other_header: 'foo'}], }); @@ -163,11 +164,12 @@ describe('runs_data_table', () => { ); const cells = dataTable.queryAll(By.directive(ContentCellComponent)); - expect(cells.length).toBe(4); + expect(cells.length).toBe(5); expect(cells[0].componentInstance.header.name).toEqual('selected'); expect(cells[1].componentInstance.header.name).toEqual('run'); - expect(cells[2].componentInstance.header.name).toEqual('other_header'); - expect(cells[3].componentInstance.header.name).toEqual('color'); + expect(cells[2].componentInstance.header.name).toEqual('disabled_header'); + expect(cells[3].componentInstance.header.name).toEqual('other_header'); + expect(cells[4].componentInstance.header.name).toEqual('color'); }); describe('color column', () => { diff --git a/tensorboard/webapp/runs/views/runs_table/runs_table_container.ts b/tensorboard/webapp/runs/views/runs_table/runs_table_container.ts index 6bf276aef5..a9fbb8aacf 100644 --- a/tensorboard/webapp/runs/views/runs_table/runs_table_container.ts +++ b/tensorboard/webapp/runs/views/runs_table/runs_table_container.ts @@ -22,7 +22,6 @@ import { import {createSelector, Store} from '@ngrx/store'; import {combineLatest, Observable, of, Subject} from 'rxjs'; import { - combineLatestWith, distinctUntilChanged, filter, map, @@ -44,13 +43,15 @@ import { getRunSelectorRegexFilter, getRunsLoadState, getRunsTableFullScreen, - getRunsTableHeaders, getRunsTableSortingInfo, + getGroupedRunsTableHeaders, } from '../../../selectors'; import {DataLoadState, LoadState} from '../../../types/data'; import { + AddColumnEvent, ColumnHeader, FilterAddedEvent, + ReorderColumnEvent, SortingInfo, TableData, } from '../../../widgets/data_table/types'; @@ -59,9 +60,6 @@ import { runPageSelectionToggled, runSelectionToggled, runSelectorRegexFilterChanged, - runsTableHeaderAdded, - runsTableHeaderOrderChanged, - runsTableHeaderRemoved, runsTableSortingInfoChanged, singleRunSelected, } from '../../actions'; @@ -70,7 +68,7 @@ import {RunsTableColumn, RunTableItem} from './types'; import { getCurrentColumnFilters, getFilteredRenderableRuns, - getPotentialHparamColumns, + getSelectableColumns, } from '../../../metrics/views/main_view/common_selectors'; import {runsTableFullScreenToggled} from '../../../core/actions'; import {sortTableDataItems} from './sorting_utils'; @@ -143,18 +141,9 @@ export class RunsTableContainer implements OnInit, OnDestroy { @Input() showHparamsAndMetrics = false; regexFilter$ = this.store.select(getRunSelectorRegexFilter); - runsColumns$ = this.store.select(getRunsTableHeaders); + runsColumns$ = this.store.select(getGroupedRunsTableHeaders); runsTableFullScreen$ = this.store.select(getRunsTableFullScreen); - - selectableColumns$ = this.store.select(getPotentialHparamColumns).pipe( - combineLatestWith(this.runsColumns$), - map(([potentialColumns, currentColumns]) => { - const currentColumnNames = new Set(currentColumns.map(({name}) => name)); - return potentialColumns.filter((columnHeader) => { - return !currentColumnNames.has(columnHeader.name); - }); - }) - ); + selectableColumns$ = this.store.select(getSelectableColumns); columnFilters$ = this.store.select(getCurrentColumnFilters); @@ -332,19 +321,26 @@ export class RunsTableContainer implements OnInit, OnDestroy { this.store.dispatch(runsTableFullScreenToggled()); } - addColumn({header, index}: {header: ColumnHeader; index: number}) { - header.enabled = true; + addColumn({column, nextTo, side}: AddColumnEvent) { this.store.dispatch( - runsTableHeaderAdded({header: {...header, enabled: true}, index}) + hparamsActions.dashboardHparamColumnAdded({ + column, + nextTo, + side, + }) ); } removeColumn(header: ColumnHeader) { - this.store.dispatch(runsTableHeaderRemoved({header})); + this.store.dispatch( + hparamsActions.dashboardHparamColumnRemoved({column: header}) + ); } - orderColumns(newHeaderOrder: ColumnHeader[]) { - this.store.dispatch(runsTableHeaderOrderChanged({newHeaderOrder})); + orderColumns(event: ReorderColumnEvent) { + this.store.dispatch( + hparamsActions.dashboardHparamColumnOrderChanged(event) + ); } addHparamFilter(event: FilterAddedEvent) { diff --git a/tensorboard/webapp/testing/BUILD b/tensorboard/webapp/testing/BUILD index 7fa69f955f..1c2722990b 100644 --- a/tensorboard/webapp/testing/BUILD +++ b/tensorboard/webapp/testing/BUILD @@ -83,7 +83,7 @@ tf_ts_library( "//tensorboard/webapp/hparams/_redux:testing", "//tensorboard/webapp/hparams/_redux:types", "//tensorboard/webapp/metrics:test_lib", - "//tensorboard/webapp/metrics/store", + "//tensorboard/webapp/metrics/store:types", "//tensorboard/webapp/notification_center/_redux:testing", "//tensorboard/webapp/notification_center/_redux:types", "//tensorboard/webapp/persistent_settings/_redux:testing", diff --git a/tensorboard/webapp/testing/utils.ts b/tensorboard/webapp/testing/utils.ts index ad63ce8a86..f2c78daad4 100644 --- a/tensorboard/webapp/testing/utils.ts +++ b/tensorboard/webapp/testing/utils.ts @@ -45,7 +45,7 @@ import { buildHparamsState, buildStateFromHparamsState, } from '../hparams/_redux/testing'; -import {METRICS_FEATURE_KEY} from '../metrics/store'; +import {METRICS_FEATURE_KEY} from '../metrics/store/metrics_types'; import {appStateFromMetricsState, buildMetricsState} from '../metrics/testing'; import {NOTIFICATION_FEATURE_KEY} from '../notification_center/_redux/notification_center_types'; import { diff --git a/tensorboard/webapp/widgets/custom_modal/custom_modal_component.ts b/tensorboard/webapp/widgets/custom_modal/custom_modal_component.ts index d082de4911..a2e1a22539 100644 --- a/tensorboard/webapp/widgets/custom_modal/custom_modal_component.ts +++ b/tensorboard/webapp/widgets/custom_modal/custom_modal_component.ts @@ -82,10 +82,11 @@ export class CustomModalComponent implements OnInit { public openAtPosition(position: {x: number; y: number}) { const root = this.viewRef.element.nativeElement; - const top = root.getBoundingClientRect().top; - if (top !== 0) { - root.style.top = top * -1 + root.offsetTop + 'px'; - } + // Set left/top to viewport (0,0) if the element has another "containing block" ancestor. + root.style.top = `${root.offsetTop - root.getBoundingClientRect().top}px`; + root.style.left = `${ + root.offsetLeft - root.getBoundingClientRect().left + }px`; this.content.nativeElement.style.left = position.x + 'px'; this.content.nativeElement.style.top = position.y + 'px'; diff --git a/tensorboard/webapp/widgets/data_table/BUILD b/tensorboard/webapp/widgets/data_table/BUILD index c43110f490..043e98805a 100644 --- a/tensorboard/webapp/widgets/data_table/BUILD +++ b/tensorboard/webapp/widgets/data_table/BUILD @@ -164,6 +164,29 @@ tf_ts_library( ], ) +tf_ts_library( + name = "utils", + srcs = [ + "utils.ts", + ], + deps = [ + ":types", + ], +) + +tf_ts_library( + name = "utils_test", + testonly = True, + srcs = [ + "utils_test.ts", + ], + deps = [ + ":types", + ":utils", + "@npm//@types/jasmine", + ], +) + tf_ts_library( name = "data_table_test", testonly = True, diff --git a/tensorboard/webapp/widgets/data_table/data_table_component.ng.html b/tensorboard/webapp/widgets/data_table/data_table_component.ng.html index e7e4357aee..25f9ca4374 100644 --- a/tensorboard/webapp/widgets/data_table/data_table_component.ng.html +++ b/tensorboard/webapp/widgets/data_table/data_table_component.ng.html @@ -17,6 +17,14 @@
No Actions Available
+