From 3365b50f195b6d25882f724193a1223021e30cc8 Mon Sep 17 00:00:00 2001 From: wcshds Date: Tue, 5 Dec 2023 00:00:39 +0800 Subject: [PATCH 01/19] implement ctc loss function --- burn-core/src/nn/loss/ctc.rs | 413 +++++++++++++++++++++++++++++++++++ burn-core/src/nn/loss/mod.rs | 2 + 2 files changed, 415 insertions(+) create mode 100644 burn-core/src/nn/loss/ctc.rs diff --git a/burn-core/src/nn/loss/ctc.rs b/burn-core/src/nn/loss/ctc.rs new file mode 100644 index 0000000000..60fa71a262 --- /dev/null +++ b/burn-core/src/nn/loss/ctc.rs @@ -0,0 +1,413 @@ +#![allow(clippy::single_range_in_vec_init)] +use core::marker::PhantomData; + +use burn_tensor::{backend::Backend, ElementConversion, Int, Tensor}; +use half::f16; + +use super::Reduction; + +const NEG_INF: f16 = f16::NEG_INFINITY; + +/// The Connectionist Temporal Classification loss. +#[derive(Clone, Debug)] +pub struct CTCLoss { + blank: usize, + backend: PhantomData, +} + +impl Default for CTCLoss { + fn default() -> Self { + CTCLoss::new(0) + } +} + +impl CTCLoss { + /// Create the criterion. + pub fn new(blank: usize) -> Self { + Self { + blank, + backend: PhantomData, + } + } + + /// Compute the criterion on the input tensor. + /// + /// # Parameters: + /// + /// - log_probs: The logarithmized probabilities of the outputs. Shape: + /// `[batch_size, input_length, num_classes]` + /// - targets: It represent the concatenated target sequences. Each + /// element in the target sequence is a class index. And the target + /// index cannot be blank. Shape: `[target_lengths_sum]` + /// - input_lengths: It represent the lengths of the inputs. And the + /// lengths are specified for each sequence to achieve masking under + /// the assumption that sequences are padded to equal lengths. Shape: + /// `[batch_size]` + /// - target_lengths: It represent lengths of the targets. Shape: + /// `[batch_size]` + /// - reduction: Specifies the reduction to apply to the output. None: + /// no reduction will be applied; Some(Reduction::Mean): the output + /// losses will be divided by the target lengths and then the mean + /// over the batch is taken; Some(Reduction::Sum): the output losses + /// will be summed. 
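+    ///
+    /// # Example
+    ///
+    /// A minimal sketch of a call to the criterion (the tensors are
+    /// placeholders with the shapes listed above, `B` is any backend, and
+    /// `0` is used as the blank index, as in the test below):
+    ///
+    /// ```rust,ignore
+    /// let ctc_loss = CTCLoss::<B>::new(0);
+    /// let loss = ctc_loss.forward(
+    ///     log_probs,      // [batch_size, input_length, num_classes]
+    ///     targets,        // [target_lengths_sum]
+    ///     input_lengths,  // [batch_size]
+    ///     target_lengths, // [batch_size]
+    ///     Some(Reduction::Mean),
+    /// );
+    /// ```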
+ /// + /// # Reference + /// + /// - [PyTorch implementation](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/LossCTC.cpp) + /// - [Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks](https://www.cs.toronto.edu/~graves/icml_2006.pdf) + pub fn forward( + &self, + log_probs: Tensor, + targets: Tensor, + input_lengths: Tensor, + target_lengths: Tensor, + reduction: Option, + ) -> Tensor { + Self::assertions( + log_probs.clone(), + targets.clone(), + input_lengths.clone(), + target_lengths.clone(), + ); + + let [batch_size, seq_length, _] = log_probs.dims(); + let max_target_length = target_lengths.clone().max().into_scalar().elem::() as usize; + let target_with_blank_length = 2 * max_target_length + 1; + + let mut log_alphas = + Tensor::::zeros([batch_size, seq_length, target_with_blank_length]); + log_alphas = log_alphas.slice_assign( + [0..batch_size, 0..1, 0..target_with_blank_length], + Tensor::::full([batch_size, 1, target_with_blank_length], NEG_INF), + ); + let mut neg_log_likelihood = Tensor::::zeros([batch_size]); + + let mut target_iter = target_lengths + .clone() + .iter_dim(0) + .scan(0usize, |start, current| { + let step = current.into_scalar().elem::() as usize; + let res = targets.clone().slice([*start..(*start + step)]); + *start += step; + + Some(res) + }); + + for b in 0..batch_size { + let target_data = target_iter.next().unwrap(); + + let input_length = input_lengths + .clone() + .slice([b..(b + 1)]) + .into_scalar() + .elem::() as usize; + let [target_length] = target_data.dims(); + + log_alphas = log_alphas.slice_assign( + [b..(b + 1), 0..1, 0..1], + log_probs + .clone() + .slice([b..(b + 1), 0..1, self.blank..(self.blank + 1)]), + ); + + if target_length > 0 { + let target_prime = Self::get_target_prime(target_data.clone(), 1, self.blank); + log_alphas = log_alphas.slice_assign( + [b..(b + 1), 0..1, 0..1], + log_probs + .clone() + .slice([b..(b + 1), 0..1, target_prime..(target_prime + 1)]), + ); + } + + for t in 1..input_length { + for s in 0..(2 * target_length + 1) { + let current_target_prime = + Self::get_target_prime(target_data.clone(), s, self.blank); + + // \alpha_{t-1}(s) + let la1 = log_alphas + .clone() + .slice([b..(b + 1), (t - 1)..t, s..(s + 1)]) + .reshape([1]); + // for the logsumexp calculation + let mut lamax = la1.clone(); + // \alpha_{t-1}(s-1) + let (la2, la3); + + if s > 0 { + la2 = log_alphas + .clone() + .slice([b..(b + 1), (t - 1)..t, (s - 1)..s]) + .reshape([1]); + if la2.clone().greater(lamax.clone()).to_data().value[0] { + lamax = la2.clone(); + } + } else { + la2 = Tensor::::full([1], NEG_INF); + } + + if (s > 1) + && (Self::get_target_prime(target_data.clone(), s - 2, self.blank) + != current_target_prime) + { + // \alpha_{t-1}(s-2) + la3 = log_alphas + .clone() + .slice([b..(b + 1), (t - 1)..t, (s - 2)..(s - 1)]) + .reshape([1]); + if la3.clone().greater(lamax.clone()).to_data().value[0] { + lamax = la3.clone(); + } + } else { + la3 = Tensor::::full([1], NEG_INF); + } + + if lamax.clone().equal_elem(NEG_INF).to_data().value[0] { + lamax = Tensor::::from_floats([0.0]); + } + log_alphas = log_alphas.slice_assign( + [b..(b + 1), t..(t + 1), s..(s + 1)], + (((la1 - lamax.clone()).exp() + + (la2 - lamax.clone()).exp() + + (la3 - lamax.clone()).exp()) + .log() + + lamax + + log_probs + .clone() + .slice([ + b..(b + 1), + t..(t + 1), + current_target_prime..(current_target_prime + 1), + ]) + .reshape([1])) + .reshape([1, 1, 1]), + ); + } + } + + // the likelihood is the 
sum of the last two alphas, + // the loss is the negative log likelihood + if target_length == 0 { + // if the target is empty then there is no preceding BLANK + // state and hence there is no path to merge + neg_log_likelihood = neg_log_likelihood.slice_assign( + [b..(b + 1)], + -log_alphas + .clone() + .slice([b..(b + 1), (input_length - 1)..input_length, 0..1]) + .reshape([1]), + ); + } else { + let l1 = log_alphas + .clone() + .slice([ + b..(b + 1), + (input_length - 1)..input_length, + (target_length * 2)..(target_length * 2 + 1), + ]) + .reshape([1]); + let l2 = log_alphas + .clone() + .slice([ + b..(b + 1), + (input_length - 1)..input_length, + (target_length * 2 - 1)..(target_length * 2), + ]) + .reshape([1]); + // for the logsumexp calculation + let mut m = Tensor::cat(vec![l1.clone(), l2.clone()], 0).max(); + + if m.clone().equal_elem(NEG_INF).to_data().value[0] { + m = Tensor::::from_floats([0.0]) + }; + let log_likelihood = ((l1 - m.clone()).exp() + (l2 - m.clone()).exp()).log() + m; + neg_log_likelihood = neg_log_likelihood.slice_assign([b..(b + 1)], -log_likelihood); + } + } + + match reduction { + Some(Reduction::Mean) | Some(Reduction::Auto) => { + (neg_log_likelihood / target_lengths.float()).mean() + } + Some(Reduction::Sum) => neg_log_likelihood.sum(), + None => neg_log_likelihood, + } + } + + fn get_target_prime(target_data: Tensor, idx: usize, blank: usize) -> usize { + if idx % 2 == 0 { + blank + } else { + target_data + .slice([(idx / 2)..(idx / 2 + 1)]) + .into_scalar() + .elem::() as usize + } + } + + fn assertions( + log_probs: Tensor, + targets: Tensor, + input_lengths: Tensor, + target_lengths: Tensor, + ) { + let [log_probs_batch_size, input_seq_length, _] = log_probs.dims(); + let [targets_size] = targets.dims(); + let [input_lengths_size] = input_lengths.dims(); + let [target_lengths_size] = target_lengths.dims(); + + assert!( + log_probs_batch_size == input_lengths_size, + "Batch size of log_probs ({}) should correspond to size of input_lengths ({}).", + log_probs_batch_size, + input_lengths + ); + + assert!( + log_probs_batch_size == target_lengths_size, + "Batch size of log_probs ({}) should correspond to size of target_lengths ({}).", + log_probs_batch_size, + target_lengths_size + ); + + assert!( + target_lengths + .sum() + .equal_elem(targets_size as u32) + .into_data() + .value[0], + "Batch size of targets ({}) should correspond to sum of target_lengths ({}).", + log_probs_batch_size, + target_lengths_size + ); + + let max_input_length = input_lengths.max(); + assert!( + max_input_length.clone() + .lower_equal_elem(input_seq_length as u32) + .into_data() + .value[0], + "The maximum value of input_lengths ({}) must not be greater than the sequence length of log_probs ({}).", + max_input_length.into_scalar(), input_seq_length + ); + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::TestBackend; + + #[test] + fn test_ctc_loss() { + let input = Tensor::::from_data([[ + [ + -3.941, -2.116, -3.559, -2.559, -2.576, -2.445, -0.759, -3.240, -3.116, -3.221, + ], + [ + -3.001, -2.181, -2.915, -2.382, -4.597, -4.133, -3.738, -3.256, -3.291, -0.571, + ], + [ + -1.127, -3.112, -2.896, -1.613, -4.025, -2.752, -2.086, -3.241, -2.187, -3.925, + ], + [ + -1.568, -4.852, -4.101, -3.584, -1.354, -2.619, -1.798, -3.845, -2.914, -1.789, + ], + [ + -3.770, -4.748, -3.915, -0.978, -6.070, -2.430, -3.295, -2.307, -3.980, -1.119, + ], + [ + -2.117, -2.178, -2.084, -2.325, -1.426, -3.922, -2.020, -4.461, -2.366, -3.078, + ], + [ + -2.195, -1.658, -2.019, -2.959, 
-3.266, -3.922, -1.259, -3.566, -2.426, -2.904, + ], + [ + -2.441, -1.606, -2.835, -3.703, -1.418, -3.456, -2.504, -2.445, -1.907, -3.263, + ], + [ + -3.509, -2.281, -2.405, -4.563, -2.469, -2.816, -1.916, -2.147, -1.701, -1.736, + ], + [ + -3.313, -1.417, -2.122, -3.138, -3.365, -2.074, -3.471, -1.530, -2.885, -2.362, + ], + [ + -3.784, -0.829, -2.479, -2.101, -3.563, -2.265, -4.733, -2.501, -2.731, -3.067, + ], + [ + -2.533, -2.684, -0.890, -2.986, -3.694, -3.484, -2.270, -2.169, -2.913, -2.751, + ], + [ + -3.435, -2.567, -2.526, -1.183, -3.210, -2.538, -1.184, -3.352, -3.935, -3.704, + ], + [ + -3.139, -2.204, -0.668, -5.249, -3.855, -3.706, -2.839, -1.971, -2.852, -3.608, + ], + [ + -1.445, -2.020, -3.576, -3.153, -2.949, -2.717, -3.902, -3.726, -1.594, -1.635, + ], + [ + -1.596, -4.902, -4.364, -4.571, -1.465, -3.689, -1.751, -2.032, -1.945, -2.764, + ], + [ + -3.326, -2.239, -2.965, -1.831, -2.958, -1.912, -1.695, -1.932, -2.353, -3.791, + ], + [ + -3.372, -2.850, -2.342, -0.841, -2.754, -3.297, -3.610, -2.152, -2.611, -2.760, + ], + [ + -2.843, -3.622, -1.551, -4.361, -4.325, -0.975, -3.459, -2.004, -2.758, -2.658, + ], + [ + -2.094, -3.114, -0.915, -3.207, -2.865, -2.215, -3.892, -4.120, -2.113, -2.693, + ], + [ + -3.049, -2.809, -3.370, -2.358, -2.038, -1.879, -1.957, -3.337, -2.198, -1.648, + ], + [ + -4.449, -2.300, -2.324, -3.414, -2.296, -1.620, -3.738, -2.128, -1.276, -3.311, + ], + [ + -2.133, -2.854, -2.711, -3.328, -3.735, -3.705, -0.627, -3.701, -4.156, -2.319, + ], + [ + -3.160, -3.321, -1.590, -3.735, -1.640, -3.614, -2.270, -1.911, -2.099, -2.314, + ], + [ + -3.044, -3.279, -1.939, -2.554, -2.272, -1.209, -2.627, -3.025, -2.187, -2.837, + ], + [ + -3.209, -3.186, -3.113, -2.002, -2.527, -2.561, -3.697, -2.347, -1.694, -1.282, + ], + [ + -1.297, -2.826, -2.052, -2.534, -2.544, -3.318, -2.015, -3.384, -2.755, -2.171, + ], + [ + -2.774, -2.740, -1.453, -3.754, -2.903, -2.309, -2.528, -1.664, -2.338, -2.345, + ], + [ + -3.036, -2.509, -0.726, -2.385, -4.339, -4.286, -3.388, -3.196, -3.755, -1.772, + ], + [ + -3.222, -3.674, -2.348, -2.324, -3.065, -2.748, -0.912, -2.595, -1.952, -4.408, + ], + ]]); + let target = Tensor::::from_data([3, 4, 7, 6, 3, 7, 3, 6, 2]); + let input_lengths = Tensor::::from_data([30]); + let target_lengths = Tensor::::from_data([9]); + let _expected_res = 47.73889923095703; + + let ctc_loss = CTCLoss::::new(0); + let res = ctc_loss.forward( + input, + target, + input_lengths, + target_lengths, + Some(Reduction::Sum), + ); + + // 47.061913 + res.to_data().assert_within_range(47..49); + } +} diff --git a/burn-core/src/nn/loss/mod.rs b/burn-core/src/nn/loss/mod.rs index 5b37df84f2..c8bb454cb9 100644 --- a/burn-core/src/nn/loss/mod.rs +++ b/burn-core/src/nn/loss/mod.rs @@ -1,9 +1,11 @@ mod binary_cross_entropy; mod cross_entropy; +mod ctc; mod mse; mod reduction; pub use binary_cross_entropy::*; pub use cross_entropy::*; +pub use ctc::*; pub use mse::*; pub use reduction::*; From f92b9451fcc32b474dbde09ba93749aeb1d8b1cb Mon Sep 17 00:00:00 2001 From: wcshds Date: Tue, 5 Dec 2023 00:24:03 +0800 Subject: [PATCH 02/19] remove vec! 
macro --- burn-core/src/nn/loss/ctc.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/burn-core/src/nn/loss/ctc.rs b/burn-core/src/nn/loss/ctc.rs index 60fa71a262..a44cf127e4 100644 --- a/burn-core/src/nn/loss/ctc.rs +++ b/burn-core/src/nn/loss/ctc.rs @@ -216,7 +216,7 @@ impl CTCLoss { ]) .reshape([1]); // for the logsumexp calculation - let mut m = Tensor::cat(vec![l1.clone(), l2.clone()], 0).max(); + let mut m = Tensor::cat([l1.clone(), l2.clone()].to_vec(), 0).max(); if m.clone().equal_elem(NEG_INF).to_data().value[0] { m = Tensor::::from_floats([0.0]) From 4ddfbb3aebbd95483531dbc8947232b08cb9ceec Mon Sep 17 00:00:00 2001 From: wcshds Date: Tue, 5 Dec 2023 12:29:51 +0800 Subject: [PATCH 03/19] fix wrong indice when assign initial val to alpha --- burn-core/src/nn/loss/ctc.rs | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/burn-core/src/nn/loss/ctc.rs b/burn-core/src/nn/loss/ctc.rs index a44cf127e4..65c01e42a5 100644 --- a/burn-core/src/nn/loss/ctc.rs +++ b/burn-core/src/nn/loss/ctc.rs @@ -75,7 +75,7 @@ impl CTCLoss { let target_with_blank_length = 2 * max_target_length + 1; let mut log_alphas = - Tensor::::zeros([batch_size, seq_length, target_with_blank_length]); + Tensor::::empty([batch_size, seq_length, target_with_blank_length]); log_alphas = log_alphas.slice_assign( [0..batch_size, 0..1, 0..target_with_blank_length], Tensor::::full([batch_size, 1, target_with_blank_length], NEG_INF), @@ -113,7 +113,7 @@ impl CTCLoss { if target_length > 0 { let target_prime = Self::get_target_prime(target_data.clone(), 1, self.blank); log_alphas = log_alphas.slice_assign( - [b..(b + 1), 0..1, 0..1], + [b..(b + 1), 0..1, 1..2], log_probs .clone() .slice([b..(b + 1), 0..1, target_prime..(target_prime + 1)]), @@ -296,6 +296,8 @@ impl CTCLoss { #[cfg(test)] mod test { + use burn_tensor::Data; + use super::*; use crate::TestBackend; @@ -396,7 +398,7 @@ mod test { let target = Tensor::::from_data([3, 4, 7, 6, 3, 7, 3, 6, 2]); let input_lengths = Tensor::::from_data([30]); let target_lengths = Tensor::::from_data([9]); - let _expected_res = 47.73889923095703; + let expected_res = Data::from([47.73889923095703]); let ctc_loss = CTCLoss::::new(0); let res = ctc_loss.forward( @@ -407,7 +409,7 @@ mod test { Some(Reduction::Sum), ); - // 47.061913 - res.to_data().assert_within_range(47..49); + // 47.7376 + res.to_data().assert_approx_eq(&expected_res, 2); } } From 92d694d51d2fc7ecd4d8503b661cb4a8bd387180 Mon Sep 17 00:00:00 2001 From: wcshds Date: Tue, 5 Dec 2023 12:57:00 +0800 Subject: [PATCH 04/19] update test case --- burn-core/src/nn/loss/ctc.rs | 73 ++++++++++++++++++------------------ 1 file changed, 37 insertions(+), 36 deletions(-) diff --git a/burn-core/src/nn/loss/ctc.rs b/burn-core/src/nn/loss/ctc.rs index 65c01e42a5..9ee061828c 100644 --- a/burn-core/src/nn/loss/ctc.rs +++ b/burn-core/src/nn/loss/ctc.rs @@ -298,107 +298,108 @@ impl CTCLoss { mod test { use burn_tensor::Data; - use super::*; use crate::TestBackend; + use super::*; + #[test] fn test_ctc_loss() { let input = Tensor::::from_data([[ [ - -3.941, -2.116, -3.559, -2.559, -2.576, -2.445, -0.759, -3.240, -3.116, -3.221, + -0.785, -3.471, -2.531, -3.948, -2.373, -3.042, -2.029, -2.255, -4.228, -3.810, ], [ - -3.001, -2.181, -2.915, -2.382, -4.597, -4.133, -3.738, -3.256, -3.291, -0.571, + -3.548, -1.692, -0.967, -2.519, -2.806, -2.760, -2.434, -2.762, -3.638, -3.669, ], [ - -1.127, -3.112, -2.896, -1.613, -4.025, -2.752, -2.086, -3.241, -2.187, -3.925, + -3.904, -1.799, -1.312, -2.530, -2.267, 
-3.169, -3.838, -2.073, -2.484, -2.418, ], [ - -1.568, -4.852, -4.101, -3.584, -1.354, -2.619, -1.798, -3.845, -2.914, -1.789, + -0.890, -2.506, -3.405, -3.038, -2.483, -2.861, -2.749, -3.086, -1.960, -3.336, ], [ - -3.770, -4.748, -3.915, -0.978, -6.070, -2.430, -3.295, -2.307, -3.980, -1.119, + -1.113, -3.557, -2.580, -1.465, -3.884, -1.993, -3.574, -3.466, -2.669, -2.985, ], [ - -2.117, -2.178, -2.084, -2.325, -1.426, -3.922, -2.020, -4.461, -2.366, -3.078, + -3.948, -0.828, -1.805, -2.842, -2.767, -3.891, -2.825, -1.783, -5.566, -5.072, ], [ - -2.195, -1.658, -2.019, -2.959, -3.266, -3.922, -1.259, -3.566, -2.426, -2.904, + -1.677, -1.703, -4.191, -3.862, -1.726, -2.616, -2.366, -2.324, -2.767, -2.418, ], [ - -2.441, -1.606, -2.835, -3.703, -1.418, -3.456, -2.504, -2.445, -1.907, -3.263, + -1.511, -1.125, -3.526, -3.007, -2.975, -3.358, -2.037, -2.093, -4.137, -3.900, ], [ - -3.509, -2.281, -2.405, -4.563, -2.469, -2.816, -1.916, -2.147, -1.701, -1.736, + -1.850, -2.767, -1.718, -2.185, -2.890, -1.998, -3.661, -3.997, -2.738, -1.671, ], [ - -3.313, -1.417, -2.122, -3.138, -3.365, -2.074, -3.471, -1.530, -2.885, -2.362, + -2.621, -1.234, -3.499, -3.494, -1.612, -1.713, -2.179, -2.884, -4.122, -4.581, ], [ - -3.784, -0.829, -2.479, -2.101, -3.563, -2.265, -4.733, -2.501, -2.731, -3.067, + -1.519, -3.283, -1.287, -3.217, -2.544, -3.128, -2.061, -3.039, -2.388, -3.272, ], [ - -2.533, -2.684, -0.890, -2.986, -3.694, -3.484, -2.270, -2.169, -2.913, -2.751, + -1.112, -1.258, -3.206, -3.103, -3.918, -2.577, -4.399, -4.488, -2.187, -2.663, ], [ - -3.435, -2.567, -2.526, -1.183, -3.210, -2.538, -1.184, -3.352, -3.935, -3.704, + -1.889, -2.344, -3.232, -2.781, -3.312, -0.911, -2.864, -4.825, -3.180, -2.243, ], [ - -3.139, -2.204, -0.668, -5.249, -3.855, -3.706, -2.839, -1.971, -2.852, -3.608, + -4.368, -1.471, -1.308, -2.950, -3.211, -2.692, -1.923, -2.020, -3.859, -3.601, ], [ - -1.445, -2.020, -3.576, -3.153, -2.949, -2.717, -3.902, -3.726, -1.594, -1.635, + -4.254, -3.291, -1.539, -2.622, -2.281, -1.427, -1.712, -3.082, -2.653, -3.809, ], [ - -1.596, -4.902, -4.364, -4.571, -1.465, -3.689, -1.751, -2.032, -1.945, -2.764, + -3.322, -2.904, -0.942, -3.157, -2.987, -3.736, -1.208, -4.155, -4.383, -2.583, ], [ - -3.326, -2.239, -2.965, -1.831, -2.958, -1.912, -1.695, -1.932, -2.353, -3.791, + -2.827, -2.293, -3.109, -3.196, -3.297, -2.451, -2.136, -3.423, -1.012, -2.146, ], [ - -3.372, -2.850, -2.342, -0.841, -2.754, -3.297, -3.610, -2.152, -2.611, -2.760, + -1.803, -1.666, -1.780, -4.024, -3.083, -4.520, -2.674, -2.527, -3.365, -1.516, ], [ - -2.843, -3.622, -1.551, -4.361, -4.325, -0.975, -3.459, -2.004, -2.758, -2.658, + -2.199, -2.340, -2.009, -3.736, -3.363, -2.721, -2.350, -1.951, -1.815, -2.009, ], [ - -2.094, -3.114, -0.915, -3.207, -2.865, -2.215, -3.892, -4.120, -2.113, -2.693, + -1.721, -3.726, -1.701, -3.503, -2.153, -3.242, -2.284, -1.838, -2.646, -2.329, ], [ - -3.049, -2.809, -3.370, -2.358, -2.038, -1.879, -1.957, -3.337, -2.198, -1.648, + -3.655, -2.916, -2.913, -1.197, -3.060, -2.154, -1.776, -3.404, -1.823, -3.310, ], [ - -4.449, -2.300, -2.324, -3.414, -2.296, -1.620, -3.738, -2.128, -1.276, -3.311, + -2.671, -2.592, -2.929, -1.416, -2.007, -2.886, -2.781, -2.597, -1.738, -2.862, ], [ - -2.133, -2.854, -2.711, -3.328, -3.735, -3.705, -0.627, -3.701, -4.156, -2.319, + -1.686, -4.173, -0.884, -5.493, -5.498, -1.707, -3.573, -5.085, -2.060, -3.352, ], [ - -3.160, -3.321, -1.590, -3.735, -1.640, -3.614, -2.270, -1.911, -2.099, -2.314, + -2.114, -2.478, -2.178, -3.457, -3.264, 
-2.659, -2.653, -1.222, -2.375, -2.475, ], [ - -3.044, -3.279, -1.939, -2.554, -2.272, -1.209, -2.627, -3.025, -2.187, -2.837, + -2.136, -3.563, -2.325, -3.081, -2.035, -3.154, -1.122, -3.486, -1.951, -3.270, ], [ - -3.209, -3.186, -3.113, -2.002, -2.527, -2.561, -3.697, -2.347, -1.694, -1.282, + -3.206, -3.031, -3.913, -2.652, -2.985, -2.635, -1.153, -3.122, -3.256, -1.203, ], [ - -1.297, -2.826, -2.052, -2.534, -2.544, -3.318, -2.015, -3.384, -2.755, -2.171, + -2.104, -1.719, -2.141, -2.695, -2.448, -2.991, -1.542, -2.646, -3.090, -3.066, ], [ - -2.774, -2.740, -1.453, -3.754, -2.903, -2.309, -2.528, -1.664, -2.338, -2.345, + -3.320, -5.098, -1.085, -1.335, -2.588, -3.098, -2.466, -2.951, -3.911, -2.538, ], [ - -3.036, -2.509, -0.726, -2.385, -4.339, -4.286, -3.388, -3.196, -3.755, -1.772, + -3.756, -1.814, -2.752, -2.410, -3.305, -2.387, -2.112, -1.720, -2.616, -1.843, ], [ - -3.222, -3.674, -2.348, -2.324, -3.065, -2.748, -0.912, -2.595, -1.952, -4.408, + -3.985, -2.489, -2.305, -1.454, -2.533, -5.091, -1.759, -2.180, -3.673, -1.779, ], ]]); - let target = Tensor::::from_data([3, 4, 7, 6, 3, 7, 3, 6, 2]); + let target = Tensor::::from_data([1, 9, 6, 9, 4]); let input_lengths = Tensor::::from_data([30]); - let target_lengths = Tensor::::from_data([9]); - let expected_res = Data::from([47.73889923095703]); + let target_lengths = Tensor::::from_data([5]); + let expected_res = Data::from([50.3788948059082]); let ctc_loss = CTCLoss::::new(0); let res = ctc_loss.forward( @@ -409,7 +410,7 @@ mod test { Some(Reduction::Sum), ); - // 47.7376 - res.to_data().assert_approx_eq(&expected_res, 2); + // 50.3789 + res.to_data().assert_approx_eq(&expected_res, 3); } } From 71fc4df79b456c0aecbefc0d80e18c3cb079e25b Mon Sep 17 00:00:00 2001 From: wcshds Date: Wed, 6 Dec 2023 13:20:49 +0800 Subject: [PATCH 05/19] remove batch loop --- burn-core/src/nn/loss/ctc.rs | 282 ++++++++++++++++++----------------- 1 file changed, 144 insertions(+), 138 deletions(-) diff --git a/burn-core/src/nn/loss/ctc.rs b/burn-core/src/nn/loss/ctc.rs index 9ee061828c..e05967d698 100644 --- a/burn-core/src/nn/loss/ctc.rs +++ b/burn-core/src/nn/loss/ctc.rs @@ -2,11 +2,10 @@ use core::marker::PhantomData; use burn_tensor::{backend::Backend, ElementConversion, Int, Tensor}; -use half::f16; use super::Reduction; -const NEG_INF: f16 = f16::NEG_INFINITY; +const NEG_INF: f32 = -10000.0; /// The Connectionist Temporal Classification loss. 
#[derive(Clone, Debug)] @@ -70,162 +69,139 @@ impl CTCLoss { target_lengths.clone(), ); - let [batch_size, seq_length, _] = log_probs.dims(); + let [batch_size, seq_length, num_classes] = log_probs.dims(); let max_target_length = target_lengths.clone().max().into_scalar().elem::() as usize; let target_with_blank_length = 2 * max_target_length + 1; + let targets_pad = Self::pad_target( + targets.clone(), + target_lengths.clone(), + max_target_length, + self.blank, + ); + let mut log_alphas = Tensor::::empty([batch_size, seq_length, target_with_blank_length]); + // initialize value at t0 log_alphas = log_alphas.slice_assign( [0..batch_size, 0..1, 0..target_with_blank_length], Tensor::::full([batch_size, 1, target_with_blank_length], NEG_INF), ); + log_alphas = log_alphas.slice_assign( + [0..batch_size, 0..1, 0..1], + log_probs + .clone() + .slice([0..batch_size, 0..1, self.blank..(self.blank + 1)]), + ); + let target_primes = Self::get_target_primes(targets_pad.clone(), 1, self.blank); + log_alphas = log_alphas.slice_assign( + [0..batch_size, 0..1, 1..2], + log_probs + .clone() + .slice([0..batch_size, 0..1, 0..num_classes]) + .gather(2, target_primes.reshape([batch_size, 1, 1])), + ); let mut neg_log_likelihood = Tensor::::zeros([batch_size]); - let mut target_iter = target_lengths - .clone() - .iter_dim(0) - .scan(0usize, |start, current| { - let step = current.into_scalar().elem::() as usize; - let res = targets.clone().slice([*start..(*start + step)]); - *start += step; - - Some(res) - }); + for t in 1..seq_length { + for s in 0..target_with_blank_length { + let current_target_prime = + Self::get_target_primes(targets_pad.clone(), s, self.blank); - for b in 0..batch_size { - let target_data = target_iter.next().unwrap(); - - let input_length = input_lengths - .clone() - .slice([b..(b + 1)]) - .into_scalar() - .elem::() as usize; - let [target_length] = target_data.dims(); - - log_alphas = log_alphas.slice_assign( - [b..(b + 1), 0..1, 0..1], - log_probs + // \alpha_{t-1}(s) + let la1 = log_alphas .clone() - .slice([b..(b + 1), 0..1, self.blank..(self.blank + 1)]), - ); + .slice([0..batch_size, (t - 1)..t, s..(s + 1)]) + .reshape([batch_size]); - if target_length > 0 { - let target_prime = Self::get_target_prime(target_data.clone(), 1, self.blank); - log_alphas = log_alphas.slice_assign( - [b..(b + 1), 0..1, 1..2], - log_probs - .clone() - .slice([b..(b + 1), 0..1, target_prime..(target_prime + 1)]), - ); - } + // for the logsumexp calculation + let mut lamax = la1.clone(); - for t in 1..input_length { - for s in 0..(2 * target_length + 1) { - let current_target_prime = - Self::get_target_prime(target_data.clone(), s, self.blank); + // \alpha_{t-1}(s-1) + let mut la2 = Tensor::::full([batch_size], NEG_INF); + if s > 0 { + la2 = log_alphas + .clone() + .slice([0..batch_size, (t - 1)..t, (s - 1)..s]) + .reshape([batch_size]); - // \alpha_{t-1}(s) - let la1 = log_alphas + lamax = lamax .clone() - .slice([b..(b + 1), (t - 1)..t, s..(s + 1)]) - .reshape([1]); - // for the logsumexp calculation - let mut lamax = la1.clone(); - // \alpha_{t-1}(s-1) - let (la2, la3); - - if s > 0 { - la2 = log_alphas - .clone() - .slice([b..(b + 1), (t - 1)..t, (s - 1)..s]) - .reshape([1]); - if la2.clone().greater(lamax.clone()).to_data().value[0] { - lamax = la2.clone(); - } - } else { - la2 = Tensor::::full([1], NEG_INF); - } - - if (s > 1) - && (Self::get_target_prime(target_data.clone(), s - 2, self.blank) - != current_target_prime) - { - // \alpha_{t-1}(s-2) - la3 = log_alphas + 
.mask_where(la2.clone().greater(lamax.clone()), la2.clone()); + } + + let mut la3 = Tensor::::full([batch_size], NEG_INF); + if s > 1 { + // \alpha_{t-1}(s-2) + la3 = la3.mask_where( + Self::get_target_primes(targets_pad.clone(), s - 2, self.blank) + .equal(current_target_prime.clone()) + .bool_not(), + log_alphas .clone() - .slice([b..(b + 1), (t - 1)..t, (s - 2)..(s - 1)]) - .reshape([1]); - if la3.clone().greater(lamax.clone()).to_data().value[0] { - lamax = la3.clone(); - } - } else { - la3 = Tensor::::full([1], NEG_INF); - } - - if lamax.clone().equal_elem(NEG_INF).to_data().value[0] { - lamax = Tensor::::from_floats([0.0]); - } - log_alphas = log_alphas.slice_assign( - [b..(b + 1), t..(t + 1), s..(s + 1)], - (((la1 - lamax.clone()).exp() - + (la2 - lamax.clone()).exp() - + (la3 - lamax.clone()).exp()) - .log() - + lamax - + log_probs - .clone() - .slice([ - b..(b + 1), - t..(t + 1), - current_target_prime..(current_target_prime + 1), - ]) - .reshape([1])) - .reshape([1, 1, 1]), + .slice([0..batch_size, (t - 1)..t, (s - 2)..(s - 1)]) + .reshape([batch_size]), ); - } - } - // the likelihood is the sum of the last two alphas, - // the loss is the negative log likelihood - if target_length == 0 { - // if the target is empty then there is no preceding BLANK - // state and hence there is no path to merge - neg_log_likelihood = neg_log_likelihood.slice_assign( - [b..(b + 1)], - -log_alphas + lamax = lamax .clone() - .slice([b..(b + 1), (input_length - 1)..input_length, 0..1]) - .reshape([1]), - ); - } else { - let l1 = log_alphas - .clone() - .slice([ - b..(b + 1), - (input_length - 1)..input_length, - (target_length * 2)..(target_length * 2 + 1), - ]) - .reshape([1]); - let l2 = log_alphas + .mask_where(la3.clone().greater(lamax.clone()), la3.clone()); + } + + lamax = lamax .clone() - .slice([ - b..(b + 1), - (input_length - 1)..input_length, - (target_length * 2 - 1)..(target_length * 2), - ]) - .reshape([1]); - // for the logsumexp calculation - let mut m = Tensor::cat([l1.clone(), l2.clone()].to_vec(), 0).max(); + .mask_fill(lamax.clone().lower_equal_elem(NEG_INF), 0.0); - if m.clone().equal_elem(NEG_INF).to_data().value[0] { - m = Tensor::::from_floats([0.0]) - }; - let log_likelihood = ((l1 - m.clone()).exp() + (l2 - m.clone()).exp()).log() + m; - neg_log_likelihood = neg_log_likelihood.slice_assign([b..(b + 1)], -log_likelihood); + log_alphas = log_alphas.slice_assign( + [0..batch_size, t..(t + 1), s..(s + 1)], + (((la1.clone() - lamax.clone()).exp() + + (la2.clone() - lamax.clone()).exp() + + (la3.clone() - lamax.clone()).exp()) + .log() + .clamp_min(NEG_INF) + + lamax.clone() + + log_probs + .clone() + .slice([0..batch_size, t..(t + 1), 0..num_classes]) + .gather(2, current_target_prime.clone().reshape([batch_size, 1, 1])) + .reshape([batch_size])) + .reshape([batch_size, 1, 1]), + ); } } + let l1 = log_alphas + .clone() + .gather( + 1, + (input_lengths.clone() - 1) + .reshape([batch_size, 1, 1]) + .repeat(2, target_with_blank_length), + ) + .gather(2, (target_lengths.clone() * 2).reshape([batch_size, 1, 1])) + .reshape([batch_size]); + let l2 = log_alphas + .clone() + .gather( + 1, + (input_lengths.clone() - 1) + .reshape([batch_size, 1, 1]) + .repeat(2, target_with_blank_length), + ) + .gather( + 2, + (target_lengths.clone() * 2 - 1).reshape([batch_size, 1, 1]), + ) + .reshape([batch_size]); + // for the logsumexp calculation + let mut m = Tensor::cat([l1.clone(), l2.clone()].to_vec(), 0).max(); + + if m.clone().lower_equal_elem(NEG_INF).to_data().value[0] { + m = 
Tensor::::from_floats([0.0]) + }; + let log_likelihood = ((l1 - m.clone()).exp() + (l2 - m.clone()).exp()).log() + m; + neg_log_likelihood = neg_log_likelihood.slice_assign([0..batch_size], -log_likelihood); + match reduction { Some(Reduction::Mean) | Some(Reduction::Auto) => { (neg_log_likelihood / target_lengths.float()).mean() @@ -235,17 +211,47 @@ impl CTCLoss { } } - fn get_target_prime(target_data: Tensor, idx: usize, blank: usize) -> usize { + fn get_target_primes( + targets_pad: Tensor, + idx: usize, + blank: usize, + ) -> Tensor { + let [batch_size, _] = targets_pad.dims(); + if idx % 2 == 0 { - blank + Tensor::::full([batch_size], blank as i32) } else { - target_data - .slice([(idx / 2)..(idx / 2 + 1)]) - .into_scalar() - .elem::() as usize + targets_pad + .slice([0..batch_size, (idx / 2)..(idx / 2 + 1)]) + .squeeze(1) } } + fn pad_target( + targets: Tensor, + target_lengths: Tensor, + max_target_length: usize, + blank: usize, + ) -> Tensor { + let [batch_size] = target_lengths.dims(); + + let mut targets_pad = + Tensor::::full([batch_size, max_target_length], blank as i32); + let mut start = 0usize; + for (batch, length) in target_lengths.iter_dim(0).enumerate() { + let length = length.into_scalar().elem::() as usize; + + targets_pad = targets_pad.clone().slice_assign( + [batch..(batch + 1), 0..length], + targets.clone().slice([start..(start + length)]).unsqueeze(), + ); + + start += length + } + + targets_pad + } + fn assertions( log_probs: Tensor, targets: Tensor, From 0813605c355481271aac0b0fda222b64901f8928 Mon Sep 17 00:00:00 2001 From: wcshds Date: Wed, 6 Dec 2023 14:24:10 +0800 Subject: [PATCH 06/19] cache current_target_primes --- burn-core/src/nn/loss/ctc.rs | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/burn-core/src/nn/loss/ctc.rs b/burn-core/src/nn/loss/ctc.rs index e05967d698..2f50540655 100644 --- a/burn-core/src/nn/loss/ctc.rs +++ b/burn-core/src/nn/loss/ctc.rs @@ -103,11 +103,10 @@ impl CTCLoss { ); let mut neg_log_likelihood = Tensor::::zeros([batch_size]); - for t in 1..seq_length { - for s in 0..target_with_blank_length { - let current_target_prime = - Self::get_target_primes(targets_pad.clone(), s, self.blank); + for s in 0..target_with_blank_length { + let current_target_primes = Self::get_target_primes(targets_pad.clone(), s, self.blank); + for t in 1..seq_length { // \alpha_{t-1}(s) let la1 = log_alphas .clone() @@ -135,7 +134,7 @@ impl CTCLoss { // \alpha_{t-1}(s-2) la3 = la3.mask_where( Self::get_target_primes(targets_pad.clone(), s - 2, self.blank) - .equal(current_target_prime.clone()) + .equal(current_target_primes.clone()) .bool_not(), log_alphas .clone() @@ -163,7 +162,7 @@ impl CTCLoss { + log_probs .clone() .slice([0..batch_size, t..(t + 1), 0..num_classes]) - .gather(2, current_target_prime.clone().reshape([batch_size, 1, 1])) + .gather(2, current_target_primes.clone().reshape([batch_size, 1, 1])) .reshape([batch_size])) .reshape([batch_size, 1, 1]), ); From dd9cf8bc497df9de8c961eb9a473968933be03a5 Mon Sep 17 00:00:00 2001 From: wcshds Date: Wed, 6 Dec 2023 19:54:58 +0800 Subject: [PATCH 07/19] make sure tensors are on the same device --- burn-core/src/nn/loss/ctc.rs | 38 +++++++++++++++++++++++++----------- 1 file changed, 27 insertions(+), 11 deletions(-) diff --git a/burn-core/src/nn/loss/ctc.rs b/burn-core/src/nn/loss/ctc.rs index 2f50540655..28a6e715ca 100644 --- a/burn-core/src/nn/loss/ctc.rs +++ b/burn-core/src/nn/loss/ctc.rs @@ -69,6 +69,10 @@ impl CTCLoss { target_lengths.clone(), ); + let device = 
log_probs.device(); + let input_lengths = input_lengths.to_device(&device); + let target_lengths = target_lengths.to_device(&device); + let [batch_size, seq_length, num_classes] = log_probs.dims(); let max_target_length = target_lengths.clone().max().into_scalar().elem::() as usize; let target_with_blank_length = 2 * max_target_length + 1; @@ -78,14 +82,21 @@ impl CTCLoss { target_lengths.clone(), max_target_length, self.blank, + &device, ); - let mut log_alphas = - Tensor::::empty([batch_size, seq_length, target_with_blank_length]); + let mut log_alphas = Tensor::::empty_device( + [batch_size, seq_length, target_with_blank_length], + &device, + ); // initialize value at t0 log_alphas = log_alphas.slice_assign( [0..batch_size, 0..1, 0..target_with_blank_length], - Tensor::::full([batch_size, 1, target_with_blank_length], NEG_INF), + Tensor::::full_device( + [batch_size, 1, target_with_blank_length], + NEG_INF, + &device, + ), ); log_alphas = log_alphas.slice_assign( [0..batch_size, 0..1, 0..1], @@ -101,7 +112,7 @@ impl CTCLoss { .slice([0..batch_size, 0..1, 0..num_classes]) .gather(2, target_primes.reshape([batch_size, 1, 1])), ); - let mut neg_log_likelihood = Tensor::::zeros([batch_size]); + let mut neg_log_likelihood = Tensor::::zeros_device([batch_size], &device); for s in 0..target_with_blank_length { let current_target_primes = Self::get_target_primes(targets_pad.clone(), s, self.blank); @@ -117,7 +128,7 @@ impl CTCLoss { let mut lamax = la1.clone(); // \alpha_{t-1}(s-1) - let mut la2 = Tensor::::full([batch_size], NEG_INF); + let mut la2 = Tensor::::full_device([batch_size], NEG_INF, &device); if s > 0 { la2 = log_alphas .clone() @@ -129,9 +140,9 @@ impl CTCLoss { .mask_where(la2.clone().greater(lamax.clone()), la2.clone()); } - let mut la3 = Tensor::::full([batch_size], NEG_INF); + // \alpha_{t-1}(s-2) + let mut la3 = Tensor::::full_device([batch_size], NEG_INF, &device); if s > 1 { - // \alpha_{t-1}(s-2) la3 = la3.mask_where( Self::get_target_primes(targets_pad.clone(), s - 2, self.blank) .equal(current_target_primes.clone()) @@ -196,7 +207,7 @@ impl CTCLoss { let mut m = Tensor::cat([l1.clone(), l2.clone()].to_vec(), 0).max(); if m.clone().lower_equal_elem(NEG_INF).to_data().value[0] { - m = Tensor::::from_floats([0.0]) + m = Tensor::::full_device([1], 0.0, &device); }; let log_likelihood = ((l1 - m.clone()).exp() + (l2 - m.clone()).exp()).log() + m; neg_log_likelihood = neg_log_likelihood.slice_assign([0..batch_size], -log_likelihood); @@ -215,10 +226,11 @@ impl CTCLoss { idx: usize, blank: usize, ) -> Tensor { + let device = targets_pad.device(); let [batch_size, _] = targets_pad.dims(); if idx % 2 == 0 { - Tensor::::full([batch_size], blank as i32) + Tensor::::full_device([batch_size], blank as i32, &device) } else { targets_pad .slice([0..batch_size, (idx / 2)..(idx / 2 + 1)]) @@ -231,11 +243,15 @@ impl CTCLoss { target_lengths: Tensor, max_target_length: usize, blank: usize, + device: &B::Device, ) -> Tensor { let [batch_size] = target_lengths.dims(); - let mut targets_pad = - Tensor::::full([batch_size, max_target_length], blank as i32); + let mut targets_pad = Tensor::::full_device( + [batch_size, max_target_length], + blank as i32, + &device, + ); let mut start = 0usize; for (batch, length) in target_lengths.iter_dim(0).enumerate() { let length = length.into_scalar().elem::() as usize; From 94d52d2ed5053c96e8a3828d7a48a2fd39982e37 Mon Sep 17 00:00:00 2001 From: wcshds Date: Wed, 6 Dec 2023 20:10:28 +0800 Subject: [PATCH 08/19] use clamp_min instead of mask_where --- 
burn-core/src/nn/loss/ctc.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/burn-core/src/nn/loss/ctc.rs b/burn-core/src/nn/loss/ctc.rs index 28a6e715ca..1b9ef9fd11 100644 --- a/burn-core/src/nn/loss/ctc.rs +++ b/burn-core/src/nn/loss/ctc.rs @@ -203,12 +203,10 @@ impl CTCLoss { (target_lengths.clone() * 2 - 1).reshape([batch_size, 1, 1]), ) .reshape([batch_size]); - // for the logsumexp calculation - let mut m = Tensor::cat([l1.clone(), l2.clone()].to_vec(), 0).max(); - if m.clone().lower_equal_elem(NEG_INF).to_data().value[0] { - m = Tensor::::full_device([1], 0.0, &device); - }; + // for the logsumexp calculation + let m = Tensor::cat([l1.clone(), l2.clone()].to_vec(), 0).max(); + let m = m.clone().clamp_min(NEG_INF); let log_likelihood = ((l1 - m.clone()).exp() + (l2 - m.clone()).exp()).log() + m; neg_log_likelihood = neg_log_likelihood.slice_assign([0..batch_size], -log_likelihood); From 712185292c8fb79a6fbc73d55ed719bbf7859a92 Mon Sep 17 00:00:00 2001 From: wcshds Date: Fri, 8 Dec 2023 18:24:10 +0800 Subject: [PATCH 09/19] remove squence loop --- burn-core/src/nn/loss/ctc.rs | 213 ++++++++++++++++++++--------------- 1 file changed, 124 insertions(+), 89 deletions(-) diff --git a/burn-core/src/nn/loss/ctc.rs b/burn-core/src/nn/loss/ctc.rs index 1b9ef9fd11..de7c8ee05c 100644 --- a/burn-core/src/nn/loss/ctc.rs +++ b/burn-core/src/nn/loss/ctc.rs @@ -1,11 +1,11 @@ #![allow(clippy::single_range_in_vec_init)] use core::marker::PhantomData; -use burn_tensor::{backend::Backend, ElementConversion, Int, Tensor}; +use burn_tensor::{backend::Backend, Element, ElementConversion, Int, Numeric, Tensor}; use super::Reduction; -const NEG_INF: f32 = -10000.0; +const NEG_INF: f32 = -1e5; /// The Connectionist Temporal Classification loss. 
#[derive(Clone, Debug)] @@ -78,19 +78,22 @@ impl CTCLoss { let target_with_blank_length = 2 * max_target_length + 1; let targets_pad = Self::pad_target( - targets.clone(), + targets, target_lengths.clone(), max_target_length, self.blank, &device, ); + let targets_intersperse = intersperse(targets_pad.clone(), self.blank as u32); + println!("{}", targets_intersperse.clone()); + let targets_one_hot = one_hot(targets_intersperse.clone(), num_classes); - let mut log_alphas = Tensor::::empty_device( + let log_alphas = Tensor::::empty_device( [batch_size, seq_length, target_with_blank_length], &device, ); // initialize value at t0 - log_alphas = log_alphas.slice_assign( + let log_alphas = log_alphas.slice_assign( [0..batch_size, 0..1, 0..target_with_blank_length], Tensor::::full_device( [batch_size, 1, target_with_blank_length], @@ -98,86 +101,71 @@ impl CTCLoss { &device, ), ); - log_alphas = log_alphas.slice_assign( + let log_alphas = log_alphas.slice_assign( [0..batch_size, 0..1, 0..1], log_probs .clone() .slice([0..batch_size, 0..1, self.blank..(self.blank + 1)]), ); - let target_primes = Self::get_target_primes(targets_pad.clone(), 1, self.blank); - log_alphas = log_alphas.slice_assign( + let target_primes: Tensor = targets_pad + .slice([0..batch_size, 0..1]) + .reshape([batch_size, 1, 1]); + let mut log_alphas = log_alphas.slice_assign( [0..batch_size, 0..1, 1..2], log_probs .clone() .slice([0..batch_size, 0..1, 0..num_classes]) - .gather(2, target_primes.reshape([batch_size, 1, 1])), + .gather(2, target_primes), ); + let log_probs_available = targets_one_hot.matmul(log_probs.swap_dims(1, 2)); let mut neg_log_likelihood = Tensor::::zeros_device([batch_size], &device); - for s in 0..target_with_blank_length { - let current_target_primes = Self::get_target_primes(targets_pad.clone(), s, self.blank); - - for t in 1..seq_length { - // \alpha_{t-1}(s) - let la1 = log_alphas + // s != s-2 + let mask_la3 = targets_intersperse + .clone() + .slice([0..batch_size, 0..(target_with_blank_length - 2)]) + .equal(targets_intersperse.slice([0..batch_size, 2..target_with_blank_length])) + .bool_not() + .float(); + let mask_la3 = pad(mask_la3, [(0, 0), (2, 0)], 0.0).unsqueeze_dim(1); + + for t in 1..seq_length { + // \alpha_{t-1}(s) + let la1 = + log_alphas .clone() - .slice([0..batch_size, (t - 1)..t, s..(s + 1)]) - .reshape([batch_size]); - - // for the logsumexp calculation - let mut lamax = la1.clone(); - - // \alpha_{t-1}(s-1) - let mut la2 = Tensor::::full_device([batch_size], NEG_INF, &device); - if s > 0 { - la2 = log_alphas - .clone() - .slice([0..batch_size, (t - 1)..t, (s - 1)..s]) - .reshape([batch_size]); - - lamax = lamax - .clone() - .mask_where(la2.clone().greater(lamax.clone()), la2.clone()); - } - - // \alpha_{t-1}(s-2) - let mut la3 = Tensor::::full_device([batch_size], NEG_INF, &device); - if s > 1 { - la3 = la3.mask_where( - Self::get_target_primes(targets_pad.clone(), s - 2, self.blank) - .equal(current_target_primes.clone()) - .bool_not(), - log_alphas - .clone() - .slice([0..batch_size, (t - 1)..t, (s - 2)..(s - 1)]) - .reshape([batch_size]), - ); - - lamax = lamax + .slice([0..batch_size, (t - 1)..t, 0..target_with_blank_length]); + // \alpha_{t-1}(s-1) + let la2 = la1 + .clone() + .slice([0..batch_size, 0..1, 0..(target_with_blank_length - 1)]) + .clamp_min(NEG_INF); + let la2 = pad(la2, [(0, 0), (0, 0), (1, 0)], NEG_INF); + // \alpha_{t-1}(s-2) + let la3 = la1 + .clone() + .slice([0..batch_size, 0..1, 0..(target_with_blank_length - 2)]) + .clamp_min(NEG_INF); + let la3 = pad(la3, 
[(0, 0), (0, 0), (2, 0)], NEG_INF); + // for the logsumexp calculation + let lamax: Tensor = + Tensor::stack::<4>([la1.clone(), la2.clone(), la3.clone()].to_vec(), 3) + .max_dim(3) + .squeeze(3); + + log_alphas = log_alphas.slice_assign( + [0..batch_size, t..(t + 1), 0..target_with_blank_length], + ((la1 - lamax.clone()).exp() + + (la2 - lamax.clone()).exp() + + (la3 - lamax.clone()).exp().mul(mask_la3.clone())) + .log() + .clamp_min(NEG_INF) + + lamax + + log_probs_available .clone() - .mask_where(la3.clone().greater(lamax.clone()), la3.clone()); - } - - lamax = lamax - .clone() - .mask_fill(lamax.clone().lower_equal_elem(NEG_INF), 0.0); - - log_alphas = log_alphas.slice_assign( - [0..batch_size, t..(t + 1), s..(s + 1)], - (((la1.clone() - lamax.clone()).exp() - + (la2.clone() - lamax.clone()).exp() - + (la3.clone() - lamax.clone()).exp()) - .log() - .clamp_min(NEG_INF) - + lamax.clone() - + log_probs - .clone() - .slice([0..batch_size, t..(t + 1), 0..num_classes]) - .gather(2, current_target_primes.clone().reshape([batch_size, 1, 1])) - .reshape([batch_size])) - .reshape([batch_size, 1, 1]), - ); - } + .slice([0..batch_size, 0..target_with_blank_length, t..(t + 1)]) + .swap_dims(1, 2), + ); } let l1 = log_alphas @@ -194,7 +182,7 @@ impl CTCLoss { .clone() .gather( 1, - (input_lengths.clone() - 1) + (input_lengths - 1) .reshape([batch_size, 1, 1]) .repeat(2, target_with_blank_length), ) @@ -219,23 +207,6 @@ impl CTCLoss { } } - fn get_target_primes( - targets_pad: Tensor, - idx: usize, - blank: usize, - ) -> Tensor { - let device = targets_pad.device(); - let [batch_size, _] = targets_pad.dims(); - - if idx % 2 == 0 { - Tensor::::full_device([batch_size], blank as i32, &device) - } else { - targets_pad - .slice([0..batch_size, (idx / 2)..(idx / 2 + 1)]) - .squeeze(1) - } - } - fn pad_target( targets: Tensor, target_lengths: Tensor, @@ -313,6 +284,62 @@ impl CTCLoss { } } +fn pad( + tensor: Tensor, + pad_width: [(usize, usize); D], + fill_value: E, +) -> Tensor +where + B: Backend, + K: Numeric, + K::Elem: Element, + E: ElementConversion, +{ + let device = tensor.device(); + let origin_shape = tensor.dims(); + + let mut pad_shape = [0; D]; + let mut assign_range = Vec::with_capacity(D); + for (idx, (&origin_len, (left_pad, right_pad))) in + origin_shape.iter().zip(pad_width).enumerate() + { + pad_shape[idx] = origin_len + left_pad + right_pad; + assign_range.push(left_pad..(left_pad + origin_len)); + } + + let padded = Tensor::::full_device(pad_shape, fill_value, &device); + + padded.slice_assign::(assign_range.try_into().unwrap(), tensor) +} + +fn intersperse(tensor: Tensor, value: E) -> Tensor +where + B: Backend, + K: Numeric, + K::Elem: Element, + E: ElementConversion + Clone, +{ + let device = tensor.device(); + let mut shape = tensor.dims(); + let constants: Tensor = Tensor::full_device(shape, value.clone(), &device); + shape[1] = shape[1] * 2; + let stack = Tensor::stack::<3>([tensor, constants].to_vec(), 2).reshape(shape); + pad(stack, [(0, 0), (1, 0)], value) +} + +fn one_hot(tensor: Tensor, num_classes: usize) -> Tensor { + let device = tensor.device(); + let shape = tensor.dims(); + + let labels: Tensor = tensor.unsqueeze_dim(2).repeat(2, num_classes); + let indices = Tensor::::arange_device(0..num_classes, &device) + .reshape([1, 1, num_classes]) + .repeat(1, shape[1]) + .repeat(0, shape[0]); + + labels.equal(indices).float() +} + #[cfg(test)] mod test { use burn_tensor::Data; @@ -321,6 +348,14 @@ mod test { use super::*; + #[test] + fn test_intersperse() { + let tensor = 
Tensor::::arange(1..25).reshape([4, 6]); + let tensor = intersperse(tensor, 0); + + println!("{}", tensor); + } + #[test] fn test_ctc_loss() { let input = Tensor::::from_data([[ From eb5357f582fc46e2b822951c5dfa4c84c388b771 Mon Sep 17 00:00:00 2001 From: wcshds Date: Fri, 8 Dec 2023 20:26:01 +0800 Subject: [PATCH 10/19] use into_scalar in assertions --- burn-core/src/nn/loss/ctc.rs | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/burn-core/src/nn/loss/ctc.rs b/burn-core/src/nn/loss/ctc.rs index de7c8ee05c..43ddbaafd4 100644 --- a/burn-core/src/nn/loss/ctc.rs +++ b/burn-core/src/nn/loss/ctc.rs @@ -1,6 +1,7 @@ #![allow(clippy::single_range_in_vec_init)] use core::marker::PhantomData; +use alloc::vec::Vec; use burn_tensor::{backend::Backend, Element, ElementConversion, Int, Numeric, Tensor}; use super::Reduction; @@ -85,7 +86,6 @@ impl CTCLoss { &device, ); let targets_intersperse = intersperse(targets_pad.clone(), self.blank as u32); - println!("{}", targets_intersperse.clone()); let targets_one_hot = one_hot(targets_intersperse.clone(), num_classes); let log_alphas = Tensor::::empty_device( @@ -262,24 +262,17 @@ impl CTCLoss { ); assert!( - target_lengths - .sum() - .equal_elem(targets_size as u32) - .into_data() - .value[0], + target_lengths.sum().into_scalar().elem::() == targets_size as u32, "Batch size of targets ({}) should correspond to sum of target_lengths ({}).", log_probs_batch_size, target_lengths_size ); - let max_input_length = input_lengths.max(); + let max_input_length = input_lengths.max().into_scalar().elem::() as usize; assert!( - max_input_length.clone() - .lower_equal_elem(input_seq_length as u32) - .into_data() - .value[0], + max_input_length == input_seq_length, "The maximum value of input_lengths ({}) must not be greater than the sequence length of log_probs ({}).", - max_input_length.into_scalar(), input_seq_length + max_input_length, input_seq_length ); } } From 5c5728325d2e70a5200efe2a0e2aaae7aa59ae32 Mon Sep 17 00:00:00 2001 From: wcshds Date: Mon, 11 Dec 2023 20:41:19 +0800 Subject: [PATCH 11/19] reduce the size of log alphas --- burn-core/src/nn/loss/ctc.rs | 39 ++++++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/burn-core/src/nn/loss/ctc.rs b/burn-core/src/nn/loss/ctc.rs index 43ddbaafd4..0428c8125e 100644 --- a/burn-core/src/nn/loss/ctc.rs +++ b/burn-core/src/nn/loss/ctc.rs @@ -70,13 +70,17 @@ impl CTCLoss { target_lengths.clone(), ); + // make sure tensors are on the same device let device = log_probs.device(); let input_lengths = input_lengths.to_device(&device); let target_lengths = target_lengths.to_device(&device); - let [batch_size, seq_length, num_classes] = log_probs.dims(); + let [batch_size, _, num_classes] = log_probs.dims(); + let min_input_length = input_lengths.clone().min().into_scalar().elem::() as usize; + let max_input_length = input_lengths.clone().max().into_scalar().elem::() as usize; let max_target_length = target_lengths.clone().max().into_scalar().elem::() as usize; let target_with_blank_length = 2 * max_target_length + 1; + let reserved_seq_length = 1 + max_input_length - min_input_length; let targets_pad = Self::pad_target( targets, @@ -88,8 +92,10 @@ impl CTCLoss { let targets_intersperse = intersperse(targets_pad.clone(), self.blank as u32); let targets_one_hot = one_hot(targets_intersperse.clone(), num_classes); + // There is no need to reserve alpha for each time step; only reserved_seq_length is needed. 
+ // If the input length is all the same, it is sufficient to save only one time step per iter. let log_alphas = Tensor::::empty_device( - [batch_size, seq_length, target_with_blank_length], + [batch_size, reserved_seq_length, target_with_blank_length], &device, ); // initialize value at t0 @@ -129,12 +135,19 @@ impl CTCLoss { .float(); let mask_la3 = pad(mask_la3, [(0, 0), (2, 0)], 0.0).unsqueeze_dim(1); - for t in 1..seq_length { + for t in 1..max_input_length { + let (alpha_prime_prev, alpha_prime_next) = if (t as i32 - min_input_length as i32) < 0 { + (0, 0) + } else { + let prev = t - min_input_length as usize; + (prev, prev + 1) + }; // \alpha_{t-1}(s) - let la1 = - log_alphas - .clone() - .slice([0..batch_size, (t - 1)..t, 0..target_with_blank_length]); + let la1 = log_alphas.clone().slice([ + 0..batch_size, + alpha_prime_prev..(alpha_prime_prev + 1), + 0..target_with_blank_length, + ]); // \alpha_{t-1}(s-1) let la2 = la1 .clone() @@ -154,7 +167,11 @@ impl CTCLoss { .squeeze(3); log_alphas = log_alphas.slice_assign( - [0..batch_size, t..(t + 1), 0..target_with_blank_length], + [ + 0..batch_size, + alpha_prime_next..(alpha_prime_next + 1), + 0..target_with_blank_length, + ], ((la1 - lamax.clone()).exp() + (la2 - lamax.clone()).exp() + (la3 - lamax.clone()).exp().mul(mask_la3.clone())) @@ -172,7 +189,7 @@ impl CTCLoss { .clone() .gather( 1, - (input_lengths.clone() - 1) + (input_lengths.clone() - min_input_length as i32) .reshape([batch_size, 1, 1]) .repeat(2, target_with_blank_length), ) @@ -182,7 +199,7 @@ impl CTCLoss { .clone() .gather( 1, - (input_lengths - 1) + (input_lengths.clone() - min_input_length as i32) .reshape([batch_size, 1, 1]) .repeat(2, target_with_blank_length), ) @@ -270,7 +287,7 @@ impl CTCLoss { let max_input_length = input_lengths.max().into_scalar().elem::() as usize; assert!( - max_input_length == input_seq_length, + max_input_length <= input_seq_length, "The maximum value of input_lengths ({}) must not be greater than the sequence length of log_probs ({}).", max_input_length, input_seq_length ); From 7d047cd7ab10667a2e2efcf0fad738257aa4b9a0 Mon Sep 17 00:00:00 2001 From: wcshds Date: Tue, 12 Dec 2023 00:08:52 +0800 Subject: [PATCH 12/19] wordaround for slice bug on libtorch backend #1055 --- burn-core/src/nn/loss/ctc.rs | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/burn-core/src/nn/loss/ctc.rs b/burn-core/src/nn/loss/ctc.rs index 0428c8125e..51ccf060bd 100644 --- a/burn-core/src/nn/loss/ctc.rs +++ b/burn-core/src/nn/loss/ctc.rs @@ -92,8 +92,9 @@ impl CTCLoss { let targets_intersperse = intersperse(targets_pad.clone(), self.blank as u32); let targets_one_hot = one_hot(targets_intersperse.clone(), num_classes); - // There is no need to reserve alpha for each time step; only reserved_seq_length is needed. - // If the input length is all the same, it is sufficient to save only one time step per iter. + // There is no need to reserve alpha for each time step; only reserved_seq_length + // is needed. For instance, if the input length is all the same, the reserved_seq_length + // value will be set to 1, which is adequate. 
let log_alphas = Tensor::::empty_device( [batch_size, reserved_seq_length, target_with_blank_length], &device, @@ -130,7 +131,12 @@ impl CTCLoss { let mask_la3 = targets_intersperse .clone() .slice([0..batch_size, 0..(target_with_blank_length - 2)]) - .equal(targets_intersperse.slice([0..batch_size, 2..target_with_blank_length])) + .equal( + targets_intersperse + .clone() + .slice([0..batch_size, 2..target_with_blank_length]) + .clone(), + ) .bool_not() .float(); let mask_la3 = pad(mask_la3, [(0, 0), (2, 0)], 0.0).unsqueeze_dim(1); @@ -358,14 +364,6 @@ mod test { use super::*; - #[test] - fn test_intersperse() { - let tensor = Tensor::::arange(1..25).reshape([4, 6]); - let tensor = intersperse(tensor, 0); - - println!("{}", tensor); - } - #[test] fn test_ctc_loss() { let input = Tensor::::from_data([[ From 9599f32d5f55a812e7c839b5ce8bcea32dbb506e Mon Sep 17 00:00:00 2001 From: wcshds Date: Tue, 12 Dec 2023 15:39:24 +0800 Subject: [PATCH 13/19] reduce the size of one_hot --- burn-core/src/nn/loss/ctc.rs | 80 +++++++++++++++++++++++------------- 1 file changed, 52 insertions(+), 28 deletions(-) diff --git a/burn-core/src/nn/loss/ctc.rs b/burn-core/src/nn/loss/ctc.rs index 51ccf060bd..79dc3c6676 100644 --- a/burn-core/src/nn/loss/ctc.rs +++ b/burn-core/src/nn/loss/ctc.rs @@ -75,7 +75,7 @@ impl CTCLoss { let input_lengths = input_lengths.to_device(&device); let target_lengths = target_lengths.to_device(&device); - let [batch_size, _, num_classes] = log_probs.dims(); + let [batch_size, seq_length, num_classes] = log_probs.dims(); let min_input_length = input_lengths.clone().min().into_scalar().elem::() as usize; let max_input_length = input_lengths.clone().max().into_scalar().elem::() as usize; let max_target_length = target_lengths.clone().max().into_scalar().elem::() as usize; @@ -89,8 +89,7 @@ impl CTCLoss { self.blank, &device, ); - let targets_intersperse = intersperse(targets_pad.clone(), self.blank as u32); - let targets_one_hot = one_hot(targets_intersperse.clone(), num_classes); + let targets_one_hot = one_hot(targets_pad.clone(), num_classes); // There is no need to reserve alpha for each time step; only reserved_seq_length // is needed. 
For instance, if the input length is all the same, the reserved_seq_length @@ -115,6 +114,7 @@ impl CTCLoss { .slice([0..batch_size, 0..1, self.blank..(self.blank + 1)]), ); let target_primes: Tensor = targets_pad + .clone() .slice([0..batch_size, 0..1]) .reshape([batch_size, 1, 1]); let mut log_alphas = log_alphas.slice_assign( @@ -124,22 +124,60 @@ impl CTCLoss { .slice([0..batch_size, 0..1, 0..num_classes]) .gather(2, target_primes), ); - let log_probs_available = targets_one_hot.matmul(log_probs.swap_dims(1, 2)); + + // Shape: [batch_size, seq_length, max_target_length] + let log_probs_letter_available = targets_one_hot + .matmul(log_probs.clone().swap_dims(1, 2)) + .swap_dims(1, 2); + // Shape: [batch_size, seq_length, 1] + let log_probs_blank_available = + log_probs + .clone() + .slice([0..batch_size, 0..seq_length, self.blank..self.blank + 1]); + // Shape: [batch_size, seq_length, 2 * max_target_length + 1] + let log_probs_available = + Tensor::::zeros([batch_size, seq_length, target_with_blank_length]); + let log_probs_available = log_probs_available.slice_assign( + [0..batch_size, 0..seq_length, 0..1], + log_probs_blank_available.clone(), + ); + let log_probs_available = log_probs_available.slice_assign( + [0..batch_size, 0..seq_length, 1..target_with_blank_length], + // interlace log_probs_letter_available and log_probs_blank_available + Tensor::stack::<4>( + [ + log_probs_letter_available.clone(), + log_probs_blank_available.repeat(2, max_target_length), + ] + .to_vec(), + 3, + ) + .reshape([batch_size, seq_length, 2 * max_target_length]), + ); let mut neg_log_likelihood = Tensor::::zeros_device([batch_size], &device); // s != s-2 - let mask_la3 = targets_intersperse + let mask_la3_letter = targets_pad .clone() - .slice([0..batch_size, 0..(target_with_blank_length - 2)]) + .slice([0..batch_size, 0..(max_target_length - 1)]) .equal( - targets_intersperse + targets_pad .clone() - .slice([0..batch_size, 2..target_with_blank_length]) + .slice([0..batch_size, 1..max_target_length]) .clone(), ) .bool_not() .float(); - let mask_la3 = pad(mask_la3, [(0, 0), (2, 0)], 0.0).unsqueeze_dim(1); + let mask_la3_blank = + Tensor::::zeros_device([batch_size, max_target_length - 1], &device); + let mask_la3: Tensor = pad( + // interlace mask_la3_letter and mask_la3_blank + Tensor::stack::<3>([mask_la3_letter, mask_la3_blank].to_vec(), 2) + .reshape([batch_size, 2 * (max_target_length - 1)]), + [(0, 0), (3, 0)], + 0.0, + ) + .unsqueeze_dim(1); for t in 1..max_input_length { let (alpha_prime_prev, alpha_prime_next) = if (t as i32 - min_input_length as i32) < 0 { @@ -184,10 +222,11 @@ impl CTCLoss { .log() .clamp_min(NEG_INF) + lamax - + log_probs_available - .clone() - .slice([0..batch_size, 0..target_with_blank_length, t..(t + 1)]) - .swap_dims(1, 2), + + log_probs_available.clone().slice([ + 0..batch_size, + t..(t + 1), + 0..target_with_blank_length, + ]), ); } @@ -328,21 +367,6 @@ where padded.slice_assign::(assign_range.try_into().unwrap(), tensor) } -fn intersperse(tensor: Tensor, value: E) -> Tensor -where - B: Backend, - K: Numeric, - K::Elem: Element, - E: ElementConversion + Clone, -{ - let device = tensor.device(); - let mut shape = tensor.dims(); - let constants: Tensor = Tensor::full_device(shape, value.clone(), &device); - shape[1] = shape[1] * 2; - let stack = Tensor::stack::<3>([tensor, constants].to_vec(), 2).reshape(shape); - pad(stack, [(0, 0), (1, 0)], value) -} - fn one_hot(tensor: Tensor, num_classes: usize) -> Tensor { let device = tensor.device(); let shape = tensor.dims(); 
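A note on the layout built by the interlacing in the patch above: position 0 of `log_probs_available` holds the blank log-probability, and the stack-and-reshape then alternates letter and blank columns, so blanks end up at every even index of the blank-extended target axis (the same even-index blank convention used throughout these patches). Below is a plain-Rust sketch of that ordering; it is illustrative only, not part of any patch and not the burn API.

    // Illustrative only: the blank-extended layout is
    // [blank, l0, blank, l1, ..., l_{n-1}, blank], blanks at even indices.
    fn interleave_with_blank(letters: &[f32], blank: f32) -> Vec<f32> {
        let mut out = Vec::with_capacity(2 * letters.len() + 1);
        out.push(blank);
        for &letter in letters {
            out.push(letter);
            out.push(blank);
        }
        out
    }

    fn main() {
        // two letter log-probs (-1.0, -2.0) and a blank log-prob (-0.5)
        assert_eq!(
            interleave_with_blank(&[-1.0, -2.0], -0.5),
            vec![-0.5, -1.0, -0.5, -2.0, -0.5]
        );
    }
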
From 25972cedec078684126eb4212bf6176f2c85ae38 Mon Sep 17 00:00:00 2001 From: wcshds Date: Tue, 12 Dec 2023 15:48:33 +0800 Subject: [PATCH 14/19] make sure tensors are on the same device --- burn-core/src/nn/loss/ctc.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/burn-core/src/nn/loss/ctc.rs b/burn-core/src/nn/loss/ctc.rs index 79dc3c6676..afbe77c936 100644 --- a/burn-core/src/nn/loss/ctc.rs +++ b/burn-core/src/nn/loss/ctc.rs @@ -135,8 +135,10 @@ impl CTCLoss { .clone() .slice([0..batch_size, 0..seq_length, self.blank..self.blank + 1]); // Shape: [batch_size, seq_length, 2 * max_target_length + 1] - let log_probs_available = - Tensor::::zeros([batch_size, seq_length, target_with_blank_length]); + let log_probs_available = Tensor::::zeros_device( + [batch_size, seq_length, target_with_blank_length], + &device, + ); let log_probs_available = log_probs_available.slice_assign( [0..batch_size, 0..seq_length, 0..1], log_probs_blank_available.clone(), From 2758b1b7aeb2bc7285aa5b642cba3e6c32825419 Mon Sep 17 00:00:00 2001 From: wcshds Date: Mon, 25 Dec 2023 00:45:32 +0800 Subject: [PATCH 15/19] adapt to burn's new device api --- burn-core/src/nn/loss/ctc.rs | 225 +++++++++++++++++------------------ 1 file changed, 110 insertions(+), 115 deletions(-) diff --git a/burn-core/src/nn/loss/ctc.rs b/burn-core/src/nn/loss/ctc.rs index afbe77c936..3a17a0db97 100644 --- a/burn-core/src/nn/loss/ctc.rs +++ b/burn-core/src/nn/loss/ctc.rs @@ -94,18 +94,14 @@ impl CTCLoss { // There is no need to reserve alpha for each time step; only reserved_seq_length // is needed. For instance, if the input length is all the same, the reserved_seq_length // value will be set to 1, which is adequate. - let log_alphas = Tensor::::empty_device( + let log_alphas = Tensor::::empty( [batch_size, reserved_seq_length, target_with_blank_length], &device, ); // initialize value at t0 let log_alphas = log_alphas.slice_assign( [0..batch_size, 0..1, 0..target_with_blank_length], - Tensor::::full_device( - [batch_size, 1, target_with_blank_length], - NEG_INF, - &device, - ), + Tensor::::full([batch_size, 1, target_with_blank_length], NEG_INF, &device), ); let log_alphas = log_alphas.slice_assign( [0..batch_size, 0..1, 0..1], @@ -135,10 +131,8 @@ impl CTCLoss { .clone() .slice([0..batch_size, 0..seq_length, self.blank..self.blank + 1]); // Shape: [batch_size, seq_length, 2 * max_target_length + 1] - let log_probs_available = Tensor::::zeros_device( - [batch_size, seq_length, target_with_blank_length], - &device, - ); + let log_probs_available = + Tensor::::zeros([batch_size, seq_length, target_with_blank_length], &device); let log_probs_available = log_probs_available.slice_assign( [0..batch_size, 0..seq_length, 0..1], log_probs_blank_available.clone(), @@ -156,7 +150,7 @@ impl CTCLoss { ) .reshape([batch_size, seq_length, 2 * max_target_length]), ); - let mut neg_log_likelihood = Tensor::::zeros_device([batch_size], &device); + let mut neg_log_likelihood = Tensor::::zeros([batch_size], &device); // s != s-2 let mask_la3_letter = targets_pad @@ -170,8 +164,7 @@ impl CTCLoss { ) .bool_not() .float(); - let mask_la3_blank = - Tensor::::zeros_device([batch_size, max_target_length - 1], &device); + let mask_la3_blank = Tensor::::zeros([batch_size, max_target_length - 1], &device); let mask_la3: Tensor = pad( // interlace mask_la3_letter and mask_la3_blank Tensor::stack::<3>([mask_la3_letter, mask_la3_blank].to_vec(), 2) @@ -280,11 +273,8 @@ impl CTCLoss { ) -> Tensor { let [batch_size] = target_lengths.dims(); - let 
mut targets_pad = Tensor::::full_device( - [batch_size, max_target_length], - blank as i32, - &device, - ); + let mut targets_pad = + Tensor::::full([batch_size, max_target_length], blank as i32, &device); let mut start = 0usize; for (batch, length) in target_lengths.iter_dim(0).enumerate() { let length = length.into_scalar().elem::() as usize; @@ -364,7 +354,7 @@ where assign_range.push(left_pad..(left_pad + origin_len)); } - let padded = Tensor::::full_device(pad_shape, fill_value, &device); + let padded = Tensor::::full(pad_shape, fill_value, &device); padded.slice_assign::(assign_range.try_into().unwrap(), tensor) } @@ -374,7 +364,7 @@ fn one_hot(tensor: Tensor, num_classes: usize) -> Tensor< let shape = tensor.dims(); let labels: Tensor = tensor.unsqueeze_dim(2).repeat(2, num_classes); - let indices = Tensor::::arange_device(0..num_classes, &device) + let indices = Tensor::::arange(0..num_classes, &device) .reshape([1, 1, num_classes]) .repeat(1, shape[1]) .repeat(0, shape[0]); @@ -392,101 +382,106 @@ mod test { #[test] fn test_ctc_loss() { - let input = Tensor::::from_data([[ - [ - -0.785, -3.471, -2.531, -3.948, -2.373, -3.042, -2.029, -2.255, -4.228, -3.810, - ], - [ - -3.548, -1.692, -0.967, -2.519, -2.806, -2.760, -2.434, -2.762, -3.638, -3.669, - ], - [ - -3.904, -1.799, -1.312, -2.530, -2.267, -3.169, -3.838, -2.073, -2.484, -2.418, - ], - [ - -0.890, -2.506, -3.405, -3.038, -2.483, -2.861, -2.749, -3.086, -1.960, -3.336, - ], - [ - -1.113, -3.557, -2.580, -1.465, -3.884, -1.993, -3.574, -3.466, -2.669, -2.985, - ], - [ - -3.948, -0.828, -1.805, -2.842, -2.767, -3.891, -2.825, -1.783, -5.566, -5.072, - ], - [ - -1.677, -1.703, -4.191, -3.862, -1.726, -2.616, -2.366, -2.324, -2.767, -2.418, - ], - [ - -1.511, -1.125, -3.526, -3.007, -2.975, -3.358, -2.037, -2.093, -4.137, -3.900, - ], - [ - -1.850, -2.767, -1.718, -2.185, -2.890, -1.998, -3.661, -3.997, -2.738, -1.671, - ], - [ - -2.621, -1.234, -3.499, -3.494, -1.612, -1.713, -2.179, -2.884, -4.122, -4.581, - ], - [ - -1.519, -3.283, -1.287, -3.217, -2.544, -3.128, -2.061, -3.039, -2.388, -3.272, - ], - [ - -1.112, -1.258, -3.206, -3.103, -3.918, -2.577, -4.399, -4.488, -2.187, -2.663, - ], - [ - -1.889, -2.344, -3.232, -2.781, -3.312, -0.911, -2.864, -4.825, -3.180, -2.243, - ], - [ - -4.368, -1.471, -1.308, -2.950, -3.211, -2.692, -1.923, -2.020, -3.859, -3.601, - ], - [ - -4.254, -3.291, -1.539, -2.622, -2.281, -1.427, -1.712, -3.082, -2.653, -3.809, - ], - [ - -3.322, -2.904, -0.942, -3.157, -2.987, -3.736, -1.208, -4.155, -4.383, -2.583, - ], - [ - -2.827, -2.293, -3.109, -3.196, -3.297, -2.451, -2.136, -3.423, -1.012, -2.146, - ], - [ - -1.803, -1.666, -1.780, -4.024, -3.083, -4.520, -2.674, -2.527, -3.365, -1.516, - ], - [ - -2.199, -2.340, -2.009, -3.736, -3.363, -2.721, -2.350, -1.951, -1.815, -2.009, - ], - [ - -1.721, -3.726, -1.701, -3.503, -2.153, -3.242, -2.284, -1.838, -2.646, -2.329, - ], - [ - -3.655, -2.916, -2.913, -1.197, -3.060, -2.154, -1.776, -3.404, -1.823, -3.310, - ], - [ - -2.671, -2.592, -2.929, -1.416, -2.007, -2.886, -2.781, -2.597, -1.738, -2.862, - ], - [ - -1.686, -4.173, -0.884, -5.493, -5.498, -1.707, -3.573, -5.085, -2.060, -3.352, - ], - [ - -2.114, -2.478, -2.178, -3.457, -3.264, -2.659, -2.653, -1.222, -2.375, -2.475, - ], - [ - -2.136, -3.563, -2.325, -3.081, -2.035, -3.154, -1.122, -3.486, -1.951, -3.270, - ], - [ - -3.206, -3.031, -3.913, -2.652, -2.985, -2.635, -1.153, -3.122, -3.256, -1.203, - ], - [ - -2.104, -1.719, -2.141, -2.695, -2.448, -2.991, -1.542, -2.646, -3.090, -3.066, - 
], - [ - -3.320, -5.098, -1.085, -1.335, -2.588, -3.098, -2.466, -2.951, -3.911, -2.538, - ], - [ - -3.756, -1.814, -2.752, -2.410, -3.305, -2.387, -2.112, -1.720, -2.616, -1.843, - ], - [ - -3.985, -2.489, -2.305, -1.454, -2.533, -5.091, -1.759, -2.180, -3.673, -1.779, - ], - ]]); - let target = Tensor::::from_data([1, 9, 6, 9, 4]); - let input_lengths = Tensor::::from_data([30]); - let target_lengths = Tensor::::from_data([5]); + let device = Default::default(); + + let input = Tensor::::from_data( + [[ + [ + -0.785, -3.471, -2.531, -3.948, -2.373, -3.042, -2.029, -2.255, -4.228, -3.810, + ], + [ + -3.548, -1.692, -0.967, -2.519, -2.806, -2.760, -2.434, -2.762, -3.638, -3.669, + ], + [ + -3.904, -1.799, -1.312, -2.530, -2.267, -3.169, -3.838, -2.073, -2.484, -2.418, + ], + [ + -0.890, -2.506, -3.405, -3.038, -2.483, -2.861, -2.749, -3.086, -1.960, -3.336, + ], + [ + -1.113, -3.557, -2.580, -1.465, -3.884, -1.993, -3.574, -3.466, -2.669, -2.985, + ], + [ + -3.948, -0.828, -1.805, -2.842, -2.767, -3.891, -2.825, -1.783, -5.566, -5.072, + ], + [ + -1.677, -1.703, -4.191, -3.862, -1.726, -2.616, -2.366, -2.324, -2.767, -2.418, + ], + [ + -1.511, -1.125, -3.526, -3.007, -2.975, -3.358, -2.037, -2.093, -4.137, -3.900, + ], + [ + -1.850, -2.767, -1.718, -2.185, -2.890, -1.998, -3.661, -3.997, -2.738, -1.671, + ], + [ + -2.621, -1.234, -3.499, -3.494, -1.612, -1.713, -2.179, -2.884, -4.122, -4.581, + ], + [ + -1.519, -3.283, -1.287, -3.217, -2.544, -3.128, -2.061, -3.039, -2.388, -3.272, + ], + [ + -1.112, -1.258, -3.206, -3.103, -3.918, -2.577, -4.399, -4.488, -2.187, -2.663, + ], + [ + -1.889, -2.344, -3.232, -2.781, -3.312, -0.911, -2.864, -4.825, -3.180, -2.243, + ], + [ + -4.368, -1.471, -1.308, -2.950, -3.211, -2.692, -1.923, -2.020, -3.859, -3.601, + ], + [ + -4.254, -3.291, -1.539, -2.622, -2.281, -1.427, -1.712, -3.082, -2.653, -3.809, + ], + [ + -3.322, -2.904, -0.942, -3.157, -2.987, -3.736, -1.208, -4.155, -4.383, -2.583, + ], + [ + -2.827, -2.293, -3.109, -3.196, -3.297, -2.451, -2.136, -3.423, -1.012, -2.146, + ], + [ + -1.803, -1.666, -1.780, -4.024, -3.083, -4.520, -2.674, -2.527, -3.365, -1.516, + ], + [ + -2.199, -2.340, -2.009, -3.736, -3.363, -2.721, -2.350, -1.951, -1.815, -2.009, + ], + [ + -1.721, -3.726, -1.701, -3.503, -2.153, -3.242, -2.284, -1.838, -2.646, -2.329, + ], + [ + -3.655, -2.916, -2.913, -1.197, -3.060, -2.154, -1.776, -3.404, -1.823, -3.310, + ], + [ + -2.671, -2.592, -2.929, -1.416, -2.007, -2.886, -2.781, -2.597, -1.738, -2.862, + ], + [ + -1.686, -4.173, -0.884, -5.493, -5.498, -1.707, -3.573, -5.085, -2.060, -3.352, + ], + [ + -2.114, -2.478, -2.178, -3.457, -3.264, -2.659, -2.653, -1.222, -2.375, -2.475, + ], + [ + -2.136, -3.563, -2.325, -3.081, -2.035, -3.154, -1.122, -3.486, -1.951, -3.270, + ], + [ + -3.206, -3.031, -3.913, -2.652, -2.985, -2.635, -1.153, -3.122, -3.256, -1.203, + ], + [ + -2.104, -1.719, -2.141, -2.695, -2.448, -2.991, -1.542, -2.646, -3.090, -3.066, + ], + [ + -3.320, -5.098, -1.085, -1.335, -2.588, -3.098, -2.466, -2.951, -3.911, -2.538, + ], + [ + -3.756, -1.814, -2.752, -2.410, -3.305, -2.387, -2.112, -1.720, -2.616, -1.843, + ], + [ + -3.985, -2.489, -2.305, -1.454, -2.533, -5.091, -1.759, -2.180, -3.673, -1.779, + ], + ]], + &device, + ); + let target = Tensor::::from_data([1, 9, 6, 9, 4], &device); + let input_lengths = Tensor::::from_data([30], &device); + let target_lengths = Tensor::::from_data([5], &device); let expected_res = Data::from([50.3788948059082]); let ctc_loss = CTCLoss::::new(0); From 
9315d9822cf00f772241ff4ee1b50bb7d17b06a0 Mon Sep 17 00:00:00 2001
From: wcshds
Date: Wed, 27 Dec 2023 20:54:05 +0800
Subject: [PATCH 16/19] make sure the argument of the logarithm is greater than 0

---
 burn-core/src/nn/loss/ctc.rs | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/burn-core/src/nn/loss/ctc.rs b/burn-core/src/nn/loss/ctc.rs
index 3a17a0db97..a8c607c862 100644
--- a/burn-core/src/nn/loss/ctc.rs
+++ b/burn-core/src/nn/loss/ctc.rs
@@ -213,9 +213,10 @@ impl CTCLoss {
                     ],
                     ((la1 - lamax.clone()).exp()
                         + (la2 - lamax.clone()).exp()
-                        + (la3 - lamax.clone()).exp().mul(mask_la3.clone()))
-                    .log()
-                    .clamp_min(NEG_INF)
+                        + (la3 - lamax.clone()).exp().mul(mask_la3.clone())
+                        + 1e-15)
+                    .log()
+                    .clamp_min(NEG_INF)
                         + lamax
                         + log_probs_available.clone().slice([
                             0..batch_size,
@@ -252,7 +253,7 @@ impl CTCLoss {
         // for the logsumexp calculation
         let m = Tensor::cat([l1.clone(), l2.clone()].to_vec(), 0).max();
         let m = m.clone().clamp_min(NEG_INF);
-        let log_likelihood = ((l1 - m.clone()).exp() + (l2 - m.clone()).exp()).log() + m;
+        let log_likelihood = ((l1 - m.clone()).exp() + (l2 - m.clone()).exp() + 1e-15).log() + m;
         neg_log_likelihood = neg_log_likelihood.slice_assign([0..batch_size], -log_likelihood);

         match reduction {

From 08dd4ebdb042296a20b9e7672521cbb7fa8561ce Mon Sep 17 00:00:00 2001
From: wcshds
Date: Wed, 27 Dec 2023 22:30:30 +0800
Subject: [PATCH 17/19] refactor the small value used to prevent log(0) into a constant

---
 burn-core/src/nn/loss/ctc.rs | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/burn-core/src/nn/loss/ctc.rs b/burn-core/src/nn/loss/ctc.rs
index a8c607c862..c22fd06908 100644
--- a/burn-core/src/nn/loss/ctc.rs
+++ b/burn-core/src/nn/loss/ctc.rs
@@ -7,6 +7,8 @@ use burn_tensor::{backend::Backend, Element, ElementConversion, Int, Numeric, Te
 use super::Reduction;

 const NEG_INF: f32 = -1e5;
+// a small value used to prevent the occurrence of log(0)
+const DELTA: f32 = -1e-5;

 /// The Connectionist Temporal Classification loss.
@@ -214,7 +216,7 @@ impl CTCLoss {
                     ((la1 - lamax.clone()).exp()
                         + (la2 - lamax.clone()).exp()
                         + (la3 - lamax.clone()).exp().mul(mask_la3.clone())
-                        + 1e-15)
+                        + DELTA)
                     .log()
                     .clamp_min(NEG_INF)
                         + lamax
@@ -253,7 +255,7 @@ impl CTCLoss {
         // for the logsumexp calculation
         let m = Tensor::cat([l1.clone(), l2.clone()].to_vec(), 0).max();
         let m = m.clone().clamp_min(NEG_INF);
-        let log_likelihood = ((l1 - m.clone()).exp() + (l2 - m.clone()).exp() + 1e-15).log() + m;
+        let log_likelihood = ((l1 - m.clone()).exp() + (l2 - m.clone()).exp() + DELTA).log() + m;
         neg_log_likelihood = neg_log_likelihood.slice_assign([0..batch_size], -log_likelihood);

         match reduction {

From 4f009fccdfb1a82430dff9a5c477e472e6add597 Mon Sep 17 00:00:00 2001
From: wcshds
Date: Wed, 27 Dec 2023 22:34:57 +0800
Subject: [PATCH 18/19] fix typo

---
 burn-core/src/nn/loss/ctc.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/burn-core/src/nn/loss/ctc.rs b/burn-core/src/nn/loss/ctc.rs
index c22fd06908..2fc3c28f2a 100644
--- a/burn-core/src/nn/loss/ctc.rs
+++ b/burn-core/src/nn/loss/ctc.rs
@@ -8,7 +8,7 @@ use super::Reduction;

 const NEG_INF: f32 = -1e5;
 // a small value used to prevent the occurrence of log(0)
-const DELTA: f32 = -1e-5;
+const DELTA: f32 = 1e-5;

 /// The Connectionist Temporal Classification loss.
 #[derive(Clone, Debug)]

From 0b999224978a225e5856b039b1b8b7c838724b7c Mon Sep 17 00:00:00 2001
From: wcshds
Date: Wed, 27 Dec 2023 23:33:19 +0800
Subject: [PATCH 19/19] remove unnecessary code

---
 burn-core/src/nn/loss/ctc.rs | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/burn-core/src/nn/loss/ctc.rs b/burn-core/src/nn/loss/ctc.rs
index 2fc3c28f2a..f7d2d883ab 100644
--- a/burn-core/src/nn/loss/ctc.rs
+++ b/burn-core/src/nn/loss/ctc.rs
@@ -134,7 +134,7 @@ impl CTCLoss {
             .slice([0..batch_size, 0..seq_length, self.blank..self.blank + 1]);
         // Shape: [batch_size, seq_length, 2 * max_target_length + 1]
         let log_probs_available =
-            Tensor::::zeros([batch_size, seq_length, target_with_blank_length], &device);
+            Tensor::::empty([batch_size, seq_length, target_with_blank_length], &device);
         let log_probs_available = log_probs_available.slice_assign(
             [0..batch_size, 0..seq_length, 0..1],
             log_probs_blank_available.clone(),
@@ -152,7 +152,6 @@ impl CTCLoss {
             )
             .reshape([batch_size, seq_length, 2 * max_target_length]),
         );
-        let mut neg_log_likelihood = Tensor::::zeros([batch_size], &device);

         // s != s-2
         let mask_la3_letter = targets_pad
@@ -180,7 +179,7 @@ impl CTCLoss {
             let (alpha_prime_prev, alpha_prime_next) = if (t as i32 - min_input_length as i32) < 0 {
                 (0, 0)
             } else {
-                let prev = t - min_input_length as usize;
+                let prev = t - min_input_length;
                 (prev, prev + 1)
             };
             // \alpha_{t-1}(s)
@@ -256,7 +255,7 @@ impl CTCLoss {
         let m = Tensor::cat([l1.clone(), l2.clone()].to_vec(), 0).max();
         let m = m.clone().clamp_min(NEG_INF);
         let log_likelihood = ((l1 - m.clone()).exp() + (l2 - m.clone()).exp() + DELTA).log() + m;
-        neg_log_likelihood = neg_log_likelihood.slice_assign([0..batch_size], -log_likelihood);
+        let neg_log_likelihood = -log_likelihood;

         match reduction {
@@ -277,7 +276,7 @@ impl CTCLoss {
         let [batch_size] = target_lengths.dims();

         let mut targets_pad =
-            Tensor::::full([batch_size, max_target_length], blank as i32, &device);
+            Tensor::::full([batch_size, max_target_length], blank as i32, device);
         let mut start = 0usize;
         for (batch, length) in target_lengths.iter_dim(0).enumerate() {
             let length = length.into_scalar().elem::() as usize;
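Patches 16 through 18 guard the log-sum-exp steps against log(0): when every term is masked out or underflows, the sum inside the logarithm can be exactly zero, so a small constant DELTA (1e-5 after the sign fix in PATCH 18) is added before taking the log and the result is clamped at NEG_INF. The standalone function below is a rough illustration of the same guard under those assumptions; it does not appear in the patches themselves.

    const NEG_INF: f32 = -1e5;
    // small value added before the log to avoid log(0), like the DELTA constant in the patch
    const DELTA: f32 = 1e-5;

    // log(exp(a) + exp(b)) computed around the running maximum, with a DELTA guard
    fn logsumexp2(a: f32, b: f32) -> f32 {
        let m = a.max(b).max(NEG_INF);
        ((a - m).exp() + (b - m).exp() + DELTA).ln().max(NEG_INF) + m
    }

    fn main() {
        // two very unlikely paths: a naive ln(exp(a) + exp(b)) would underflow to ln(0)
        println!("{}", logsumexp2(-1e5, -1e5));
    }
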