diff --git a/tensorboard/plugins/hparams/_keras.py b/tensorboard/plugins/hparams/_keras.py index dc57c09301..29a86fd086 100644 --- a/tensorboard/plugins/hparams/_keras.py +++ b/tensorboard/plugins/hparams/_keras.py @@ -17,7 +17,7 @@ Use `tensorboard.plugins.hparams.api` to access this module's contents. """ - +import os import tensorflow as tf from tensorboard.plugins.hparams import api_pb2 @@ -39,7 +39,7 @@ def __init__(self, writer, hparams, trial_id=None): Args: writer: The `SummaryWriter` object to which hparams should be - written, or a logdir (as a `str`) to be passed to + written, or a logdir (as a `str` or `PathLike`) to be passed to `tf.summary.create_file_writer` to create such a writer. hparams: A `dict` mapping hyperparameters to the values used in this session. Keys should be the names of `HParam` objects used @@ -62,10 +62,14 @@ def __init__(self, writer, hparams, trial_id=None): summary_v2.hparams_pb(self._hparams, trial_id=self._trial_id) if writer is None: raise TypeError( - "writer must be a `SummaryWriter` or `str`, not None" + "writer must be a `SummaryWriter`, `str` or `PathLike`, not None" ) elif isinstance(writer, str): self._writer = tf.compat.v2.summary.create_file_writer(writer) + elif isinstance(writer, os.PathLike): + self._writer = tf.compat.v2.summary.create_file_writer( + os.fspath(writer) + ) else: self._writer = writer diff --git a/tensorboard/plugins/hparams/_keras_test.py b/tensorboard/plugins/hparams/_keras_test.py index 570bd15acf..1cef8ffc80 100644 --- a/tensorboard/plugins/hparams/_keras_test.py +++ b/tensorboard/plugins/hparams/_keras_test.py @@ -15,6 +15,7 @@ import os +from pathlib import Path from unittest import mock from google.protobuf import text_format @@ -145,6 +146,14 @@ def test_explicit_writer(self): # We'll assume that the contents are correct, as in the case where # the file writer was constructed implicitly. + def test_pathlib_writer(self): + writer = Path(self.logdir) + self._initialize_model(writer=writer) + self.model.fit(x=[(1,)], y=[(2,)], callbacks=[self.callback]) + + files = os.listdir(self.logdir) + self.assertEqual(len(files), 1, files) + def test_non_eager_failure(self): with tf.compat.v1.Graph().as_default(): assert not tf.executing_eagerly() @@ -165,7 +174,7 @@ def test_reuse_failure(self): def test_invalid_writer(self): with self.assertRaisesRegex( TypeError, - "writer must be a `SummaryWriter` or `str`, not None", + "writer must be a `SummaryWriter`, `str` or `PathLike`, not None", ): _keras.Callback(writer=None, hparams={})