Skip to content

Commit

Permalink
Adjust demo scripts to be Keras 3 compatible (#6761)
Browse files Browse the repository at this point in the history
This adjusts *_demo.py files to work with Keras 3.

The hparams_demo is fully backward compatible with Keras 2 and forward
compatible with Keras 3.

Unfortunately the graphs_demo is not backward compatible with Keras 2.
Users attempting to run it with Keras 2 will get the following error:

```
  File "/usr/local/google/home/bdubois/.cache/bazel/_bazel_bdubois/079646a57be11faea0b2bfefccb2a81a/execroot/org_tensorflow_tensorboard/bazel-out/k8-fastbuild/bin/tensorboard/plugins/graph/graphs_demo.runfiles/org_tensorflow_tensorboard/tensorboard/plugins/graph/graphs_demo.py", line 128, in profile
    tf.summary.trace_on(profiler=True, profiler_outdir=logdir)
TypeError: trace_on() got an unexpected keyword argument 'profiler_outdir'
```

Amazingly, though, the graph that the demo generates with Keras 3 can be
successfully loaded in the Graph dashboard. This makes me optimistic to
get the Graph plugin fully Keras 3 compatible after addressing the
user-reported error in #6686.

Old Keras 2 Graph: 

![image](https://github.com/tensorflow/tensorboard/assets/17152369/b8745739-ac06-4171-a7bc-c97135b2dec7)
New Keras 3 Graph:

![image](https://github.com/tensorflow/tensorboard/assets/17152369/04dcaf14-3464-47dc-b4a5-373cccd8370b)
  • Loading branch information
bmd3k authored Feb 16, 2024
1 parent 8a99668 commit 4c004d4
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
8 changes: 3 additions & 5 deletions tensorboard/plugins/graph/graphs_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def keras():
model = tf.keras.models.Sequential(layers)
model.compile(
loss=tf.keras.losses.mean_squared_error,
optimizer=tf.keras.optimizers.SGD(lr=0.2),
optimizer=tf.keras.optimizers.SGD(learning_rate=0.2),
)
model.fit(
x_train,
Expand Down Expand Up @@ -125,11 +125,9 @@ def g(i):
for step in range(3):
# Suppress the profiler deprecation warnings from tf.summary.trace_*.
with _silence_deprecation_warnings():
tf.summary.trace_on(profiler=True)
tf.summary.trace_on(profiler=True, profiler_outdir=logdir)
print(f(tf.constant(step)).numpy())
tf.summary.trace_export(
"prof_f", step=step, profiler_outdir=logdir
)
tf.summary.trace_export("prof_f", step=step)

tf.summary.trace_on(profiler=False)
print(g(tf.constant(step)).numpy())
Expand Down
6 changes: 5 additions & 1 deletion tensorboard/plugins/hparams/hparams_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,11 @@ def model_fn(hparams, seed):
conv_filters *= 2

model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dropout(hparams[HP_DROPOUT], seed=rng.random()))
model.add(
tf.keras.layers.Dropout(
hparams[HP_DROPOUT], seed=rng.randrange(1 << 32)
)
)

# Add fully connected layers.
dense_neurons = 32
Expand Down

0 comments on commit 4c004d4

Please sign in to comment.