Bring back loss information for multiple outputs (#20023)
* Bring back loss info for multiple outputs

* Fix CI

* Update torch eval mode

* Add SymbolicScope

* Minor updates and add tests

* Address comment

* Fix TorchWrapper of training args
james77777778 authored Jul 24, 2024
1 parent de5d18b commit f6a81cc
Showing 18 changed files with 289 additions and 41 deletions.
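
With this change, fitting a multi-output model once again reports one loss entry per output (e.g. `output_a_loss`) alongside the aggregate `loss`, as `tf.keras` does. A minimal sketch of the restored behavior, assuming a functional model whose outputs are named like those in the tests updated below:

import numpy as np
import keras

# Two-output functional model; the layer names become the history key prefixes.
inputs = keras.Input(shape=(3,))
output_a = keras.layers.Dense(1, name="output_a")(inputs)
output_b = keras.layers.Dense(1, name="output_b")(inputs)
model = keras.Model(inputs, [output_a, output_b])
model.compile(optimizer="sgd", loss=["mean_squared_error", "mean_squared_error"])

x = np.random.rand(8, 3)
y1 = np.random.rand(8, 1)
y2 = np.random.rand(8, 1)
hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
print(sorted(hist.history.keys()))
# With this commit: ['loss', 'output_a_loss', 'output_b_loss']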
1 change: 1 addition & 0 deletions keras/__init__.py
@@ -18,6 +18,7 @@
 from keras.api import Regularizer
 from keras.api import Sequential
 from keras.api import StatelessScope
+from keras.api import SymbolicScope
 from keras.api import Variable
 from keras.api import __version__
 from keras.api import activations
1 change: 1 addition & 0 deletions keras/api/__init__.py
@@ -33,6 +33,7 @@
 from keras.api import utils
 from keras.src.backend.common.keras_tensor import KerasTensor
 from keras.src.backend.common.stateless_scope import StatelessScope
+from keras.src.backend.common.symbolic_scope import SymbolicScope
 from keras.src.backend.exports import Variable
 from keras.src.backend.exports import device
 from keras.src.backend.exports import name_scope
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/__init__.py
@@ -31,6 +31,7 @@
 from keras.api._tf_keras.keras import preprocessing
 from keras.src.backend.common.keras_tensor import KerasTensor
 from keras.src.backend.common.stateless_scope import StatelessScope
+from keras.src.backend.common.symbolic_scope import SymbolicScope
 from keras.src.backend.exports import Variable
 from keras.src.backend.exports import device
 from keras.src.backend.exports import name_scope
2 changes: 2 additions & 0 deletions keras/src/backend/__init__.py
@@ -14,6 +14,8 @@
 from keras.src.backend.common.stateless_scope import StatelessScope
 from keras.src.backend.common.stateless_scope import get_stateless_scope
 from keras.src.backend.common.stateless_scope import in_stateless_scope
+from keras.src.backend.common.symbolic_scope import SymbolicScope
+from keras.src.backend.common.symbolic_scope import in_symbolic_scope
 from keras.src.backend.common.variables import AutocastScope
 from keras.src.backend.common.variables import get_autocast_scope
 from keras.src.backend.common.variables import is_float_dtype
23 changes: 23 additions & 0 deletions keras/src/backend/common/symbolic_scope.py
@@ -0,0 +1,23 @@
+from keras.src.api_export import keras_export
+from keras.src.backend.common import global_state
+
+
+@keras_export("keras.SymbolicScope")
+class SymbolicScope:
+    """Scope to indicate the symbolic stage."""
+
+    def __enter__(self):
+        self.original_scope = get_symbolic_scope()
+        global_state.set_global_attribute("symbolic_scope", self)
+        return self
+
+    def __exit__(self, *args, **kwargs):
+        global_state.set_global_attribute("symbolic_scope", self.original_scope)
+
+
+def in_symbolic_scope():
+    return global_state.get_global_attribute("symbolic_scope") is not None
+
+
+def get_symbolic_scope():
+    return global_state.get_global_attribute("symbolic_scope")
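
A quick sketch of the scope mechanics defined above: the scope lives in `global_state`, so instances nest, and `__exit__` restores whatever scope was active on entry:

from keras.src.backend.common.symbolic_scope import SymbolicScope
from keras.src.backend.common.symbolic_scope import get_symbolic_scope
from keras.src.backend.common.symbolic_scope import in_symbolic_scope

assert not in_symbolic_scope()
with SymbolicScope() as outer:
    with SymbolicScope():
        # The inner scope shadows the outer one...
        assert in_symbolic_scope()
    # ...and exiting it restores the outer scope.
    assert get_symbolic_scope() is outer
assert not in_symbolic_scope()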
26 changes: 26 additions & 0 deletions keras/src/backend/common/symbolic_scope_test.py
@@ -0,0 +1,26 @@
+import numpy as np
+
+from keras.src import ops
+from keras.src import testing
+from keras.src.backend.common.symbolic_scope import SymbolicScope
+from keras.src.backend.common.symbolic_scope import in_symbolic_scope
+
+
+class TestSymbolicScope(testing.TestCase):
+    def test_basic_flow(self):
+
+        # Define a function that behaves differently according to
+        # `in_symbolic_scope`.
+        def compute_loss(y, y_pred):
+            if in_symbolic_scope():
+                return ops.zeros_like(y)
+            return ops.add(y, y_pred)
+
+        y = ops.ones(shape=(2,))
+        y_pred = ops.ones(shape=(2,))
+        with SymbolicScope():
+            loss = compute_loss(y, y_pred)
+        self.assertAllClose(loss, np.zeros((2,)))
+
+        loss = compute_loss(y, y_pred)
+        self.assertAllClose(loss, 2 * np.ones((2,)))
3 changes: 2 additions & 1 deletion keras/src/backend/jax/core.py
@@ -10,6 +10,7 @@
 from keras.src.backend.common import standardize_dtype
 from keras.src.backend.common.keras_tensor import KerasTensor
 from keras.src.backend.common.stateless_scope import StatelessScope
+from keras.src.backend.common.symbolic_scope import SymbolicScope
 from keras.src.backend.jax import distribution_lib
 
 SUPPORTS_SPARSE_TENSORS = True
@@ -101,7 +102,7 @@ def cast(x, dtype):
 
 # Shape / dtype / sparseness inference util
 def compute_output_spec(fn, *args, **kwargs):
-    with StatelessScope():
+    with StatelessScope(), SymbolicScope():
         built_in_types = (type(None), int, float, str, bool, complex, bytes)
 
         # First, separate symbolic args from other args
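
Since every backend's `compute_output_spec` now also enters a `SymbolicScope` (here for JAX; the other backends follow below), any function traced through it can ask `in_symbolic_scope()` and skip work that needs concrete values. A hedged sketch of that pattern, with `my_loss` as a hypothetical user function:

from keras.src import backend
from keras.src import ops

def my_loss(y, y_pred):
    # Under backend.compute_output_spec, inputs are placeholders, so return
    # a structurally correct dummy rather than computing a real reduction.
    if backend.in_symbolic_scope():
        return ops.zeros_like(y)
    return ops.mean(ops.square(ops.subtract(y, y_pred)))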
3 changes: 2 additions & 1 deletion keras/src/backend/numpy/core.py
@@ -12,6 +12,7 @@
 from keras.src.backend.common.dtypes import result_type
 from keras.src.backend.common.keras_tensor import KerasTensor
 from keras.src.backend.common.stateless_scope import StatelessScope
+from keras.src.backend.common.symbolic_scope import SymbolicScope
 
 SUPPORTS_SPARSE_TENSORS = False
 
@@ -88,7 +89,7 @@ def vectorized_map(function, elements):
 
 # Shape / dtype inference util
 def compute_output_spec(fn, *args, **kwargs):
-    with StatelessScope():
+    with StatelessScope(), SymbolicScope():
 
         def has_none_shape(x):
             if isinstance(x, KerasTensor):
14 changes: 13 additions & 1 deletion keras/src/backend/numpy/trainer.py
@@ -97,7 +97,10 @@ def _symbolic_build(self, data_batch):
             self._compile_metrics is not None
             and not self._compile_metrics.built
         )
-        if model_unbuilt or compile_metrics_unbuilt:
+        compile_loss_unbuilt = (
+            self._compile_loss is not None and not self._compile_loss.built
+        )
+        if model_unbuilt or compile_metrics_unbuilt or compile_loss_unbuilt:
             # Create symbolic tensors matching an input batch.
 
             def to_symbolic_input(v):
@@ -133,6 +136,15 @@ def to_symbolic_input(v):
                     y_pred,
                     sample_weight=sample_weight,
                 )
+            if compile_loss_unbuilt:
+                # Build `CompileLoss` state with `backend.compute_output_spec`.
+                backend.compute_output_spec(
+                    self._compute_loss,
+                    x,
+                    y,
+                    y_pred,
+                    sample_weight=sample_weight,
+                )
             self._post_build()
 
     def fit(
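
The `_symbolic_build` pattern above traces `self._compute_loss` over `KerasTensor` placeholders so that `CompileLoss` (including its per-output loss trackers) is built without running any real computation. A sketch of the conversion step, assuming a helper shaped like the trainer's `to_symbolic_input`:

from keras.src import backend

def to_symbolic_input(v):
    # Swap a concrete batch array for a KerasTensor placeholder of the same
    # shape and dtype; compute_output_spec then traces over the placeholders.
    if v is None:
        return None
    return backend.KerasTensor(v.shape, backend.standardize_dtype(v.dtype))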
3 changes: 2 additions & 1 deletion keras/src/backend/tensorflow/core.py
@@ -14,6 +14,7 @@
 from keras.src.backend.common.name_scope import name_scope as base_name_scope
 from keras.src.backend.common.stateless_scope import StatelessScope
 from keras.src.backend.common.stateless_scope import in_stateless_scope
+from keras.src.backend.common.symbolic_scope import SymbolicScope
 from keras.src.backend.tensorflow.sparse import sparse_to_dense
 from keras.src.utils.naming import auto_name
 
@@ -182,7 +183,7 @@ def cast(x, dtype):
 
 
 def compute_output_spec(fn, *args, **kwargs):
-    with StatelessScope():
+    with StatelessScope(), SymbolicScope():
         graph_name = auto_name("scratch_graph")
         with tf.__internal__.FuncGraph(graph_name).as_default():
 
3 changes: 2 additions & 1 deletion keras/src/backend/torch/core.py
@@ -17,6 +17,7 @@
 from keras.src.backend.common.stateless_scope import StatelessScope
 from keras.src.backend.common.stateless_scope import get_stateless_scope
 from keras.src.backend.common.stateless_scope import in_stateless_scope
+from keras.src.backend.common.symbolic_scope import SymbolicScope
 from keras.src.backend.config import floatx
 
 SUPPORTS_SPARSE_TENSORS = False
@@ -335,7 +336,7 @@ def symbolic_call(fn, args, kwargs, fill_value):
         )
         return fn(*eager_args, **eager_kwargs)
 
-    with StatelessScope(), torch.no_grad():
+    with StatelessScope(), SymbolicScope(), torch.no_grad():
         outputs = symbolic_call(fn, args, kwargs, fill_value=83)
 
     none_in_shape = any(
6 changes: 5 additions & 1 deletion keras/src/layers/layer.py
@@ -32,6 +32,7 @@
 from keras.src.backend import KerasTensor
 from keras.src.backend.common import global_state
 from keras.src.backend.common.name_scope import current_path
+from keras.src.backend.common.symbolic_scope import in_symbolic_scope
 from keras.src.distribution import distribution_lib
 from keras.src.dtype_policies import DTypePolicyMap
 from keras.src.layers import input_spec
@@ -1148,7 +1149,10 @@ def _get_regularization_losses(self):
         for variable in self.trainable_weights:
             if variable.regularizer is None:
                 continue
-            if backend.in_stateless_scope():
+            if backend.in_stateless_scope() and not in_symbolic_scope():
+                # If in symbolic scope, we might get `None` from
+                # `get_current_value` in `backend.compute_output_spec`. So we
+                # assign `variable` instead.
                 v = backend.get_stateless_scope().get_current_value(variable)
             else:
                 v = variable
47 changes: 24 additions & 23 deletions keras/src/models/model_test.py
@@ -239,14 +239,13 @@ def test_functional_list_outputs_list_losses(self):
         # Fit the model to make sure compile_metrics are built
         hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
         hist_keys = sorted(hist.history.keys())
-        # TODO `tf.keras` also outputs individual losses for outputs
         ref_keys = sorted(
             [
                 "loss",
-                # "output_a_loss",
+                "output_a_loss",
                 "output_a_mean_squared_error",
                 "output_b_accuracy",
-                # "output_b_loss",
+                "output_b_loss",
                 "output_b_mean_squared_error",
             ]
         )
@@ -270,16 +269,15 @@ def test_functional_list_outputs_list_losses_abbr(self):
         # Fit the model to make sure compile_metrics are built
         hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
         hist_keys = sorted(hist.history.keys())
-        # TODO `tf.keras` also outputs individual losses for outputs
         ref_keys = sorted(
             [
                 "loss",
-                # "output_a_loss",
+                "output_a_loss",
                 "output_a_bce",
                 "output_a_mae",
                 "output_a_mse",
                 "output_b_acc",
-                # "output_b_loss",
+                "output_b_loss",
                 "output_b_mse",
             ]
         )
@@ -303,14 +301,13 @@ def test_functional_list_outputs_nested_list_losses(self):
         # Fit the model to make sure compile_metrics are built
         hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
         hist_keys = sorted(hist.history.keys())
-        # TODO `tf.keras` also outputs individual losses for outputs
         ref_keys = sorted(
             [
                 "loss",
-                # "output_a_loss",
+                "output_a_loss",
                 "output_a_mean_squared_error",
                 "output_b_accuracy",
-                # "output_b_loss",
+                "output_b_loss",
                 "output_b_mean_squared_error",
             ]
         )
@@ -351,15 +348,14 @@ def test_functional_dict_outputs_dict_losses(self):
             verbose=0,
         )
         hist_keys = sorted(hist.history.keys())
-        # TODO `tf.keras` also outputs individual losses for outputs
         ref_keys = sorted(
             [
                 "loss",
-                # "output_a_loss",
+                "output_a_loss",
                 "output_a_mean_squared_error",
                 "output_a_weighted_mean_squared_error",
                 "output_b_accuracy",
-                # "output_b_loss",
+                "output_b_loss",
                 "output_b_mean_squared_error",
                 "output_b_weighted_accuracy",
                 "output_b_weighted_mean_squared_error",
@@ -396,15 +392,14 @@ def test_functional_list_outputs_dict_losses_metrics(self):
         # Fit the model to make sure compile_metrics are built
         hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
         hist_keys = sorted(hist.history.keys())
-        # TODO `tf.keras` also outputs individual losses for outputs
         ref_keys = sorted(
             [
                 "loss",
-                # "output_a_loss",
+                "output_a_loss",
                 "output_a_mean_squared_error",
                 "output_a_weighted_mean_squared_error",
                 "output_b_accuracy",
-                # "output_b_loss",
+                "output_b_loss",
                 "output_b_mean_squared_error",
                 "output_b_weighted_accuracy",
                 "output_b_weighted_mean_squared_error",
@@ -436,18 +431,17 @@ def test_functional_list_outputs_dict_losses_metrics_uniq_weighted(self):
         # Fit the model to make sure compile_metrics are built
         hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
         hist_keys = sorted(hist.history.keys())
-        # TODO `tf.keras` also outputs individual losses for outputs
         # `output_b_accuracy` doesn't have `weighted_` in metric name.
         # When a metric is only in weighted metrics, it skips `weighted_`
         # prefix. This behavior matches `tf.keras`.
         ref_keys = sorted(
             [
                 "loss",
-                # "output_a_loss",
+                "output_a_loss",
                 "output_a_mean_squared_error",
                 "output_a_weighted_mean_squared_error",
                 "output_b_accuracy",
-                # "output_b_loss",
+                "output_b_loss",
                 "output_b_mean_squared_error",
             ]
         )
@@ -472,13 +466,12 @@ def test_functional_list_outputs_dict_losses_partial_metrics(self):
         # Fit the model to make sure compile_metrics are built
         hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
         hist_keys = sorted(hist.history.keys())
-        # TODO `tf.keras` also outputs individual losses for outputs
        ref_keys = sorted(
             [
                 "loss",
-                # "output_a_loss",
+                "output_a_loss",
                 "output_b_accuracy",
-                # "output_b_loss",
+                "output_b_loss",
                 "output_b_mean_squared_error",
             ]
         )
@@ -500,7 +493,10 @@ def test_functional_dict_outputs_with_single_tensor(self):
                 "output_b": "binary_crossentropy",
             },
         )
-        model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
+        hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
+        hist_keys = sorted(hist.history.keys())
+        ref_keys = sorted(["loss", "output_a_loss", "output_b_loss"])
+        self.assertListEqual(hist_keys, ref_keys)
 
     def test_functional_list_outputs_with_custom_compute_loss(self):
         model = _get_model_with_custom_compute_loss()
@@ -514,7 +510,12 @@ def test_functional_list_outputs_with_custom_compute_loss(self):
         model.compile(
             optimizer="sgd", loss=["mean_squared_error", "binary_crossentropy"]
         )
-        model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
+        hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
+        hist_keys = sorted(hist.history.keys())
+        ref_keys = sorted(
+            ["binary_crossentropy_loss", "loss", "mean_squared_error_loss"]
+        )
+        self.assertListEqual(hist_keys, ref_keys)
 
     def test_functional_list_outputs_dict_losses_invalid_keys(self):
         model = _get_model_multi_outputs_list()
(Diffs for the remaining 5 of the 18 changed files did not load in this view.)