implement ctc loss function #1049

Closed · wants to merge 20 commits
6 changes: 4 additions & 2 deletions burn-core/src/nn/loss/ctc.rs
@@ -7,6 +7,8 @@ use burn_tensor::{backend::Backend, Element, ElementConversion, Int, Numeric, Te
use super::Reduction;

const NEG_INF: f32 = -1e5;
// a small value used to prevent the occurrence of log(0)
const DELTA: f32 = -1e-5;
Collaborator:

Did you mean for this number to be negative? The literal as written is negative, while the comment above suggests a small positive guard. In an unlikely event, the (l1 - m.clone()).exp() + (l2 - m.clone()).exp() expression could equal abs(DELTA), which would still lead to a log(0) situation.

Additionally, I suggest we use [f32's EPSILON](https://doc.rust-lang.org/std/primitive.f32.html#associatedconstant.EPSILON) or [f16's EPSILON](https://docs.rs/tract-core/latest/tract_core/prelude/struct.f16.html#associatedconstant.EPSILON), depending on the Backend's precision settings. @nathanielsimard or @louisfd can suggest how we can extract this. 1e-5 seems a rather large value for f16 or f32. (It probably won't work for f16, because f16's epsilon is 4.88e-4; we need to double-check this.)
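
For illustration, here is a minimal scalar sketch of the guarded log-sum-exp pattern under discussion, using the float type's own EPSILON as the guard instead of a hand-picked constant. This is plain f32 arithmetic rather than burn's tensor API; the NEG_INF floor is copied from the diff above, and the helper name log_sum_exp2 is made up for the example.

```rust
/// Minimal scalar sketch (not burn's tensor API): ln(exp(l1) + exp(l2))
/// computed with the max-shift trick and a small positive guard so the
/// argument of ln() can never reach zero.
const NEG_INF: f32 = -1e5;

fn log_sum_exp2(l1: f32, l2: f32) -> f32 {
    // Shift by the maximum so the exponentials cannot overflow.
    let m = l1.max(l2).max(NEG_INF);
    // f32::EPSILON plays the role of DELTA: strictly positive, so the sum
    // stays above zero even when both shifted terms underflow to 0.0.
    let sum = (l1 - m).exp() + (l2 - m).exp() + f32::EPSILON;
    sum.ln().max(NEG_INF) + m
}

fn main() {
    // Both log-probabilities are tiny: without the guard, ln(0.0) = -inf.
    println!("{}", log_sum_exp2(-1e9, -1e9));
    // A regular case for comparison: ln(exp(-1.0) + exp(-2.0)) ≈ -0.687.
    println!("{}", log_sum_exp2(-1.0, -2.0));
}
```

In the first call both shifted exponentials underflow to 0.0, so without the EPSILON term the argument of ln() would be exactly zero.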

Contributor (author):

I'm sorry, it's a typo. DELTA should be positive.

1e-5 keeps the loss accurate to roughly three decimal places, but 4.88e-4 is a bit large. Perhaps CTC loss is simply not well suited to half-precision training.
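
As a rough check of this accuracy claim (illustrative values only, not taken from the PR): adding a guard delta to the argument of ln shifts the result by roughly delta / p for a summed-exp value p, so 1e-5 perturbs ln(0.1) by about 1e-4, while an f16-sized epsilon of 4.88e-4 perturbs it by about 5e-3.

```rust
// Rough illustration of the precision point above (assumed example value p,
// not from the PR): how much does the guard delta shift ln(p)?
fn main() {
    let p: f64 = 0.1; // an example summed-exp value
    for delta in [1e-5_f64, 4.88e-4] {
        let shift = ((p + delta).ln() - p.ln()).abs();
        println!("delta = {delta:e} -> |shift in ln(p)| = {shift:e}");
    }
}
```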


/// The Connectionist Temporal Classification loss.
#[derive(Clone, Debug)]
@@ -214,7 +216,7 @@ impl<B: Backend> CTCLoss<B> {
((la1 - lamax.clone()).exp()
+ (la2 - lamax.clone()).exp()
+ (la3 - lamax.clone()).exp().mul(mask_la3.clone())
+ 1e-15)
+ DELTA)
.log()
.clamp_min(NEG_INF)
+ lamax
@@ -253,7 +255,7 @@ impl<B: Backend> CTCLoss<B> {
// for the logsumexp calculation
let m = Tensor::cat([l1.clone(), l2.clone()].to_vec(), 0).max();
let m = m.clone().clamp_min(NEG_INF);
let log_likelihood = ((l1 - m.clone()).exp() + (l2 - m.clone()).exp() + 1e-15).log() + m;
let log_likelihood = ((l1 - m.clone()).exp() + (l2 - m.clone()).exp() + DELTA).log() + m;
neg_log_likelihood = neg_log_likelihood.slice_assign([0..batch_size], -log_likelihood);

match reduction {