From 92e7171a46c89b899664ac6b05181ee6642aaff1 Mon Sep 17 00:00:00 2001 From: Maanas Arora Date: Thu, 14 Dec 2023 12:10:13 -0500 Subject: [PATCH] CTC loss for TF and torch (#18929) * Implement CTC loss in tensorflow backend * Implement CTC api in torch backend * Add CTC loss to keras losses * Remove CTC from losses * Perform log softmax in torch CTC loss * Refactor reviewed code in CTC API - Refactor sparse labels into main ctc_batch_cost function for tf * Fix formatting issue in docstring * Removed trailing space * Naming changes in nn.ctc_loss backend functions * Add ctc_loss keras op * Add correctness unit test for CTC loss * Skip test for CTC loss in JAX backend * Update ctc_loss function to also export to ops.nn * Add static type testing for CTC loss * Fix enabled backends for CTC loss test * Linting keras ops * Fix line overflow in CtcLoss class --- keras/backend/tensorflow/nn.py | 52 +++++++++++++++++++++++++++ keras/backend/torch/nn.py | 24 +++++++++++++ keras/ops/nn.py | 64 ++++++++++++++++++++++++++++++++++ keras/ops/nn_test.py | 31 ++++++++++++++++ 4 files changed, 171 insertions(+) diff --git a/keras/backend/tensorflow/nn.py b/keras/backend/tensorflow/nn.py index 20a6a6c65b1..acc7e921343 100644 --- a/keras/backend/tensorflow/nn.py +++ b/keras/backend/tensorflow/nn.py @@ -787,3 +787,55 @@ def batch_normalization( scale=scale, variance_epsilon=epsilon, ) + + +def ctc_loss( + target, + output, + target_length, + output_length, + mask_index=0, +): + """Runs CTC (Connectionist Temporal Classification) loss on each + batch element. + + Arguments: + target: Tensor `(batch_size, max_target_length)` containing the + target sequences in integer format. + output: Tensor `(batch_size, max_output_length, num_classes)` + containing the output of the softmax. + target_length: Tensor `(batch_size,)` containing the sequence length + for each target sequence in the batch. + output_length: Tensor `(batch_size,)` containing the sequence length + for each output sequence in the batch. + mask_index: The value in `target` and `output` that represents the + blank label. + + Returns: + A tensor of shape `(batch_size,)` containing the CTC loss for each + sample in the batch. + """ + target = tf.convert_to_tensor(target) + target = tf.cast(target, dtype="int32") + output = tf.convert_to_tensor(output) + output = tf.cast(output, dtype="float32") + + max_label_len = tf.shape(target)[1] + + mask = tf.sequence_mask(target_length, max_label_len) + indices = tf.where(mask) + values = tf.boolean_mask(target, mask) + + sparse_target = tf.SparseTensor( + indices=indices, + values=values, + dense_shape=tf.cast(tf.shape(target), dtype="int64"), + ) + + return tf.nn.ctc_loss( + labels=sparse_target, + logits=output, + label_length=target_length, + logit_length=output_length, + blank_index=mask_index, + ) diff --git a/keras/backend/torch/nn.py b/keras/backend/torch/nn.py index 8d5d61472b0..840d5072459 100644 --- a/keras/backend/torch/nn.py +++ b/keras/backend/torch/nn.py @@ -745,3 +745,27 @@ def _batch_norm(): order.pop(1) order.insert(axis, 1) return x.permute(order) + + +def ctc_loss( + target, + output, + target_length, + output_length, + mask_index=0, +): + target = convert_to_tensor(target) + output = convert_to_tensor(output) + target_length = convert_to_tensor(target_length) + output_length = convert_to_tensor(output_length) + + logits = tnn.log_softmax(output, dim=-1) + + return tnn.ctc_loss( + logits, + target, + output_length, + target_length, + blank=mask_index, + reduction="none", + ) diff --git a/keras/ops/nn.py b/keras/ops/nn.py index ec5ffaf7178..7fe3c759f43 100644 --- a/keras/ops/nn.py +++ b/keras/ops/nn.py @@ -1777,3 +1777,67 @@ def batch_normalization( return backend.nn.batch_normalization( x, mean, variance, axis, offset, scale, epsilon ) + + +class CtcLoss(Operation): + def __init__(self, mask_index): + super().__init__() + self.mask_index = mask_index + + def call(self, target, output, target_length, output_length): + return backend.nn.ctc_loss( + target, output, target_length, output_length, self.mask_index + ) + + def _check_shape_first_dim(self, name1, shape1, name2, shape2): + if shape1[0] != shape2[0]: + raise ValueError( + f"Arguments `{name1}` and `{name2}` must have the same " + "first dimension. " + f"Received shapes: `{shape1}` and `{shape2}`." + ) + + def compute_output_spec(self, target, output, target_length, output_length): + self._check_shape_first_dim( + "target", target.shape, "output", output.shape + ) + self._check_shape_first_dim( + "target_length", target_length.shape, "target", target.shape + ) + self._check_shape_first_dim( + "output_length", output_length.shape, "output", output.shape + ) + + return KerasTensor((target.shape[0],), dtype=target.dtype) + + +@keras_export( + [ + "keras.ops.ctc_loss", + "keras.ops.nn.ctc_loss", + ] +) +def ctc_loss(target, output, target_length, output_length, mask_index=0): + """CTC (Connectionist Temporal Classification) loss. + + Args: + target: A tensor of shape `(batch_size, target_max_length)` containing + the true labels in integer format. + output: A tensor of shape `(batch_size, output_max_length, num_classes)` + containing the output from the network. + target_length: A tensor of shape `(batch_size,)` containing the + true label lengths. + output_length: A tensor of shape `(batch_size,)` containing the + output lengths. + mask_index: The index of the mask character in the vocabulary. + Defaults to `0`. + """ + + if any_symbolic_tensors((target, output, target_length, output_length)): + return CtcLoss(mask_index).symbolic_call( + target, output, target_length, output_length + ) + + return backend.nn.ctc_loss( + target, output, target_length, output_length, mask_index + ) diff --git a/keras/ops/nn_test.py b/keras/ops/nn_test.py index daf4bd11de8..b56df81a808 100644 --- a/keras/ops/nn_test.py +++ b/keras/ops/nn_test.py @@ -974,6 +974,17 @@ def test_batch_normalization(self): (10, 3, 4, 5), ) + @pytest.mark.skipif( + backend.backend() not in ["tensorflow", "torch"], + reason="Only TF and Torch support CTC loss", + ) + def test_ctc_loss(self): + x = KerasTensor([10, 3, 4]) + y = KerasTensor([10, 3], dtype="int32") + x_lengths = KerasTensor([10], dtype="int32") + y_lengths = KerasTensor([10], dtype="int32") + self.assertEqual(knn.ctc_loss(x, y, x_lengths, y_lengths).shape, (10,)) + class NNOpsCorrectnessTest(testing.TestCase, parameterized.TestCase): def test_relu(self): @@ -1750,6 +1761,26 @@ def test_batch_normalization(self): ) self.assertEqual(tuple(output.shape), (2, 3, 3, 5)) + @pytest.mark.skipif( + backend.backend() not in ["tensorflow", "torch"], + reason="Only TF and Torch support CTC loss", + ) + def test_ctc_loss(self): + labels = np.array([[1, 2, 1], [1, 2, 2]]) + outputs = np.array( + [ + [[0.4, 0.8, 0.4], [0.4, 0.8, 0.4]], + [[0.2, 0.8, 0.3], [0.2, 0.3, 0.3]], + [[0.9, 0.4, 0.5], [0.4, 0.3, 0.2]], + ] + ) + + label_length = np.array([3, 2]) + output_length = np.array([3, 2]) + + result = knn.ctc_loss(labels, outputs, label_length, output_length) + self.assertAllClose(result, np.array([3.4411672, 1.91680186])) + class TestLogitRecovery(testing.TestCase): def test_logit_recovery_binary_crossentropy(self):