data: define interface for read_last_scalars (#6657)
This will be used for `hparams.backend_context.read_last_scalars`.

Googlers, see b/292102513 for more context. Tested internally at cl/577906864.

#hparams
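
For context, a minimal usage sketch (not part of this commit) of how a downstream caller, such as the hparams backend mentioned above, might invoke the new method. The logdir path and the zero-argument `EventMultiplexer()` setup are illustrative assumptions:

# Hypothetical usage sketch: read the latest scalar datum for every
# run/tag pair in a local logdir.
from tensorboard import context as tb_context
from tensorboard.backend.event_processing import data_provider
from tensorboard.backend.event_processing import plugin_event_multiplexer

logdir = "/tmp/logs"  # illustrative path
multiplexer = plugin_event_multiplexer.EventMultiplexer()
multiplexer.AddRunsFromDirectory(logdir)
multiplexer.Reload()
dp = data_provider.MultiplexerDataProvider(multiplexer, logdir)

ctx = tb_context.RequestContext()
last = dp.read_last_scalars(
    ctx,
    experiment_id="",  # ignored by the multiplexer-backed provider
    plugin_name="scalars",
)
for run, tags in last.items():
    for tag, datum in tags.items():
        print(f"{run}/{tag}: step={datum.step} value={datum.value}")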
yatbear authored Oct 30, 2023
1 parent 8497ae1 commit d379d08
Showing 7 changed files with 208 additions and 0 deletions.
3 changes: 3 additions & 0 deletions tensorboard/backend/application_test.py
@@ -153,6 +153,9 @@ def list_scalars(self, ctx=None, *, experiment_id):
    def read_scalars(self, ctx=None, *, experiment_id):
        raise NotImplementedError()

    def read_last_scalars(self, ctx=None, *, experiment_id, plugin_name):
        raise NotImplementedError()


class HandlingErrorsTest(tb_test.TestCase):
    def test_successful_response_passes_through(self):
25 changes: 25 additions & 0 deletions tensorboard/backend/event_processing/data_provider.py
@@ -16,6 +16,7 @@


import base64
import collections
import json
import random

@@ -137,6 +138,30 @@ def read_scalars(
        )
        return self._read(_convert_scalar_event, index, downsample)

    def read_last_scalars(
        self,
        ctx=None,
        *,
        experiment_id,
        plugin_name,
        run_tag_filter=None,
    ):
        self._validate_context(ctx)
        self._validate_experiment_id(experiment_id)
        index = self._index(
            plugin_name, run_tag_filter, summary_pb2.DATA_CLASS_SCALAR
        )
        run_tag_to_last_scalar_datum = collections.defaultdict(dict)
        for (run, tags_for_run) in index.items():
            for (tag, metadata) in tags_for_run.items():
                events = self._multiplexer.Tensors(run, tag)
                if events:
                    run_tag_to_last_scalar_datum[run][
                        tag
                    ] = _convert_scalar_event(events[-1])

        return run_tag_to_last_scalar_datum

    def list_tensors(
        self, ctx=None, *, experiment_id, plugin_name, run_tag_filter=None
    ):
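To pin down the semantics: for the multiplexer-backed provider, the datum returned for each run/tag pair should match the final datum from `read_scalars` over the same series. A hedged consistency sketch, reusing the hypothetical `dp` and `ctx` from the example above and assuming the downsample budget is large enough to retain every point:

# Consistency sketch (assumes `dp` and `ctx` as in the earlier example):
# read_last_scalars agrees with the tail of read_scalars, since
# downsampling always retains the most recent datum of each series.
full = dp.read_scalars(
    ctx,
    experiment_id="",
    plugin_name="scalars",
    downsample=1000,  # large enough to keep all points in this sketch
)
last = dp.read_last_scalars(ctx, experiment_id="", plugin_name="scalars")
for run, tags in last.items():
    for tag, datum in tags.items():
        assert datum == full[run][tag][-1]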
36 changes: 36 additions & 0 deletions tensorboard/backend/event_processing/data_provider_test.py
@@ -328,6 +328,42 @@ def test_read_scalars_but_not_rank_0(self):
                downsample=100,
            )

    def test_read_last_scalars(self):
        multiplexer = self.create_multiplexer()
        provider = data_provider.MultiplexerDataProvider(
            multiplexer, self.logdir
        )

        run_tag_filter = base_provider.RunTagFilter(
            runs=["waves", "polynomials", "unicorns"],
            tags=["sine", "square", "cube", "iridescence"],
        )
        result = provider.read_last_scalars(
            self.ctx,
            experiment_id="unused",
            plugin_name=scalar_metadata.PLUGIN_NAME,
            run_tag_filter=run_tag_filter,
        )

        self.assertCountEqual(result.keys(), ["polynomials", "waves"])
        self.assertCountEqual(result["polynomials"].keys(), ["square", "cube"])
        self.assertCountEqual(result["waves"].keys(), ["square", "sine"])
        for run in result:
            for tag in result[run]:
                events = multiplexer.Tensors(run, tag)
                if events:
                    last_event = events[-1]
                    datum = result[run][tag]
                    self.assertIsInstance(datum, base_provider.ScalarDatum)
                    self.assertEqual(datum.step, last_event.step)
                    self.assertEqual(datum.wall_time, last_event.wall_time)
                    self.assertEqual(
                        datum.value,
                        tensor_util.make_ndarray(
                            last_event.tensor_proto
                        ).item(),
                    )

    def test_list_tensors_all(self):
        provider = self.create_provider()
        result = provider.list_tensors(
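The expected runs and tags come from the test's `create_multiplexer` fixture, which is outside this diff. For readers without the rest of the file, a hedged sketch of log data that would satisfy the assertions; the TF2 summary-writer code below is illustrative, not the fixture the test actually uses:

# Illustrative data layout (not the actual test fixture): runs "waves"
# and "polynomials" exist with the tags asserted above, while the
# "unicorns" run and "iridescence" tag do not, so the filter simply
# drops them from the result.
import math
import os

import tensorflow as tf

logdir = "/tmp/logs"  # stands in for self.logdir
runs = {
    "waves": {"sine": math.sin, "square": lambda x: float(x % 2)},
    "polynomials": {"square": lambda x: x**2, "cube": lambda x: x**3},
}
for run, tags in runs.items():
    writer = tf.summary.create_file_writer(os.path.join(logdir, run))
    with writer.as_default():
        for step in range(5):
            for tag, fn in tags.items():
                tf.summary.scalar(tag, fn(step), step=step)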
39 changes: 39 additions & 0 deletions tensorboard/data/grpc_provider.py
@@ -14,6 +14,7 @@
# ==============================================================================
"""A data provider that talks to a gRPC server."""

import collections
import contextlib

import grpc
@@ -148,6 +149,44 @@ def read_scalars(
                        series.append(point)
        return result

    @timing.log_latency
    def read_last_scalars(
        self,
        ctx,
        *,
        experiment_id,
        plugin_name,
        run_tag_filter=None,
    ):
        with timing.log_latency("build request"):
            req = data_provider_pb2.ReadScalarsRequest()
            req.experiment_id = experiment_id
            req.plugin_filter.plugin_name = plugin_name
            _populate_rtf(run_tag_filter, req.run_tag_filter)
            # `ReadScalars` always includes the most recent datum, so
            # downsampling to one point means fetching only the latest value.
            req.downsample.num_points = 1
        with timing.log_latency("_stub.ReadScalars"):
            with _translate_grpc_error():
                res = self._stub.ReadScalars(req)
        with timing.log_latency("build result"):
            result = collections.defaultdict(dict)
            for run_entry in res.runs:
                run_name = run_entry.run_name
                for tag_entry in run_entry.tags:
                    d = tag_entry.data
                    # There should be no more than one datum in
                    # `tag_entry.data` since downsample was set to 1.
                    for (step, wt, value) in zip(d.step, d.wall_time, d.value):
                        result[run_name][
                            tag_entry.tag_name
                        ] = provider.ScalarDatum(
                            step=step,
                            wall_time=wt,
                            value=value,
                        )
        return result

    @timing.log_latency
    def list_tensors(
        self, ctx, *, experiment_id, plugin_name, run_tag_filter=None
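One behavioral note on both new implementations: the nested result is a `collections.defaultdict(dict)`, so looking up a missing run silently inserts an empty entry rather than raising `KeyError`. Callers probing for runs should prefer `in` or `.get`. A small sketch of that standard-library behavior:

# Sketch of the defaultdict behavior of the returned mapping.
import collections

result = collections.defaultdict(dict)
result["run1"]["tag1"] = "datum"

_ = result["no_such_run"]  # silently creates an empty dict entry
assert "no_such_run" in result  # now present, with no tags

# Safer membership probe that does not mutate the mapping:
assert result.get("run2") is None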
57 changes: 57 additions & 0 deletions tensorboard/data/grpc_provider_test.py
@@ -230,6 +230,63 @@ def test_read_scalars(self):
        req.downsample.num_points = 4
        self.stub.ReadScalars.assert_called_once_with(req)

    def test_read_last_scalars(self):
        tag1 = data_provider_pb2.ReadScalarsResponse.TagEntry(
            tag_name="tag1",
            data=data_provider_pb2.ScalarData(
                step=[10000], wall_time=[1234.0], value=[1]
            ),
        )
        tag2 = data_provider_pb2.ReadScalarsResponse.TagEntry(
            tag_name="tag2",
            data=data_provider_pb2.ScalarData(
                step=[10000], wall_time=[1235.0], value=[0.50]
            ),
        )
        run1 = data_provider_pb2.ReadScalarsResponse.RunEntry(
            run_name="run1", tags=[tag1]
        )
        run2 = data_provider_pb2.ReadScalarsResponse.RunEntry(
            run_name="run2", tags=[tag2]
        )
        res = data_provider_pb2.ReadScalarsResponse(runs=[run1, run2])
        self.stub.ReadScalars.return_value = res

        actual = self.provider.read_last_scalars(
            self.ctx,
            experiment_id="123",
            plugin_name="scalars",
            run_tag_filter=provider.RunTagFilter(
                runs=["train", "test", "nope"]
            ),
        )
        expected = {
            "run1": {
                "tag1": provider.ScalarDatum(
                    step=10000, wall_time=1234.0, value=1
                ),
            },
            "run2": {
                "tag2": provider.ScalarDatum(
                    step=10000, wall_time=1235.0, value=0.50
                ),
            },
        }

        self.assertEqual(actual, expected)

        expected_req = data_provider_pb2.ReadScalarsRequest(
            experiment_id="123",
            plugin_filter=data_provider_pb2.PluginFilter(plugin_name="scalars"),
            run_tag_filter=data_provider_pb2.RunTagFilter(
                runs=data_provider_pb2.RunFilter(
                    names=["nope", "test", "train"]  # sorted
                )
            ),
            downsample=data_provider_pb2.Downsample(num_points=1),
        )
        self.stub.ReadScalars.assert_called_once_with(expected_req)

    def test_list_tensors(self):
        res = data_provider_pb2.ListTensorsResponse()
        run1 = res.runs.add(run_name="val")
37 changes: 37 additions & 0 deletions tensorboard/data/provider.py
@@ -248,6 +248,43 @@ def read_scalars(
"""
pass

@abc.abstractmethod
def read_last_scalars(
self,
ctx=None,
*,
experiment_id,
plugin_name,
run_tag_filter=None,
):
"""Read the most recent values from scalar time series.
The most recent scalar value for each tag under each run is retrieved
from the latest event (at the latest step).
Args:
ctx: A TensorBoard `RequestContext` value.
experiment_id: ID of enclosing experiment.
plugin_name: String name of the TensorBoard plugin that created
the data to be queried. Required.
run_tag_filter: Optional `RunTagFilter` value. If provided, a datum
series will only be included in the result if its run and tag
both pass this filter. If `None`, all time series will be
included.
The result will only contain keys for run-tag combinations that
actually exist, which may not include all entries in the
`run_tag_filter`.
Returns:
A nested map `d` such that `d[run][tag]` is a `ScalarDatum`
representing the latest scalar in the time series.
Raises:
tensorboard.errors.PublicError: See `DataProvider` class docstring.
"""
pass

def list_tensors(
self, ctx=None, *, experiment_id, plugin_name, run_tag_filter=None
):
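For provider authors, the new abstract method can be satisfied by delegating to an existing `read_scalars` with a downsample of one, which is exactly the strategy the gRPC provider uses above. A hedged sketch of such an adapter (hypothetical helper, not part of this commit):

# Hypothetical adapter: implement read_last_scalars in terms of
# read_scalars, relying on the guarantee that downsampling always
# retains the most recent datum of each series.
def read_last_scalars_via_read_scalars(
    dp, ctx=None, *, experiment_id, plugin_name, run_tag_filter=None
):
    scalars = dp.read_scalars(
        ctx,
        experiment_id=experiment_id,
        plugin_name=plugin_name,
        run_tag_filter=run_tag_filter,
        downsample=1,  # one point per series: the latest
    )
    return {
        run: {tag: data[-1] for tag, data in tags.items() if data}
        for run, tags in scalars.items()
    }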
11 changes: 11 additions & 0 deletions tensorboard/plugins/debugger_v2/debug_data_provider.py
@@ -522,6 +522,17 @@ def read_scalars(
        del experiment_id, plugin_name, downsample, run_tag_filter
        raise TypeError("Debugger V2 DataProvider doesn't support scalars.")

    def read_last_scalars(
        self,
        ctx=None,
        *,
        experiment_id,
        plugin_name,
        run_tag_filter=None,
    ):
        del experiment_id, plugin_name, run_tag_filter
        raise TypeError("Debugger V2 DataProvider doesn't support scalars.")

    def list_blob_sequences(
        self, ctx=None, *, experiment_id, plugin_name, run_tag_filter=None
    ):
):
Expand Down
