From 71a360b98e561d44c947e9143d61c660ec79f084 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Sun, 21 Jul 2024 16:41:51 +0800 Subject: [PATCH 1/7] Bring back loss info for multiple outputs --- keras/src/models/model_test.py | 47 ++++++++++---------- keras/src/trainers/compile_utils.py | 67 +++++++++++++++++++++++++++-- keras/src/trainers/trainer.py | 16 ++++++- 3 files changed, 103 insertions(+), 27 deletions(-) diff --git a/keras/src/models/model_test.py b/keras/src/models/model_test.py index ee61c64166b..94891e7fde4 100644 --- a/keras/src/models/model_test.py +++ b/keras/src/models/model_test.py @@ -239,14 +239,13 @@ def test_functional_list_outputs_list_losses(self): # Fit the model to make sure compile_metrics are built hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) hist_keys = sorted(hist.history.keys()) - # TODO `tf.keras` also outputs individual losses for outputs ref_keys = sorted( [ "loss", - # "output_a_loss", + "output_a_loss", "output_a_mean_squared_error", "output_b_accuracy", - # "output_b_loss", + "output_b_loss", "output_b_mean_squared_error", ] ) @@ -270,16 +269,15 @@ def test_functional_list_outputs_list_losses_abbr(self): # Fit the model to make sure compile_metrics are built hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) hist_keys = sorted(hist.history.keys()) - # TODO `tf.keras` also outputs individual losses for outputs ref_keys = sorted( [ "loss", - # "output_a_loss", + "output_a_loss", "output_a_bce", "output_a_mae", "output_a_mse", "output_b_acc", - # "output_b_loss", + "output_b_loss", "output_b_mse", ] ) @@ -303,14 +301,13 @@ def test_functional_list_outputs_nested_list_losses(self): # Fit the model to make sure compile_metrics are built hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) hist_keys = sorted(hist.history.keys()) - # TODO `tf.keras` also outputs individual losses for outputs ref_keys = sorted( [ "loss", - # "output_a_loss", + "output_a_loss", "output_a_mean_squared_error", "output_b_accuracy", - # "output_b_loss", + "output_b_loss", "output_b_mean_squared_error", ] ) @@ -351,15 +348,14 @@ def test_functional_dict_outputs_dict_losses(self): verbose=0, ) hist_keys = sorted(hist.history.keys()) - # TODO `tf.keras` also outputs individual losses for outputs ref_keys = sorted( [ "loss", - # "output_a_loss", + "output_a_loss", "output_a_mean_squared_error", "output_a_weighted_mean_squared_error", "output_b_accuracy", - # "output_b_loss", + "output_b_loss", "output_b_mean_squared_error", "output_b_weighted_accuracy", "output_b_weighted_mean_squared_error", @@ -396,15 +392,14 @@ def test_functional_list_outputs_dict_losses_metrics(self): # Fit the model to make sure compile_metrics are built hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) hist_keys = sorted(hist.history.keys()) - # TODO `tf.keras` also outputs individual losses for outputs ref_keys = sorted( [ "loss", - # "output_a_loss", + "output_a_loss", "output_a_mean_squared_error", "output_a_weighted_mean_squared_error", "output_b_accuracy", - # "output_b_loss", + "output_b_loss", "output_b_mean_squared_error", "output_b_weighted_accuracy", "output_b_weighted_mean_squared_error", @@ -436,18 +431,17 @@ def test_functional_list_outputs_dict_losses_metrics_uniq_weighted(self): # Fit the model to make sure compile_metrics are built hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) hist_keys = sorted(hist.history.keys()) - # TODO `tf.keras` also outputs individual losses for 
outputs # `output_b_accuracy` doesn't have `weighted_` in metric name. # When a metric is only in weighted metrics, it skips `weighted_` # prefix. This behavior matches`tf.keras`. ref_keys = sorted( [ "loss", - # "output_a_loss", + "output_a_loss", "output_a_mean_squared_error", "output_a_weighted_mean_squared_error", "output_b_accuracy", - # "output_b_loss", + "output_b_loss", "output_b_mean_squared_error", ] ) @@ -472,13 +466,12 @@ def test_functional_list_outputs_dict_losses_partial_metrics(self): # Fit the model to make sure compile_metrics are built hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) hist_keys = sorted(hist.history.keys()) - # TODO `tf.keras` also outputs individual losses for outputs ref_keys = sorted( [ "loss", - # "output_a_loss", + "output_a_loss", "output_b_accuracy", - # "output_b_loss", + "output_b_loss", "output_b_mean_squared_error", ] ) @@ -500,7 +493,10 @@ def test_functional_dict_outputs_with_single_tensor(self): "output_b": "binary_crossentropy", }, ) - model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) + hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) + hist_keys = sorted(hist.history.keys()) + ref_keys = sorted(["loss", "output_a_loss", "output_b_loss"]) + self.assertListEqual(hist_keys, ref_keys) def test_functional_list_outputs_with_custom_compute_loss(self): model = _get_model_with_custom_compute_loss() @@ -514,7 +510,12 @@ def test_functional_list_outputs_with_custom_compute_loss(self): model.compile( optimizer="sgd", loss=["mean_squared_error", "binary_crossentropy"] ) - model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) + hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0) + hist_keys = sorted(hist.history.keys()) + ref_keys = sorted( + ["binary_crossentropy_loss", "loss", "mean_squared_error_loss"] + ) + self.assertListEqual(hist_keys, ref_keys) def test_functional_list_outputs_dict_losses_invalid_keys(self): model = _get_model_multi_outputs_list() diff --git a/keras/src/trainers/compile_utils.py b/keras/src/trainers/compile_utils.py index 9e21da2cf75..8a2aa4df31c 100644 --- a/keras/src/trainers/compile_utils.py +++ b/keras/src/trainers/compile_utils.py @@ -3,6 +3,7 @@ from keras.src import ops from keras.src import tree from keras.src.utils.naming import get_object_name +from keras.src.utils.tracking import Tracker class MetricsList(metrics_module.Metric): @@ -431,6 +432,39 @@ def __init__( # Inferred by `y_pred` and `output_names` self.inferred_output_names = None + # Use `Tracker` to track metrcis for individual losses. + self._metrics = [] + self._tracker = Tracker( + { + "metrics": ( + lambda x: isinstance(x, metrics_module.Metric), + self._metrics, + ) + } + ) + + @property + def metrics(self): + if not self.built: + return [] + metrics = [] + for m in self._metrics: + if m is not None: + metrics.append(m) + return metrics + + @property + def variables(self): + # Avoiding relying on implicit tracking since + # CompileLoss may be instantiated or built in a no tracking scope. + if not self.built: + return [] + vars = [] + for m in self.metrics: + if m is not None: + vars.extend(m.variables) + return vars + def build(self, y_true, y_pred): loss = self._user_loss loss_weights = self._user_loss_weights @@ -527,6 +561,21 @@ def build(self, y_true, y_pred): for identifier, _y_true, _y_pred in zip(flat_losses, y_true, y_pred) ] + # Add `Mean` metric to the tracker for each loss. 
+ if len(flat_losses) > 1: + for i, loss in enumerate(flat_losses): + if loss is not None: + if inferred_output_names is not None and len( + inferred_output_names + ) == len(flat_losses): + name = inferred_output_names[i] + else: + name = loss.name + name += "_loss" + self._tracker.add_to_store( + "metrics", metrics_module.Mean(name=name) + ) + self.flat_losses = flat_losses self.flat_loss_weights = flat_loss_weights self.filtered_y_true_keys = filtered_y_true_keys @@ -595,23 +644,35 @@ def call(self, y_true, y_pred, sample_weight=None): sample_weight = [sample_weight[0] for _ in range(len(y_true))] else: sample_weight = [None for _ in y_true] + if len(self.metrics) == 0: + # This means that the model has a single output. We need to add a + # dummy `None` for the following `zip` to function correctly. + metrics = [None] + else: + metrics = self.metrics # Iterate all losses in flat form. loss_values = [] - for loss, y_t, y_p, loss_weight, sample_weight in zip( + for loss_fn, y_t, y_p, loss_weight, sample_weight, metric in zip( self.flat_losses, y_true, y_pred, self.flat_loss_weights, sample_weight, + metrics, ): - if loss: + if loss_fn: value = ops.cast( - loss(y_t, y_p, sample_weight), dtype=self.dtype + loss_fn(y_t, y_p, sample_weight), dtype=self.dtype ) if loss_weight is not None: value = ops.multiply(value, loss_weight) loss_values.append(value) + # Record individual losses. + if metric: + metric.update_state( + value, sample_weight=tree.flatten(y_p)[0].shape[0] + ) if loss_values: total_loss = sum(loss_values) return total_loss diff --git a/keras/src/trainers/trainer.py b/keras/src/trainers/trainer.py index 397c49ce391..8cf4cd9445b 100644 --- a/keras/src/trainers/trainer.py +++ b/keras/src/trainers/trainer.py @@ -250,6 +250,8 @@ def metrics(self): metrics.extend(super().metrics) if self.compiled and self._compile_metrics is not None: metrics += [self._compile_metrics] + if self.compiled and self._compile_loss is not None: + metrics.extend(self._compile_loss.metrics) return metrics @property @@ -1004,10 +1006,13 @@ def _symbolic_build(self, iterator=None, data_batch=None): self._compile_metrics is not None and not self._compile_metrics.built ) + compile_loss_unbuilt = ( + self._compile_loss is not None and not self._compile_loss.built + ) optimizer_unbuilt = ( self.optimizer is not None and not self.optimizer.built ) - if model_unbuilt or compile_metrics_unbuilt: + if model_unbuilt or compile_metrics_unbuilt or compile_loss_unbuilt: # Create symbolic tensors matching an input batch. def to_symbolic_input(v): @@ -1052,6 +1057,15 @@ def to_symbolic_input(v): y_pred, sample_weight=sample_weight, ) + if compile_loss_unbuilt: + # Build `CompileLoss` state with `backend.compute_output_spec`. 
+ backend.compute_output_spec( + self._compute_loss, + x, + y, + y_pred, + sample_weight=sample_weight, + ) if optimizer_unbuilt: # Build optimizer self.optimizer.build(self.trainable_variables) From 6c68683712d0cb6aa5d7b69423839b3ce6dec30b Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Sun, 21 Jul 2024 19:43:50 +0800 Subject: [PATCH 2/7] Fix CI --- keras/src/backend/jax/trainer.py | 4 ++-- keras/src/backend/numpy/trainer.py | 8 ++++---- keras/src/backend/torch/trainer.py | 4 ++-- keras/src/layers/layer.py | 2 ++ keras/src/trainers/compile_utils.py | 15 +++++---------- keras/src/trainers/trainer.py | 12 +++++++++--- 6 files changed, 24 insertions(+), 21 deletions(-) diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py index cab8333688a..ab3d49e6a2c 100644 --- a/keras/src/backend/jax/trainer.py +++ b/keras/src/backend/jax/trainer.py @@ -536,7 +536,7 @@ def evaluate( steps_per_execution=self.steps_per_execution, ) - self._symbolic_build(iterator=epoch_iterator) + self._symbolic_build(iterator=epoch_iterator, training=False) # Container that configures and calls callbacks. if not isinstance(callbacks, callbacks_module.CallbackList): @@ -765,7 +765,7 @@ def test_on_batch( data = (x, y, sample_weight) data = _distribute_data(data) # Maybe build model - self._symbolic_build(data_batch=data) + self._symbolic_build(data_batch=data, training=False) self._record_training_state_sharding_spec() self.make_test_function() diff --git a/keras/src/backend/numpy/trainer.py b/keras/src/backend/numpy/trainer.py index 6d40982be43..553602363e5 100644 --- a/keras/src/backend/numpy/trainer.py +++ b/keras/src/backend/numpy/trainer.py @@ -91,7 +91,7 @@ def multi_predict_steps(data): self.predict_function = predict_step - def _symbolic_build(self, data_batch): + def _symbolic_build(self, data_batch, training=True): model_unbuilt = not all(layer.built for layer in self._flatten_layers()) compile_metrics_unbuilt = ( self._compile_metrics is not None @@ -113,7 +113,7 @@ def to_symbolic_input(v): ) = data_adapter_utils.unpack_x_y_sample_weight(data_batch) # Build all model state with `backend.compute_output_spec`. try: - y_pred = backend.compute_output_spec(self, x) + y_pred = backend.compute_output_spec(self, x, training=training) except: raise RuntimeError( "Unable to automatically build the model. " @@ -246,7 +246,7 @@ def evaluate( # Build the model on one batch of data. for _, data in epoch_iterator.enumerate_epoch(): data_batch = data[0] - self._symbolic_build(data_batch) + self._symbolic_build(data_batch, training=False) break # Container that configures and calls callbacks. @@ -304,7 +304,7 @@ def test_on_batch( data = (x, y, sample_weight) # Maybe build model - self._symbolic_build(data) + self._symbolic_build(data, training=False) self.make_test_function() logs = self.test_function([data]) diff --git a/keras/src/backend/torch/trainer.py b/keras/src/backend/torch/trainer.py index bf0e133d67d..017d8d786b9 100644 --- a/keras/src/backend/torch/trainer.py +++ b/keras/src/backend/torch/trainer.py @@ -344,7 +344,7 @@ def evaluate( steps_per_execution=self.steps_per_execution, ) - self._symbolic_build(iterator=epoch_iterator) + self._symbolic_build(iterator=epoch_iterator, training=False) # Container that configures and calls callbacks. 
if not isinstance(callbacks, callbacks_module.CallbackList): @@ -483,7 +483,7 @@ def test_on_batch( data = (x, y, sample_weight) # Maybe build model - self._symbolic_build(data_batch=data) + self._symbolic_build(data_batch=data, training=False) self.make_test_function() logs = self.test_function([data]) diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index dc945e22781..68ba4e1e06b 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -1143,6 +1143,8 @@ def _get_regularization_losses(self): v = backend.get_stateless_scope().get_current_value(variable) else: v = variable + if v is None: + v = variable weight_regularization_losses.append(variable.regularizer(v)) return weight_regularization_losses diff --git a/keras/src/trainers/compile_utils.py b/keras/src/trainers/compile_utils.py index 8a2aa4df31c..51b34742fc7 100644 --- a/keras/src/trainers/compile_utils.py +++ b/keras/src/trainers/compile_utils.py @@ -447,11 +447,7 @@ def __init__( def metrics(self): if not self.built: return [] - metrics = [] - for m in self._metrics: - if m is not None: - metrics.append(m) - return metrics + return self._metrics @property def variables(self): @@ -461,8 +457,7 @@ def variables(self): return [] vars = [] for m in self.metrics: - if m is not None: - vars.extend(m.variables) + vars.extend(m.variables) return vars def build(self, y_true, y_pred): @@ -563,14 +558,14 @@ def build(self, y_true, y_pred): # Add `Mean` metric to the tracker for each loss. if len(flat_losses) > 1: - for i, loss in enumerate(flat_losses): - if loss is not None: + for i, _loss in enumerate(flat_losses): + if _loss is not None: if inferred_output_names is not None and len( inferred_output_names ) == len(flat_losses): name = inferred_output_names[i] else: - name = loss.name + name = _loss.name name += "_loss" self._tracker.add_to_store( "metrics", metrics_module.Mean(name=name) diff --git a/keras/src/trainers/trainer.py b/keras/src/trainers/trainer.py index 8cf4cd9445b..6ca3e2e0a0b 100644 --- a/keras/src/trainers/trainer.py +++ b/keras/src/trainers/trainer.py @@ -1000,7 +1000,7 @@ def _assert_compile_called(self, method_name=None): msg += f"calling `{method_name}()`." raise ValueError(msg) - def _symbolic_build(self, iterator=None, data_batch=None): + def _symbolic_build(self, iterator=None, data_batch=None, training=True): model_unbuilt = not all(layer.built for layer in self._flatten_layers()) compile_metrics_unbuilt = ( self._compile_metrics is not None @@ -1012,7 +1012,12 @@ def _symbolic_build(self, iterator=None, data_batch=None): optimizer_unbuilt = ( self.optimizer is not None and not self.optimizer.built ) - if model_unbuilt or compile_metrics_unbuilt or compile_loss_unbuilt: + need_build = model_unbuilt or compile_metrics_unbuilt + if backend.backend() != "torch": + # TODO: TorchModuleWrapper will have incorrect behavior using + # `_symbolic_build`. Not sure why. + need_build = need_build or compile_loss_unbuilt + if need_build: # Create symbolic tensors matching an input batch. def to_symbolic_input(v): @@ -1035,7 +1040,7 @@ def to_symbolic_input(v): # Build all model state with `backend.compute_output_spec`. try: - y_pred = backend.compute_output_spec(self, x) + y_pred = backend.compute_output_spec(self, x, training=training) except Exception as e: raise RuntimeError( "Unable to automatically build the model. 
" @@ -1065,6 +1070,7 @@ def to_symbolic_input(v): y, y_pred, sample_weight=sample_weight, + training=training, ) if optimizer_unbuilt: # Build optimizer From 1276d2e755b28c948f1f6c127f85cae5f90ed98a Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Sun, 21 Jul 2024 23:21:39 +0800 Subject: [PATCH 3/7] Update torch eval mode --- keras/src/backend/jax/trainer.py | 4 ++-- keras/src/backend/numpy/trainer.py | 8 ++++---- keras/src/backend/torch/trainer.py | 4 ++-- keras/src/trainers/trainer.py | 18 +++++++++--------- 4 files changed, 17 insertions(+), 17 deletions(-) diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py index ab3d49e6a2c..cab8333688a 100644 --- a/keras/src/backend/jax/trainer.py +++ b/keras/src/backend/jax/trainer.py @@ -536,7 +536,7 @@ def evaluate( steps_per_execution=self.steps_per_execution, ) - self._symbolic_build(iterator=epoch_iterator, training=False) + self._symbolic_build(iterator=epoch_iterator) # Container that configures and calls callbacks. if not isinstance(callbacks, callbacks_module.CallbackList): @@ -765,7 +765,7 @@ def test_on_batch( data = (x, y, sample_weight) data = _distribute_data(data) # Maybe build model - self._symbolic_build(data_batch=data, training=False) + self._symbolic_build(data_batch=data) self._record_training_state_sharding_spec() self.make_test_function() diff --git a/keras/src/backend/numpy/trainer.py b/keras/src/backend/numpy/trainer.py index 553602363e5..6d40982be43 100644 --- a/keras/src/backend/numpy/trainer.py +++ b/keras/src/backend/numpy/trainer.py @@ -91,7 +91,7 @@ def multi_predict_steps(data): self.predict_function = predict_step - def _symbolic_build(self, data_batch, training=True): + def _symbolic_build(self, data_batch): model_unbuilt = not all(layer.built for layer in self._flatten_layers()) compile_metrics_unbuilt = ( self._compile_metrics is not None @@ -113,7 +113,7 @@ def to_symbolic_input(v): ) = data_adapter_utils.unpack_x_y_sample_weight(data_batch) # Build all model state with `backend.compute_output_spec`. try: - y_pred = backend.compute_output_spec(self, x, training=training) + y_pred = backend.compute_output_spec(self, x) except: raise RuntimeError( "Unable to automatically build the model. " @@ -246,7 +246,7 @@ def evaluate( # Build the model on one batch of data. for _, data in epoch_iterator.enumerate_epoch(): data_batch = data[0] - self._symbolic_build(data_batch, training=False) + self._symbolic_build(data_batch) break # Container that configures and calls callbacks. @@ -304,7 +304,7 @@ def test_on_batch( data = (x, y, sample_weight) # Maybe build model - self._symbolic_build(data, training=False) + self._symbolic_build(data) self.make_test_function() logs = self.test_function([data]) diff --git a/keras/src/backend/torch/trainer.py b/keras/src/backend/torch/trainer.py index 017d8d786b9..bf0e133d67d 100644 --- a/keras/src/backend/torch/trainer.py +++ b/keras/src/backend/torch/trainer.py @@ -344,7 +344,7 @@ def evaluate( steps_per_execution=self.steps_per_execution, ) - self._symbolic_build(iterator=epoch_iterator, training=False) + self._symbolic_build(iterator=epoch_iterator) # Container that configures and calls callbacks. 
if not isinstance(callbacks, callbacks_module.CallbackList): @@ -483,7 +483,7 @@ def test_on_batch( data = (x, y, sample_weight) # Maybe build model - self._symbolic_build(data_batch=data, training=False) + self._symbolic_build(data_batch=data) self.make_test_function() logs = self.test_function([data]) diff --git a/keras/src/trainers/trainer.py b/keras/src/trainers/trainer.py index 6ca3e2e0a0b..59c82e55b33 100644 --- a/keras/src/trainers/trainer.py +++ b/keras/src/trainers/trainer.py @@ -1000,7 +1000,7 @@ def _assert_compile_called(self, method_name=None): msg += f"calling `{method_name}()`." raise ValueError(msg) - def _symbolic_build(self, iterator=None, data_batch=None, training=True): + def _symbolic_build(self, iterator=None, data_batch=None): model_unbuilt = not all(layer.built for layer in self._flatten_layers()) compile_metrics_unbuilt = ( self._compile_metrics is not None @@ -1012,12 +1012,10 @@ def _symbolic_build(self, iterator=None, data_batch=None, training=True): optimizer_unbuilt = ( self.optimizer is not None and not self.optimizer.built ) - need_build = model_unbuilt or compile_metrics_unbuilt - if backend.backend() != "torch": - # TODO: TorchModuleWrapper will have incorrect behavior using - # `_symbolic_build`. Not sure why. - need_build = need_build or compile_loss_unbuilt - if need_build: + if model_unbuilt or compile_metrics_unbuilt or compile_loss_unbuilt: + if backend.backend() == "torch": + original_training = self.training + self.eval() # Create symbolic tensors matching an input batch. def to_symbolic_input(v): @@ -1040,7 +1038,7 @@ def to_symbolic_input(v): # Build all model state with `backend.compute_output_spec`. try: - y_pred = backend.compute_output_spec(self, x, training=training) + y_pred = backend.compute_output_spec(self, x) except Exception as e: raise RuntimeError( "Unable to automatically build the model. 
" @@ -1070,8 +1068,10 @@ def to_symbolic_input(v): y, y_pred, sample_weight=sample_weight, - training=training, ) + if backend.backend() == "torch": + if original_training: + self.train() if optimizer_unbuilt: # Build optimizer self.optimizer.build(self.trainable_variables) From 11edb68a85b86cae1a791d606c2d4fe91bb40ee6 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Mon, 22 Jul 2024 09:40:07 +0800 Subject: [PATCH 4/7] Add SymbolicScope --- keras/__init__.py | 1 + keras/api/__init__.py | 1 + keras/api/_tf_keras/keras/__init__.py | 1 + keras/src/backend/common/symbolic_scope.py | 21 +++++++++++++++++++++ keras/src/backend/jax/core.py | 3 ++- keras/src/backend/numpy/core.py | 3 ++- keras/src/backend/tensorflow/core.py | 3 ++- keras/src/backend/torch/core.py | 3 ++- keras/src/layers/layer.py | 2 -- keras/src/trainers/trainer.py | 11 ++++++++--- keras/src/trainers/trainer_test.py | 10 +++++++--- 11 files changed, 47 insertions(+), 12 deletions(-) create mode 100644 keras/src/backend/common/symbolic_scope.py diff --git a/keras/__init__.py b/keras/__init__.py index 701568f2a01..5a429d3a5d8 100644 --- a/keras/__init__.py +++ b/keras/__init__.py @@ -18,6 +18,7 @@ from keras.api import Regularizer from keras.api import Sequential from keras.api import StatelessScope +from keras.api import SymbolicScope from keras.api import Variable from keras.api import __version__ from keras.api import activations diff --git a/keras/api/__init__.py b/keras/api/__init__.py index 1750a42e869..9d082ae9b89 100644 --- a/keras/api/__init__.py +++ b/keras/api/__init__.py @@ -33,6 +33,7 @@ from keras.api import utils from keras.src.backend.common.keras_tensor import KerasTensor from keras.src.backend.common.stateless_scope import StatelessScope +from keras.src.backend.common.symbolic_scope import SymbolicScope from keras.src.backend.exports import Variable from keras.src.backend.exports import device from keras.src.backend.exports import name_scope diff --git a/keras/api/_tf_keras/keras/__init__.py b/keras/api/_tf_keras/keras/__init__.py index 5e0a7229473..39a7e9cdb18 100644 --- a/keras/api/_tf_keras/keras/__init__.py +++ b/keras/api/_tf_keras/keras/__init__.py @@ -31,6 +31,7 @@ from keras.api._tf_keras.keras import preprocessing from keras.src.backend.common.keras_tensor import KerasTensor from keras.src.backend.common.stateless_scope import StatelessScope +from keras.src.backend.common.symbolic_scope import SymbolicScope from keras.src.backend.exports import Variable from keras.src.backend.exports import device from keras.src.backend.exports import name_scope diff --git a/keras/src/backend/common/symbolic_scope.py b/keras/src/backend/common/symbolic_scope.py new file mode 100644 index 00000000000..780032d5728 --- /dev/null +++ b/keras/src/backend/common/symbolic_scope.py @@ -0,0 +1,21 @@ +from keras.src.api_export import keras_export +from keras.src.backend.common import global_state + + +@keras_export("keras.SymbolicScope") +class SymbolicScope: + def __enter__(self): + self.original_scope = get_symbolic_scope() + global_state.set_global_attribute("symbolic_scope", self) + return self + + def __exit__(self, *args, **kwargs): + global_state.set_global_attribute("symbolic_scope", self.original_scope) + + +def in_symbolic_scope(): + return global_state.get_global_attribute("symbolic_scope") is not None + + +def get_symbolic_scope(): + return global_state.get_global_attribute("symbolic_scope") diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index 
3ccaf06a980..c36dfee6a04 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -10,6 +10,7 @@ from keras.src.backend.common import standardize_dtype from keras.src.backend.common.keras_tensor import KerasTensor from keras.src.backend.common.stateless_scope import StatelessScope +from keras.src.backend.common.symbolic_scope import SymbolicScope from keras.src.backend.jax import distribution_lib SUPPORTS_SPARSE_TENSORS = True @@ -101,7 +102,7 @@ def cast(x, dtype): # Shape / dtype / sparseness inference util def compute_output_spec(fn, *args, **kwargs): - with StatelessScope(): + with StatelessScope(), SymbolicScope(): built_in_types = (type(None), int, float, str, bool, complex, bytes) # First, separate symbolic args from other args diff --git a/keras/src/backend/numpy/core.py b/keras/src/backend/numpy/core.py index 2d34fcd7c6c..97be123f9e8 100644 --- a/keras/src/backend/numpy/core.py +++ b/keras/src/backend/numpy/core.py @@ -12,6 +12,7 @@ from keras.src.backend.common.dtypes import result_type from keras.src.backend.common.keras_tensor import KerasTensor from keras.src.backend.common.stateless_scope import StatelessScope +from keras.src.backend.common.symbolic_scope import SymbolicScope SUPPORTS_SPARSE_TENSORS = False @@ -88,7 +89,7 @@ def vectorized_map(function, elements): # Shape / dtype inference util def compute_output_spec(fn, *args, **kwargs): - with StatelessScope(): + with StatelessScope(), SymbolicScope(): def has_none_shape(x): if isinstance(x, KerasTensor): diff --git a/keras/src/backend/tensorflow/core.py b/keras/src/backend/tensorflow/core.py index db33ce227d0..09d65e827cc 100644 --- a/keras/src/backend/tensorflow/core.py +++ b/keras/src/backend/tensorflow/core.py @@ -14,6 +14,7 @@ from keras.src.backend.common.name_scope import name_scope as base_name_scope from keras.src.backend.common.stateless_scope import StatelessScope from keras.src.backend.common.stateless_scope import in_stateless_scope +from keras.src.backend.common.symbolic_scope import SymbolicScope from keras.src.backend.tensorflow.sparse import sparse_to_dense from keras.src.utils.naming import auto_name @@ -182,7 +183,7 @@ def cast(x, dtype): def compute_output_spec(fn, *args, **kwargs): - with StatelessScope(): + with StatelessScope(), SymbolicScope(): graph_name = auto_name("scratch_graph") with tf.__internal__.FuncGraph(graph_name).as_default(): diff --git a/keras/src/backend/torch/core.py b/keras/src/backend/torch/core.py index 3a941fc46a4..5f01d57d5b7 100644 --- a/keras/src/backend/torch/core.py +++ b/keras/src/backend/torch/core.py @@ -17,6 +17,7 @@ from keras.src.backend.common.stateless_scope import StatelessScope from keras.src.backend.common.stateless_scope import get_stateless_scope from keras.src.backend.common.stateless_scope import in_stateless_scope +from keras.src.backend.common.symbolic_scope import SymbolicScope from keras.src.backend.config import floatx SUPPORTS_SPARSE_TENSORS = False @@ -335,7 +336,7 @@ def symbolic_call(fn, args, kwargs, fill_value): ) return fn(*eager_args, **eager_kwargs) - with StatelessScope(), torch.no_grad(): + with StatelessScope(), SymbolicScope(), torch.no_grad(): outputs = symbolic_call(fn, args, kwargs, fill_value=83) none_in_shape = any( diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index 68ba4e1e06b..dc945e22781 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -1143,8 +1143,6 @@ def _get_regularization_losses(self): v = backend.get_stateless_scope().get_current_value(variable) else: v = 
variable - if v is None: - v = variable weight_regularization_losses.append(variable.regularizer(v)) return weight_regularization_losses diff --git a/keras/src/trainers/trainer.py b/keras/src/trainers/trainer.py index 59c82e55b33..22917f61644 100644 --- a/keras/src/trainers/trainer.py +++ b/keras/src/trainers/trainer.py @@ -7,6 +7,7 @@ from keras.src import ops from keras.src import optimizers from keras.src import tree +from keras.src.backend.common.symbolic_scope import in_symbolic_scope from keras.src.optimizers.loss_scale_optimizer import LossScaleOptimizer from keras.src.saving import serialization_lib from keras.src.trainers.compile_utils import CompileLoss @@ -327,8 +328,11 @@ def metrics(self): loss = self._compile_loss(y, y_pred, sample_weight) if loss is not None: losses.append(loss) - for loss in self.losses: - losses.append(ops.sum(ops.cast(loss, dtype=backend.floatx()))) + if not in_symbolic_scope(): + # If in symbolic scope, skip `self.losses` to ensure we don't access + # any variables. + for loss in self.losses: + losses.append(ops.sum(ops.cast(loss, dtype=backend.floatx()))) if backend.backend() != "jax" and len(losses) == 0: raise ValueError( "No loss to compute. Provide a `loss` argument in `compile()`." @@ -1038,7 +1042,7 @@ def to_symbolic_input(v): # Build all model state with `backend.compute_output_spec`. try: - y_pred = backend.compute_output_spec(self, x) + y_pred = backend.compute_output_spec(self, x, training=False) except Exception as e: raise RuntimeError( "Unable to automatically build the model. " @@ -1068,6 +1072,7 @@ def to_symbolic_input(v): y, y_pred, sample_weight=sample_weight, + training=False, ) if backend.backend() == "torch": if original_training: diff --git a/keras/src/trainers/trainer_test.py b/keras/src/trainers/trainer_test.py index a4b21e5f505..e5bc3cbdc8f 100644 --- a/keras/src/trainers/trainer_test.py +++ b/keras/src/trainers/trainer_test.py @@ -14,6 +14,7 @@ from keras.src import ops from keras.src import optimizers from keras.src import testing +from keras.src.backend.common.symbolic_scope import in_symbolic_scope from keras.src.callbacks.callback import Callback from keras.src.optimizers.rmsprop import RMSprop from keras.src.testing.test_utils import named_product @@ -1406,7 +1407,8 @@ def compute_loss( sample_weight=None, training=True, ): - test_self.assertTrue(training) + if not in_symbolic_scope(): + test_self.assertTrue(training) loss = super().compute_loss( x, y, y_pred, sample_weight, training ) @@ -1443,7 +1445,8 @@ def compute_loss( sample_weight=None, training=True, ): - test_self.assertTrue(training) + if not in_symbolic_scope(): + test_self.assertTrue(training) loss = super().compute_loss( x, y, y_pred, sample_weight, training ) @@ -1478,7 +1481,8 @@ def compute_loss( sample_weight=None, training=True, ): - test_self.assertFalse(training) + if not in_symbolic_scope(): + test_self.assertFalse(training) loss = super().compute_loss( x, y, y_pred, sample_weight, training ) From 97915e335528f5bee45f14f8ab3f15b79f4e3b34 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Mon, 22 Jul 2024 10:28:17 +0800 Subject: [PATCH 5/7] Minor updates and add tests --- keras/src/backend/__init__.py | 2 + keras/src/backend/common/symbolic_scope.py | 2 + .../src/backend/common/symbolic_scope_test.py | 26 ++++++++ keras/src/backend/numpy/trainer.py | 14 ++++- keras/src/trainers/compile_utils.py | 15 +---- keras/src/trainers/trainer.py | 8 +-- keras/src/trainers/trainer_test.py | 59 +++++++++++++++++++ 7 
files changed, 109 insertions(+), 17 deletions(-) create mode 100644 keras/src/backend/common/symbolic_scope_test.py diff --git a/keras/src/backend/__init__.py b/keras/src/backend/__init__.py index 5c7fa223520..794fe3ca364 100644 --- a/keras/src/backend/__init__.py +++ b/keras/src/backend/__init__.py @@ -14,6 +14,8 @@ from keras.src.backend.common.stateless_scope import StatelessScope from keras.src.backend.common.stateless_scope import get_stateless_scope from keras.src.backend.common.stateless_scope import in_stateless_scope +from keras.src.backend.common.symbolic_scope import SymbolicScope +from keras.src.backend.common.symbolic_scope import in_symbolic_scope from keras.src.backend.common.variables import AutocastScope from keras.src.backend.common.variables import get_autocast_scope from keras.src.backend.common.variables import is_float_dtype diff --git a/keras/src/backend/common/symbolic_scope.py b/keras/src/backend/common/symbolic_scope.py index 780032d5728..15cd7a5ee05 100644 --- a/keras/src/backend/common/symbolic_scope.py +++ b/keras/src/backend/common/symbolic_scope.py @@ -4,6 +4,8 @@ @keras_export("keras.SymbolicScope") class SymbolicScope: + """Scope to indicate the symbolic stage.""" + def __enter__(self): self.original_scope = get_symbolic_scope() global_state.set_global_attribute("symbolic_scope", self) diff --git a/keras/src/backend/common/symbolic_scope_test.py b/keras/src/backend/common/symbolic_scope_test.py new file mode 100644 index 00000000000..092dcfe0748 --- /dev/null +++ b/keras/src/backend/common/symbolic_scope_test.py @@ -0,0 +1,26 @@ +import numpy as np + +from keras.src import ops +from keras.src import testing +from keras.src.backend.common.symbolic_scope import SymbolicScope +from keras.src.backend.common.symbolic_scope import in_symbolic_scope + + +class TestSymbolicScope(testing.TestCase): + def test_basic_flow(self): + + # Define a function that behaves differently according to + # `in_symbolic_scope`. + def compute_loss(y, y_pred): + if in_symbolic_scope(): + return ops.zeros_like(y) + return ops.add(y, y_pred) + + y = ops.ones(shape=(2,)) + y_pred = ops.ones(shape=(2,)) + with SymbolicScope(): + loss = compute_loss(y, y_pred) + self.assertAllClose(loss, np.zeros((2,))) + + loss = compute_loss(y, y_pred) + self.assertAllClose(loss, 2 * np.ones((2,))) diff --git a/keras/src/backend/numpy/trainer.py b/keras/src/backend/numpy/trainer.py index 6d40982be43..12c3aad56b6 100644 --- a/keras/src/backend/numpy/trainer.py +++ b/keras/src/backend/numpy/trainer.py @@ -97,7 +97,10 @@ def _symbolic_build(self, data_batch): self._compile_metrics is not None and not self._compile_metrics.built ) - if model_unbuilt or compile_metrics_unbuilt: + compile_loss_unbuilt = ( + self._compile_loss is not None and not self._compile_loss.built + ) + if model_unbuilt or compile_metrics_unbuilt or compile_loss_unbuilt: # Create symbolic tensors matching an input batch. def to_symbolic_input(v): @@ -133,6 +136,15 @@ def to_symbolic_input(v): y_pred, sample_weight=sample_weight, ) + if compile_loss_unbuilt: + # Build `CompileLoss` state with `backend.compute_output_spec`. 
+ backend.compute_output_spec( + self._compute_loss, + x, + y, + y_pred, + sample_weight=sample_weight, + ) self._post_build() def fit( diff --git a/keras/src/trainers/compile_utils.py b/keras/src/trainers/compile_utils.py index 51b34742fc7..114925e669d 100644 --- a/keras/src/trainers/compile_utils.py +++ b/keras/src/trainers/compile_utils.py @@ -445,16 +445,10 @@ def __init__( @property def metrics(self): - if not self.built: - return [] return self._metrics @property def variables(self): - # Avoiding relying on implicit tracking since - # CompileLoss may be instantiated or built in a no tracking scope. - if not self.built: - return [] vars = [] for m in self.metrics: vars.extend(m.variables) @@ -639,12 +633,9 @@ def call(self, y_true, y_pred, sample_weight=None): sample_weight = [sample_weight[0] for _ in range(len(y_true))] else: sample_weight = [None for _ in y_true] - if len(self.metrics) == 0: - # This means that the model has a single output. We need to add a - # dummy `None` for the following `zip` to function correctly. - metrics = [None] - else: - metrics = self.metrics + + # We need to add a dummy `None` if the model has only a single output. + metrics = [None] if len(self.metrics) == 0 else self.metrics # Iterate all losses in flat form. loss_values = [] diff --git a/keras/src/trainers/trainer.py b/keras/src/trainers/trainer.py index 22917f61644..c2f6ede112d 100644 --- a/keras/src/trainers/trainer.py +++ b/keras/src/trainers/trainer.py @@ -328,9 +328,10 @@ def metrics(self): loss = self._compile_loss(y, y_pred, sample_weight) if loss is not None: losses.append(loss) + + # If in symbolic scope, skip `self.losses` to ensure we don't access + # any variables. Otherwise, it might break. if not in_symbolic_scope(): - # If in symbolic scope, skip `self.losses` to ensure we don't access - # any variables. for loss in self.losses: losses.append(ops.sum(ops.cast(loss, dtype=backend.floatx()))) if backend.backend() != "jax" and len(losses) == 0: @@ -1042,7 +1043,7 @@ def to_symbolic_input(v): # Build all model state with `backend.compute_output_spec`. try: - y_pred = backend.compute_output_spec(self, x, training=False) + y_pred = backend.compute_output_spec(self, x) except Exception as e: raise RuntimeError( "Unable to automatically build the model. 
" @@ -1072,7 +1073,6 @@ def to_symbolic_input(v): y, y_pred, sample_weight=sample_weight, - training=False, ) if backend.backend() == "torch": if original_training: diff --git a/keras/src/trainers/trainer_test.py b/keras/src/trainers/trainer_test.py index e5bc3cbdc8f..064df23adc2 100644 --- a/keras/src/trainers/trainer_test.py +++ b/keras/src/trainers/trainer_test.py @@ -1617,6 +1617,65 @@ def test_loss_weights(self): atol=1e-3, ) + def test_symbolic_build(self): + class ExampleModelWithTrainingArgs(Trainer, layers.Layer): + def __init__(self, units): + layers.Layer.__init__(self) + Trainer.__init__(self) + self.dense = layers.Dense(units) + self.bn = layers.BatchNormalization(axis=-1) + + def build(self, input_shape): + self.dense.build(input_shape) + input_shape = self.dense.compute_output_shape(input_shape) + self.bn.build(input_shape) + + def call(self, x, training=None): + outputs = self.bn(self.dense(x), training=training) + return [outputs, outputs] + + model = ExampleModelWithTrainingArgs(units=3) + model.compile( + optimizer=optimizers.SGD(), + loss=[losses.MeanSquaredError(), losses.MeanSquaredError()], + metrics=[metrics.MeanSquaredError(), metrics.MeanSquaredError()], + ) + x = np.ones((4, 4)) + y = np.zeros((4, 3)) + model(x) # Eager call to build model weights + ref_weights = model.get_weights() + + # Before `_symbolic_build` + self.assertTrue(model.built) + self.assertFalse(model._compile_metrics.built) + self.assertFalse(model._compile_loss.built) + self.assertLen(model._compile_loss.metrics, 0) + self.assertLen(model.metrics, 2) + + model._symbolic_build(data_batch=(x, (y, y))) + weights = model.get_weights() + + # Ensure weights are intact + self.assertEqual(len(weights), len(ref_weights)) + for w, ref_w in zip(weights, ref_weights): + self.assertAllClose(w, ref_w) + + # Ensure `built` + self.assertTrue(model.built) + self.assertTrue(model._compile_metrics.built) + self.assertTrue(model._compile_loss.built) + + # Ensure the len of metrics (original metrics + loss trackers) + self.assertLen(model._compile_metrics.metrics, 2) + self.assertLen(model._compile_loss.metrics, 2) + self.assertLen(model.metrics, 4) + + # Ensure no values in metrics + for v in model._compile_metrics.variables: + self.assertAllClose(v, 0.0) + for v in model._compile_loss.variables: + self.assertAllClose(v, 0.0) + class TrainerDistributeTest(testing.TestCase): @pytest.mark.skipif( From 8ef4d64845223183caf135e1ddf8c870dcba1ca8 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Mon, 22 Jul 2024 14:25:15 +0800 Subject: [PATCH 6/7] Address comment --- keras/src/layers/layer.py | 6 +++++- keras/src/trainers/trainer.py | 9 ++------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index dc945e22781..d6b68e2ae16 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -32,6 +32,7 @@ from keras.src.backend import KerasTensor from keras.src.backend.common import global_state from keras.src.backend.common.name_scope import current_path +from keras.src.backend.common.symbolic_scope import in_symbolic_scope from keras.src.distribution import distribution_lib from keras.src.dtype_policies import DTypePolicyMap from keras.src.layers import input_spec @@ -1139,7 +1140,10 @@ def _get_regularization_losses(self): for variable in self.trainable_weights: if variable.regularizer is None: continue - if backend.in_stateless_scope(): + if backend.in_stateless_scope() and not 
in_symbolic_scope(): + # If in symbolic scope, we might get `None` from + # `get_current_value` in `backend.compute_output_spec`. So we + # assign `variable` instead. v = backend.get_stateless_scope().get_current_value(variable) else: v = variable diff --git a/keras/src/trainers/trainer.py b/keras/src/trainers/trainer.py index c2f6ede112d..59c82e55b33 100644 --- a/keras/src/trainers/trainer.py +++ b/keras/src/trainers/trainer.py @@ -7,7 +7,6 @@ from keras.src import ops from keras.src import optimizers from keras.src import tree -from keras.src.backend.common.symbolic_scope import in_symbolic_scope from keras.src.optimizers.loss_scale_optimizer import LossScaleOptimizer from keras.src.saving import serialization_lib from keras.src.trainers.compile_utils import CompileLoss @@ -328,12 +327,8 @@ def metrics(self): loss = self._compile_loss(y, y_pred, sample_weight) if loss is not None: losses.append(loss) - - # If in symbolic scope, skip `self.losses` to ensure we don't access - # any variables. Otherwise, it might break. - if not in_symbolic_scope(): - for loss in self.losses: - losses.append(ops.sum(ops.cast(loss, dtype=backend.floatx()))) + for loss in self.losses: + losses.append(ops.sum(ops.cast(loss, dtype=backend.floatx()))) if backend.backend() != "jax" and len(losses) == 0: raise ValueError( "No loss to compute. Provide a `loss` argument in `compile()`." From 603525c2ab05129f805184dd5e8f5b9a9628eb7d Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Mon, 22 Jul 2024 15:48:16 +0800 Subject: [PATCH 7/7] Fix TorchWrapper of training args --- keras/src/trainers/trainer.py | 9 ++---- keras/src/utils/torch_utils.py | 6 +++- keras/src/utils/torch_utils_test.py | 50 +++++++++++++++++++++++++++-- 3 files changed, 54 insertions(+), 11 deletions(-) diff --git a/keras/src/trainers/trainer.py b/keras/src/trainers/trainer.py index 59c82e55b33..9b027da9c4c 100644 --- a/keras/src/trainers/trainer.py +++ b/keras/src/trainers/trainer.py @@ -1013,9 +1013,6 @@ def _symbolic_build(self, iterator=None, data_batch=None): self.optimizer is not None and not self.optimizer.built ) if model_unbuilt or compile_metrics_unbuilt or compile_loss_unbuilt: - if backend.backend() == "torch": - original_training = self.training - self.eval() # Create symbolic tensors matching an input batch. def to_symbolic_input(v): @@ -1038,7 +1035,7 @@ def to_symbolic_input(v): # Build all model state with `backend.compute_output_spec`. try: - y_pred = backend.compute_output_spec(self, x) + y_pred = backend.compute_output_spec(self, x, training=False) except Exception as e: raise RuntimeError( "Unable to automatically build the model. 
" @@ -1068,10 +1065,8 @@ def to_symbolic_input(v): y, y_pred, sample_weight=sample_weight, + training=False, ) - if backend.backend() == "torch": - if original_training: - self.train() if optimizer_unbuilt: # Build optimizer self.optimizer.build(self.trainable_variables) diff --git a/keras/src/utils/torch_utils.py b/keras/src/utils/torch_utils.py index f20669e4955..e81018e0da7 100644 --- a/keras/src/utils/torch_utils.py +++ b/keras/src/utils/torch_utils.py @@ -112,7 +112,11 @@ def _track_module_parameters(self): self._track_variable(variable) self.built = True - def call(self, *args, **kwargs): + def call(self, *args, training=None, **kwargs): + if training is False: + self.eval() + else: + self.train() return self.module(*args, **kwargs) def save_own_variables(self, store): diff --git a/keras/src/utils/torch_utils_test.py b/keras/src/utils/torch_utils_test.py index 7e972f5b1b5..55003240710 100644 --- a/keras/src/utils/torch_utils_test.py +++ b/keras/src/utils/torch_utils_test.py @@ -29,9 +29,9 @@ def __init__( self.torch_wrappers.append(TorchModuleWrapper(torch_model)) self.fc = layers.Dense(1) - def call(self, x): + def call(self, x, training=None): for wrapper in self.torch_wrappers: - x = wrapper(x) + x = wrapper(x, training=training) return self.fc(x) def get_config(self): @@ -49,7 +49,7 @@ def __init__(self, *args, **kwargs): self.fc2 = torch.nn.Linear(4, 4) self.fc3 = layers.Dense(2) - def call(self, x): + def call(self, x, training=None): return self.fc3(self.fc2(self.bn1(self.fc1(x)))) @@ -82,6 +82,50 @@ def test_basic_usage(self, use_batch_norm, num_torch_layers): model.compile(optimizer="sgd", loss="mse") model.fit(np.random.random((3, 2)), np.random.random((3, 1))) + @parameterized.named_parameters( + ( + "explicit_torch_wrapper", + Classifier, + {"use_batch_norm": True, "num_torch_layers": 1}, + ), + ("implicit_torch_wrapper", ClassifierWithNoSpecialCasing, {}), + ) + def test_training_args(self, cls, kwargs): + model = cls(**kwargs) + model(np.random.random((3, 2)), training=False) # Eager call to build + ref_weights = model.get_weights() + ref_running_mean = backend.convert_to_numpy( + model.torch_wrappers[0].module[-1].running_mean + if cls is Classifier + else model.bn1.module.running_mean + ) + + # Test training=False doesn't affect model weights + model(np.random.random((3, 2)), training=False) + weights = model.get_weights() + for w, ref_w in zip(weights, ref_weights): + self.assertAllClose(w, ref_w) + + # Test training=None affects BN's stats + model.set_weights(ref_weights) # Restore previous weights + model(np.random.random((3, 2))) + running_mean = backend.convert_to_numpy( + model.torch_wrappers[0].module[-1].running_mean + if cls is Classifier + else model.bn1.module.running_mean + ) + self.assertNotAllClose(running_mean, ref_running_mean) + + # Test training=True affects BN's stats + model.set_weights(ref_weights) # Restore previous weights + model(np.random.random((3, 2)), training=True) + running_mean = backend.convert_to_numpy( + model.torch_wrappers[0].module[-1].running_mean + if cls is Classifier + else model.bn1.module.running_mean + ) + self.assertNotAllClose(running_mean, ref_running_mean) + def test_module_autowrapping(self): model = ClassifierWithNoSpecialCasing() self.assertIsInstance(model.fc1, TorchModuleWrapper)