Skip to content

Commit

Permalink
CTC loss for TF and torch (#18929)
Browse files Browse the repository at this point in the history
* Implement CTC loss in tensorflow backend

* Implement CTC api in torch backend

* Add CTC loss to keras losses

* Remove CTC from losses

* Perform log softmax in torch CTC loss

* Refactor reviewed code in CTC API

- Refactor sparse labels into main ctc_batch_cost function for tf

* Fix formatting issue in docstring

* Removed trailing space

* Naming changes in nn.ctc_loss backend functions

* Add ctc_loss keras op

* Add correctness unit test for CTC loss

* Skip test for CTC loss in JAX backend

* Update ctc_loss function to also export to ops.nn

* Add static type testing for CTC loss

* Fix enabled backends for CTC loss test

* Linting keras ops

* Fix line overflow in CtcLoss class
  • Loading branch information
MaanasArora authored Dec 14, 2023
1 parent e70f28f commit 92e7171
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 0 deletions.
52 changes: 52 additions & 0 deletions keras/backend/tensorflow/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,3 +787,55 @@ def batch_normalization(
scale=scale,
variance_epsilon=epsilon,
)


def ctc_loss(
target,
output,
target_length,
output_length,
mask_index=0,
):
"""Runs CTC (Connectionist Temporal Classification) loss on each
batch element.
Arguments:
target: Tensor `(batch_size, max_target_length)` containing the
target sequences in integer format.
output: Tensor `(batch_size, max_output_length, num_classes)`
containing the output of the softmax.
target_length: Tensor `(batch_size,)` containing the sequence length
for each target sequence in the batch.
output_length: Tensor `(batch_size,)` containing the sequence length
for each output sequence in the batch.
mask_index: The value in `target` and `output` that represents the
blank label.
Returns:
A tensor of shape `(batch_size,)` containing the CTC loss for each
sample in the batch.
"""
target = tf.convert_to_tensor(target)
target = tf.cast(target, dtype="int32")
output = tf.convert_to_tensor(output)
output = tf.cast(output, dtype="float32")

max_label_len = tf.shape(target)[1]

mask = tf.sequence_mask(target_length, max_label_len)
indices = tf.where(mask)
values = tf.boolean_mask(target, mask)

sparse_target = tf.SparseTensor(
indices=indices,
values=values,
dense_shape=tf.cast(tf.shape(target), dtype="int64"),
)

return tf.nn.ctc_loss(
labels=sparse_target,
logits=output,
label_length=target_length,
logit_length=output_length,
blank_index=mask_index,
)
24 changes: 24 additions & 0 deletions keras/backend/torch/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,3 +745,27 @@ def _batch_norm():
order.pop(1)
order.insert(axis, 1)
return x.permute(order)


def ctc_loss(
target,
output,
target_length,
output_length,
mask_index=0,
):
target = convert_to_tensor(target)
output = convert_to_tensor(output)
target_length = convert_to_tensor(target_length)
output_length = convert_to_tensor(output_length)

logits = tnn.log_softmax(output, dim=-1)

return tnn.ctc_loss(
logits,
target,
output_length,
target_length,
blank=mask_index,
reduction="none",
)
64 changes: 64 additions & 0 deletions keras/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1777,3 +1777,67 @@ def batch_normalization(
return backend.nn.batch_normalization(
x, mean, variance, axis, offset, scale, epsilon
)


class CtcLoss(Operation):
def __init__(self, mask_index):
super().__init__()
self.mask_index = mask_index

def call(self, target, output, target_length, output_length):
return backend.nn.ctc_loss(
target, output, target_length, output_length, self.mask_index
)

def _check_shape_first_dim(self, name1, shape1, name2, shape2):
if shape1[0] != shape2[0]:
raise ValueError(
f"Arguments `{name1}` and `{name2}` must have the same "
"first dimension. "
f"Received shapes: `{shape1}` and `{shape2}`."
)

def compute_output_spec(self, target, output, target_length, output_length):
self._check_shape_first_dim(
"target", target.shape, "output", output.shape
)
self._check_shape_first_dim(
"target_length", target_length.shape, "target", target.shape
)
self._check_shape_first_dim(
"output_length", output_length.shape, "output", output.shape
)

return KerasTensor((target.shape[0],), dtype=target.dtype)


@keras_export(
[
"keras.ops.ctc_loss",
"keras.ops.nn.ctc_loss",
]
)
def ctc_loss(target, output, target_length, output_length, mask_index=0):
"""CTC (Connectionist Temporal Classification) loss.
Args:
target: A tensor of shape `(batch_size, target_max_length)` containing
the true labels in integer format.
output: A tensor of shape `(batch_size, output_max_length, num_classes)`
containing the output from the network.
target_length: A tensor of shape `(batch_size,)` containing the
true label lengths.
output_length: A tensor of shape `(batch_size,)` containing the
output lengths.
mask_index: The index of the mask character in the vocabulary.
Defaults to `0`.
"""

if any_symbolic_tensors((target, output, target_length, output_length)):
return CtcLoss(mask_index).symbolic_call(
target, output, target_length, output_length
)

return backend.nn.ctc_loss(
target, output, target_length, output_length, mask_index
)
31 changes: 31 additions & 0 deletions keras/ops/nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,17 @@ def test_batch_normalization(self):
(10, 3, 4, 5),
)

@pytest.mark.skipif(
backend.backend() not in ["tensorflow", "torch"],
reason="Only TF and Torch support CTC loss",
)
def test_ctc_loss(self):
x = KerasTensor([10, 3, 4])
y = KerasTensor([10, 3], dtype="int32")
x_lengths = KerasTensor([10], dtype="int32")
y_lengths = KerasTensor([10], dtype="int32")
self.assertEqual(knn.ctc_loss(x, y, x_lengths, y_lengths).shape, (10,))


class NNOpsCorrectnessTest(testing.TestCase, parameterized.TestCase):
def test_relu(self):
Expand Down Expand Up @@ -1750,6 +1761,26 @@ def test_batch_normalization(self):
)
self.assertEqual(tuple(output.shape), (2, 3, 3, 5))

@pytest.mark.skipif(
backend.backend() not in ["tensorflow", "torch"],
reason="Only TF and Torch support CTC loss",
)
def test_ctc_loss(self):
labels = np.array([[1, 2, 1], [1, 2, 2]])
outputs = np.array(
[
[[0.4, 0.8, 0.4], [0.4, 0.8, 0.4]],
[[0.2, 0.8, 0.3], [0.2, 0.3, 0.3]],
[[0.9, 0.4, 0.5], [0.4, 0.3, 0.2]],
]
)

label_length = np.array([3, 2])
output_length = np.array([3, 2])

result = knn.ctc_loss(labels, outputs, label_length, output_length)
self.assertAllClose(result, np.array([3.4411672, 1.91680186]))


class TestLogitRecovery(testing.TestCase):
def test_logit_recovery_binary_crossentropy(self):
Expand Down

0 comments on commit 92e7171

Please sign in to comment.