Skip to content

Commit

Permalink
Fix CI
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 committed Jul 21, 2024
1 parent 71a360b commit b0a318c
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 10 deletions.
4 changes: 2 additions & 2 deletions keras/src/backend/jax/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ def evaluate(
steps_per_execution=self.steps_per_execution,
)

self._symbolic_build(iterator=epoch_iterator)
self._symbolic_build(iterator=epoch_iterator, training=False)

# Container that configures and calls callbacks.
if not isinstance(callbacks, callbacks_module.CallbackList):
Expand Down Expand Up @@ -765,7 +765,7 @@ def test_on_batch(
data = (x, y, sample_weight)
data = _distribute_data(data)
# Maybe build model
self._symbolic_build(data_batch=data)
self._symbolic_build(data_batch=data, training=False)
self._record_training_state_sharding_spec()
self.make_test_function()

Expand Down
8 changes: 4 additions & 4 deletions keras/src/backend/numpy/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def multi_predict_steps(data):

self.predict_function = predict_step

def _symbolic_build(self, data_batch):
def _symbolic_build(self, data_batch, training=True):
model_unbuilt = not all(layer.built for layer in self._flatten_layers())
compile_metrics_unbuilt = (
self._compile_metrics is not None
Expand All @@ -113,7 +113,7 @@ def to_symbolic_input(v):
) = data_adapter_utils.unpack_x_y_sample_weight(data_batch)
# Build all model state with `backend.compute_output_spec`.
try:
y_pred = backend.compute_output_spec(self, x)
y_pred = backend.compute_output_spec(self, x, training=training)
except:
raise RuntimeError(
"Unable to automatically build the model. "
Expand Down Expand Up @@ -246,7 +246,7 @@ def evaluate(
# Build the model on one batch of data.
for _, data in epoch_iterator.enumerate_epoch():
data_batch = data[0]
self._symbolic_build(data_batch)
self._symbolic_build(data_batch, training=False)
break

# Container that configures and calls callbacks.
Expand Down Expand Up @@ -304,7 +304,7 @@ def test_on_batch(
data = (x, y, sample_weight)

# Maybe build model
self._symbolic_build(data)
self._symbolic_build(data, training=False)
self.make_test_function()

logs = self.test_function([data])
Expand Down
4 changes: 2 additions & 2 deletions keras/src/backend/torch/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def evaluate(
steps_per_execution=self.steps_per_execution,
)

self._symbolic_build(iterator=epoch_iterator)
self._symbolic_build(iterator=epoch_iterator, training=False)

# Container that configures and calls callbacks.
if not isinstance(callbacks, callbacks_module.CallbackList):
Expand Down Expand Up @@ -483,7 +483,7 @@ def test_on_batch(
data = (x, y, sample_weight)

# Maybe build model
self._symbolic_build(data_batch=data)
self._symbolic_build(data_batch=data, training=False)
self.make_test_function()

logs = self.test_function([data])
Expand Down
2 changes: 2 additions & 0 deletions keras/src/layers/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,6 +1143,8 @@ def _get_regularization_losses(self):
v = backend.get_stateless_scope().get_current_value(variable)
else:
v = variable
if v is None:
v = variable
weight_regularization_losses.append(variable.regularizer(v))
return weight_regularization_losses

Expand Down
5 changes: 3 additions & 2 deletions keras/src/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,7 +1000,7 @@ def _assert_compile_called(self, method_name=None):
msg += f"calling `{method_name}()`."
raise ValueError(msg)

def _symbolic_build(self, iterator=None, data_batch=None):
def _symbolic_build(self, iterator=None, data_batch=None, training=True):
model_unbuilt = not all(layer.built for layer in self._flatten_layers())
compile_metrics_unbuilt = (
self._compile_metrics is not None
Expand Down Expand Up @@ -1035,7 +1035,7 @@ def to_symbolic_input(v):

# Build all model state with `backend.compute_output_spec`.
try:
y_pred = backend.compute_output_spec(self, x)
y_pred = backend.compute_output_spec(self, x, training=training)
except Exception as e:
raise RuntimeError(
"Unable to automatically build the model. "
Expand Down Expand Up @@ -1065,6 +1065,7 @@ def to_symbolic_input(v):
y,
y_pred,
sample_weight=sample_weight,
training=training,
)
if optimizer_unbuilt:
# Build optimizer
Expand Down

0 comments on commit b0a318c

Please sign in to comment.