diff --git a/burn-core/src/nn/loss/ctc.rs b/burn-core/src/nn/loss/ctc.rs
index 1b9ef9fd11..de7c8ee05c 100644
--- a/burn-core/src/nn/loss/ctc.rs
+++ b/burn-core/src/nn/loss/ctc.rs
@@ -1,11 +1,11 @@
 #![allow(clippy::single_range_in_vec_init)]
 use core::marker::PhantomData;
 
-use burn_tensor::{backend::Backend, ElementConversion, Int, Tensor};
+use burn_tensor::{backend::Backend, Element, ElementConversion, Int, Numeric, Tensor};
 
 use super::Reduction;
 
-const NEG_INF: f32 = -10000.0;
+const NEG_INF: f32 = -1e5;
 
 /// The Connectionist Temporal Classification loss.
 #[derive(Clone, Debug)]
@@ -78,19 +78,22 @@ impl<B: Backend> CTCLoss<B> {
         let target_with_blank_length = 2 * max_target_length + 1;
 
         let targets_pad = Self::pad_target(
-            targets.clone(),
+            targets,
             target_lengths.clone(),
             max_target_length,
             self.blank,
             &device,
         );
+        let targets_intersperse = intersperse(targets_pad.clone(), self.blank as u32);
+        println!("{}", targets_intersperse.clone());
+        let targets_one_hot = one_hot(targets_intersperse.clone(), num_classes);
 
-        let mut log_alphas = Tensor::<B, 3>::empty_device(
+        let log_alphas = Tensor::<B, 3>::empty_device(
             [batch_size, seq_length, target_with_blank_length],
             &device,
         );
         // initialize value at t0
-        log_alphas = log_alphas.slice_assign(
+        let log_alphas = log_alphas.slice_assign(
             [0..batch_size, 0..1, 0..target_with_blank_length],
             Tensor::<B, 3>::full_device(
                 [batch_size, 1, target_with_blank_length],
@@ -98,86 +101,71 @@ impl<B: Backend> CTCLoss<B> {
                 &device,
             ),
         );
-        log_alphas = log_alphas.slice_assign(
+        let log_alphas = log_alphas.slice_assign(
             [0..batch_size, 0..1, 0..1],
             log_probs
                 .clone()
                 .slice([0..batch_size, 0..1, self.blank..(self.blank + 1)]),
         );
-        let target_primes = Self::get_target_primes(targets_pad.clone(), 1, self.blank);
-        log_alphas = log_alphas.slice_assign(
+        let target_primes: Tensor<B, 3, Int> = targets_pad
+            .slice([0..batch_size, 0..1])
+            .reshape([batch_size, 1, 1]);
+        let mut log_alphas = log_alphas.slice_assign(
             [0..batch_size, 0..1, 1..2],
             log_probs
                 .clone()
                 .slice([0..batch_size, 0..1, 0..num_classes])
-                .gather(2, target_primes.reshape([batch_size, 1, 1])),
+                .gather(2, target_primes),
         );
+        let log_probs_available = targets_one_hot.matmul(log_probs.swap_dims(1, 2));
 
         let mut neg_log_likelihood = Tensor::<B, 1>::zeros_device([batch_size], &device);
 
-        for s in 0..target_with_blank_length {
-            let current_target_primes = Self::get_target_primes(targets_pad.clone(), s, self.blank);
-
-            for t in 1..seq_length {
-                // \alpha_{t-1}(s)
-                let la1 = log_alphas
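+        // The recursion below is vectorized over s: instead of visiting each of
+        // the 2 * max_target_length + 1 positions of the blank-extended target
+        // in an inner loop, every time step now updates all positions at once.
+        // mask_la3 is 1.0 where l'(s) != l'(s-2) and 0.0 at blanks and repeated
+        // labels, so the \alpha_{t-1}(s-2) path only contributes where the CTC
+        // recursion permits it.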
+        // s != s-2
+        let mask_la3 = targets_intersperse
+            .clone()
+            .slice([0..batch_size, 0..(target_with_blank_length - 2)])
+            .equal(targets_intersperse.slice([0..batch_size, 2..target_with_blank_length]))
+            .bool_not()
+            .float();
+        let mask_la3 = pad(mask_la3, [(0, 0), (2, 0)], 0.0).unsqueeze_dim(1);
+
+        for t in 1..seq_length {
+            // \alpha_{t-1}(s)
+            let la1 =
+                log_alphas
                     .clone()
-                    .slice([0..batch_size, (t - 1)..t, s..(s + 1)])
-                    .reshape([batch_size]);
-
-                // for the logsumexp calculation
-                let mut lamax = la1.clone();
-
-                // \alpha_{t-1}(s-1)
-                let mut la2 = Tensor::<B, 1>::full_device([batch_size], NEG_INF, &device);
-                if s > 0 {
-                    la2 = log_alphas
-                        .clone()
-                        .slice([0..batch_size, (t - 1)..t, (s - 1)..s])
-                        .reshape([batch_size]);
-
-                    lamax = lamax
-                        .clone()
-                        .mask_where(la2.clone().greater(lamax.clone()), la2.clone());
-                }
-
-                // \alpha_{t-1}(s-2)
-                let mut la3 = Tensor::<B, 1>::full_device([batch_size], NEG_INF, &device);
-                if s > 1 {
-                    la3 = la3.mask_where(
-                        Self::get_target_primes(targets_pad.clone(), s - 2, self.blank)
-                            .equal(current_target_primes.clone())
-                            .bool_not(),
-                        log_alphas
-                            .clone()
-                            .slice([0..batch_size, (t - 1)..t, (s - 2)..(s - 1)])
-                            .reshape([batch_size]),
-                    );
-
-                    lamax = lamax
+                    .slice([0..batch_size, (t - 1)..t, 0..target_with_blank_length]);
+            // \alpha_{t-1}(s-1)
+            let la2 = la1
+                .clone()
+                .slice([0..batch_size, 0..1, 0..(target_with_blank_length - 1)])
+                .clamp_min(NEG_INF);
+            let la2 = pad(la2, [(0, 0), (0, 0), (1, 0)], NEG_INF);
+            // \alpha_{t-1}(s-2)
+            let la3 = la1
+                .clone()
+                .slice([0..batch_size, 0..1, 0..(target_with_blank_length - 2)])
+                .clamp_min(NEG_INF);
+            let la3 = pad(la3, [(0, 0), (0, 0), (2, 0)], NEG_INF);
+            // for the logsumexp calculation
+            let lamax: Tensor<B, 3> =
+                Tensor::stack::<4>([la1.clone(), la2.clone(), la3.clone()].to_vec(), 3)
+                    .max_dim(3)
+                    .squeeze(3);
+
+            log_alphas = log_alphas.slice_assign(
+                [0..batch_size, t..(t + 1), 0..target_with_blank_length],
+                ((la1 - lamax.clone()).exp()
+                    + (la2 - lamax.clone()).exp()
+                    + (la3 - lamax.clone()).exp().mul(mask_la3.clone()))
+                .log()
+                .clamp_min(NEG_INF)
+                    + lamax
+                    + log_probs_available
                         .clone()
-                        .mask_where(la3.clone().greater(lamax.clone()), la3.clone());
-                }
-
-                lamax = lamax
-                    .clone()
-                    .mask_fill(lamax.clone().lower_equal_elem(NEG_INF), 0.0);
-
-                log_alphas = log_alphas.slice_assign(
-                    [0..batch_size, t..(t + 1), s..(s + 1)],
-                    (((la1.clone() - lamax.clone()).exp()
-                        + (la2.clone() - lamax.clone()).exp()
-                        + (la3.clone() - lamax.clone()).exp())
-                    .log()
-                    .clamp_min(NEG_INF)
-                        + lamax.clone()
-                        + log_probs
-                            .clone()
-                            .slice([0..batch_size, t..(t + 1), 0..num_classes])
-                            .gather(2, current_target_primes.clone().reshape([batch_size, 1, 1]))
-                            .reshape([batch_size]))
-                    .reshape([batch_size, 1, 1]),
-                );
-            }
+                        .slice([0..batch_size, 0..target_with_blank_length, t..(t + 1)])
+                        .swap_dims(1, 2),
+            );
         }
 
         let l1 = log_alphas
@@ -194,7 +182,7 @@ impl<B: Backend> CTCLoss<B> {
                 .clone()
                 .gather(
                     1,
-                    (input_lengths.clone() - 1)
+                    (input_lengths - 1)
                         .reshape([batch_size, 1, 1])
                         .repeat(2, target_with_blank_length),
                 )
@@ -219,23 +207,6 @@ impl<B: Backend> CTCLoss<B> {
         }
     }
 
-    fn get_target_primes(
-        targets_pad: Tensor<B, 2, Int>,
-        idx: usize,
-        blank: usize,
-    ) -> Tensor<B, 1, Int> {
-        let device = targets_pad.device();
-        let [batch_size, _] = targets_pad.dims();
-
-        if idx % 2 == 0 {
-            Tensor::<B, 1, Int>::full_device([batch_size], blank as i32, &device)
-        } else {
-            targets_pad
-                .slice([0..batch_size, (idx / 2)..(idx / 2 + 1)])
-                .squeeze(1)
-        }
-    }
-
     fn pad_target(
         targets: Tensor<B, 2, Int>,
         target_lengths: Tensor<B, 1, Int>,
@@ -313,6 +284,62 @@
     }
 }
 
+fn pad<B, const D: usize, K, E>(
+    tensor: Tensor<B, D, K>,
+    pad_width: [(usize, usize); D],
+    fill_value: E,
+) -> Tensor<B, D, K>
+where
+    B: Backend,
+    K: Numeric<B>,
+    K::Elem: Element,
+    E: ElementConversion,
+{
+    let device = tensor.device();
+    let origin_shape = tensor.dims();
+
+    let mut pad_shape = [0; D];
+    let mut assign_range = Vec::with_capacity(D);
+    for (idx, (&origin_len, (left_pad, right_pad))) in
+        origin_shape.iter().zip(pad_width).enumerate()
+    {
+        pad_shape[idx] = origin_len + left_pad + right_pad;
+        assign_range.push(left_pad..(left_pad + origin_len));
+    }
+
+    let padded = Tensor::<B, D, K>::full_device(pad_shape, fill_value, &device);
+
+    padded.slice_assign::<D>(assign_range.try_into().unwrap(), tensor)
+}
+
+fn intersperse<B, K, E>(tensor: Tensor<B, 2, K>, value: E) -> Tensor<B, 2, K>
+where
+    B: Backend,
+    K: Numeric<B>,
+    K::Elem: Element,
+    E: ElementConversion + Clone,
+{
+    let device = tensor.device();
+    let mut shape = tensor.dims();
+    let constants: Tensor<B, 2, K> = Tensor::full_device(shape, value.clone(), &device);
+    shape[1] = shape[1] * 2;
+    let stack = Tensor::stack::<3>([tensor, constants].to_vec(), 2).reshape(shape);
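+    // stack + reshape interleaves each row to [t0, v, t1, v, ...]; left-padding
+    // one more `value` gives [v, t0, v, t1, ..., v]: length 2n + 1 with the
+    // separator at every even index, the blank-extended form CTC expects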
+    pad(stack, [(0, 0), (1, 0)], value)
+}
+
+fn one_hot<B: Backend>(tensor: Tensor<B, 2, Int>, num_classes: usize) -> Tensor<B, 3> {
+    let device = tensor.device();
+    let shape = tensor.dims();
+
+    let labels: Tensor<B, 3, Int> = tensor.unsqueeze_dim(2).repeat(2, num_classes);
+    let indices = Tensor::<B, 1, Int>::arange_device(0..num_classes, &device)
+        .reshape([1, 1, num_classes])
+        .repeat(1, shape[1])
+        .repeat(0, shape[0]);
+
+    labels.equal(indices).float()
+}
+
 #[cfg(test)]
 mod test {
     use burn_tensor::Data;
@@ -321,6 +348,14 @@ mod test {
 
     use super::*;
 
+    #[test]
+    fn test_intersperse() {
+        let tensor = Tensor::<TestBackend, 1, Int>::arange(1..25).reshape([4, 6]);
+        let tensor = intersperse(tensor, 0);
+
+        println!("{}", tensor);
+    }
+
     #[test]
     fn test_ctc_loss() {
         let input = Tensor::<TestBackend, 3>::from_data([[