diff --git a/keras/src/backend/__init__.py b/keras/src/backend/__init__.py index 5c7fa2235207..794fe3ca3645 100644 --- a/keras/src/backend/__init__.py +++ b/keras/src/backend/__init__.py @@ -14,6 +14,8 @@ from keras.src.backend.common.stateless_scope import StatelessScope from keras.src.backend.common.stateless_scope import get_stateless_scope from keras.src.backend.common.stateless_scope import in_stateless_scope +from keras.src.backend.common.symbolic_scope import SymbolicScope +from keras.src.backend.common.symbolic_scope import in_symbolic_scope from keras.src.backend.common.variables import AutocastScope from keras.src.backend.common.variables import get_autocast_scope from keras.src.backend.common.variables import is_float_dtype diff --git a/keras/src/backend/common/symbolic_scope.py b/keras/src/backend/common/symbolic_scope.py index 780032d57282..15cd7a5ee059 100644 --- a/keras/src/backend/common/symbolic_scope.py +++ b/keras/src/backend/common/symbolic_scope.py @@ -4,6 +4,8 @@ @keras_export("keras.SymbolicScope") class SymbolicScope: + """Scope to indicate the symbolic stage.""" + def __enter__(self): self.original_scope = get_symbolic_scope() global_state.set_global_attribute("symbolic_scope", self) diff --git a/keras/src/backend/common/symbolic_scope_test.py b/keras/src/backend/common/symbolic_scope_test.py new file mode 100644 index 000000000000..092dcfe0748c --- /dev/null +++ b/keras/src/backend/common/symbolic_scope_test.py @@ -0,0 +1,26 @@ +import numpy as np + +from keras.src import ops +from keras.src import testing +from keras.src.backend.common.symbolic_scope import SymbolicScope +from keras.src.backend.common.symbolic_scope import in_symbolic_scope + + +class TestSymbolicScope(testing.TestCase): + def test_basic_flow(self): + + # Define a function that behaves differently according to + # `in_symbolic_scope`. + def compute_loss(y, y_pred): + if in_symbolic_scope(): + return ops.zeros_like(y) + return ops.add(y, y_pred) + + y = ops.ones(shape=(2,)) + y_pred = ops.ones(shape=(2,)) + with SymbolicScope(): + loss = compute_loss(y, y_pred) + self.assertAllClose(loss, np.zeros((2,))) + + loss = compute_loss(y, y_pred) + self.assertAllClose(loss, 2 * np.ones((2,))) diff --git a/keras/src/backend/numpy/trainer.py b/keras/src/backend/numpy/trainer.py index 6d40982be43e..12c3aad56b65 100644 --- a/keras/src/backend/numpy/trainer.py +++ b/keras/src/backend/numpy/trainer.py @@ -97,7 +97,10 @@ def _symbolic_build(self, data_batch): self._compile_metrics is not None and not self._compile_metrics.built ) - if model_unbuilt or compile_metrics_unbuilt: + compile_loss_unbuilt = ( + self._compile_loss is not None and not self._compile_loss.built + ) + if model_unbuilt or compile_metrics_unbuilt or compile_loss_unbuilt: # Create symbolic tensors matching an input batch. def to_symbolic_input(v): @@ -133,6 +136,15 @@ def to_symbolic_input(v): y_pred, sample_weight=sample_weight, ) + if compile_loss_unbuilt: + # Build `CompileLoss` state with `backend.compute_output_spec`. + backend.compute_output_spec( + self._compute_loss, + x, + y, + y_pred, + sample_weight=sample_weight, + ) self._post_build() def fit( diff --git a/keras/src/trainers/trainer.py b/keras/src/trainers/trainer.py index 22917f616449..6a35f93a54e6 100644 --- a/keras/src/trainers/trainer.py +++ b/keras/src/trainers/trainer.py @@ -1042,7 +1042,7 @@ def to_symbolic_input(v): # Build all model state with `backend.compute_output_spec`. try: - y_pred = backend.compute_output_spec(self, x, training=False) + y_pred = backend.compute_output_spec(self, x) except Exception as e: raise RuntimeError( "Unable to automatically build the model. " @@ -1072,7 +1072,6 @@ def to_symbolic_input(v): y, y_pred, sample_weight=sample_weight, - training=False, ) if backend.backend() == "torch": if original_training: diff --git a/keras/src/trainers/trainer_test.py b/keras/src/trainers/trainer_test.py index e5bc3cbdc8ff..d7c320c39bfe 100644 --- a/keras/src/trainers/trainer_test.py +++ b/keras/src/trainers/trainer_test.py @@ -1617,6 +1617,53 @@ def test_loss_weights(self): atol=1e-3, ) + def test_symbolic_build(self): + class ExampleModelWithTrainingArgs(Trainer, layers.Layer): + def __init__(self, units): + layers.Layer.__init__(self) + Trainer.__init__(self) + self.dense = layers.Dense( + units, + use_bias=False, + kernel_initializer=initializers.Ones(), + ) + self.bn = layers.BatchNormalization(axis=-1) + + def build(self, input_shape): + self.dense.build(input_shape) + input_shape = self.dense.compute_output_shape(input_shape) + self.bn.build(input_shape) + + def call(self, x, training=None): + outputs = self.bn(self.dense(x), training=training) + return [outputs, outputs] + + model = ExampleModelWithTrainingArgs(units=3) + model.compile( + optimizer=optimizers.SGD(), + loss=[losses.MeanSquaredError(), losses.MeanSquaredError()], + metrics=[metrics.MeanSquaredError(), metrics.MeanSquaredError()], + ) + x = np.ones((4, 4)) + y = np.zeros((4, 3)) + model.build(x.shape) + ref_weights = model.get_weights() + model._symbolic_build(data_batch=(x, (y, y))) + weights = model.get_weights() + + # Ensure weights are intact + self.assertEqual(len(weights), len(ref_weights)) + for w, ref_w in zip(weights, ref_weights): + self.assertAllClose(w, ref_w) + + # Ensure `built` + self.assertTrue(model.built) + self.assertTrue(model._compile_metrics.built) + self.assertTrue(model._compile_loss.built) + + # Ensure the len of CompileLoss's metrics (loss trackers) + self.assertLen(model._compile_loss.metrics, 2) + class TrainerDistributeTest(testing.TestCase): @pytest.mark.skipif(