diff --git a/n3fit/src/n3fit/backends/keras_backend/callbacks.py b/n3fit/src/n3fit/backends/keras_backend/callbacks.py index 8a593af814..f3627e9e3b 100644 --- a/n3fit/src/n3fit/backends/keras_backend/callbacks.py +++ b/n3fit/src/n3fit/backends/keras_backend/callbacks.py @@ -15,7 +15,6 @@ import logging from time import time -# Callbacks need tensorflow installed even if the backend is pytorch from keras.callbacks import Callback, TensorBoard import numpy as np @@ -196,7 +195,8 @@ def gen_tensorboard_callback(log_dir, profiling=False, histogram_freq=0): If the profiling flag is set to True, it will also attempt to save profiling data. - Note the usage of this callback can hurt performance. + Note the usage of this callback can hurt performance + At the moment can only be used with TensorFlow: https://github.com/keras-team/keras/issues/19121 Parameters ---------- diff --git a/n3fit/src/n3fit/backends/keras_backend/internal_state.py b/n3fit/src/n3fit/backends/keras_backend/internal_state.py index a10e317010..3b7be3f7ed 100644 --- a/n3fit/src/n3fit/backends/keras_backend/internal_state.py +++ b/n3fit/src/n3fit/backends/keras_backend/internal_state.py @@ -56,6 +56,10 @@ def set_threading(threads, cores): "Could not set tensorflow parallelism settings from n3fit, maybe tensorflow is already initialized by a third program" ) +else: + # Keras should've failed by now, if it doesn't it could be a new backend that works ootb? + log.warning(f"Backend {K.backend()} not recognized. You are entering uncharted territory") + def set_number_of_cores(max_cores=None, max_threads=None): """ diff --git a/n3fit/src/n3fit/checks.py b/n3fit/src/n3fit/checks.py index 32c2d26acf..ee7cdaee43 100644 --- a/n3fit/src/n3fit/checks.py +++ b/n3fit/src/n3fit/checks.py @@ -159,6 +159,14 @@ def check_dropout(parameters): def check_tensorboard(tensorboard): """Check that the tensorbard callback can be enabled correctly""" if tensorboard is not None: + # Check that Tensorflow is installed + try: + import tensorflow + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "The tensorboard callback requires `tensorflow` to be installed" + ) from e + weight_freq = tensorboard.get("weight_freq", 0) if weight_freq < 0: raise CheckError(