Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 committed Jul 22, 2024
1 parent 11edb68 commit da41073
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 3 deletions.
2 changes: 2 additions & 0 deletions keras/src/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from keras.src.backend.common.stateless_scope import StatelessScope
from keras.src.backend.common.stateless_scope import get_stateless_scope
from keras.src.backend.common.stateless_scope import in_stateless_scope
from keras.src.backend.common.symbolic_scope import SymbolicScope
from keras.src.backend.common.symbolic_scope import in_symbolic_scope
from keras.src.backend.common.variables import AutocastScope
from keras.src.backend.common.variables import get_autocast_scope
from keras.src.backend.common.variables import is_float_dtype
Expand Down
2 changes: 2 additions & 0 deletions keras/src/backend/common/symbolic_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

@keras_export("keras.SymbolicScope")
class SymbolicScope:
"""Scope to indicate the symbolic stage."""

def __enter__(self):
self.original_scope = get_symbolic_scope()
global_state.set_global_attribute("symbolic_scope", self)
Expand Down
26 changes: 26 additions & 0 deletions keras/src/backend/common/symbolic_scope_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import numpy as np

from keras.src import ops
from keras.src import testing
from keras.src.backend.common.symbolic_scope import SymbolicScope
from keras.src.backend.common.symbolic_scope import in_symbolic_scope


class TestSymbolicScope(testing.TestCase):
    def test_basic_flow(self):
        """`in_symbolic_scope()` is True only inside a `SymbolicScope`."""

        def scope_aware_loss(targets, preds):
            # Symbolic tracing takes the zeros branch; eager execution
            # takes the real elementwise-sum branch.
            return (
                ops.zeros_like(targets)
                if in_symbolic_scope()
                else ops.add(targets, preds)
            )

        targets = ops.ones(shape=(2,))
        preds = ops.ones(shape=(2,))

        # Inside the scope: the symbolic branch must fire.
        with SymbolicScope():
            symbolic_loss = scope_aware_loss(targets, preds)
        self.assertAllClose(symbolic_loss, np.zeros((2,)))

        # Outside the scope: the eager branch must fire.
        eager_loss = scope_aware_loss(targets, preds)
        self.assertAllClose(eager_loss, 2 * np.ones((2,)))
14 changes: 13 additions & 1 deletion keras/src/backend/numpy/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,10 @@ def _symbolic_build(self, data_batch):
self._compile_metrics is not None
and not self._compile_metrics.built
)
if model_unbuilt or compile_metrics_unbuilt:
compile_loss_unbuilt = (
self._compile_loss is not None and not self._compile_loss.built
)
if model_unbuilt or compile_metrics_unbuilt or compile_loss_unbuilt:
# Create symbolic tensors matching an input batch.

def to_symbolic_input(v):
Expand Down Expand Up @@ -133,6 +136,15 @@ def to_symbolic_input(v):
y_pred,
sample_weight=sample_weight,
)
if compile_loss_unbuilt:
# Build `CompileLoss` state with `backend.compute_output_spec`.
backend.compute_output_spec(
self._compute_loss,
x,
y,
y_pred,
sample_weight=sample_weight,
)
self._post_build()

def fit(
Expand Down
3 changes: 1 addition & 2 deletions keras/src/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1042,7 +1042,7 @@ def to_symbolic_input(v):

# Build all model state with `backend.compute_output_spec`.
try:
y_pred = backend.compute_output_spec(self, x, training=False)
y_pred = backend.compute_output_spec(self, x)
except Exception as e:
raise RuntimeError(
"Unable to automatically build the model. "
Expand Down Expand Up @@ -1072,7 +1072,6 @@ def to_symbolic_input(v):
y,
y_pred,
sample_weight=sample_weight,
training=False,
)
if backend.backend() == "torch":
if original_training:
Expand Down
47 changes: 47 additions & 0 deletions keras/src/trainers/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1617,6 +1617,53 @@ def test_loss_weights(self):
atol=1e-3,
)

def test_symbolic_build(self):
    """`_symbolic_build` builds all compile state without altering weights."""

    class TwoOutputTrainerModel(Trainer, layers.Layer):
        # Minimal Trainer+Layer hybrid: a `training`-aware `call` that
        # yields two outputs, exercising multi-output loss/metric builds.
        def __init__(self, units):
            layers.Layer.__init__(self)
            Trainer.__init__(self)
            self.dense = layers.Dense(
                units,
                use_bias=False,
                kernel_initializer=initializers.Ones(),
            )
            self.bn = layers.BatchNormalization(axis=-1)

        def build(self, input_shape):
            self.dense.build(input_shape)
            input_shape = self.dense.compute_output_shape(input_shape)
            self.bn.build(input_shape)

        def call(self, x, training=None):
            outputs = self.bn(self.dense(x), training=training)
            return [outputs, outputs]

    model = TwoOutputTrainerModel(units=3)
    model.compile(
        optimizer=optimizers.SGD(),
        loss=[losses.MeanSquaredError(), losses.MeanSquaredError()],
        metrics=[metrics.MeanSquaredError(), metrics.MeanSquaredError()],
    )
    x = np.ones((4, 4))
    y = np.zeros((4, 3))
    model.build(x.shape)
    weights_before = model.get_weights()
    model._symbolic_build(data_batch=(x, (y, y)))
    weights_after = model.get_weights()

    # Symbolic building must leave every weight value untouched.
    self.assertEqual(len(weights_after), len(weights_before))
    for actual, expected in zip(weights_after, weights_before):
        self.assertAllClose(actual, expected)

    # All build flags must now be set.
    self.assertTrue(model.built)
    self.assertTrue(model._compile_metrics.built)
    self.assertTrue(model._compile_loss.built)

    # One loss tracker per model output.
    self.assertLen(model._compile_loss.metrics, 2)


class TrainerDistributeTest(testing.TestCase):
@pytest.mark.skipif(
Expand Down

0 comments on commit da41073

Please sign in to comment.