diff --git a/keras/__init__.py b/keras/__init__.py index 701568f2a01..5a429d3a5d8 100644 --- a/keras/__init__.py +++ b/keras/__init__.py @@ -18,6 +18,7 @@ from keras.api import Regularizer from keras.api import Sequential from keras.api import StatelessScope +from keras.api import SymbolicScope from keras.api import Variable from keras.api import __version__ from keras.api import activations diff --git a/keras/api/__init__.py b/keras/api/__init__.py index 1750a42e869..9d082ae9b89 100644 --- a/keras/api/__init__.py +++ b/keras/api/__init__.py @@ -33,6 +33,7 @@ from keras.api import utils from keras.src.backend.common.keras_tensor import KerasTensor from keras.src.backend.common.stateless_scope import StatelessScope +from keras.src.backend.common.symbolic_scope import SymbolicScope from keras.src.backend.exports import Variable from keras.src.backend.exports import device from keras.src.backend.exports import name_scope diff --git a/keras/api/_tf_keras/keras/__init__.py b/keras/api/_tf_keras/keras/__init__.py index 5e0a7229473..39a7e9cdb18 100644 --- a/keras/api/_tf_keras/keras/__init__.py +++ b/keras/api/_tf_keras/keras/__init__.py @@ -31,6 +31,7 @@ from keras.api._tf_keras.keras import preprocessing from keras.src.backend.common.keras_tensor import KerasTensor from keras.src.backend.common.stateless_scope import StatelessScope +from keras.src.backend.common.symbolic_scope import SymbolicScope from keras.src.backend.exports import Variable from keras.src.backend.exports import device from keras.src.backend.exports import name_scope diff --git a/keras/src/backend/__init__.py b/keras/src/backend/__init__.py index 5c7fa223520..794fe3ca364 100644 --- a/keras/src/backend/__init__.py +++ b/keras/src/backend/__init__.py @@ -14,6 +14,8 @@ from keras.src.backend.common.stateless_scope import StatelessScope from keras.src.backend.common.stateless_scope import get_stateless_scope from keras.src.backend.common.stateless_scope import in_stateless_scope +from keras.src.backend.common.symbolic_scope import SymbolicScope +from keras.src.backend.common.symbolic_scope import in_symbolic_scope from keras.src.backend.common.variables import AutocastScope from keras.src.backend.common.variables import get_autocast_scope from keras.src.backend.common.variables import is_float_dtype diff --git a/keras/src/backend/common/symbolic_scope.py b/keras/src/backend/common/symbolic_scope.py new file mode 100644 index 00000000000..15cd7a5ee05 --- /dev/null +++ b/keras/src/backend/common/symbolic_scope.py @@ -0,0 +1,23 @@ +from keras.src.api_export import keras_export +from keras.src.backend.common import global_state + + +@keras_export("keras.SymbolicScope") +class SymbolicScope: + """Scope to indicate the symbolic stage.""" + + def __enter__(self): + self.original_scope = get_symbolic_scope() + global_state.set_global_attribute("symbolic_scope", self) + return self + + def __exit__(self, *args, **kwargs): + global_state.set_global_attribute("symbolic_scope", self.original_scope) + + +def in_symbolic_scope(): + return global_state.get_global_attribute("symbolic_scope") is not None + + +def get_symbolic_scope(): + return global_state.get_global_attribute("symbolic_scope") diff --git a/keras/src/backend/common/symbolic_scope_test.py b/keras/src/backend/common/symbolic_scope_test.py new file mode 100644 index 00000000000..092dcfe0748 --- /dev/null +++ b/keras/src/backend/common/symbolic_scope_test.py @@ -0,0 +1,26 @@ +import numpy as np + +from keras.src import ops +from keras.src import testing +from keras.src.backend.common.symbolic_scope import SymbolicScope +from keras.src.backend.common.symbolic_scope import in_symbolic_scope + + +class TestSymbolicScope(testing.TestCase): + def test_basic_flow(self): + + # Define a function that behaves differently according to + # `in_symbolic_scope`. + def compute_loss(y, y_pred): + if in_symbolic_scope(): + return ops.zeros_like(y) + return ops.add(y, y_pred) + + y = ops.ones(shape=(2,)) + y_pred = ops.ones(shape=(2,)) + with SymbolicScope(): + loss = compute_loss(y, y_pred) + self.assertAllClose(loss, np.zeros((2,))) + + loss = compute_loss(y, y_pred) + self.assertAllClose(loss, 2 * np.ones((2,))) diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index 3ccaf06a980..c36dfee6a04 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -10,6 +10,7 @@ from keras.src.backend.common import standardize_dtype from keras.src.backend.common.keras_tensor import KerasTensor from keras.src.backend.common.stateless_scope import StatelessScope +from keras.src.backend.common.symbolic_scope import SymbolicScope from keras.src.backend.jax import distribution_lib SUPPORTS_SPARSE_TENSORS = True @@ -101,7 +102,7 @@ def cast(x, dtype): # Shape / dtype / sparseness inference util def compute_output_spec(fn, *args, **kwargs): - with StatelessScope(): + with StatelessScope(), SymbolicScope(): built_in_types = (type(None), int, float, str, bool, complex, bytes) # First, separate symbolic args from other args diff --git a/keras/src/backend/numpy/core.py b/keras/src/backend/numpy/core.py index 2d34fcd7c6c..97be123f9e8 100644 --- a/keras/src/backend/numpy/core.py +++ b/keras/src/backend/numpy/core.py @@ -12,6 +12,7 @@ from keras.src.backend.common.dtypes import result_type from keras.src.backend.common.keras_tensor import KerasTensor from keras.src.backend.common.stateless_scope import StatelessScope +from keras.src.backend.common.symbolic_scope import SymbolicScope SUPPORTS_SPARSE_TENSORS = False @@ -88,7 +89,7 @@ def vectorized_map(function, elements): # Shape / dtype inference util def compute_output_spec(fn, *args, **kwargs): - with StatelessScope(): + with StatelessScope(), SymbolicScope(): def has_none_shape(x): if isinstance(x, KerasTensor): diff --git a/keras/src/backend/numpy/trainer.py b/keras/src/backend/numpy/trainer.py index 6d40982be43..12c3aad56b6 100644 --- a/keras/src/backend/numpy/trainer.py +++ b/keras/src/backend/numpy/trainer.py @@ -97,7 +97,10 @@ def _symbolic_build(self, data_batch): self._compile_metrics is not None and not self._compile_metrics.built ) - if model_unbuilt or compile_metrics_unbuilt: + compile_loss_unbuilt = ( + self._compile_loss is not None and not self._compile_loss.built + ) + if model_unbuilt or compile_metrics_unbuilt or compile_loss_unbuilt: # Create symbolic tensors matching an input batch. def to_symbolic_input(v): @@ -133,6 +136,15 @@ def to_symbolic_input(v): y_pred, sample_weight=sample_weight, ) + if compile_loss_unbuilt: + # Build `CompileLoss` state with `backend.compute_output_spec`. + backend.compute_output_spec( + self._compute_loss, + x, + y, + y_pred, + sample_weight=sample_weight, + ) self._post_build() def fit( diff --git a/keras/src/backend/tensorflow/core.py b/keras/src/backend/tensorflow/core.py index db33ce227d0..09d65e827cc 100644 --- a/keras/src/backend/tensorflow/core.py +++ b/keras/src/backend/tensorflow/core.py @@ -14,6 +14,7 @@ from keras.src.backend.common.name_scope import name_scope as base_name_scope from keras.src.backend.common.stateless_scope import StatelessScope from keras.src.backend.common.stateless_scope import in_stateless_scope +from keras.src.backend.common.symbolic_scope import SymbolicScope from keras.src.backend.tensorflow.sparse import sparse_to_dense from keras.src.utils.naming import auto_name @@ -182,7 +183,7 @@ def cast(x, dtype): def compute_output_spec(fn, *args, **kwargs): - with StatelessScope(): + with StatelessScope(), SymbolicScope(): graph_name = auto_name("scratch_graph") with tf.__internal__.FuncGraph(graph_name).as_default(): diff --git a/keras/src/backend/torch/core.py b/keras/src/backend/torch/core.py index 3a941fc46a4..5f01d57d5b7 100644 --- a/keras/src/backend/torch/core.py +++ b/keras/src/backend/torch/core.py @@ -17,6 +17,7 @@ from keras.src.backend.common.stateless_scope import StatelessScope from keras.src.backend.common.stateless_scope import get_stateless_scope from keras.src.backend.common.stateless_scope import in_stateless_scope +from keras.src.backend.common.symbolic_scope import SymbolicScope from keras.src.backend.config import floatx SUPPORTS_SPARSE_TENSORS = False @@ -335,7 +336,7 @@ def symbolic_call(fn, args, kwargs, fill_value): ) return fn(*eager_args, **eager_kwargs) - with StatelessScope(), torch.no_grad(): + with StatelessScope(), SymbolicScope(), torch.no_grad(): outputs = symbolic_call(fn, args, kwargs, fill_value=83) none_in_shape = any( diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index dc945e22781..d6b68e2ae16 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -32,6 +32,7 @@ from keras.src.backend import KerasTensor from keras.src.backend.common import global_state from keras.src.backend.common.name_scope import current_path +from keras.src.backend.common.symbolic_scope import in_symbolic_scope from keras.src.distribution import distribution_lib from keras.src.dtype_policies import DTypePolicyMap from keras.src.layers import input_spec @@ -1139,7 +1140,10 @@ def _get_regularization_losses(self): for variable in self.trainable_weights: if variable.regularizer is None: continue - if backend.in_stateless_scope(): + if backend.in_stateless_scope() and not in_symbolic_scope(): + # If in symbolic scope, we might get `None` from + # `get_current_value` in `backend.compute_output_spec`. So we + # assign `variable` instead. v = backend.get_stateless_scope().get_current_value(variable) else: v = variable diff --git a/keras/src/models/model_test.py b/keras/src/models/model_test.py index ee61c64166b..94891e7fde4 100644 --- a/keras/src/models/model_test.py +++ b/keras/src/models/model_test.py @@ -239,14 +239,13 @@ def test_functional_list_outputs_list_losses(self): # Fit the model to make sure compile_metrics are built hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) hist_keys = sorted(hist.history.keys()) - # TODO `tf.keras` also outputs individual losses for outputs ref_keys = sorted( [ "loss", - # "output_a_loss", + "output_a_loss", "output_a_mean_squared_error", "output_b_accuracy", - # "output_b_loss", + "output_b_loss", "output_b_mean_squared_error", ] ) @@ -270,16 +269,15 @@ def test_functional_list_outputs_list_losses_abbr(self): # Fit the model to make sure compile_metrics are built hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) hist_keys = sorted(hist.history.keys()) - # TODO `tf.keras` also outputs individual losses for outputs ref_keys = sorted( [ "loss", - # "output_a_loss", + "output_a_loss", "output_a_bce", "output_a_mae", "output_a_mse", "output_b_acc", - # "output_b_loss", + "output_b_loss", "output_b_mse", ] ) @@ -303,14 +301,13 @@ def test_functional_list_outputs_nested_list_losses(self): # Fit the model to make sure compile_metrics are built hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) hist_keys = sorted(hist.history.keys()) - # TODO `tf.keras` also outputs individual losses for outputs ref_keys = sorted( [ "loss", - # "output_a_loss", + "output_a_loss", "output_a_mean_squared_error", "output_b_accuracy", - # "output_b_loss", + "output_b_loss", "output_b_mean_squared_error", ] ) @@ -351,15 +348,14 @@ def test_functional_dict_outputs_dict_losses(self): verbose=0, ) hist_keys = sorted(hist.history.keys()) - # TODO `tf.keras` also outputs individual losses for outputs ref_keys = sorted( [ "loss", - # "output_a_loss", + "output_a_loss", "output_a_mean_squared_error", "output_a_weighted_mean_squared_error", "output_b_accuracy", - # "output_b_loss", + "output_b_loss", "output_b_mean_squared_error", "output_b_weighted_accuracy", "output_b_weighted_mean_squared_error", @@ -396,15 +392,14 @@ def test_functional_list_outputs_dict_losses_metrics(self): # Fit the model to make sure compile_metrics are built hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) hist_keys = sorted(hist.history.keys()) - # TODO `tf.keras` also outputs individual losses for outputs ref_keys = sorted( [ "loss", - # "output_a_loss", + "output_a_loss", "output_a_mean_squared_error", "output_a_weighted_mean_squared_error", "output_b_accuracy", - # "output_b_loss", + "output_b_loss", "output_b_mean_squared_error", "output_b_weighted_accuracy", "output_b_weighted_mean_squared_error", @@ -436,18 +431,17 @@ def test_functional_list_outputs_dict_losses_metrics_uniq_weighted(self): # Fit the model to make sure compile_metrics are built hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) hist_keys = sorted(hist.history.keys()) - # TODO `tf.keras` also outputs individual losses for outputs # `output_b_accuracy` doesn't have `weighted_` in metric name. # When a metric is only in weighted metrics, it skips `weighted_` # prefix. This behavior matches`tf.keras`. ref_keys = sorted( [ "loss", - # "output_a_loss", + "output_a_loss", "output_a_mean_squared_error", "output_a_weighted_mean_squared_error", "output_b_accuracy", - # "output_b_loss", + "output_b_loss", "output_b_mean_squared_error", ] ) @@ -472,13 +466,12 @@ def test_functional_list_outputs_dict_losses_partial_metrics(self): # Fit the model to make sure compile_metrics are built hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) hist_keys = sorted(hist.history.keys()) - # TODO `tf.keras` also outputs individual losses for outputs ref_keys = sorted( [ "loss", - # "output_a_loss", + "output_a_loss", "output_b_accuracy", - # "output_b_loss", + "output_b_loss", "output_b_mean_squared_error", ] ) @@ -500,7 +493,10 @@ def test_functional_dict_outputs_with_single_tensor(self): "output_b": "binary_crossentropy", }, ) - model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) + hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) + hist_keys = sorted(hist.history.keys()) + ref_keys = sorted(["loss", "output_a_loss", "output_b_loss"]) + self.assertListEqual(hist_keys, ref_keys) def test_functional_list_outputs_with_custom_compute_loss(self): model = _get_model_with_custom_compute_loss() @@ -514,7 +510,12 @@ def test_functional_list_outputs_with_custom_compute_loss(self): model.compile( optimizer="sgd", loss=["mean_squared_error", "binary_crossentropy"] ) - model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) + hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) + hist_keys = sorted(hist.history.keys()) + ref_keys = sorted( + ["binary_crossentropy_loss", "loss", "mean_squared_error_loss"] + ) + self.assertListEqual(hist_keys, ref_keys) def test_functional_list_outputs_dict_losses_invalid_keys(self): model = _get_model_multi_outputs_list() diff --git a/keras/src/trainers/compile_utils.py b/keras/src/trainers/compile_utils.py index 9e21da2cf75..114925e669d 100644 --- a/keras/src/trainers/compile_utils.py +++ b/keras/src/trainers/compile_utils.py @@ -3,6 +3,7 @@ from keras.src import ops from keras.src import tree from keras.src.utils.naming import get_object_name +from keras.src.utils.tracking import Tracker class MetricsList(metrics_module.Metric): @@ -431,6 +432,28 @@ def __init__( # Inferred by `y_pred` and `output_names` self.inferred_output_names = None + # Use `Tracker` to track metrcis for individual losses. + self._metrics = [] + self._tracker = Tracker( + { + "metrics": ( + lambda x: isinstance(x, metrics_module.Metric), + self._metrics, + ) + } + ) + + @property + def metrics(self): + return self._metrics + + @property + def variables(self): + vars = [] + for m in self.metrics: + vars.extend(m.variables) + return vars + def build(self, y_true, y_pred): loss = self._user_loss loss_weights = self._user_loss_weights @@ -527,6 +550,21 @@ def build(self, y_true, y_pred): for identifier, _y_true, _y_pred in zip(flat_losses, y_true, y_pred) ] + # Add `Mean` metric to the tracker for each loss. + if len(flat_losses) > 1: + for i, _loss in enumerate(flat_losses): + if _loss is not None: + if inferred_output_names is not None and len( + inferred_output_names + ) == len(flat_losses): + name = inferred_output_names[i] + else: + name = _loss.name + name += "_loss" + self._tracker.add_to_store( + "metrics", metrics_module.Mean(name=name) + ) + self.flat_losses = flat_losses self.flat_loss_weights = flat_loss_weights self.filtered_y_true_keys = filtered_y_true_keys @@ -596,22 +634,31 @@ def call(self, y_true, y_pred, sample_weight=None): else: sample_weight = [None for _ in y_true] + # We need to add a dummy `None` if the model has only a single output. + metrics = [None] if len(self.metrics) == 0 else self.metrics + # Iterate all losses in flat form. loss_values = [] - for loss, y_t, y_p, loss_weight, sample_weight in zip( + for loss_fn, y_t, y_p, loss_weight, sample_weight, metric in zip( self.flat_losses, y_true, y_pred, self.flat_loss_weights, sample_weight, + metrics, ): - if loss: + if loss_fn: value = ops.cast( - loss(y_t, y_p, sample_weight), dtype=self.dtype + loss_fn(y_t, y_p, sample_weight), dtype=self.dtype ) if loss_weight is not None: value = ops.multiply(value, loss_weight) loss_values.append(value) + # Record individual losses. + if metric: + metric.update_state( + value, sample_weight=tree.flatten(y_p)[0].shape[0] + ) if loss_values: total_loss = sum(loss_values) return total_loss diff --git a/keras/src/trainers/trainer.py b/keras/src/trainers/trainer.py index 397c49ce391..9b027da9c4c 100644 --- a/keras/src/trainers/trainer.py +++ b/keras/src/trainers/trainer.py @@ -250,6 +250,8 @@ def metrics(self): metrics.extend(super().metrics) if self.compiled and self._compile_metrics is not None: metrics += [self._compile_metrics] + if self.compiled and self._compile_loss is not None: + metrics.extend(self._compile_loss.metrics) return metrics @property @@ -1004,10 +1006,13 @@ def _symbolic_build(self, iterator=None, data_batch=None): self._compile_metrics is not None and not self._compile_metrics.built ) + compile_loss_unbuilt = ( + self._compile_loss is not None and not self._compile_loss.built + ) optimizer_unbuilt = ( self.optimizer is not None and not self.optimizer.built ) - if model_unbuilt or compile_metrics_unbuilt: + if model_unbuilt or compile_metrics_unbuilt or compile_loss_unbuilt: # Create symbolic tensors matching an input batch. def to_symbolic_input(v): @@ -1030,7 +1035,7 @@ def to_symbolic_input(v): # Build all model state with `backend.compute_output_spec`. try: - y_pred = backend.compute_output_spec(self, x) + y_pred = backend.compute_output_spec(self, x, training=False) except Exception as e: raise RuntimeError( "Unable to automatically build the model. " @@ -1052,6 +1057,16 @@ def to_symbolic_input(v): y_pred, sample_weight=sample_weight, ) + if compile_loss_unbuilt: + # Build `CompileLoss` state with `backend.compute_output_spec`. + backend.compute_output_spec( + self._compute_loss, + x, + y, + y_pred, + sample_weight=sample_weight, + training=False, + ) if optimizer_unbuilt: # Build optimizer self.optimizer.build(self.trainable_variables) diff --git a/keras/src/trainers/trainer_test.py b/keras/src/trainers/trainer_test.py index a4b21e5f505..064df23adc2 100644 --- a/keras/src/trainers/trainer_test.py +++ b/keras/src/trainers/trainer_test.py @@ -14,6 +14,7 @@ from keras.src import ops from keras.src import optimizers from keras.src import testing +from keras.src.backend.common.symbolic_scope import in_symbolic_scope from keras.src.callbacks.callback import Callback from keras.src.optimizers.rmsprop import RMSprop from keras.src.testing.test_utils import named_product @@ -1406,7 +1407,8 @@ def compute_loss( sample_weight=None, training=True, ): - test_self.assertTrue(training) + if not in_symbolic_scope(): + test_self.assertTrue(training) loss = super().compute_loss( x, y, y_pred, sample_weight, training ) @@ -1443,7 +1445,8 @@ def compute_loss( sample_weight=None, training=True, ): - test_self.assertTrue(training) + if not in_symbolic_scope(): + test_self.assertTrue(training) loss = super().compute_loss( x, y, y_pred, sample_weight, training ) @@ -1478,7 +1481,8 @@ def compute_loss( sample_weight=None, training=True, ): - test_self.assertFalse(training) + if not in_symbolic_scope(): + test_self.assertFalse(training) loss = super().compute_loss( x, y, y_pred, sample_weight, training ) @@ -1613,6 +1617,65 @@ def test_loss_weights(self): atol=1e-3, ) + def test_symbolic_build(self): + class ExampleModelWithTrainingArgs(Trainer, layers.Layer): + def __init__(self, units): + layers.Layer.__init__(self) + Trainer.__init__(self) + self.dense = layers.Dense(units) + self.bn = layers.BatchNormalization(axis=-1) + + def build(self, input_shape): + self.dense.build(input_shape) + input_shape = self.dense.compute_output_shape(input_shape) + self.bn.build(input_shape) + + def call(self, x, training=None): + outputs = self.bn(self.dense(x), training=training) + return [outputs, outputs] + + model = ExampleModelWithTrainingArgs(units=3) + model.compile( + optimizer=optimizers.SGD(), + loss=[losses.MeanSquaredError(), losses.MeanSquaredError()], + metrics=[metrics.MeanSquaredError(), metrics.MeanSquaredError()], + ) + x = np.ones((4, 4)) + y = np.zeros((4, 3)) + model(x) # Eager call to build model weights + ref_weights = model.get_weights() + + # Before `_symbolic_build` + self.assertTrue(model.built) + self.assertFalse(model._compile_metrics.built) + self.assertFalse(model._compile_loss.built) + self.assertLen(model._compile_loss.metrics, 0) + self.assertLen(model.metrics, 2) + + model._symbolic_build(data_batch=(x, (y, y))) + weights = model.get_weights() + + # Ensure weights are intact + self.assertEqual(len(weights), len(ref_weights)) + for w, ref_w in zip(weights, ref_weights): + self.assertAllClose(w, ref_w) + + # Ensure `built` + self.assertTrue(model.built) + self.assertTrue(model._compile_metrics.built) + self.assertTrue(model._compile_loss.built) + + # Ensure the len of metrics (original metrics + loss trackers) + self.assertLen(model._compile_metrics.metrics, 2) + self.assertLen(model._compile_loss.metrics, 2) + self.assertLen(model.metrics, 4) + + # Ensure no values in metrics + for v in model._compile_metrics.variables: + self.assertAllClose(v, 0.0) + for v in model._compile_loss.variables: + self.assertAllClose(v, 0.0) + class TrainerDistributeTest(testing.TestCase): @pytest.mark.skipif( diff --git a/keras/src/utils/torch_utils.py b/keras/src/utils/torch_utils.py index f20669e4955..e81018e0da7 100644 --- a/keras/src/utils/torch_utils.py +++ b/keras/src/utils/torch_utils.py @@ -112,7 +112,11 @@ def _track_module_parameters(self): self._track_variable(variable) self.built = True - def call(self, *args, **kwargs): + def call(self, *args, training=None, **kwargs): + if training is False: + self.eval() + else: + self.train() return self.module(*args, **kwargs) def save_own_variables(self, store): diff --git a/keras/src/utils/torch_utils_test.py b/keras/src/utils/torch_utils_test.py index 7e972f5b1b5..55003240710 100644 --- a/keras/src/utils/torch_utils_test.py +++ b/keras/src/utils/torch_utils_test.py @@ -29,9 +29,9 @@ def __init__( self.torch_wrappers.append(TorchModuleWrapper(torch_model)) self.fc = layers.Dense(1) - def call(self, x): + def call(self, x, training=None): for wrapper in self.torch_wrappers: - x = wrapper(x) + x = wrapper(x, training=training) return self.fc(x) def get_config(self): @@ -49,7 +49,7 @@ def __init__(self, *args, **kwargs): self.fc2 = torch.nn.Linear(4, 4) self.fc3 = layers.Dense(2) - def call(self, x): + def call(self, x, training=None): return self.fc3(self.fc2(self.bn1(self.fc1(x)))) @@ -82,6 +82,50 @@ def test_basic_usage(self, use_batch_norm, num_torch_layers): model.compile(optimizer="sgd", loss="mse") model.fit(np.random.random((3, 2)), np.random.random((3, 1))) + @parameterized.named_parameters( + ( + "explicit_torch_wrapper", + Classifier, + {"use_batch_norm": True, "num_torch_layers": 1}, + ), + ("implicit_torch_wrapper", ClassifierWithNoSpecialCasing, {}), + ) + def test_training_args(self, cls, kwargs): + model = cls(**kwargs) + model(np.random.random((3, 2)), training=False) # Eager call to build + ref_weights = model.get_weights() + ref_running_mean = backend.convert_to_numpy( + model.torch_wrappers[0].module[-1].running_mean + if cls is Classifier + else model.bn1.module.running_mean + ) + + # Test training=False doesn't affect model weights + model(np.random.random((3, 2)), training=False) + weights = model.get_weights() + for w, ref_w in zip(weights, ref_weights): + self.assertAllClose(w, ref_w) + + # Test training=None affects BN's stats + model.set_weights(ref_weights) # Restore previous weights + model(np.random.random((3, 2))) + running_mean = backend.convert_to_numpy( + model.torch_wrappers[0].module[-1].running_mean + if cls is Classifier + else model.bn1.module.running_mean + ) + self.assertNotAllClose(running_mean, ref_running_mean) + + # Test training=True affects BN's stats + model.set_weights(ref_weights) # Restore previous weights + model(np.random.random((3, 2)), training=True) + running_mean = backend.convert_to_numpy( + model.torch_wrappers[0].module[-1].running_mean + if cls is Classifier + else model.bn1.module.running_mean + ) + self.assertNotAllClose(running_mean, ref_running_mean) + def test_module_autowrapping(self): model = ClassifierWithNoSpecialCasing() self.assertIsInstance(model.fc1, TorchModuleWrapper)