From dabd4f53775424f7ead469ab5729a21fd32b0333 Mon Sep 17 00:00:00 2001
From: josephsdavid
Date: Tue, 20 Dec 2022 19:05:59 +0000
Subject: [PATCH 1/4] implement logit_focal_loss

---
 src/losses/functions.jl | 50 +++++++++++++++++++++++++++++++++++++----
 test/losses.jl          | 17 ++++++++++++--
 2 files changed, 61 insertions(+), 6 deletions(-)

diff --git a/src/losses/functions.jl b/src/losses/functions.jl
index ffda2ff99a..be52740e60 100644
--- a/src/losses/functions.jl
+++ b/src/losses/functions.jl
@@ -603,14 +603,56 @@ function focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=epseltype(ŷ))
     agg(sum(@. -y * (1 - ŷ)^γ * log(ŷ); dims=dims))
 end
 
+"""
+    logit_focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=eps(ŷ))
+
+Return the [focal_loss](https://arxiv.org/pdf/1708.02002.pdf) computed from logits,
+which can be used in classification tasks with highly imbalanced classes.
+It down-weights well-classified examples and focuses on hard examples.
+The input, `ŷ`, is expected to be unnormalized logits (i.e. *not* [softmax](@ref Softmax) output).
+
+The modulating factor, `γ`, controls the down-weighting strength.
+For `γ == 0`, the loss is mathematically equivalent to [`Losses.logitcrossentropy`](@ref).
+
+# Example
+```jldoctest
+julia> y = [1 0 0 0 1
+            0 1 0 1 0
+            0 0 1 0 0]
+3×5 Matrix{Int64}:
+ 1  0  0  0  1
+ 0  1  0  1  0
+ 0  0  1  0  0
+
+julia> ŷ = softmax(reshape(-7:7, 3, 5) .* 1f0)
+3×5 Matrix{Float32}:
+ 0.0900306  0.0900306  0.0900306  0.0900306  0.0900306
+ 0.244728   0.244728   0.244728   0.244728   0.244728
+ 0.665241   0.665241   0.665241   0.665241   0.665241
+
+julia> Flux.logit_focal_loss(ŷ, y) ≈ 1.1277571935622628
+true
+```
+
+See also: [`Losses.binary_focal_loss`](@ref) for binary (not one-hot) labels
+
+"""
+function logit_focal_loss(ŷ, y; γ=2.0f0, agg=mean, dims=1, ϵ=Flux.epseltype(ŷ))
+    Flux.losses._check_sizes(ŷ, y)
+    logpt = logsoftmax(ŷ; dims=dims)
+    logpt .+= ϵ
+    loss = agg(sum(@. -y * (1 - exp.(logpt))^γ * logpt; dims=dims))
+    return loss
+end
+
 """
     siamese_contrastive_loss(ŷ, y; margin = 1, agg = mean)
-
+
 Return the [contrastive loss](http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf)
 which can be useful for training Siamese Networks. It is given by
-
-    agg(@. (1 - y) * ŷ^2 + y * max(0, margin - ŷ)^2)
-
+
+    agg(@. (1 - y) * ŷ^2 + y * max(0, margin - ŷ)^2)
+
 Specify `margin` to set the baseline for distance at which pairs are dissimilar.
 
 # Example
diff --git a/test/losses.jl b/test/losses.jl
index 2ca697a657..d1db2390eb 100644
--- a/test/losses.jl
+++ b/test/losses.jl
@@ -14,7 +14,7 @@ const ALL_LOSSES = [Flux.Losses.mse, Flux.Losses.mae, Flux.Losses.msle,
                     Flux.Losses.dice_coeff_loss, Flux.Losses.poisson_loss,
                     Flux.Losses.hinge_loss, Flux.Losses.squared_hinge_loss,
-                    Flux.Losses.binary_focal_loss, Flux.Losses.focal_loss, Flux.Losses.siamese_contrastive_loss]
+                    Flux.Losses.binary_focal_loss, Flux.Losses.focal_loss, Flux.Losses.siamese_contrastive_loss, Flux.Losses.logit_focal_loss]
 
 
 @testset "xlogx & xlogy" begin
@@ -210,7 +210,20 @@
   @test Flux.focal_loss(ŷ1, y1) ≈ 0.45990566879720157
   @test Flux.focal_loss(ŷ, y; γ=0.0) ≈ Flux.crossentropy(ŷ, y)
 end
-
+
+@testset "logit_focal_loss" begin
+    rng = Random.seed!(Random.default_rng(), 5)
+    y = rand(rng, Float32, 6, 40, 2)
+    yhat = rand(rng, Float32, 6, 40, 2)
+
+    @test logit_focal_loss(yhat, y; γ=0) ≈
+          Flux.Losses.logitcrossentropy(yhat, y)
+
+
+    @test logit_focal_loss(yhat, y; γ=2) ==
+          Flux.Losses.focal_loss(Flux.softmax(yhat; dims=1), y; γ=2)
+end
+
 @testset "siamese_contrastive_loss" begin
   y = [1 0 0 0

From 6a91e0dd99c0eb0ab0f7efd0c8c1763d46cabe2c Mon Sep 17 00:00:00 2001
From: josephsdavid
Date: Tue, 20 Dec 2022 19:07:08 +0000
Subject: [PATCH 2/4] doctest maybe?

---
 src/losses/functions.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/losses/functions.jl b/src/losses/functions.jl
index be52740e60..2f287e90bf 100644
--- a/src/losses/functions.jl
+++ b/src/losses/functions.jl
@@ -624,7 +624,7 @@ julia> y = [1 0 0 0 1
  0  1  0  1  0
  0  0  1  0  0
 
-julia> ŷ = softmax(reshape(-7:7, 3, 5) .* 1f0)
+julia> ŷ = reshape(-7:7, 3, 5) .* 1f0
 3×5 Matrix{Float32}:
  0.0900306  0.0900306  0.0900306  0.0900306  0.0900306
  0.244728   0.244728   0.244728   0.244728   0.244728
  0.665241   0.665241   0.665241   0.665241   0.665241
@@ -634,7 +634,7 @@ julia> Flux.logit_focal_loss(ŷ, y) ≈ 1.1277571935622628
 true
 ```
 
-See also: [`Losses.binary_focal_loss`](@ref) for binary (not one-hot) labels
+See also: [`Losses.focal_loss`](@ref)
 
 """
 function logit_focal_loss(ŷ, y; γ=2.0f0, agg=mean, dims=1, ϵ=Flux.epseltype(ŷ))

From a86eb7b483b9d85adc436d9e5cb1ff46297a6ed5 Mon Sep 17 00:00:00 2001
From: josephsdavid
Date: Tue, 20 Dec 2022 19:20:12 +0000
Subject: [PATCH 3/4] fix typo

---
 src/losses/functions.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/losses/functions.jl b/src/losses/functions.jl
index 2f287e90bf..579f84b6f1 100644
--- a/src/losses/functions.jl
+++ b/src/losses/functions.jl
@@ -637,8 +637,8 @@ true
 See also: [`Losses.focal_loss`](@ref)
 
 """
-function logit_focal_loss(ŷ, y; γ=2.0f0, agg=mean, dims=1, ϵ=Flux.epseltype(ŷ))
-    Flux.losses._check_sizes(ŷ, y)
+function logit_focal_loss(ŷ, y; γ=2.0f0, agg=mean, dims=1, ϵ=epseltype(ŷ))
+    _check_sizes(ŷ, y)
     logpt = logsoftmax(ŷ; dims=dims)
     logpt .+= ϵ
     loss = agg(sum(@. -y * (1 - exp.(logpt))^γ * logpt; dims=dims))

From 7698243e0d0f8a00f82b6c6f447cdd9bef8df61b Mon Sep 17 00:00:00 2001
From: David Josephs <42522233+josephsdavid@users.noreply.github.com>
Date: Wed, 21 Dec 2022 11:46:31 -0800
Subject: [PATCH 4/4] Update src/losses/functions.jl

Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
---
 src/losses/functions.jl | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/losses/functions.jl b/src/losses/functions.jl
index 579f84b6f1..457cfae9a4 100644
--- a/src/losses/functions.jl
+++ b/src/losses/functions.jl
@@ -639,10 +639,8 @@ See also: [`Losses.focal_loss`](@ref)
 """
 function logit_focal_loss(ŷ, y; γ=2.0f0, agg=mean, dims=1, ϵ=epseltype(ŷ))
     _check_sizes(ŷ, y)
-    logpt = logsoftmax(ŷ; dims=dims)
-    logpt .+= ϵ
-    loss = agg(sum(@. -y * (1 - exp.(logpt))^γ * logpt; dims=dims))
-    return loss
+    logpt = logsoftmax(ŷ; dims)
+    agg(sum(@. -y * (1 - exp(logpt + ϵ))^γ * (logpt + ϵ); dims))
 end
 
 """
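For readers following the series, here is a minimal standalone sketch of what the final (patch 4/4) version computes and of the two equivalences the new test set relies on. It is an illustration, not the merged Flux API: the name `logit_focal_loss_sketch`, the `onehotbatch` targets, and the data shapes are assumptions made for this example, while `logsoftmax`, `softmax`, `Flux.Losses.focal_loss`, and `Flux.Losses.logitcrossentropy` are existing Flux functions.

```julia
using Flux                      # provides softmax, logsoftmax and the existing losses
using Statistics: mean

# Standalone restatement of the patch-4/4 body; `ŷ` holds unnormalized logits.
function logit_focal_loss_sketch(ŷ, y; γ=2.0f0, agg=mean, dims=1, ϵ=eps(eltype(ŷ)))
    logpt = logsoftmax(ŷ; dims)
    agg(sum(@. -y * (1 - exp(logpt + ϵ))^γ * (logpt + ϵ); dims))
end

# One-hot targets and random logits, loosely mirroring the new test set.
y = Flux.onehotbatch(rand(1:3, 16), 1:3)   # 3×16 one-hot matrix
ŷ = randn(Float32, 3, 16)                  # raw logits, not softmax output

# γ = 0 removes the modulating factor, leaving plain logit cross-entropy.
@assert logit_focal_loss_sketch(ŷ, y; γ = 0) ≈ Flux.Losses.logitcrossentropy(ŷ, y)

# For γ > 0 the result matches focal_loss applied to the softmax of the logits.
@assert logit_focal_loss_sketch(ŷ, y; γ = 2) ≈
        Flux.Losses.focal_loss(softmax(ŷ; dims = 1), y; γ = 2)
```

On the shape suggested in patch 4/4: folding `ϵ` into `logpt + ϵ` drops the in-place `logpt .+= ϵ`, which matters because Zygote, Flux's AD, does not support array mutation; and applying `logsoftmax` to the logits is numerically better behaved than taking `log` of a softmax output, presumably the motivation for a logit variant in the first place, mirroring `logitcrossentropy` versus `crossentropy`.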