From bedf1f7eb3249883fb9f9286e8dcd1060ed12796 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 15 Jun 2024 20:51:38 -0700 Subject: [PATCH 01/13] Add xlogx and xlogy functions --- src/Lux.jl | 5 ++++ src/losses/Losses.jl | 15 ++++++++++ src/losses/functions.jl | 0 src/losses/utils.jl | 62 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 82 insertions(+) create mode 100644 src/losses/Losses.jl create mode 100644 src/losses/functions.jl create mode 100644 src/losses/utils.jl diff --git a/src/Lux.jl b/src/Lux.jl index 87fccffc9..920a4e5cf 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -66,6 +66,9 @@ include("layers/display.jl") # AutoDiff include("chainrules.jl") +# Losses +include("losses/Losses.jl") + # Experimental include("contrib/contrib.jl") @@ -108,6 +111,8 @@ export jacobian_vector_product, vector_jacobian_product export batched_jacobian export AutoForwardDiff, AutoReverseDiff, AutoTracker, AutoZygote +export Losses + export f16, f32, f64 export transform diff --git a/src/losses/Losses.jl b/src/losses/Losses.jl new file mode 100644 index 000000000..690d6f4d8 --- /dev/null +++ b/src/losses/Losses.jl @@ -0,0 +1,15 @@ +module Losses # A huge chunk of this code has been derived from Flux.jl + +using PrecompileTools: @recompile_invalidations + +@recompile_invalidations begin + using ChainRulesCore: ChainRulesCore, NoTangent, ZeroTangent, @thunk + using FastClosures: @closure +end + +const CRC = ChainRulesCore + +include("utils.jl") +include("functions.jl") + +end diff --git a/src/losses/functions.jl b/src/losses/functions.jl new file mode 100644 index 000000000..e69de29bb diff --git a/src/losses/utils.jl b/src/losses/utils.jl new file mode 100644 index 000000000..99937d659 --- /dev/null +++ b/src/losses/utils.jl @@ -0,0 +1,62 @@ +""" + xlogx(x::Number) + +Return `x * log(x)` for `x ≥ 0`, handling `x == 0` by taking the limit from above, to get +zero. +""" +@inline function xlogx(x::Number) + result = x * log(x) + return ifelse(iszero(x), zero(result), result) +end + +function CRC.rrule(::typeof(xlogx), x::Number) + iszero(x) && return x, Δ -> (NoTangent(), ZeroTangent()) + logx = log(x) + ∇xlogx = @closure Δ -> (NoTangent(), @thunk(Δ*(logx + true))) + return x * logx, ∇xlogx +end + +function CRC.rrule( + ::typeof(Broadcast.broadcasted), ::typeof(xlogx), x::AbstractArray{<:Number}) + logx = log.(x) + y = x .* logx + ∇xlogx = @closure Δ -> (NoTangent(), NoTangent(), @thunk(Δ.*(logx .+ true))) + return y, ∇xlogx +end + +""" + xlogy(x::Number, y::Number) + +Return `x * log(y)` for `y > 0`, and zero when `x == 0`. 
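+
+For example (a small sketch of the intended behaviour, shown as a plain snippet rather
+than a doctest):
+
+```julia
+julia> xlogy(2, 3) ≈ 2 * log(3)
+true
+
+julia> xlogy(0, 0) == 0  # the `x == 0` case returns the limit, zero
+true
+```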
+""" +@inline function xlogy(x::Number, y::Number) + result = x * log(y) + return ifelse(iszero(x), zero(result), result) +end + +function CRC.rrule(::typeof(xlogy), x::Number, y::Number) + iszero(x) && return x, Δ -> (NoTangent(), ZeroTangent()) + logy = log(y) + ∇xlogy = @closure Δ -> (NoTangent(), @thunk(Δ*logy), @thunk(Δ * x/y)) + return x * logy, ∇xlogy +end + +function CRC.rrule(::typeof(Broadcast.broadcasted), ::typeof(xlogy), + x::AbstractArray{<:Number}, y::AbstractArray{<:Number}) + logy = log.(y) + y = x .* logy + ∇xlogy = @closure Δ -> (NoTangent(), NoTangent(), @thunk(Δ.*logy), @thunk(Δ .* x./y)) + return y, ∇xlogy +end + +@inline function __check_sizes(ŷ::AbstractArray, y::AbstractArray) + for d in 1:max(ndims(ŷ), ndims(y)) + if size(ŷ, d) != size(y, d) + throw(DimensionMismatch("loss function expects size(ŷ) = $(size(ŷ)) to match \ + size(y) = $(size(y))")) + end + end +end +@inline __check_sizes(ŷ, y) = nothing + +CRC.@non_differentiable __check_sizes(ŷ::Any, y::Any) From f836a778e9097f98c475b2493d942f09c3e012f1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 15 Jun 2024 23:17:42 -0700 Subject: [PATCH 02/13] [skip ci] Add common loss functions --- Project.toml | 1 + src/Lux.jl | 2 +- src/losses/Losses.jl | 13 +++++- src/losses/functions.jl | 0 src/losses/loss_functions.jl | 78 ++++++++++++++++++++++++++++++++++++ src/losses/utils.jl | 11 +++++ 6 files changed, 103 insertions(+), 2 deletions(-) delete mode 100644 src/losses/functions.jl create mode 100644 src/losses/loss_functions.jl diff --git a/Project.toml b/Project.toml index fc063ccb0..bd874b077 100644 --- a/Project.toml +++ b/Project.toml @@ -27,6 +27,7 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" [weakdeps] diff --git a/src/Lux.jl b/src/Lux.jl index 920a4e5cf..351eacbe3 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -111,7 +111,7 @@ export jacobian_vector_product, vector_jacobian_product export batched_jacobian export AutoForwardDiff, AutoReverseDiff, AutoTracker, AutoZygote -export Losses +export CrossEntropyLoss, L1Loss, MSELoss, MSLELoss export f16, f32, f64 diff --git a/src/losses/Losses.jl b/src/losses/Losses.jl index 690d6f4d8..0e7e0cd30 100644 --- a/src/losses/Losses.jl +++ b/src/losses/Losses.jl @@ -3,13 +3,24 @@ module Losses # A huge chunk of this code has been derived from Flux.jl using PrecompileTools: @recompile_invalidations @recompile_invalidations begin + using ArgCheck: @argcheck using ChainRulesCore: ChainRulesCore, NoTangent, ZeroTangent, @thunk + using ConcreteStructs: @concrete using FastClosures: @closure + using ..Lux: __unwrap_val + using Markdown: @doc_str + using Statistics: mean end const CRC = ChainRulesCore +abstract type AbstractLossFunction <: Function end + include("utils.jl") -include("functions.jl") +include("loss_functions.jl") + +export CrossEntropyLoss, L1Loss, MSELoss, MSLELoss end + +using .Losses diff --git a/src/losses/functions.jl b/src/losses/functions.jl deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/losses/loss_functions.jl b/src/losses/loss_functions.jl new file mode 100644 index 000000000..f57382ae8 --- /dev/null +++ b/src/losses/loss_functions.jl @@ -0,0 +1,78 @@ +# In this file, doctests which differ in the printed Float32 values won't fail +```@meta +DocTestFilters = r"[0-9\.]+f0" 
+``` + +function (loss::AbstractLossFunction)(ŷ, y) + __check_sizes(ŷ, y) + return __unsafe_apply_loss(loss, ŷ, y) +end + +function __unsafe_apply_loss end + +@kwdef @concrete struct L1Loss <: AbstractLossFunction + agg = mean +end + +@inline __unsafe_apply_loss(loss::L1Loss, ŷ, y) = __fused_agg(loss.agg, abs, ŷ .- y) + +@kwdef @concrete struct MSELoss <: AbstractLossFunction + agg = mean +end + +@inline __unsafe_apply_loss(loss::MSELoss, ŷ, y) = __fused_agg(loss.agg, abs2, ŷ .- y) + +@kwdef @concrete struct MSLELoss <: AbstractLossFunction + agg = mean + epsilon = nothing +end + +@inline function __unsafe_apply_loss(loss::MSLELoss, ŷ, y) + ϵ = loss.epsilon === nothing ? eps(eltype(ŷ)) : loss.epsilon + return __fused_agg(loss.agg, abs2, log.((ŷ .+ ϵ) ./ (y .+ ϵ))) +end + +@concrete struct CrossEntropyLoss{logits, L <: Union{Nothing, Real}} <: AbstractLossFunction + label_smoothing::L + dims + agg + epsilon +end + +function CrossEntropyLoss(; + dims=1, agg=mean, epsilon=nothing, label_smoothing::Union{Nothing, Real}=nothing, + logits::Union{Bool, Val}=Val(false)) + label_smoothing !== nothing && @argcheck 0 ≤ label_smoothing ≤ 1 + return CrossEntropyLoss{__unwrap_val(logits)}(label_smoothing, dims, agg, epsilon) +end + +for logits in (true, false) + return_expr = logits ? + :(return __fused_agg( + loss.agg, -, sum(y_smooth .* logsoftmax(ŷ; loss.dims); loss.dims))) : + :(return __fused_agg( + loss.agg, -, sum(xlogy.(y_smooth, ŷ .+ ϵ); loss.dims))) + + @eval function __unsafe_apply_loss(loss::CrossEntropyLoss{$(logits)}, ŷ, y) + ϵ = loss.epsilon === nothing ? eps(eltype(ŷ)) : loss.epsilon + y_smooth = __label_smoothing( + loss.label_smoothing, y, promote_type(eltype(ŷ), eltype(y))) + $(return_expr) + end +end + +# TODO: HuberLoss +# TODO: BCELoss +# TODO: KLDivergenceLoss +# TODO: PoissonLoss +# TODO: HingeLoss +# TODO: SquaredHingeLoss +# TODO: DiceCoeffLoss +# TODO: TverskyLoss +# TODO: FocalLoss +# TODO: BinaryFocalLoss +# TODO: SimameseContrastiveLoss + +```@meta +DocTestFilters = nothing +``` diff --git a/src/losses/utils.jl b/src/losses/utils.jl index 99937d659..530575278 100644 --- a/src/losses/utils.jl +++ b/src/losses/utils.jl @@ -60,3 +60,14 @@ end @inline __check_sizes(ŷ, y) = nothing CRC.@non_differentiable __check_sizes(ŷ::Any, y::Any) + +@inline __fused_agg(::typeof(mean), op::OP, x) where {OP} = mean(op, x) +@inline __fused_agg(::typeof(sum), op::OP, x) where {OP} = sum(op, x) +@inline __fused_agg(::Nothing, op::OP, x) where {OP} = op.(x) +@inline __fused_agg(f::F, op::OP, x) where {F, OP} = f(op.(x)) + +@inline __label_smoothing(::Nothing, y, ::Type{T}) where {T} = y +@inline function __label_smoothing(label_smoothing::Real, y, ::Type{T}) where {T} + label_smoothing = T(label_smoothing) + return y .* (1 - label_smoothing) .+ label_smoothing ./ size(y, ndims(y) - 1) +end From bf9ff9b664fff023972e6f3f0432baf5fbff090f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 16 Jun 2024 12:32:23 -0700 Subject: [PATCH 03/13] [skip ci] Add more loss functions --- src/Lux.jl | 3 +- src/losses/Losses.jl | 4 +- src/losses/loss_functions.jl | 136 +++++++++++++++++++++++++++++++---- src/losses/utils.jl | 14 ++++ 4 files changed, 143 insertions(+), 14 deletions(-) diff --git a/src/Lux.jl b/src/Lux.jl index 351eacbe3..3a77b42de 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -111,7 +111,8 @@ export jacobian_vector_product, vector_jacobian_product export batched_jacobian export AutoForwardDiff, AutoReverseDiff, AutoTracker, AutoZygote -export CrossEntropyLoss, L1Loss, MSELoss, MSLELoss +export 
BinaryCrossEntropyLoss, BinaryFocalLoss, CrossEntropyLoss, FocalLoss, L1Loss, L2Loss, + MAELoss, MSELoss, MSLELoss, SiameseContrastiveLoss, TverskyLoss export f16, f32, f64 diff --git a/src/losses/Losses.jl b/src/losses/Losses.jl index 0e7e0cd30..33c121e1f 100644 --- a/src/losses/Losses.jl +++ b/src/losses/Losses.jl @@ -9,6 +9,7 @@ using PrecompileTools: @recompile_invalidations using FastClosures: @closure using ..Lux: __unwrap_val using Markdown: @doc_str + using LuxLib: logsoftmax, logsigmoid using Statistics: mean end @@ -19,7 +20,8 @@ abstract type AbstractLossFunction <: Function end include("utils.jl") include("loss_functions.jl") -export CrossEntropyLoss, L1Loss, MSELoss, MSLELoss +export BinaryCrossEntropyLoss, BinaryFocalLoss, CrossEntropyLoss, FocalLoss, L1Loss, L2Loss, + MAELoss, MSELoss, MSLELoss, SiameseContrastiveLoss, TverskyLoss end diff --git a/src/losses/loss_functions.jl b/src/losses/loss_functions.jl index f57382ae8..06924d6a6 100644 --- a/src/losses/loss_functions.jl +++ b/src/losses/loss_functions.jl @@ -1,5 +1,6 @@ # In this file, doctests which differ in the printed Float32 values won't fail ```@meta +using Base: func_for_method_checked DocTestFilters = r"[0-9\.]+f0" ``` @@ -10,16 +11,20 @@ end function __unsafe_apply_loss end -@kwdef @concrete struct L1Loss <: AbstractLossFunction +@kwdef @concrete struct MAELoss <: AbstractLossFunction agg = mean end -@inline __unsafe_apply_loss(loss::L1Loss, ŷ, y) = __fused_agg(loss.agg, abs, ŷ .- y) +const L1Loss = MAELoss + +@inline __unsafe_apply_loss(loss::MAELoss, ŷ, y) = __fused_agg(loss.agg, abs, ŷ .- y) @kwdef @concrete struct MSELoss <: AbstractLossFunction agg = mean end +const L2Loss = MSELoss + @inline __unsafe_apply_loss(loss::MSELoss, ŷ, y) = __fused_agg(loss.agg, abs2, ŷ .- y) @kwdef @concrete struct MSLELoss <: AbstractLossFunction @@ -28,7 +33,8 @@ end end @inline function __unsafe_apply_loss(loss::MSLELoss, ŷ, y) - ϵ = loss.epsilon === nothing ? eps(eltype(ŷ)) : loss.epsilon + T = promote_type(eltype(ŷ), eltype(y)) + ϵ = __get_epsilon(T, loss.epsilon) return __fused_agg(loss.agg, abs2, log.((ŷ .+ ϵ) ./ (y .+ ϵ))) end @@ -54,24 +60,130 @@ for logits in (true, false) loss.agg, -, sum(xlogy.(y_smooth, ŷ .+ ϵ); loss.dims))) @eval function __unsafe_apply_loss(loss::CrossEntropyLoss{$(logits)}, ŷ, y) - ϵ = loss.epsilon === nothing ? eps(eltype(ŷ)) : loss.epsilon - y_smooth = __label_smoothing( - loss.label_smoothing, y, promote_type(eltype(ŷ), eltype(y))) + T = promote_type(eltype(ŷ), eltype(y)) + ϵ = __get_epsilon(T, loss.epsilon) + y_smooth = __label_smoothing(loss.label_smoothing, y, T) $(return_expr) end end +@concrete struct BinaryCrossEntropyLoss{logits, L <: Union{Nothing, Real}} <: + AbstractLossFunction + label_smoothing::L + agg + epsilon +end + +function BinaryCrossEntropyLoss(; + agg=mean, epsilon=nothing, label_smoothing::Union{Nothing, Real}=nothing, + logits::Union{Bool, Val}=Val(false)) + label_smoothing !== nothing && @argcheck 0 ≤ label_smoothing ≤ 1 + return BinaryCrossEntropyLoss{__unwrap_val(logits)}(label_smoothing, agg, epsilon) +end + +for logits in (true, false) + return_expr = logits ? 
:(return loss.agg((1 .- y_smooth) .* y̋ .- logsigmoid.(ŷ))) : + :(return loss.agg(-xlogy.(y_smooth, ŷ .+ ϵ) .- + xlogy.(1 .- y_smooth, 1 .- ŷ .+ ϵ))) + + @eval function __unsafe_apply_loss(loss::BinaryCrossEntropyLoss{$(logits)}, ŷ, y) + T = promote_type(eltype(ŷ), eltype(y)) + ϵ = __get_epsilon(T, loss.epsilon) + y_smooth = __label_smoothing_binary(loss.label_smoothing, y, T) + $(return_expr) + end +end + +@kwdef @concrete struct BinaryFocalLoss <: AbstractLossFunction + gamma = 2 + agg = mean + epsilon = nothing +end + +@inline function __unsafe_apply_loss(loss::BinaryFocalLoss, ŷ, y) + T = promote_type(eltype(ŷ), eltype(y)) + γ = loss.gamma isa Integer ? loss.gamma : T(loss.gamma) + ϵ = __get_epsilon(T, loss.epsilon) + ŷϵ = ŷ .+ ϵ + p_t = y .* ŷϵ + (1 .- y) .* (1 .- ŷϵ) + return __fused_agg(loss.agg, -, (1 .- p_t) .^ γ .* log.(p_t)) +end + +@kwdef @concrete struct FocalLoss <: AbstractLossFunction + gamma = 2 + dims = 1 + agg = mean + epsilon = nothing +end + +@inline function __unsafe_apply_loss(loss::FocalLoss, ŷ, y) + T = promote_type(eltype(ŷ), eltype(y)) + γ = loss.gamma isa Integer ? loss.gamma : T(loss.gamma) + ϵ = __get_epsilon(T, loss.epsilon) + ŷϵ = ŷ .+ ϵ + return loss.agg(sum(-y .* (1 .- ŷϵ) .^ γ .+ log.(ŷϵ); loss.dims)) +end + +@concrete struct SiameseContrastiveLoss <: AbstractLossFunction + margin + agg +end + +function SiameseContrastiveLoss(; margin::Real=true, agg=mean) + @argcheck margin ≥ 0 + return SiameseContrastiveLoss(margin, agg) +end + +@inline function __unsafe_apply_loss(loss::SiameseContrastiveLoss, ŷ, y) + z = @. (1 - y) * ŷ^2 + y * max(0, loss.margin - ŷ)^2 + return loss.agg(z) +end + +@kwdef @concrete struct TverskyLoss <: AbstractLossFunction + beta = 0.7 + smooth = true + agg = mean +end + +function __unsafe_apply_loss(loss::TverskyLoss, ŷ, y) + T = promote_type(eltype(ŷ), eltype(y)) + β = T(loss.beta) + α = T(loss.smooth) + + yŷ = y .* ŷ + dims = __get_dims(yŷ) + + TP = sum(yŷ; dims) + FP = sum((true .- y) .* ŷ; dims) + FN = sum(y .* (true .- ŷ); dims) + + return loss.agg(1 - (TP + α) / (TP + α * FP + β * FN + α)) +end + +@kwdef @concrete struct DiceCoeffLoss <: AbstractLossFunction + smooth = true + agg = mean +end + +function __unsafe_apply_loss(loss::DiceCoeffLoss, ŷ, y) + T = promote_type(eltype(ŷ), eltype(y)) + α = T(loss.smooth) + + yŷ = y .* ŷ + dims = __get_dims(yŷ) + + num = T(2) .* sum(yŷ; dims) .+ α + den = sum(abs2, ŷ; dims) .+ sum(abs2, y; dims) .+ α + + return loss.agg(true - num ./ den) +end + + # TODO: HuberLoss -# TODO: BCELoss # TODO: KLDivergenceLoss # TODO: PoissonLoss # TODO: HingeLoss # TODO: SquaredHingeLoss -# TODO: DiceCoeffLoss -# TODO: TverskyLoss -# TODO: FocalLoss -# TODO: BinaryFocalLoss -# TODO: SimameseContrastiveLoss ```@meta DocTestFilters = nothing diff --git a/src/losses/utils.jl b/src/losses/utils.jl index 530575278..00e150b68 100644 --- a/src/losses/utils.jl +++ b/src/losses/utils.jl @@ -1,5 +1,6 @@ """ xlogx(x::Number) +using Base: func_for_method_checked Return `x * log(x)` for `x ≥ 0`, handling `x == 0` by taking the limit from above, to get zero. 
@@ -71,3 +72,16 @@ CRC.@non_differentiable __check_sizes(ŷ::Any, y::Any) label_smoothing = T(label_smoothing) return y .* (1 - label_smoothing) .+ label_smoothing ./ size(y, ndims(y) - 1) end + +@inline __label_smoothing_binary(::Nothing, y, ::Type{T}) where {T} = y +@inline function __label_smoothing_binary(label_smoothing::Real, y, ::Type{T}) where {T} + label_smoothing = T(label_smoothing) + return y .* (1 - label_smoothing) .+ label_smoothing ./ 2 +end + +@inline __get_epsilon(::Type{T}, ϵ::Real) where {T} = T(ϵ) +@inline __get_epsilon(::Type{T}, ::Nothing) where {T} = eps(T) + +@inline __get_dims(_) = Colon() +@inline __get_dims(::AbstractVector) = Colon() +@inline __get_dims(::AbstractArray{T, N}) where {T, N} = 1:(N - 1) From bcb75029cd3bd89d2e4f59df43172ffe5e223f9d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 16 Jun 2024 17:57:15 -0700 Subject: [PATCH 04/13] More loss functions --- Project.toml | 1 + src/Lux.jl | 5 +-- src/losses/Losses.jl | 7 ++-- src/losses/loss_functions.jl | 65 +++++++++++++++++++++++++++++++++--- 4 files changed, 69 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index bd874b077..9e1f596b2 100644 --- a/Project.toml +++ b/Project.toml @@ -100,6 +100,7 @@ MacroTools = "0.5.13" Markdown = "1.10" NCCL = "0.1.1" OhMyThreads = "0.5.1" +OneHotArrays = "0.2.5" Optimisers = "0.3" Pkg = "1.10" PrecompileTools = "1.2" diff --git a/src/Lux.jl b/src/Lux.jl index 3a77b42de..9d3f07184 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -111,8 +111,9 @@ export jacobian_vector_product, vector_jacobian_product export batched_jacobian export AutoForwardDiff, AutoReverseDiff, AutoTracker, AutoZygote -export BinaryCrossEntropyLoss, BinaryFocalLoss, CrossEntropyLoss, FocalLoss, L1Loss, L2Loss, - MAELoss, MSELoss, MSLELoss, SiameseContrastiveLoss, TverskyLoss +export BinaryCrossEntropyLoss, BinaryFocalLoss, CrossEntropyLoss, FocalLoss, HingeLoss, + HuberLoss, KLDivergenceLoss, L1Loss, L2Loss, MAELoss, MSELoss, MSLELoss, PoissonLoss, + SiameseContrastiveLoss, SquaredHingeLoss, TverskyLoss export f16, f32, f64 diff --git a/src/losses/Losses.jl b/src/losses/Losses.jl index 33c121e1f..c6ed58ccc 100644 --- a/src/losses/Losses.jl +++ b/src/losses/Losses.jl @@ -1,3 +1,5 @@ +# Eventually the idea is to create a package `DeepLearningLosses.jl` and move this +# functionality there and simply reexport it here. 
module Losses # A huge chunk of this code has been derived from Flux.jl using PrecompileTools: @recompile_invalidations @@ -20,8 +22,9 @@ abstract type AbstractLossFunction <: Function end include("utils.jl") include("loss_functions.jl") -export BinaryCrossEntropyLoss, BinaryFocalLoss, CrossEntropyLoss, FocalLoss, L1Loss, L2Loss, - MAELoss, MSELoss, MSLELoss, SiameseContrastiveLoss, TverskyLoss +export BinaryCrossEntropyLoss, BinaryFocalLoss, CrossEntropyLoss, FocalLoss, HingeLoss, + HuberLoss, KLDivergenceLoss, L1Loss, L2Loss, MAELoss, MSELoss, MSLELoss, + PoissonLoss, SiameseContrastiveLoss, SquaredHingeLoss, TverskyLoss end diff --git a/src/losses/loss_functions.jl b/src/losses/loss_functions.jl index 06924d6a6..ccdea30c6 100644 --- a/src/losses/loss_functions.jl +++ b/src/losses/loss_functions.jl @@ -178,12 +178,67 @@ function __unsafe_apply_loss(loss::DiceCoeffLoss, ŷ, y) return loss.agg(true - num ./ den) end +@kwdef @concrete struct HuberLoss <: AbstractLossFunction + delta = 1 + agg = mean +end + +function __unsafe_apply_loss(loss::HuberLoss, ŷ, y) + T = promote_type(eltype(ŷ), eltype(y)) + return __fused_agg(loss.agg, Base.Fix2(__huber_metric, T(loss.delta)), abs.(ŷ .- y)) +end + +function __huber_metric(err::T1, δ::T2) where {T1, T2} + T = promote_type(T1, T2) + x = T(1 // 2) + return ifelse(err < δ, err^2 * x, δ * (err - x * δ)) +end + +@kwdef @concrete struct HingeLoss <: AbstractLossFunction + agg = mean +end + +function __unsafe_apply_loss(loss::HingeLoss, ŷ, y) + T = promote_type(eltype(ŷ), eltype(y)) + return __fused_agg(loss.agg, Base.Fix1(max, T(0)), 1 .- y .* ŷ) +end + +@kwdef @concrete struct SquaredHingeLoss <: AbstractLossFunction + agg = mean +end + +function __unsafe_apply_loss(loss::SquaredHingeLoss, ŷ, y) + T = promote_type(eltype(ŷ), eltype(y)) + return __fused_agg(loss.agg, abs2 ∘ Base.Fix2(max, T(0)), 1 .- y .* ŷ) +end -# TODO: HuberLoss -# TODO: KLDivergenceLoss -# TODO: PoissonLoss -# TODO: HingeLoss -# TODO: SquaredHingeLoss +@concrete struct KLDivergenceLoss{C <: CrossEntropyLoss} <: AbstractLossFunction + agg + dims + celoss::C +end + +function KLDivergenceLoss(; dims=1, agg=mean, epsilon=nothing, label_smoothing=nothing) + celoss = CrossEntropyLoss(; dims, agg, epsilon, label_smoothing) + return KLDivergenceLoss(agg, dims, celoss) +end + +function __unsafe_apply_loss(loss::KLDivergenceLoss, ŷ, y) + cross_entropy = __unsafe_apply_loss(loss.celoss, ŷ, y) + entropy = loss.agg(sum(xlogx, y; loss.dims)) + return entropy + cross_entropy +end + +@kwdef @concrete struct PoissonLoss <: AbstractLossFunction + agg = mean + epsilon = nothing +end + +function __unsafe_apply_loss(loss::PoissonLoss, ŷ, y) + T = promote_type(eltype(ŷ), eltype(y)) + ϵ = __get_epsilon(T, loss.epsilon) + return loss.agg(ŷ .- xlogy.(y, ŷ .+ ϵ)) +end ```@meta DocTestFilters = nothing From 2009f7be0c288fa18bfa2cec4c0c00391c248376 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 16 Jun 2024 20:17:51 -0700 Subject: [PATCH 05/13] Add documentation with doctests --- Project.toml | 1 - src/Lux.jl | 6 +- src/losses/Losses.jl | 4 +- src/losses/loss_functions.jl | 348 +++++++++++++++++++++++++---------- src/losses/utils.jl | 1 - 5 files changed, 251 insertions(+), 109 deletions(-) diff --git a/Project.toml b/Project.toml index 9e1f596b2..bd874b077 100644 --- a/Project.toml +++ b/Project.toml @@ -100,7 +100,6 @@ MacroTools = "0.5.13" Markdown = "1.10" NCCL = "0.1.1" OhMyThreads = "0.5.1" -OneHotArrays = "0.2.5" Optimisers = "0.3" Pkg = "1.10" PrecompileTools = "1.2" diff --git a/src/Lux.jl 
b/src/Lux.jl index 9d3f07184..3b8ca68f2 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -111,9 +111,9 @@ export jacobian_vector_product, vector_jacobian_product export batched_jacobian export AutoForwardDiff, AutoReverseDiff, AutoTracker, AutoZygote -export BinaryCrossEntropyLoss, BinaryFocalLoss, CrossEntropyLoss, FocalLoss, HingeLoss, - HuberLoss, KLDivergenceLoss, L1Loss, L2Loss, MAELoss, MSELoss, MSLELoss, PoissonLoss, - SiameseContrastiveLoss, SquaredHingeLoss, TverskyLoss +export BinaryCrossEntropyLoss, BinaryFocalLoss, CrossEntropyLoss, DiceCoeffLoss, FocalLoss, + HingeLoss, HuberLoss, KLDivergenceLoss, L1Loss, L2Loss, MAELoss, MSELoss, MSLELoss, + PoissonLoss, SiameseContrastiveLoss, SquaredHingeLoss, TverskyLoss export f16, f32, f64 diff --git a/src/losses/Losses.jl b/src/losses/Losses.jl index c6ed58ccc..17187a14f 100644 --- a/src/losses/Losses.jl +++ b/src/losses/Losses.jl @@ -22,8 +22,8 @@ abstract type AbstractLossFunction <: Function end include("utils.jl") include("loss_functions.jl") -export BinaryCrossEntropyLoss, BinaryFocalLoss, CrossEntropyLoss, FocalLoss, HingeLoss, - HuberLoss, KLDivergenceLoss, L1Loss, L2Loss, MAELoss, MSELoss, MSLELoss, +export BinaryCrossEntropyLoss, BinaryFocalLoss, CrossEntropyLoss, DiceCoeffLoss, FocalLoss, + HingeLoss, HuberLoss, KLDivergenceLoss, L1Loss, L2Loss, MAELoss, MSELoss, MSLELoss, PoissonLoss, SiameseContrastiveLoss, SquaredHingeLoss, TverskyLoss end diff --git a/src/losses/loss_functions.jl b/src/losses/loss_functions.jl index ccdea30c6..3f7b1081d 100644 --- a/src/losses/loss_functions.jl +++ b/src/losses/loss_functions.jl @@ -1,6 +1,5 @@ # In this file, doctests which differ in the printed Float32 values won't fail ```@meta -using Base: func_for_method_checked DocTestFilters = r"[0-9\.]+f0" ``` @@ -11,62 +10,65 @@ end function __unsafe_apply_loss end -@kwdef @concrete struct MAELoss <: AbstractLossFunction - agg = mean -end +@doc doc""" + BinaryCrossEntropyLoss(; agg = mean, epsilon = nothing, + label_smoothing::Union{Nothing, Real}=nothing, + logits::Union{Bool, Val}=Val(false)) -const L1Loss = MAELoss +Binary Cross Entropy Loss with optional label smoothing and fused logit computation. -@inline __unsafe_apply_loss(loss::MAELoss, ŷ, y) = __fused_agg(loss.agg, abs, ŷ .- y) +Returns the binary cross entropy loss computed as: -@kwdef @concrete struct MSELoss <: AbstractLossFunction - agg = mean -end + - If `logits` is either `false` or `Val(false)`: -const L2Loss = MSELoss +$$agg\left(-y\tilde * \log\left(y\hat + \epsilon\right) - (1 - y\tilde) * \log\left(1 - y\hat + \epsilon\right)\right)$$ -@inline __unsafe_apply_loss(loss::MSELoss, ŷ, y) = __fused_agg(loss.agg, abs2, ŷ .- y) + - If `logits` is `true` or `Val(true)`: -@kwdef @concrete struct MSLELoss <: AbstractLossFunction - agg = mean - epsilon = nothing -end +$$agg\left((1 - y\tilde) * y\hat - log\sigma(y\hat)\right)$$ -@inline function __unsafe_apply_loss(loss::MSLELoss, ŷ, y) - T = promote_type(eltype(ŷ), eltype(y)) - ϵ = __get_epsilon(T, loss.epsilon) - return __fused_agg(loss.agg, abs2, log.((ŷ .+ ϵ) ./ (y .+ ϵ))) -end +The value of $y\tilde$ is computed using label smoothing. If `label_smoothing` is `nothing`, +then no label smoothing is applied. 
If `label_smoothing` is a real number $\in [0, 1]$, +then the value of $y\tilde$ is: -@concrete struct CrossEntropyLoss{logits, L <: Union{Nothing, Real}} <: AbstractLossFunction - label_smoothing::L - dims - agg - epsilon -end +$$y\tilde = (1 - \alpha) * y + \alpha * 0.5$$ -function CrossEntropyLoss(; - dims=1, agg=mean, epsilon=nothing, label_smoothing::Union{Nothing, Real}=nothing, - logits::Union{Bool, Val}=Val(false)) - label_smoothing !== nothing && @argcheck 0 ≤ label_smoothing ≤ 1 - return CrossEntropyLoss{__unwrap_val(logits)}(label_smoothing, dims, agg, epsilon) -end +where $\alpha$ is the value of `label_smoothing`. -for logits in (true, false) - return_expr = logits ? - :(return __fused_agg( - loss.agg, -, sum(y_smooth .* logsoftmax(ŷ; loss.dims); loss.dims))) : - :(return __fused_agg( - loss.agg, -, sum(xlogy.(y_smooth, ŷ .+ ϵ); loss.dims))) +## Example - @eval function __unsafe_apply_loss(loss::CrossEntropyLoss{$(logits)}, ŷ, y) - T = promote_type(eltype(ŷ), eltype(y)) - ϵ = __get_epsilon(T, loss.epsilon) - y_smooth = __label_smoothing(loss.label_smoothing, y, T) - $(return_expr) - end -end +```jldoctest +julia> bce = BinaryCrossEntropyLoss(); + +julia> y_bin = Bool[1, 0, 1]; + +julia> y_model = Float32[2, -1, pi] +3-element Vector{Float32}: + 2.0 + -1.0 + 3.1415927 + +julia> logitbce = BinaryCrossEntropyLoss(; logits=Val(true)); + +julia> logitbce(y_model, y_bin) +0.160832f0 +julia> bce(sigmoid.(y_model), y_bin) +0.16083185f0 + +julia> bce_ls = BinaryCrossEntropyLoss(label_smoothing=0.1); + +julia> bce_ls(sigmoid.(y_model), y_bin) > bce(sigmoid.(y_model), y_bin) +true + +julia> logitbce_ls = BinaryCrossEntropyLoss(label_smoothing=0.1, logits=Val(true)); + +julia> logitbce_ls(y_model, y_bin) > logitbce(y_model, y_bin) +true +``` + +See also [`CrossEntropyLoss`](@ref). +""" @concrete struct BinaryCrossEntropyLoss{logits, L <: Union{Nothing, Real}} <: AbstractLossFunction label_smoothing::L @@ -82,7 +84,7 @@ function BinaryCrossEntropyLoss(; end for logits in (true, false) - return_expr = logits ? :(return loss.agg((1 .- y_smooth) .* y̋ .- logsigmoid.(ŷ))) : + return_expr = logits ? :(return loss.agg((1 .- y_smooth) .* ŷ .- logsigmoid.(ŷ))) : :(return loss.agg(-xlogy.(y_smooth, ŷ .+ ϵ) .- xlogy.(1 .- y_smooth, 1 .- ŷ .+ ϵ))) @@ -94,6 +96,32 @@ for logits in (true, false) end end +@doc doc""" + BinaryFocalLoss(; gamma = 2, agg = mean, epsilon = nothing) + +Return the [binary focal loss](https://arxiv.org/pdf/1708.02002.pdf). The model input, +$y\hat$, is expected to be normalized (i.e. [softmax](@ref Softmax) output). + +For $\gamma = 0$ this is equivalent to [`BinaryCrossEntropyLoss`](@ref). + +## Example + +```jldoctest +julia> y = [0 1 0 + 1 0 1]; + +julia> ŷ = [0.268941 0.5 0.268941 + 0.731059 0.5 0.731059]; + +julia> BinaryFocalLoss()(ŷ, y) ≈ 0.0728675615927385 +true + +julia> BinaryFocalLoss(gamma=0)(ŷ, y) ≈ BinaryCrossEntropyLoss()(ŷ, y) +true +``` + +See also [`FocalLoss`](@ref) for multi-class focal loss. +""" @kwdef @concrete struct BinaryFocalLoss <: AbstractLossFunction gamma = 2 agg = mean @@ -109,55 +137,33 @@ end return __fused_agg(loss.agg, -, (1 .- p_t) .^ γ .* log.(p_t)) end -@kwdef @concrete struct FocalLoss <: AbstractLossFunction - gamma = 2 - dims = 1 - agg = mean - epsilon = nothing -end - -@inline function __unsafe_apply_loss(loss::FocalLoss, ŷ, y) - T = promote_type(eltype(ŷ), eltype(y)) - γ = loss.gamma isa Integer ? 
loss.gamma : T(loss.gamma) - ϵ = __get_epsilon(T, loss.epsilon) - ŷϵ = ŷ .+ ϵ - return loss.agg(sum(-y .* (1 .- ŷϵ) .^ γ .+ log.(ŷϵ); loss.dims)) -end - -@concrete struct SiameseContrastiveLoss <: AbstractLossFunction - margin +@concrete struct CrossEntropyLoss{logits, L <: Union{Nothing, Real}} <: AbstractLossFunction + label_smoothing::L + dims agg + epsilon end -function SiameseContrastiveLoss(; margin::Real=true, agg=mean) - @argcheck margin ≥ 0 - return SiameseContrastiveLoss(margin, agg) -end - -@inline function __unsafe_apply_loss(loss::SiameseContrastiveLoss, ŷ, y) - z = @. (1 - y) * ŷ^2 + y * max(0, loss.margin - ŷ)^2 - return loss.agg(z) -end - -@kwdef @concrete struct TverskyLoss <: AbstractLossFunction - beta = 0.7 - smooth = true - agg = mean +function CrossEntropyLoss(; + dims=1, agg=mean, epsilon=nothing, label_smoothing::Union{Nothing, Real}=nothing, + logits::Union{Bool, Val}=Val(false)) + label_smoothing !== nothing && @argcheck 0 ≤ label_smoothing ≤ 1 + return CrossEntropyLoss{__unwrap_val(logits)}(label_smoothing, dims, agg, epsilon) end -function __unsafe_apply_loss(loss::TverskyLoss, ŷ, y) - T = promote_type(eltype(ŷ), eltype(y)) - β = T(loss.beta) - α = T(loss.smooth) - - yŷ = y .* ŷ - dims = __get_dims(yŷ) - - TP = sum(yŷ; dims) - FP = sum((true .- y) .* ŷ; dims) - FN = sum(y .* (true .- ŷ); dims) +for logits in (true, false) + return_expr = logits ? + :(return __fused_agg( + loss.agg, -, sum(y_smooth .* logsoftmax(ŷ; loss.dims); loss.dims))) : + :(return __fused_agg( + loss.agg, -, sum(xlogy.(y_smooth, ŷ .+ ϵ); loss.dims))) - return loss.agg(1 - (TP + α) / (TP + α * FP + β * FN + α)) + @eval function __unsafe_apply_loss(loss::CrossEntropyLoss{$(logits)}, ŷ, y) + T = promote_type(eltype(ŷ), eltype(y)) + ϵ = __get_epsilon(T, loss.epsilon) + y_smooth = __label_smoothing(loss.label_smoothing, y, T) + $(return_expr) + end end @kwdef @concrete struct DiceCoeffLoss <: AbstractLossFunction @@ -178,20 +184,19 @@ function __unsafe_apply_loss(loss::DiceCoeffLoss, ŷ, y) return loss.agg(true - num ./ den) end -@kwdef @concrete struct HuberLoss <: AbstractLossFunction - delta = 1 +@kwdef @concrete struct FocalLoss <: AbstractLossFunction + gamma = 2 + dims = 1 agg = mean + epsilon = nothing end -function __unsafe_apply_loss(loss::HuberLoss, ŷ, y) +@inline function __unsafe_apply_loss(loss::FocalLoss, ŷ, y) T = promote_type(eltype(ŷ), eltype(y)) - return __fused_agg(loss.agg, Base.Fix2(__huber_metric, T(loss.delta)), abs.(ŷ .- y)) -end - -function __huber_metric(err::T1, δ::T2) where {T1, T2} - T = promote_type(T1, T2) - x = T(1 // 2) - return ifelse(err < δ, err^2 * x, δ * (err - x * δ)) + γ = loss.gamma isa Integer ? 
loss.gamma : T(loss.gamma) + ϵ = __get_epsilon(T, loss.epsilon) + ŷϵ = ŷ .+ ϵ + return loss.agg(sum(-y .* (1 .- ŷϵ) .^ γ .+ log.(ŷϵ); loss.dims)) end @kwdef @concrete struct HingeLoss <: AbstractLossFunction @@ -203,13 +208,20 @@ function __unsafe_apply_loss(loss::HingeLoss, ŷ, y) return __fused_agg(loss.agg, Base.Fix1(max, T(0)), 1 .- y .* ŷ) end -@kwdef @concrete struct SquaredHingeLoss <: AbstractLossFunction +@kwdef @concrete struct HuberLoss <: AbstractLossFunction + delta = 1 agg = mean end -function __unsafe_apply_loss(loss::SquaredHingeLoss, ŷ, y) +function __unsafe_apply_loss(loss::HuberLoss, ŷ, y) T = promote_type(eltype(ŷ), eltype(y)) - return __fused_agg(loss.agg, abs2 ∘ Base.Fix2(max, T(0)), 1 .- y .* ŷ) + return __fused_agg(loss.agg, Base.Fix2(__huber_metric, T(loss.delta)), abs.(ŷ .- y)) +end + +function __huber_metric(err::T1, δ::T2) where {T1, T2} + T = promote_type(T1, T2) + x = T(1 // 2) + return ifelse(err < δ, err^2 * x, δ * (err - x * δ)) end @concrete struct KLDivergenceLoss{C <: CrossEntropyLoss} <: AbstractLossFunction @@ -229,6 +241,93 @@ function __unsafe_apply_loss(loss::KLDivergenceLoss, ŷ, y) return entropy + cross_entropy end +@doc doc""" + MAELoss(; agg = mean) + +Returns the loss corresponding to mean absolute error: + +$$agg\left(\left| y\hat - y \right|\right)$$ + +## Example + +```jldoctest +julia> loss = MAELoss(); + +julia> y_model = [1.1, 1.9, 3.1]; + +julia> loss(y_model, 1:3) +0.10000000000000009 +``` +""" +@kwdef @concrete struct MAELoss <: AbstractLossFunction + agg = mean +end + +const L1Loss = MAELoss + +@inline __unsafe_apply_loss(loss::MAELoss, ŷ, y) = __fused_agg(loss.agg, abs, ŷ .- y) + +@doc doc""" + MSELoss(; agg = mean) + +Returns the loss corresponding to mean squared error: + +$$agg\left(\left( y\hat - y \right)^2\right)$$ + +## Example + +```jldoctest +julia> loss = MSELoss(); + +julia> y_model = [1.1, 1.9, 3.1]; + +julia> loss(y_model, 1:3) +0.010000000000000018 +``` + +See also [`MSELoss`](@ref). +""" +@kwdef @concrete struct MSELoss <: AbstractLossFunction + agg = mean +end + +const L2Loss = MSELoss + +@inline __unsafe_apply_loss(loss::MSELoss, ŷ, y) = __fused_agg(loss.agg, abs2, ŷ .- y) + +@doc doc""" + MSLELoss(; agg = mean, epsilon = nothing) + +Returns the loss corresponding to mean squared logarithmic error: + +$$agg\left(\left( \log\left( y\hat + \epsilon \right) - \log\left( y + \epsilon \right) \right)^2\right)$$ + +`epsilon` is added to both `y` and `ŷ` to prevent taking the logarithm of zero. If `epsilon` +is `nothing`, then we set it to `eps()`. 
+ +## Example + +```jldoctest +julia> loss = MSLELoss(); + +julia> loss(Float32[1.1, 2.2, 3.3], 1:3) +0.009084041f0 + +julia> loss(Float32[0.9, 1.8, 2.7], 1:3) +0.011100831f0 +``` +""" +@kwdef @concrete struct MSLELoss <: AbstractLossFunction + agg = mean + epsilon = nothing +end + +@inline function __unsafe_apply_loss(loss::MSLELoss, ŷ, y) + T = promote_type(eltype(ŷ), eltype(y)) + ϵ = __get_epsilon(T, loss.epsilon) + return __fused_agg(loss.agg, abs2 ∘ log, (ŷ .+ ϵ) ./ (y .+ ϵ)) +end + @kwdef @concrete struct PoissonLoss <: AbstractLossFunction agg = mean epsilon = nothing @@ -240,6 +339,51 @@ function __unsafe_apply_loss(loss::PoissonLoss, ŷ, y) return loss.agg(ŷ .- xlogy.(y, ŷ .+ ϵ)) end +@concrete struct SiameseContrastiveLoss <: AbstractLossFunction + margin + agg +end + +function SiameseContrastiveLoss(; margin::Real=true, agg=mean) + @argcheck margin ≥ 0 + return SiameseContrastiveLoss(margin, agg) +end + +@inline function __unsafe_apply_loss(loss::SiameseContrastiveLoss, ŷ, y) + z = @. (1 - y) * ŷ^2 + y * max(0, loss.margin - ŷ)^2 + return loss.agg(z) +end + +@kwdef @concrete struct SquaredHingeLoss <: AbstractLossFunction + agg = mean +end + +function __unsafe_apply_loss(loss::SquaredHingeLoss, ŷ, y) + T = promote_type(eltype(ŷ), eltype(y)) + return __fused_agg(loss.agg, abs2 ∘ Base.Fix2(max, T(0)), 1 .- y .* ŷ) +end + +@kwdef @concrete struct TverskyLoss <: AbstractLossFunction + beta = 0.7 + smooth = true + agg = mean +end + +function __unsafe_apply_loss(loss::TverskyLoss, ŷ, y) + T = promote_type(eltype(ŷ), eltype(y)) + β = T(loss.beta) + α = T(loss.smooth) + + yŷ = y .* ŷ + dims = __get_dims(yŷ) + + TP = sum(yŷ; dims) + FP = sum((true .- y) .* ŷ; dims) + FN = sum(y .* (true .- ŷ); dims) + + return loss.agg(1 - (TP + α) / (TP + α * FP + β * FN + α)) +end + ```@meta DocTestFilters = nothing ``` diff --git a/src/losses/utils.jl b/src/losses/utils.jl index 00e150b68..296a8dd8f 100644 --- a/src/losses/utils.jl +++ b/src/losses/utils.jl @@ -1,6 +1,5 @@ """ xlogx(x::Number) -using Base: func_for_method_checked Return `x * log(x)` for `x ≥ 0`, handling `x == 0` by taking the limit from above, to get zero. From 4e360dbdb61ea58398475bdced9b612749154a68 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 18 Jun 2024 22:31:13 -0700 Subject: [PATCH 06/13] Add docstrings --- src/losses/loss_functions.jl | 330 +++++++++++++++++++++++++++++++++-- src/losses/utils.jl | 2 +- 2 files changed, 317 insertions(+), 15 deletions(-) diff --git a/src/losses/loss_functions.jl b/src/losses/loss_functions.jl index 3f7b1081d..b4cd786fc 100644 --- a/src/losses/loss_functions.jl +++ b/src/losses/loss_functions.jl @@ -21,17 +21,17 @@ Returns the binary cross entropy loss computed as: - If `logits` is either `false` or `Val(false)`: -$$agg\left(-y\tilde * \log\left(y\hat + \epsilon\right) - (1 - y\tilde) * \log\left(1 - y\hat + \epsilon\right)\right)$$ +$$\text{agg}\left(-\tilde{y} * \log\left(\hat{y} + \epsilon\right) - (1 - \tilde{y}) * \log\left(1 - \hat{y} + \epsilon\right)\right)$$ - If `logits` is `true` or `Val(true)`: -$$agg\left((1 - y\tilde) * y\hat - log\sigma(y\hat)\right)$$ +$$\text{agg}\left((1 - \tilde{y}) * \hat{y} - log\sigma(\hat{y})\right)$$ -The value of $y\tilde$ is computed using label smoothing. If `label_smoothing` is `nothing`, -then no label smoothing is applied. If `label_smoothing` is a real number $\in [0, 1]$, -then the value of $y\tilde$ is: +The value of $\tilde{y}$ is computed using label smoothing. If `label_smoothing` is +`nothing`, then no label smoothing is applied. 
If `label_smoothing` is a real number +$\in [0, 1]$, then the value of $\tilde{y}$ is: -$$y\tilde = (1 - \alpha) * y + \alpha * 0.5$$ +$$\tilde{y} = (1 - \alpha) * y + \alpha * 0.5$$ where $\alpha$ is the value of `label_smoothing`. @@ -99,8 +99,8 @@ end @doc doc""" BinaryFocalLoss(; gamma = 2, agg = mean, epsilon = nothing) -Return the [binary focal loss](https://arxiv.org/pdf/1708.02002.pdf). The model input, -$y\hat$, is expected to be normalized (i.e. [softmax](@ref Softmax) output). +Return the binary focal loss [1]. The model input, $\hat{y}$, is expected to be normalized +(i.e. [softmax](@ref Softmax) output). For $\gamma = 0$ this is equivalent to [`BinaryCrossEntropyLoss`](@ref). @@ -121,6 +121,11 @@ true ``` See also [`FocalLoss`](@ref) for multi-class focal loss. + +## References + +[1] Lin, Tsung-Yi, et al. "Focal loss for dense object detection." Proceedings of the IEEE +international conference on computer vision. 2017. """ @kwdef @concrete struct BinaryFocalLoss <: AbstractLossFunction gamma = 2 @@ -137,6 +142,49 @@ end return __fused_agg(loss.agg, -, (1 .- p_t) .^ γ .* log.(p_t)) end +@doc doc""" + CrossEntropyLoss(; agg=mean, epsilon=nothing, dims=1, + label_smoothing::Union{Nothing, Real}=nothing) + +Return the cross entropy loss which is used in multi-class classification tasks. The input, +$\hat{y}$, is expected to be normalized (i.e. `softmax` output) if `logits` is `false` or +`Val(false)`. + +The loss is calculated as: + +$$\text{agg}\left(-\sum \tilde{y} \log(\hat{y} + \epsilon)\right)$$ + +where $\epsilon$ is the smoothing factor. + +## Example + +```jldoctest +julia> y = [1 0 0 0 1 + 0 1 0 1 0 + 0 0 1 0 0] +3×5 Matrix{Int64}: + 1 0 0 0 1 + 0 1 0 1 0 + 0 0 1 0 0 + +julia> y_model = softmax(reshape(-7:7, 3, 5) .* 1f0) +3×5 Matrix{Float32}: + 0.0900306 0.0900306 0.0900306 0.0900306 0.0900306 + 0.244728 0.244728 0.244728 0.244728 0.244728 + 0.665241 0.665241 0.665241 0.665241 0.665241 + +julia> CrossEntropyLoss()(y_model, y) +1.6076053f0 + +julia> 5 * 1.6076053f0 ≈ CrossEntropyLoss(; agg=sum)(y_model, y) +true + +julia> CrossEntropyLoss(label_smoothing=0.15)(y_model, y) +1.5776052f0 +``` + +See also [`BinaryCrossEntropyLoss`](@ref) for binary classification tasks. +""" @concrete struct CrossEntropyLoss{logits, L <: Union{Nothing, Real}} <: AbstractLossFunction label_smoothing::L dims @@ -166,6 +214,34 @@ for logits in (true, false) end end +@doc doc""" + DiceCoeffLoss(; smooth = true, agg = mean) + +Return the Dice Coefficient loss [1] which is used in segmentation tasks. The dice +coefficient is similar to the F1_score. Loss calculated as: + +$$agg\left(1 - \frac{2 \sum \tilde{y} \hat{y} + \alpha}{\sum \tilde{y}^2 + \sum \hat{y}^2 + \alpha}\right)$$ + +where $\alpha$ is the smoothing factor (`smooth`). + +## Example + +```jldoctest +julia> y_pred = [1.1, 2.1, 3.1]; + +julia> DiceCoeffLoss()(y_pred, 1:3) +0.000992391663909964 + +julia> 1 - DiceCoeffLoss()(y_pred, 1:3) +0.99900760833609 +``` + +## References + +[1] Milletari, Fausto, Nassir Navab, and Seyed-Ahmad Ahmadi. "V-net: Fully convolutional +neural networks for volumetric medical image segmentation." 2016 fourth international +conference on 3D vision (3DV). Ieee, 2016. 
+""" @kwdef @concrete struct DiceCoeffLoss <: AbstractLossFunction smooth = true agg = mean @@ -184,6 +260,42 @@ function __unsafe_apply_loss(loss::DiceCoeffLoss, ŷ, y) return loss.agg(true - num ./ den) end +@doc doc""" + FocalLoss(; gamma = 2, dims = 1, agg = mean, epsilon = nothing) + +Return the focal loss [1] which can be used in classification tasks with highly imbalanced +classes. It down-weights well-classified examples and focuses on hard examples. +The input, $\hat{y}$, is expected to be normalized (i.e. `softmax` output). + +The modulating factor $\gamma$, controls the down-weighting strength. For $\gamma = 0$ this +is equivalent to [`CrossEntropyLoss`](@ref). + +## Example + +```jldoctest +julia> y = [1 0 0 0 1 + 0 1 0 1 0 + 0 0 1 0 0] +3×5 Matrix{Int64}: + 1 0 0 0 1 + 0 1 0 1 0 + 0 0 1 0 0 + +julia> ŷ = softmax(reshape(-7:7, 3, 5) .* 1f0) +3×5 Matrix{Float32}: + 0.0900306 0.0900306 0.0900306 0.0900306 0.0900306 + 0.244728 0.244728 0.244728 0.244728 0.244728 + 0.665241 0.665241 0.665241 0.665241 0.665241 + +julia> FocalLoss()(ŷ, y) ≈ 1.1277556f0 +true +``` + +## References + +[1] Lin, Tsung-Yi, et al. "Focal loss for dense object detection." Proceedings of the IEEE +international conference on computer vision. 2017. +""" @kwdef @concrete struct FocalLoss <: AbstractLossFunction gamma = 2 dims = 1 @@ -196,9 +308,43 @@ end γ = loss.gamma isa Integer ? loss.gamma : T(loss.gamma) ϵ = __get_epsilon(T, loss.epsilon) ŷϵ = ŷ .+ ϵ - return loss.agg(sum(-y .* (1 .- ŷϵ) .^ γ .+ log.(ŷϵ); loss.dims)) + return loss.agg(sum(-y .* (1 .- ŷϵ) .^ γ .* log.(ŷϵ); loss.dims)) end +@doc doc""" + HingeLoss(; agg = mean) + +Return the hinge loss loss given the prediction `ŷ` and true labels `y` (containing +1 or -1); calculated as: + +$$\text{agg}\left(\max(0, 1 - y \hat{y})\right)$$ + +Usually used with classifiers like Support Vector Machines. + +## Example + +```jldoctest +julia> loss = HingeLoss(); + +julia> y_true = [1, -1, 1, 1]; + +julia> y_pred = [0.1, 0.3, 1, 1.5]; + +julia> loss(y_pred, y_true) +0.55 + +julia> loss(y_pred[1], y_true[1]) +0.9 + +julia> loss(y_pred[2], y_true[2]) +1.3 + +julia> loss(y_pred[3], y_true[3]) +0.0 +``` + +See also [`SquaredHingeLoss`](@ref). +""" @kwdef @concrete struct HingeLoss <: AbstractLossFunction agg = mean end @@ -208,6 +354,30 @@ function __unsafe_apply_loss(loss::HingeLoss, ŷ, y) return __fused_agg(loss.agg, Base.Fix1(max, T(0)), 1 .- y .* ŷ) end +@doc doc""" + HuberLoss(; delta = 1, agg = mean) + +Returns the Huber loss, calculated as: + +$$L = \begin{cases} + 0.5 * |y - \hat{y}|^2 & \text{if } |y - \hat{y}| \leq \delta \\ + \delta * (|y - \hat{y}| - 0.5 * \delta) & \text{otherwise} +\end{cases}$$ + +where $\delta$ is the `delta` parameter. 
+ +## Example + +```jldoctest +julia> y_model = [1.1, 2.1, 3.1]; + +julia> HuberLoss()(y_model, 1:3) +0.005000000000000009 + +julia> HuberLoss(delta=0.05)(y_model, 1:3) +0.003750000000000005 +``` +""" @kwdef @concrete struct HuberLoss <: AbstractLossFunction delta = 1 agg = mean @@ -219,11 +389,47 @@ function __unsafe_apply_loss(loss::HuberLoss, ŷ, y) end function __huber_metric(err::T1, δ::T2) where {T1, T2} - T = promote_type(T1, T2) - x = T(1 // 2) + x = promote_type(T1, T2)(0.5) return ifelse(err < δ, err^2 * x, δ * (err - x * δ)) end +@doc doc""" + KLDivergenceLoss(; dims = 1, agg = mean, epsilon = nothing, label_smoothing = nothing) + +Return the Kullback-Leibler Divergence loss between the predicted distribution $\hat{y}$ +and the true distribution $y$: + +The KL divergence is a measure of how much one probability distribution is different from +the other. It is always non-negative, and zero only when both the distributions are equal. + +For `epsilon` and `label_smoothing`, see [`CrossEntropyLoss`](@ref). + +## Example + +```jldoctest +julia> p1 = [1 0; 0 1] +2×2 Matrix{Int64}: + 1 0 + 0 1 + +julia> p2 = fill(0.5, 2, 2) +2×2 Matrix{Float64}: + 0.5 0.5 + 0.5 0.5 + +julia> KLDivergenceLoss()(p2, p1) ≈ log(2) +true + +julia> KLDivergenceLoss(; agg=sum)(p2, p1) ≈ 2 * log(2) +true + +julia> KLDivergenceLoss(; epsilon=0)(p2, p2) +0.0 + +julia> KLDivergenceLoss(; epsilon=0)(p1, p2) +Inf +``` +""" @concrete struct KLDivergenceLoss{C <: CrossEntropyLoss} <: AbstractLossFunction agg dims @@ -246,7 +452,7 @@ end Returns the loss corresponding to mean absolute error: -$$agg\left(\left| y\hat - y \right|\right)$$ +$$\text{agg}\left(\left| \hat{y} - y \right|\right)$$ ## Example @@ -272,7 +478,7 @@ const L1Loss = MAELoss Returns the loss corresponding to mean squared error: -$$agg\left(\left( y\hat - y \right)^2\right)$$ +$$\text{agg}\left(\left( \hat{y} - y \right)^2\right)$$ ## Example @@ -300,7 +506,7 @@ const L2Loss = MSELoss Returns the loss corresponding to mean squared logarithmic error: -$$agg\left(\left( \log\left( y\hat + \epsilon \right) - \log\left( y + \epsilon \right) \right)^2\right)$$ +$$\text{agg}\left(\left( \log\left( \hat{y} + \epsilon \right) - \log\left( y + \epsilon \right) \right)^2\right)$$ `epsilon` is added to both `y` and `ŷ` to prevent taking the logarithm of zero. If `epsilon` is `nothing`, then we set it to `eps()`. @@ -328,6 +534,23 @@ end return __fused_agg(loss.agg, abs2 ∘ log, (ŷ .+ ϵ) ./ (y .+ ϵ)) end +@doc doc""" + PoissonLoss(; agg = mean, epsilon = nothing) + +Return how much the predicted distribution $\hat{y}$ diverges from the expected Poisson +distribution $y$, calculated as: + +$$\text{agg}\left(\hat{y} - y * \log(\hat{y})\right)$$ + +## Example + +```jldoctest +julia> y_model = [1, 3, 3]; # data should only take integral values + +julia> PoissonLoss()(y_model, 1:3) +0.502312852219817 +``` +""" @kwdef @concrete struct PoissonLoss <: AbstractLossFunction agg = mean epsilon = nothing @@ -339,6 +562,34 @@ function __unsafe_apply_loss(loss::PoissonLoss, ŷ, y) return loss.agg(ŷ .- xlogy.(y, ŷ .+ ϵ)) end +@doc doc""" + SiameseContrastiveLoss(; margin = true, agg = mean) + +Return the contrastive loss [1] which can be useful for training Siamese Networks. It is +given by: + +$$\text{agg}\left((1 - y) \hat{y}^2 + y * \max(0, \text{margin} - \hat{y})^2\right)$$ + +Specify `margin` to set the baseline for distance at which pairs are dissimilar. 
+ +## Example + +```jldoctest +julia> ŷ = [0.5, 1.5, 2.5]; + +julia> SiameseContrastiveLoss()(ŷ, 1:3) +-4.833333333333333 + +julia> SiameseContrastiveLoss(margin=2)(ŷ, 1:3) +-4.0 +``` + +## References + +[1] Hadsell, Raia, Sumit Chopra, and Yann LeCun. "Dimensionality reduction by learning an +invariant mapping." 2006 IEEE computer society conference on computer vision and pattern +recognition (CVPR'06). Vol. 2. IEEE, 2006. +""" @concrete struct SiameseContrastiveLoss <: AbstractLossFunction margin agg @@ -354,6 +605,40 @@ end return loss.agg(z) end +@doc doc""" + SquaredHingeLoss(; agg = mean) + +Return the squared hinge loss loss given the prediction `ŷ` and true labels `y` (containing +1 or -1); calculated as: + +$$\text{agg}\left(\max(0, 1 - y \hat{y})^2\right)$$ + +Usually used with classifiers like Support Vector Machines. + +## Example + +```jldoctest +julia> loss = SquaredHingeLoss(); + +julia> y_true = [1, -1, 1, 1]; + +julia> y_pred = [0.1, 0.3, 1, 1.5]; + +julia> loss(y_pred, y_true) +0.625 + +julia> loss(y_pred[1], y_true[1]) ≈ 0.81 +true + +julia> loss(y_pred[2], y_true[2]) ≈ 1.69 +true + +julia> loss(y_pred[3], y_true[3]) +0.0 +``` + +See also [`HingeLoss`](@ref). +""" @kwdef @concrete struct SquaredHingeLoss <: AbstractLossFunction agg = mean end @@ -363,6 +648,23 @@ function __unsafe_apply_loss(loss::SquaredHingeLoss, ŷ, y) return __fused_agg(loss.agg, abs2 ∘ Base.Fix2(max, T(0)), 1 .- y .* ŷ) end +@doc doc""" + TverskyLoss(; beta = 0.7, smooth = true, agg = mean) + +Return the Tversky loss [1]. Used with imbalanced data to give more weight to false +negatives. Larger `beta` weigh recall more than precision (by placing more emphasis on +false negatives). Calculated as: + +$$1 - \frac{\sum \left(y * \hat{y}\right) + \alpha}{\sum \left(y * \hat{y} + (1 - \beta) * (1 - y) * \hat{y} + \beta y * (1 - \hat{y})\right) + \alpha}$$ + +where $\alpha$ is the smoothing factor (`smooth`). + +## References + +[1] Salehi, Seyed Sadegh Mohseni, Deniz Erdogmus, and Ali Gholipour. "Tversky loss function +for image segmentation using 3D fully convolutional deep networks." International workshop +on machine learning in medical imaging. Cham: Springer International Publishing, 2017. 
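+
+## Example
+
+A small sketch on a toy binary mask (shown as a plain snippet rather than a doctest; the
+values are only illustrative):
+
+```julia
+julia> y = [1.0, 0.0, 1.0, 1.0];
+
+julia> TverskyLoss()(y, y) ≈ 0  # perfect overlap gives zero loss
+true
+
+julia> ŷ = [0.9, 0.1, 0.2, 0.8];  # the third positive is largely missed
+
+julia> TverskyLoss(beta=0.9)(ŷ, y) > TverskyLoss(beta=0.1)(ŷ, y)  # larger beta penalizes the miss more
+true
+```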
+""" @kwdef @concrete struct TverskyLoss <: AbstractLossFunction beta = 0.7 smooth = true diff --git a/src/losses/utils.jl b/src/losses/utils.jl index 296a8dd8f..40b0a7b95 100644 --- a/src/losses/utils.jl +++ b/src/losses/utils.jl @@ -79,7 +79,7 @@ end end @inline __get_epsilon(::Type{T}, ϵ::Real) where {T} = T(ϵ) -@inline __get_epsilon(::Type{T}, ::Nothing) where {T} = eps(T) +@inline __get_epsilon(::Type{T}, ::Nothing) where {T} = eps(float(T)) @inline __get_dims(_) = Colon() @inline __get_dims(::AbstractVector) = Colon() From 99a36c272ee5f106c3a77ad48a74238a2ea0bcc0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 19 Jun 2024 08:31:39 -0700 Subject: [PATCH 07/13] [skip tests] Add to docs --- docs/make.jl | 4 +-- docs/src/.vitepress/config.mts | 2 -- docs/src/api/Accelerator_Support/LuxCUDA.md | 16 --------- docs/src/api/Lux/autodiff.md | 2 +- docs/src/api/Lux/utilities.md | 40 ++++++++++++++++----- src/losses/Losses.jl | 3 ++ src/losses/loss_functions.jl | 2 +- 7 files changed, 37 insertions(+), 32 deletions(-) delete mode 100644 docs/src/api/Accelerator_Support/LuxCUDA.md diff --git a/docs/make.jl b/docs/make.jl index b4c44688b..dd9b7d827 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -52,7 +52,6 @@ pages = [ "api/Lux/distributed_utils.md", ], "Accelerator Support" => [ - "api/Accelerator_Support/LuxCUDA.md", "api/Accelerator_Support/LuxDeviceUtils.md" ], "Building Blocks" => [ @@ -82,8 +81,7 @@ makedocs(; sitename="Lux.jl Documentation", authors="Avik Pal et al.", clean=true, doctest=false, # We test it in the CI, no need to run it here - modules=[Lux, LuxCore, LuxLib, WeightInitializers, - Boltz, LuxTestUtils, LuxDeviceUtils, LuxCUDA], + modules=[Lux, LuxCore, LuxLib, WeightInitializers, Boltz, LuxTestUtils, LuxDeviceUtils], linkcheck=true, repo="https://github.com/LuxDL/Lux.jl/blob/{commit}{path}#{line}", format=DocumenterVitepress.MarkdownVitepress(; diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts index 5bbe72603..42049caea 100644 --- a/docs/src/.vitepress/config.mts +++ b/docs/src/.vitepress/config.mts @@ -80,7 +80,6 @@ export default defineConfig({ }, { text: 'Accelerator Support', items: [ - { text: 'LuxCUDA', link: '/api/Accelerator_Support/LuxCUDA' }, { text: 'LuxDeviceUtils', link: '/api/Accelerator_Support/LuxDeviceUtils' } ] }, @@ -196,7 +195,6 @@ export default defineConfig({ }, { text: 'Accelerator Support', collapsed: false, items: [ - { text: 'LuxCUDA', link: '/api/Accelerator_Support/LuxCUDA' }, { text: 'LuxDeviceUtils', link: '/api/Accelerator_Support/LuxDeviceUtils' }] }, { diff --git a/docs/src/api/Accelerator_Support/LuxCUDA.md b/docs/src/api/Accelerator_Support/LuxCUDA.md deleted file mode 100644 index a998cfdc5..000000000 --- a/docs/src/api/Accelerator_Support/LuxCUDA.md +++ /dev/null @@ -1,16 +0,0 @@ -# LuxCUDA - -`LuxCUDA` is meant to be used as a trigger package for all `CUDA` dependencies in `Lux`. -Users requiring CUDA support should install `LuxCUDA` and load it alongside `Lux`. - -## Index - -```@index -Pages = ["LuxCUDA.md"] -``` - -## API Reference - -```@autodocs -Modules = [LuxCUDA] -``` diff --git a/docs/src/api/Lux/autodiff.md b/docs/src/api/Lux/autodiff.md index 39e3c1e3e..f372b05aa 100644 --- a/docs/src/api/Lux/autodiff.md +++ b/docs/src/api/Lux/autodiff.md @@ -14,7 +14,7 @@ Lux. Additionally, we provide some convenience functions for working with AD. 
| [`ReverseDiff.jl`](https://github.com/JuliaDiff/ReverseDiff.jl) | ✔️ | ❌ | ❌ | Tier II | | [`Tracker.jl`](https://github.com/FluxML/Tracker.jl) | ✔️ | ✔️ | ❌ | Tier II | | [`Enzyme.jl`](https://github.com/EnzymeAD/Enzyme.jl) | ✔️ | ❓[^q] | ❓[^q] | Tier II | -| [`Tapir.jl`](https://github.com/withbayes/Tapir.jl) | ❓[^q] | ❓[^q] | ❌ | Tier IV | +| [`Tapir.jl`](https://github.com/withbayes/Tapir.jl) | ❓[^q] | ❌ | ❌ | Tier IV | | [`Diffractor.jl`](https://github.com/JuliaDiff/Diffractor.jl) | ❓[^q] | ❓[^q] | ❓[^q] | Tier IV | [^q]: This feature is supported downstream, but we don't extensively test it to ensure diff --git a/docs/src/api/Lux/utilities.md b/docs/src/api/Lux/utilities.md index 660404dcf..f50f92509 100644 --- a/docs/src/api/Lux/utilities.md +++ b/docs/src/api/Lux/utilities.md @@ -6,18 +6,26 @@ Pages = ["utilities.md"] ``` -## Device Management / Data Transfer +## Loss Functions ```@docs -Lux.cpu -Lux.gpu +BinaryCrossEntropyLoss +BinaryFocalLoss +CrossEntropyLoss +DiceCoeffLoss +FocalLoss +HingeLoss +HuberLoss +KLDivergenceLoss +MAELoss +MSELoss +MSLELoss +PoissonLoss +SiameseContrastiveLoss +SquaredHingeLoss +TverskyLoss ``` -!!! warning - - For detailed API documentation on Data Transfer check out the - [LuxDeviceUtils.jl](@ref LuxDeviceUtils-API) - ## Weight Initialization !!! warning @@ -31,6 +39,8 @@ Lux.gpu Lux.foldl_init Lux.istraining Lux.multigate +Lux.Losses.xlogy +Lux.Losses.xlogx ``` ## Updating Floating Point Precision @@ -56,8 +66,20 @@ StatefulLuxLayer @compact ``` -## Truncated Stacktraces +## Truncated Stacktraces (Deprecated) ```@docs Lux.disable_stacktrace_truncation! ``` + +## Device Management / Data Transfer (Deprecated) + +```@docs +Lux.cpu +Lux.gpu +``` + +!!! warning + + For detailed API documentation on Data Transfer check out the + [LuxDeviceUtils.jl](@ref LuxDeviceUtils-API) diff --git a/src/losses/Losses.jl b/src/losses/Losses.jl index 17187a14f..29083838e 100644 --- a/src/losses/Losses.jl +++ b/src/losses/Losses.jl @@ -7,6 +7,7 @@ using PrecompileTools: @recompile_invalidations @recompile_invalidations begin using ArgCheck: @argcheck using ChainRulesCore: ChainRulesCore, NoTangent, ZeroTangent, @thunk + using Compat: @compat using ConcreteStructs: @concrete using FastClosures: @closure using ..Lux: __unwrap_val @@ -22,6 +23,8 @@ abstract type AbstractLossFunction <: Function end include("utils.jl") include("loss_functions.jl") +@compat public xlogx, xlogy + export BinaryCrossEntropyLoss, BinaryFocalLoss, CrossEntropyLoss, DiceCoeffLoss, FocalLoss, HingeLoss, HuberLoss, KLDivergenceLoss, L1Loss, L2Loss, MAELoss, MSELoss, MSLELoss, PoissonLoss, SiameseContrastiveLoss, SquaredHingeLoss, TverskyLoss diff --git a/src/losses/loss_functions.jl b/src/losses/loss_functions.jl index b4cd786fc..70d11bdca 100644 --- a/src/losses/loss_functions.jl +++ b/src/losses/loss_functions.jl @@ -100,7 +100,7 @@ end BinaryFocalLoss(; gamma = 2, agg = mean, epsilon = nothing) Return the binary focal loss [1]. The model input, $\hat{y}$, is expected to be normalized -(i.e. [softmax](@ref Softmax) output). +(i.e. softmax output). For $\gamma = 0$ this is equivalent to [`BinaryCrossEntropyLoss`](@ref). 
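
A note on the `agg` keyword shared by the loss functions documented above: going by
`__fused_agg` in `src/losses/utils.jl`, `agg` may be `mean` (the default), `sum`, `nothing`
(no aggregation), or any callable applied to the elementwise losses. A rough sketch, not a
doctest, with results checked via `≈`:

```julia
julia> ŷ, y = [1.0, 2.0, 3.0], [2.0, 2.0, 2.0];

julia> MSELoss()(ŷ, y) ≈ 2 / 3                   # mean of the squared errors [1, 0, 1]
true

julia> MSELoss(; agg=sum)(ŷ, y) ≈ 2
true

julia> MSELoss(; agg=nothing)(ŷ, y) ≈ [1, 0, 1]  # unaggregated, elementwise
true
```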
From c10aa78e50d2963b5c9b8593b17105339588fdaf Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 19 Jun 2024 14:58:57 -0700 Subject: [PATCH 08/13] Reuse more code --- Project.toml | 4 + docs/src/api/Lux/utilities.md | 11 ++ src/Lux.jl | 1 + src/losses/Losses.jl | 10 +- src/losses/loss_functions.jl | 222 ++++++++++++++-------------------- src/losses/utils.jl | 46 ++++++- 6 files changed, 154 insertions(+), 140 deletions(-) diff --git a/Project.toml b/Project.toml index bd874b077..20beeedf2 100644 --- a/Project.toml +++ b/Project.toml @@ -16,12 +16,14 @@ FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5" +PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -90,6 +92,7 @@ Functors = "0.4.10" GPUArraysCore = "0.1.6" LinearAlgebra = "1.10" Logging = "1.10" +LossFunctions = "0.11.1" LuxCore = "0.1.14" LuxDeviceUtils = "0.1.22" LuxLib = "0.3.23" @@ -101,6 +104,7 @@ Markdown = "1.10" NCCL = "0.1.1" OhMyThreads = "0.5.1" Optimisers = "0.3" +PartialFunctions = "1.2.0" Pkg = "1.10" PrecompileTools = "1.2" Preferences = "1.4.3" diff --git a/docs/src/api/Lux/utilities.md b/docs/src/api/Lux/utilities.md index f50f92509..2b560d573 100644 --- a/docs/src/api/Lux/utilities.md +++ b/docs/src/api/Lux/utilities.md @@ -8,6 +8,17 @@ Pages = ["utilities.md"] ## Loss Functions +!!! warning + + When using ChainRules.jl compatible AD (like Zygote), we only compute the gradients + wrt the inputs and drop any gradients wrt the targets. + +```@docs +GenericLossFunction +``` + +### Specialized Loss Functions + ```@docs BinaryCrossEntropyLoss BinaryFocalLoss diff --git a/src/Lux.jl b/src/Lux.jl index 3b8ca68f2..9709b35b8 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -114,6 +114,7 @@ export AutoForwardDiff, AutoReverseDiff, AutoTracker, AutoZygote export BinaryCrossEntropyLoss, BinaryFocalLoss, CrossEntropyLoss, DiceCoeffLoss, FocalLoss, HingeLoss, HuberLoss, KLDivergenceLoss, L1Loss, L2Loss, MAELoss, MSELoss, MSLELoss, PoissonLoss, SiameseContrastiveLoss, SquaredHingeLoss, TverskyLoss +export GenericLossFunction export f16, f32, f64 diff --git a/src/losses/Losses.jl b/src/losses/Losses.jl index 29083838e..75f867602 100644 --- a/src/losses/Losses.jl +++ b/src/losses/Losses.jl @@ -1,18 +1,21 @@ -# Eventually the idea is to create a package `DeepLearningLosses.jl` and move this -# functionality there and simply reexport it here. 
+# Eventually the idea is to get `LossFunctions.jl` up to speed so that we don't need this +# sort of an implementation module Losses # A huge chunk of this code has been derived from Flux.jl using PrecompileTools: @recompile_invalidations @recompile_invalidations begin + using ArrayInterface: fast_scalar_indexing using ArgCheck: @argcheck using ChainRulesCore: ChainRulesCore, NoTangent, ZeroTangent, @thunk using Compat: @compat using ConcreteStructs: @concrete using FastClosures: @closure using ..Lux: __unwrap_val - using Markdown: @doc_str + using LossFunctions: LossFunctions using LuxLib: logsoftmax, logsigmoid + using Markdown: @doc_str + using PartialFunctions: @$ using Statistics: mean end @@ -28,6 +31,7 @@ include("loss_functions.jl") export BinaryCrossEntropyLoss, BinaryFocalLoss, CrossEntropyLoss, DiceCoeffLoss, FocalLoss, HingeLoss, HuberLoss, KLDivergenceLoss, L1Loss, L2Loss, MAELoss, MSELoss, MSLELoss, PoissonLoss, SiameseContrastiveLoss, SquaredHingeLoss, TverskyLoss +export GenericLossFunction end diff --git a/src/losses/loss_functions.jl b/src/losses/loss_functions.jl index 70d11bdca..2236a53dd 100644 --- a/src/losses/loss_functions.jl +++ b/src/losses/loss_functions.jl @@ -50,11 +50,11 @@ julia> y_model = Float32[2, -1, pi] julia> logitbce = BinaryCrossEntropyLoss(; logits=Val(true)); -julia> logitbce(y_model, y_bin) -0.160832f0 +julia> logitbce(y_model, y_bin) ≈ 0.160832f0 +true -julia> bce(sigmoid.(y_model), y_bin) -0.16083185f0 +julia> bce(sigmoid.(y_model), y_bin) ≈ 0.16083185f0 +true julia> bce_ls = BinaryCrossEntropyLoss(label_smoothing=0.1); @@ -66,8 +66,6 @@ julia> logitbce_ls = BinaryCrossEntropyLoss(label_smoothing=0.1, logits=Val(true julia> logitbce_ls(y_model, y_bin) > logitbce(y_model, y_bin) true ``` - -See also [`CrossEntropyLoss`](@ref). """ @concrete struct BinaryCrossEntropyLoss{logits, L <: Union{Nothing, Real}} <: AbstractLossFunction @@ -120,8 +118,6 @@ julia> BinaryFocalLoss(gamma=0)(ŷ, y) ≈ BinaryCrossEntropyLoss()(ŷ, y) true ``` -See also [`FocalLoss`](@ref) for multi-class focal loss. - ## References [1] Lin, Tsung-Yi, et al. "Focal loss for dense object detection." Proceedings of the IEEE @@ -154,7 +150,14 @@ The loss is calculated as: $$\text{agg}\left(-\sum \tilde{y} \log(\hat{y} + \epsilon)\right)$$ -where $\epsilon$ is the smoothing factor. +where $\epsilon$ is added for numerical stability. The value of $\tilde{y}$ is computed +using label smoothing. If `label_smoothing` is `nothing`, then no label smoothing is +applied. If `label_smoothing` is a real number $\in [0, 1]$, then the value of +$\tilde{y}$ is calculated as: + +$$\tilde{y} = (1 - \alpha) * y + \alpha * \text{size along dim}$$ + +where $\alpha$ is the value of `label_smoothing`. ## Example @@ -173,17 +176,15 @@ julia> y_model = softmax(reshape(-7:7, 3, 5) .* 1f0) 0.244728 0.244728 0.244728 0.244728 0.244728 0.665241 0.665241 0.665241 0.665241 0.665241 -julia> CrossEntropyLoss()(y_model, y) -1.6076053f0 +julia> CrossEntropyLoss()(y_model, y) ≈ 1.6076053f0 +true julia> 5 * 1.6076053f0 ≈ CrossEntropyLoss(; agg=sum)(y_model, y) true -julia> CrossEntropyLoss(label_smoothing=0.15)(y_model, y) -1.5776052f0 +julia> CrossEntropyLoss(label_smoothing=0.15)(y_model, y) ≈ 1.5776052f0 +true ``` - -See also [`BinaryCrossEntropyLoss`](@ref) for binary classification tasks. """ @concrete struct CrossEntropyLoss{logits, L <: Union{Nothing, Real}} <: AbstractLossFunction label_smoothing::L @@ -220,7 +221,7 @@ end Return the Dice Coefficient loss [1] which is used in segmentation tasks. 
The dice coefficient is similar to the F1_score. Loss calculated as: -$$agg\left(1 - \frac{2 \sum \tilde{y} \hat{y} + \alpha}{\sum \tilde{y}^2 + \sum \hat{y}^2 + \alpha}\right)$$ +$$agg\left(1 - \frac{2 \sum y \hat{y} + \alpha}{\sum y^2 + \sum \hat{y}^2 + \alpha}\right)$$ where $\alpha$ is the smoothing factor (`smooth`). @@ -229,11 +230,11 @@ where $\alpha$ is the smoothing factor (`smooth`). ```jldoctest julia> y_pred = [1.1, 2.1, 3.1]; -julia> DiceCoeffLoss()(y_pred, 1:3) -0.000992391663909964 +julia> DiceCoeffLoss()(y_pred, 1:3) ≈ 0.000992391663909964 +true -julia> 1 - DiceCoeffLoss()(y_pred, 1:3) -0.99900760833609 +julia> 1 - DiceCoeffLoss()(y_pred, 1:3) ≈ 0.99900760833609 +true ``` ## References @@ -330,29 +331,11 @@ julia> y_true = [1, -1, 1, 1]; julia> y_pred = [0.1, 0.3, 1, 1.5]; -julia> loss(y_pred, y_true) -0.55 - -julia> loss(y_pred[1], y_true[1]) -0.9 - -julia> loss(y_pred[2], y_true[2]) -1.3 - -julia> loss(y_pred[3], y_true[3]) -0.0 +julia> loss(y_pred, y_true) ≈ 0.55 +true ``` - -See also [`SquaredHingeLoss`](@ref). """ -@kwdef @concrete struct HingeLoss <: AbstractLossFunction - agg = mean -end - -function __unsafe_apply_loss(loss::HingeLoss, ŷ, y) - T = promote_type(eltype(ŷ), eltype(y)) - return __fused_agg(loss.agg, Base.Fix1(max, T(0)), 1 .- y .* ŷ) -end +HingeLoss(; agg=mean) = GenericLossFunction(LossFunctions.L1HingeLoss(); agg) @doc doc""" HuberLoss(; delta = 1, agg = mean) @@ -371,26 +354,16 @@ where $\delta$ is the `delta` parameter. ```jldoctest julia> y_model = [1.1, 2.1, 3.1]; -julia> HuberLoss()(y_model, 1:3) -0.005000000000000009 +julia> HuberLoss()(y_model, 1:3) ≈ 0.005000000000000009 +true -julia> HuberLoss(delta=0.05)(y_model, 1:3) -0.003750000000000005 +julia> HuberLoss(delta=0.05)(y_model, 1:3) ≈ 0.003750000000000005 +true ``` """ -@kwdef @concrete struct HuberLoss <: AbstractLossFunction - delta = 1 - agg = mean -end - -function __unsafe_apply_loss(loss::HuberLoss, ŷ, y) - T = promote_type(eltype(ŷ), eltype(y)) - return __fused_agg(loss.agg, Base.Fix2(__huber_metric, T(loss.delta)), abs.(ŷ .- y)) -end - -function __huber_metric(err::T1, δ::T2) where {T1, T2} - x = promote_type(T1, T2)(0.5) - return ifelse(err < δ, err^2 * x, δ * (err - x * δ)) +function HuberLoss(; delta::Union{Nothing, AbstractFloat}=nothing, agg=mean) + delta = ifelse(delta === nothing, Float16(1), delta) + return GenericLossFunction(LossFunctions.HuberLoss(delta); agg) end @doc doc""" @@ -461,18 +434,14 @@ julia> loss = MAELoss(); julia> y_model = [1.1, 1.9, 3.1]; -julia> loss(y_model, 1:3) -0.10000000000000009 +julia> loss(y_model, 1:3) ≈ 0.1 +true ``` """ -@kwdef @concrete struct MAELoss <: AbstractLossFunction - agg = mean -end +MAELoss(; agg=mean) = GenericLossFunction(LossFunctions.L1DistLoss(); agg) const L1Loss = MAELoss -@inline __unsafe_apply_loss(loss::MAELoss, ŷ, y) = __fused_agg(loss.agg, abs, ŷ .- y) - @doc doc""" MSELoss(; agg = mean) @@ -487,20 +456,14 @@ julia> loss = MSELoss(); julia> y_model = [1.1, 1.9, 3.1]; -julia> loss(y_model, 1:3) -0.010000000000000018 +julia> loss(y_model, 1:3) ≈ 0.01 +true ``` - -See also [`MSELoss`](@ref). """ -@kwdef @concrete struct MSELoss <: AbstractLossFunction - agg = mean -end +MSELoss(; agg=mean) = GenericLossFunction(LossFunctions.L2DistLoss(); agg) const L2Loss = MSELoss -@inline __unsafe_apply_loss(loss::MSELoss, ŷ, y) = __fused_agg(loss.agg, abs2, ŷ .- y) - @doc doc""" MSLELoss(; agg = mean, epsilon = nothing) @@ -516,22 +479,15 @@ is `nothing`, then we set it to `eps()`. 
```jldoctest julia> loss = MSLELoss(); -julia> loss(Float32[1.1, 2.2, 3.3], 1:3) -0.009084041f0 +julia> loss(Float32[1.1, 2.2, 3.3], 1:3) ≈ 0.009084041f0 +true -julia> loss(Float32[0.9, 1.8, 2.7], 1:3) -0.011100831f0 +julia> loss(Float32[0.9, 1.8, 2.7], 1:3) ≈ 0.011100831f0 +true ``` """ -@kwdef @concrete struct MSLELoss <: AbstractLossFunction - agg = mean - epsilon = nothing -end - -@inline function __unsafe_apply_loss(loss::MSLELoss, ŷ, y) - T = promote_type(eltype(ŷ), eltype(y)) - ϵ = __get_epsilon(T, loss.epsilon) - return __fused_agg(loss.agg, abs2 ∘ log, (ŷ .+ ϵ) ./ (y .+ ϵ)) +function MSLELoss(; agg=mean, epsilon=nothing) + return GenericLossFunction(@$(__msle_loss(_, _, epsilon)); agg) end @doc doc""" @@ -547,19 +503,12 @@ $$\text{agg}\left(\hat{y} - y * \log(\hat{y})\right)$$ ```jldoctest julia> y_model = [1, 3, 3]; # data should only take integral values -julia> PoissonLoss()(y_model, 1:3) -0.502312852219817 +julia> PoissonLoss()(y_model, 1:3) ≈ 0.502312852219817 +true ``` """ -@kwdef @concrete struct PoissonLoss <: AbstractLossFunction - agg = mean - epsilon = nothing -end - -function __unsafe_apply_loss(loss::PoissonLoss, ŷ, y) - T = promote_type(eltype(ŷ), eltype(y)) - ϵ = __get_epsilon(T, loss.epsilon) - return loss.agg(ŷ .- xlogy.(y, ŷ .+ ϵ)) +function PoissonLoss(; agg=mean, epsilon=nothing) + return GenericLossFunction(@$(__poisson_loss(_, _, epsilon)); agg) end @doc doc""" @@ -577,11 +526,11 @@ Specify `margin` to set the baseline for distance at which pairs are dissimilar. ```jldoctest julia> ŷ = [0.5, 1.5, 2.5]; -julia> SiameseContrastiveLoss()(ŷ, 1:3) --4.833333333333333 +julia> SiameseContrastiveLoss()(ŷ, 1:3) ≈ -4.833333333333333 +true -julia> SiameseContrastiveLoss(margin=2)(ŷ, 1:3) --4.0 +julia> SiameseContrastiveLoss(margin=2)(ŷ, 1:3) ≈ -4.0 +true ``` ## References @@ -590,19 +539,9 @@ julia> SiameseContrastiveLoss(margin=2)(ŷ, 1:3) invariant mapping." 2006 IEEE computer society conference on computer vision and pattern recognition (CVPR'06). Vol. 2. IEEE, 2006. """ -@concrete struct SiameseContrastiveLoss <: AbstractLossFunction - margin - agg -end - function SiameseContrastiveLoss(; margin::Real=true, agg=mean) @argcheck margin ≥ 0 - return SiameseContrastiveLoss(margin, agg) -end - -@inline function __unsafe_apply_loss(loss::SiameseContrastiveLoss, ŷ, y) - z = @. (1 - y) * ŷ^2 + y * max(0, loss.margin - ŷ)^2 - return loss.agg(z) + return GenericLossFunction(@$(__siamese_contrastive_loss(_, _, margin)); agg) end @doc doc""" @@ -624,29 +563,11 @@ julia> y_true = [1, -1, 1, 1]; julia> y_pred = [0.1, 0.3, 1, 1.5]; -julia> loss(y_pred, y_true) -0.625 - -julia> loss(y_pred[1], y_true[1]) ≈ 0.81 +julia> loss(y_pred, y_true) ≈ 0.625 true - -julia> loss(y_pred[2], y_true[2]) ≈ 1.69 -true - -julia> loss(y_pred[3], y_true[3]) -0.0 ``` - -See also [`HingeLoss`](@ref). """ -@kwdef @concrete struct SquaredHingeLoss <: AbstractLossFunction - agg = mean -end - -function __unsafe_apply_loss(loss::SquaredHingeLoss, ŷ, y) - T = promote_type(eltype(ŷ), eltype(y)) - return __fused_agg(loss.agg, abs2 ∘ Base.Fix2(max, T(0)), 1 .- y .* ŷ) -end +SquaredHingeLoss(; agg=mean) = GenericLossFunction(LossFunctions.L2HingeLoss(); agg) @doc doc""" TverskyLoss(; beta = 0.7, smooth = true, agg = mean) @@ -655,7 +576,7 @@ Return the Tversky loss [1]. Used with imbalanced data to give more weight to fa negatives. Larger `beta` weigh recall more than precision (by placing more emphasis on false negatives). 
Calculated as: -$$1 - \frac{\sum \left(y * \hat{y}\right) + \alpha}{\sum \left(y * \hat{y} + (1 - \beta) * (1 - y) * \hat{y} + \beta y * (1 - \hat{y})\right) + \alpha}$$ +$$1 - \frac{\sum \left(y \hat{y}\right) + \alpha}{\sum \left(y \hat{y} + (1 - \beta) (1 - y) \hat{y} + \beta y (1 - \hat{y})\right) + \alpha}$$ where $\alpha$ is the smoothing factor (`smooth`). @@ -686,6 +607,39 @@ function __unsafe_apply_loss(loss::TverskyLoss, ŷ, y) return loss.agg(1 - (TP + α) / (TP + α * FP + β * FN + α)) end +# Wrapper for LossFunctions.jl +@doc doc""" + GenericLossFunction(loss_fn; agg = mean) + +Takes any function `loss_fn` that maps 2 number inputs to a single number output. +Additionally, array inputs are efficiently broadcasted and aggregated using `agg`. + +```jldoctest +julia> mseloss = GenericLossFunction((ŷ, y) -> abs2(ŷ - y)); + +julia> y_model = [1.1, 1.9, 3.1]; + +julia> mseloss(y_model, 1:3) ≈ 0.01 +true +``` + +## Special Note + +This function takes any of the +[`LossFunctions.jl`](https://juliaml.github.io/LossFunctions.jl/stable). public functions +into the Lux Losses API with efficient aggregation. +""" +@concrete struct GenericLossFunction <: AbstractLossFunction + loss_fn + agg +end + +GenericLossFunction(loss_fn; agg=mean) = GenericLossFunction(loss_fn, agg) + +function __unsafe_apply_loss(loss::GenericLossFunction, ŷ, y) + return __fused_agg(loss.agg, loss.loss_fn, ŷ, y) +end + ```@meta DocTestFilters = nothing ``` diff --git a/src/losses/utils.jl b/src/losses/utils.jl index 40b0a7b95..b716b04c5 100644 --- a/src/losses/utils.jl +++ b/src/losses/utils.jl @@ -49,6 +49,23 @@ function CRC.rrule(::typeof(Broadcast.broadcasted), ::typeof(xlogy), return y, ∇xlogy end +# Some functional forms of losses + +@inline function __siamese_contrastive_loss(x::T1, y::T2, margin=true) where {T1, T2} + return (1 - y) * x^2 + y * max(promote_type(T1, T2)(0), margin - x)^2 +end + +@inline function __poisson_loss(x::T1, y::T2, ϵ) where {T1, T2} + return x - xlogy(y, x + __get_epsilon(T1, ϵ)) +end + +@inline function __msle_loss(x::T1, y::T2, ϵ) where {T1, T2} + ϵ = __get_epsilon(promote_type(T1, T2), ϵ) + return log((x + ϵ) / (y + ϵ))^2 +end + +# Misc Utils + @inline function __check_sizes(ŷ::AbstractArray, y::AbstractArray) for d in 1:max(ndims(ŷ), ndims(y)) if size(ŷ, d) != size(y, d) @@ -61,10 +78,33 @@ end CRC.@non_differentiable __check_sizes(ŷ::Any, y::Any) -@inline __fused_agg(::typeof(mean), op::OP, x) where {OP} = mean(op, x) +@inline function __fused_agg(::typeof(mean), op::OP, x) where {OP} + return __fused_agg(sum, op, x) / length(x) +end +@inline function __fused_agg(::typeof(mean), lfn::LossFunctions.Traits.Loss, x, y) + return __fused_agg(sum, lfn, x, y) / length(x) +end + @inline __fused_agg(::typeof(sum), op::OP, x) where {OP} = sum(op, x) -@inline __fused_agg(::Nothing, op::OP, x) where {OP} = op.(x) -@inline __fused_agg(f::F, op::OP, x) where {F, OP} = f(op.(x)) +@inline function __fused_agg(::typeof(sum), lfn::LossFunctions.Traits.Loss, x, y) + fast_scalar_indexing(x) && fast_scalar_indexing(y) && return sum(lfn, x, y) + return mapreduce(Broadcast.BroadcastFunction(lfn), +, x, y) +end + +@inline function CRC.rrule( + ::typeof(__fused_agg), ::typeof(sum), lfn::LossFunctions.Traits.Loss, x, y) + z = lfn.(x, y) + ∇lfn = let z = z, y = y, lfn = lfn + Δ -> begin + ∂x = @thunk LossFunctions.deriv.((lfn,), z, y) .* Δ + return NoTangent(), NoTangent(), NoTangent(), ∂x, NoTangent() + end + end + return sum(z), ∇lfn +end + +@inline __fused_agg(::Nothing, op::OP, args...) 
where {OP} = op.(args...) +@inline __fused_agg(f::F, op::OP, args...) where {F, OP} = f(op.(args...)) @inline __label_smoothing(::Nothing, y, ::Type{T}) where {T} = y @inline function __label_smoothing(label_smoothing::Real, y, ::Type{T}) where {T} From 61bef4ddb87fafad0bb3cf8d0601fa38c9e727d2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 19 Jun 2024 17:06:04 -0700 Subject: [PATCH 09/13] Remove the module --- docs/src/api/Lux/utilities.md | 9 +- src/Lux.jl | 18 +-- src/chainrules.jl | 44 +++++- .../loss_functions.jl => helpers/losses.jl} | 4 +- src/losses/Losses.jl | 38 ------ src/losses/utils.jl | 126 ------------------ src/utils.jl | 85 ++++++++++++ 7 files changed, 142 insertions(+), 182 deletions(-) rename src/{losses/loss_functions.jl => helpers/losses.jl} (99%) delete mode 100644 src/losses/Losses.jl delete mode 100644 src/losses/utils.jl diff --git a/docs/src/api/Lux/utilities.md b/docs/src/api/Lux/utilities.md index 2b560d573..81cd1daf7 100644 --- a/docs/src/api/Lux/utilities.md +++ b/docs/src/api/Lux/utilities.md @@ -15,11 +15,6 @@ Pages = ["utilities.md"] ```@docs GenericLossFunction -``` - -### Specialized Loss Functions - -```@docs BinaryCrossEntropyLoss BinaryFocalLoss CrossEntropyLoss @@ -50,8 +45,8 @@ TverskyLoss Lux.foldl_init Lux.istraining Lux.multigate -Lux.Losses.xlogy -Lux.Losses.xlogx +Lux.xlogy +Lux.xlogx ``` ## Updating Floating Point Precision diff --git a/src/Lux.jl b/src/Lux.jl index 9709b35b8..e118e2355 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -6,18 +6,21 @@ using PrecompileTools: @recompile_invalidations using ADTypes: AbstractADType, AutoForwardDiff, AutoReverseDiff, AutoTracker, AutoZygote using Adapt: Adapt, adapt using ArgCheck: @argcheck - using ArrayInterface: ArrayInterface + using ArrayInterface: ArrayInterface, fast_scalar_indexing using ChainRulesCore: ChainRulesCore, AbstractZero, HasReverseMode, NoTangent, - ProjectTo, RuleConfig, ZeroTangent + ProjectTo, RuleConfig, ZeroTangent, @thunk using ConcreteStructs: @concrete using FastClosures: @closure using Functors: Functors, fmap using GPUArraysCore: GPUArraysCore + using LossFunctions: LossFunctions using Markdown: @doc_str using OhMyThreads: tmapreduce + using PartialFunctions: @$ using Preferences: @load_preference using Random: Random, AbstractRNG using Reexport: @reexport + using Statistics: mean using LuxCore, LuxLib, LuxDeviceUtils, WeightInitializers using LuxLib: __apply_bias_activation @@ -63,12 +66,6 @@ include("layers/extension.jl") # Pretty Printing include("layers/display.jl") -# AutoDiff -include("chainrules.jl") - -# Losses -include("losses/Losses.jl") - # Experimental include("contrib/contrib.jl") @@ -77,6 +74,10 @@ include("helpers/stateful.jl") include("helpers/compact.jl") include("helpers/autodiff.jl") include("helpers/nested_ad.jl") +include("helpers/losses.jl") + +# AutoDiff +include("chainrules.jl") # Transform to and from other frameworks include("transform/types.jl") @@ -127,5 +128,6 @@ export MPIBackend, NCCLBackend, DistributedUtils # Unexported functions that are part of the public API @compat public Experimental +@compat public xlogx, xlogy end diff --git a/src/chainrules.jl b/src/chainrules.jl index 642302a8c..1e259eab7 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -9,7 +9,6 @@ CRC.@non_differentiable _conv_transpose_dims(::Any...) CRC.@non_differentiable _calc_padding(::Any...) CRC.@non_differentiable Base.printstyled(::Any...) 
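# `@non_differentiable` registers a rule that returns `NoTangent()` for these calls, so
# ChainRules-based ADs such as Zygote treat these utility calls (printing, shape
# bookkeeping, `fieldcount`) as constants instead of trying to differentiate through them.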
## Type Piracy: Needs upstreaming -## This is needed for fixing NamedTuple nested differentiation CRC.@non_differentiable fieldcount(::Any) # Utilities @@ -75,3 +74,46 @@ function CRC.rrule(::typeof(getproperty), m::AbstractExplicitLayer, name::Symbol ∇getproperty = Δ -> ntuple(Returns(NoTangent()), 3) return res, ∇getproperty end + +# For loss functions +@inline function CRC.rrule( + ::typeof(__fused_agg), ::typeof(sum), lfn::LossFunctions.Traits.Loss, x, y) + z = lfn.(x, y) + ∇lfn = let z = z, y = y, lfn = lfn + Δ -> begin + ∂x = @thunk LossFunctions.deriv.((lfn,), z, y) .* Δ + return NoTangent(), NoTangent(), NoTangent(), ∂x, NoTangent() + end + end + return sum(z), ∇lfn +end + +function CRC.rrule(::typeof(xlogx), x::Number) + iszero(x) && return x, Δ -> (NoTangent(), ZeroTangent()) + logx = log(x) + ∇xlogx = @closure Δ -> (NoTangent(), @thunk(Δ*(logx + true))) + return x * logx, ∇xlogx +end + +function CRC.rrule( + ::typeof(Broadcast.broadcasted), ::typeof(xlogx), x::AbstractArray{<:Number}) + logx = log.(x) + y = x .* logx + ∇xlogx = @closure Δ -> (NoTangent(), NoTangent(), @thunk(Δ.*(logx .+ true))) + return y, ∇xlogx +end + +function CRC.rrule(::typeof(xlogy), x::Number, y::Number) + iszero(x) && return x, Δ -> (NoTangent(), ZeroTangent()) + logy = log(y) + ∇xlogy = @closure Δ -> (NoTangent(), @thunk(Δ*logy), @thunk(Δ * x/y)) + return x * logy, ∇xlogy +end + +function CRC.rrule(::typeof(Broadcast.broadcasted), ::typeof(xlogy), + x::AbstractArray{<:Number}, y::AbstractArray{<:Number}) + logy = log.(y) + y = x .* logy + ∇xlogy = @closure Δ -> (NoTangent(), NoTangent(), @thunk(Δ.*logy), @thunk(Δ .* x./y)) + return y, ∇xlogy +end diff --git a/src/losses/loss_functions.jl b/src/helpers/losses.jl similarity index 99% rename from src/losses/loss_functions.jl rename to src/helpers/losses.jl index 2236a53dd..5ffc78e55 100644 --- a/src/losses/loss_functions.jl +++ b/src/helpers/losses.jl @@ -2,6 +2,7 @@ ```@meta DocTestFilters = r"[0-9\.]+f0" ``` +abstract type AbstractLossFunction <: Function end function (loss::AbstractLossFunction)(ŷ, y) __check_sizes(ŷ, y) @@ -607,7 +608,6 @@ function __unsafe_apply_loss(loss::TverskyLoss, ŷ, y) return loss.agg(1 - (TP + α) / (TP + α * FP + β * FN + α)) end -# Wrapper for LossFunctions.jl @doc doc""" GenericLossFunction(loss_fn; agg = mean) @@ -626,7 +626,7 @@ true ## Special Note This function takes any of the -[`LossFunctions.jl`](https://juliaml.github.io/LossFunctions.jl/stable). public functions +[`LossFunctions.jl`](https://juliaml.github.io/LossFunctions.jl/stable) public functions into the Lux Losses API with efficient aggregation. 
""" @concrete struct GenericLossFunction <: AbstractLossFunction diff --git a/src/losses/Losses.jl b/src/losses/Losses.jl deleted file mode 100644 index 75f867602..000000000 --- a/src/losses/Losses.jl +++ /dev/null @@ -1,38 +0,0 @@ -# Eventually the idea is to get `LossFunctions.jl` up to speed so that we don't need this -# sort of an implementation -module Losses # A huge chunk of this code has been derived from Flux.jl - -using PrecompileTools: @recompile_invalidations - -@recompile_invalidations begin - using ArrayInterface: fast_scalar_indexing - using ArgCheck: @argcheck - using ChainRulesCore: ChainRulesCore, NoTangent, ZeroTangent, @thunk - using Compat: @compat - using ConcreteStructs: @concrete - using FastClosures: @closure - using ..Lux: __unwrap_val - using LossFunctions: LossFunctions - using LuxLib: logsoftmax, logsigmoid - using Markdown: @doc_str - using PartialFunctions: @$ - using Statistics: mean -end - -const CRC = ChainRulesCore - -abstract type AbstractLossFunction <: Function end - -include("utils.jl") -include("loss_functions.jl") - -@compat public xlogx, xlogy - -export BinaryCrossEntropyLoss, BinaryFocalLoss, CrossEntropyLoss, DiceCoeffLoss, FocalLoss, - HingeLoss, HuberLoss, KLDivergenceLoss, L1Loss, L2Loss, MAELoss, MSELoss, MSLELoss, - PoissonLoss, SiameseContrastiveLoss, SquaredHingeLoss, TverskyLoss -export GenericLossFunction - -end - -using .Losses diff --git a/src/losses/utils.jl b/src/losses/utils.jl deleted file mode 100644 index b716b04c5..000000000 --- a/src/losses/utils.jl +++ /dev/null @@ -1,126 +0,0 @@ -""" - xlogx(x::Number) - -Return `x * log(x)` for `x ≥ 0`, handling `x == 0` by taking the limit from above, to get -zero. -""" -@inline function xlogx(x::Number) - result = x * log(x) - return ifelse(iszero(x), zero(result), result) -end - -function CRC.rrule(::typeof(xlogx), x::Number) - iszero(x) && return x, Δ -> (NoTangent(), ZeroTangent()) - logx = log(x) - ∇xlogx = @closure Δ -> (NoTangent(), @thunk(Δ*(logx + true))) - return x * logx, ∇xlogx -end - -function CRC.rrule( - ::typeof(Broadcast.broadcasted), ::typeof(xlogx), x::AbstractArray{<:Number}) - logx = log.(x) - y = x .* logx - ∇xlogx = @closure Δ -> (NoTangent(), NoTangent(), @thunk(Δ.*(logx .+ true))) - return y, ∇xlogx -end - -""" - xlogy(x::Number, y::Number) - -Return `x * log(y)` for `y > 0`, and zero when `x == 0`. 
-""" -@inline function xlogy(x::Number, y::Number) - result = x * log(y) - return ifelse(iszero(x), zero(result), result) -end - -function CRC.rrule(::typeof(xlogy), x::Number, y::Number) - iszero(x) && return x, Δ -> (NoTangent(), ZeroTangent()) - logy = log(y) - ∇xlogy = @closure Δ -> (NoTangent(), @thunk(Δ*logy), @thunk(Δ * x/y)) - return x * logy, ∇xlogy -end - -function CRC.rrule(::typeof(Broadcast.broadcasted), ::typeof(xlogy), - x::AbstractArray{<:Number}, y::AbstractArray{<:Number}) - logy = log.(y) - y = x .* logy - ∇xlogy = @closure Δ -> (NoTangent(), NoTangent(), @thunk(Δ.*logy), @thunk(Δ .* x./y)) - return y, ∇xlogy -end - -# Some functional forms of losses - -@inline function __siamese_contrastive_loss(x::T1, y::T2, margin=true) where {T1, T2} - return (1 - y) * x^2 + y * max(promote_type(T1, T2)(0), margin - x)^2 -end - -@inline function __poisson_loss(x::T1, y::T2, ϵ) where {T1, T2} - return x - xlogy(y, x + __get_epsilon(T1, ϵ)) -end - -@inline function __msle_loss(x::T1, y::T2, ϵ) where {T1, T2} - ϵ = __get_epsilon(promote_type(T1, T2), ϵ) - return log((x + ϵ) / (y + ϵ))^2 -end - -# Misc Utils - -@inline function __check_sizes(ŷ::AbstractArray, y::AbstractArray) - for d in 1:max(ndims(ŷ), ndims(y)) - if size(ŷ, d) != size(y, d) - throw(DimensionMismatch("loss function expects size(ŷ) = $(size(ŷ)) to match \ - size(y) = $(size(y))")) - end - end -end -@inline __check_sizes(ŷ, y) = nothing - -CRC.@non_differentiable __check_sizes(ŷ::Any, y::Any) - -@inline function __fused_agg(::typeof(mean), op::OP, x) where {OP} - return __fused_agg(sum, op, x) / length(x) -end -@inline function __fused_agg(::typeof(mean), lfn::LossFunctions.Traits.Loss, x, y) - return __fused_agg(sum, lfn, x, y) / length(x) -end - -@inline __fused_agg(::typeof(sum), op::OP, x) where {OP} = sum(op, x) -@inline function __fused_agg(::typeof(sum), lfn::LossFunctions.Traits.Loss, x, y) - fast_scalar_indexing(x) && fast_scalar_indexing(y) && return sum(lfn, x, y) - return mapreduce(Broadcast.BroadcastFunction(lfn), +, x, y) -end - -@inline function CRC.rrule( - ::typeof(__fused_agg), ::typeof(sum), lfn::LossFunctions.Traits.Loss, x, y) - z = lfn.(x, y) - ∇lfn = let z = z, y = y, lfn = lfn - Δ -> begin - ∂x = @thunk LossFunctions.deriv.((lfn,), z, y) .* Δ - return NoTangent(), NoTangent(), NoTangent(), ∂x, NoTangent() - end - end - return sum(z), ∇lfn -end - -@inline __fused_agg(::Nothing, op::OP, args...) where {OP} = op.(args...) -@inline __fused_agg(f::F, op::OP, args...) 
where {F, OP} = f(op.(args...)) - -@inline __label_smoothing(::Nothing, y, ::Type{T}) where {T} = y -@inline function __label_smoothing(label_smoothing::Real, y, ::Type{T}) where {T} - label_smoothing = T(label_smoothing) - return y .* (1 - label_smoothing) .+ label_smoothing ./ size(y, ndims(y) - 1) -end - -@inline __label_smoothing_binary(::Nothing, y, ::Type{T}) where {T} = y -@inline function __label_smoothing_binary(label_smoothing::Real, y, ::Type{T}) where {T} - label_smoothing = T(label_smoothing) - return y .* (1 - label_smoothing) .+ label_smoothing ./ 2 -end - -@inline __get_epsilon(::Type{T}, ϵ::Real) where {T} = T(ϵ) -@inline __get_epsilon(::Type{T}, ::Nothing) where {T} = eps(float(T)) - -@inline __get_dims(_) = Colon() -@inline __get_dims(::AbstractVector) = Colon() -@inline __get_dims(::AbstractArray{T, N}) where {T, N} = 1:(N - 1) diff --git a/src/utils.jl b/src/utils.jl index 9772d724c..ca89047c8 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -306,3 +306,88 @@ end __recursive_make_zero!!, values(x))) @inline __recursive_make_zero!!(::Nothing) = nothing @inline __recursive_make_zero!!(x) = fmap(__recursive_make_zero!!, x) + +# helpers for the loss functions +""" + xlogx(x::Number) + +Return `x * log(x)` for `x ≥ 0`, handling `x == 0` by taking the limit from above, to get +zero. +""" +@inline function xlogx(x::Number) + result = x * log(x) + return ifelse(iszero(x), zero(result), result) +end + +""" + xlogy(x::Number, y::Number) + +Return `x * log(y)` for `y > 0`, and zero when `x == 0`. +""" +@inline function xlogy(x::Number, y::Number) + result = x * log(y) + return ifelse(iszero(x), zero(result), result) +end + +# Some functional forms of losses + +@inline function __siamese_contrastive_loss(x::T1, y::T2, margin=true) where {T1, T2} + return (1 - y) * x^2 + y * max(promote_type(T1, T2)(0), margin - x)^2 +end + +@inline function __poisson_loss(x::T1, y::T2, ϵ) where {T1, T2} + return x - xlogy(y, x + __get_epsilon(T1, ϵ)) +end + +@inline function __msle_loss(x::T1, y::T2, ϵ) where {T1, T2} + ϵ = __get_epsilon(promote_type(T1, T2), ϵ) + return log((x + ϵ) / (y + ϵ))^2 +end + +# Misc Utils +@inline function __check_sizes(ŷ::AbstractArray, y::AbstractArray) + for d in 1:max(ndims(ŷ), ndims(y)) + if size(ŷ, d) != size(y, d) + throw(DimensionMismatch("loss function expects size(ŷ) = $(size(ŷ)) to match \ + size(y) = $(size(y))")) + end + end +end +@inline __check_sizes(ŷ, y) = nothing + +CRC.@non_differentiable __check_sizes(ŷ::Any, y::Any) + +@inline function __fused_agg(::typeof(mean), op::OP, x) where {OP} + return __fused_agg(sum, op, x) / length(x) +end +@inline function __fused_agg(::typeof(mean), lfn::LossFunctions.Traits.Loss, x, y) + return __fused_agg(sum, lfn, x, y) / length(x) +end + +@inline __fused_agg(::typeof(sum), op::OP, x) where {OP} = sum(op, x) +@inline function __fused_agg(::typeof(sum), lfn::LossFunctions.Traits.Loss, x, y) + fast_scalar_indexing(x) && fast_scalar_indexing(y) && return sum(lfn, x, y) + return mapreduce(Broadcast.BroadcastFunction(lfn), +, x, y) +end + +@inline __fused_agg(::Nothing, op::OP, args...) where {OP} = op.(args...) +@inline __fused_agg(f::F, op::OP, args...) 
where {F, OP} = f(op.(args...)) + +@inline __label_smoothing(::Nothing, y, ::Type{T}) where {T} = y +@inline function __label_smoothing(label_smoothing::Real, y, ::Type{T}) where {T} + label_smoothing = T(label_smoothing) + return y .* (1 - label_smoothing) .+ label_smoothing ./ size(y, ndims(y) - 1) +end + +@inline __label_smoothing_binary(::Nothing, y, ::Type{T}) where {T} = y +@inline function __label_smoothing_binary(label_smoothing::Real, y, ::Type{T}) where {T} + label_smoothing = T(label_smoothing) + return y .* (1 - label_smoothing) .+ label_smoothing ./ 2 +end + +@inline __get_epsilon(::Type{T}, ϵ::Real) where {T} = T(ϵ) +@inline __get_epsilon(::Type{T}, ::Nothing) where {T} = eps(float(T)) + +@inline __get_dims(_) = Colon() +@inline __get_dims(::AbstractVector) = Colon() +@inline __get_dims(::AbstractArray{T, N}) where {T, N} = 1:(N - 1) From 638a1763dfb9287278af4fe1764a07dc0679a612 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 19 Jun 2024 18:13:25 -0700 Subject: [PATCH 10/13] Add tests --- Project.toml | 4 +- docs/src/api/Lux/utilities.md | 1 - src/Lux.jl | 2 +- src/chainrules.jl | 15 +- src/helpers/losses.jl | 40 +--- src/utils.jl | 5 + test/helpers/loss_tests.jl | 386 ++++++++++++++++++++++++++++++++++ 7 files changed, 402 insertions(+), 51 deletions(-) create mode 100644 test/helpers/loss_tests.jl diff --git a/Project.toml b/Project.toml index 20beeedf2..8c800a24d 100644 --- a/Project.toml +++ b/Project.toml @@ -103,6 +103,7 @@ MacroTools = "0.5.13" Markdown = "1.10" NCCL = "0.1.1" OhMyThreads = "0.5.1" +OneHotArrays = "0.2.5" Optimisers = "0.3" PartialFunctions = "1.2.0" Pkg = "1.10" @@ -135,6 +136,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" @@ -147,4 +149,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ComponentArrays", "Documenter", "DynamicExpressions", "Enzyme", "ExplicitImports", "FiniteDifferences", "ForwardDiff", "Logging", "LuxTestUtils", "MLUtils", "Optimisers", "Pkg", "ReTestItems", "ReverseDiff", "SimpleChains", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote"] +test = ["Aqua", "ComponentArrays", "Documenter", "DynamicExpressions", "Enzyme", "ExplicitImports", "FiniteDifferences", "ForwardDiff", "Logging", "LuxTestUtils", "MLUtils", "OneHotArrays", "Optimisers", "Pkg", "ReTestItems", "ReverseDiff", "SimpleChains", "StableRNGs", "Statistics", "Test", "Tracker", "Zygote"] diff --git a/docs/src/api/Lux/utilities.md b/docs/src/api/Lux/utilities.md index 81cd1daf7..e20b1956e 100644 --- a/docs/src/api/Lux/utilities.md +++ b/docs/src/api/Lux/utilities.md @@ -29,7 +29,6 @@ MSLELoss PoissonLoss SiameseContrastiveLoss SquaredHingeLoss -TverskyLoss ``` ## Weight Initialization diff --git a/src/Lux.jl b/src/Lux.jl index e118e2355..0eac19186 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -114,7 +114,7 @@ export AutoForwardDiff, AutoReverseDiff, AutoTracker, AutoZygote export BinaryCrossEntropyLoss, BinaryFocalLoss, CrossEntropyLoss, DiceCoeffLoss, FocalLoss, HingeLoss, HuberLoss, KLDivergenceLoss, L1Loss, L2Loss, MAELoss, MSELoss, MSLELoss, - PoissonLoss, SiameseContrastiveLoss, SquaredHingeLoss, TverskyLoss + PoissonLoss, 
SiameseContrastiveLoss, SquaredHingeLoss export GenericLossFunction export f16, f32, f64 diff --git a/src/chainrules.jl b/src/chainrules.jl index 1e259eab7..ad951980b 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -78,14 +78,11 @@ end # For loss functions @inline function CRC.rrule( ::typeof(__fused_agg), ::typeof(sum), lfn::LossFunctions.Traits.Loss, x, y) - z = lfn.(x, y) - ∇lfn = let z = z, y = y, lfn = lfn - Δ -> begin - ∂x = @thunk LossFunctions.deriv.((lfn,), z, y) .* Δ - return NoTangent(), NoTangent(), NoTangent(), ∂x, NoTangent() - end + ∇lfn = @closure Δ -> begin + ∂x = @thunk LossFunctions.deriv.(Ref(lfn), x, y) .* Δ + return NoTangent(), NoTangent(), NoTangent(), ∂x, NoTangent() end - return sum(z), ∇lfn + return __fused_agg(sum, lfn, x, y), ∇lfn end function CRC.rrule(::typeof(xlogx), x::Number) @@ -113,7 +110,7 @@ end function CRC.rrule(::typeof(Broadcast.broadcasted), ::typeof(xlogy), x::AbstractArray{<:Number}, y::AbstractArray{<:Number}) logy = log.(y) - y = x .* logy + z = x .* logy ∇xlogy = @closure Δ -> (NoTangent(), NoTangent(), @thunk(Δ.*logy), @thunk(Δ .* x./y)) - return y, ∇xlogy + return z, ∇xlogy end diff --git a/src/helpers/losses.jl b/src/helpers/losses.jl index 5ffc78e55..90b269e73 100644 --- a/src/helpers/losses.jl +++ b/src/helpers/losses.jl @@ -417,7 +417,7 @@ end function __unsafe_apply_loss(loss::KLDivergenceLoss, ŷ, y) cross_entropy = __unsafe_apply_loss(loss.celoss, ŷ, y) - entropy = loss.agg(sum(xlogx, y; loss.dims)) + entropy = loss.agg(sum(xlogx.(y); loss.dims)) # Intentional broadcasting for Zygote type stability return entropy + cross_entropy end @@ -570,44 +570,6 @@ true """ SquaredHingeLoss(; agg=mean) = GenericLossFunction(LossFunctions.L2HingeLoss(); agg) -@doc doc""" - TverskyLoss(; beta = 0.7, smooth = true, agg = mean) - -Return the Tversky loss [1]. Used with imbalanced data to give more weight to false -negatives. Larger `beta` weigh recall more than precision (by placing more emphasis on -false negatives). Calculated as: - -$$1 - \frac{\sum \left(y \hat{y}\right) + \alpha}{\sum \left(y \hat{y} + (1 - \beta) (1 - y) \hat{y} + \beta y (1 - \hat{y})\right) + \alpha}$$ - -where $\alpha$ is the smoothing factor (`smooth`). - -## References - -[1] Salehi, Seyed Sadegh Mohseni, Deniz Erdogmus, and Ali Gholipour. "Tversky loss function -for image segmentation using 3D fully convolutional deep networks." International workshop -on machine learning in medical imaging. Cham: Springer International Publishing, 2017. 
-""" -@kwdef @concrete struct TverskyLoss <: AbstractLossFunction - beta = 0.7 - smooth = true - agg = mean -end - -function __unsafe_apply_loss(loss::TverskyLoss, ŷ, y) - T = promote_type(eltype(ŷ), eltype(y)) - β = T(loss.beta) - α = T(loss.smooth) - - yŷ = y .* ŷ - dims = __get_dims(yŷ) - - TP = sum(yŷ; dims) - FP = sum((true .- y) .* ŷ; dims) - FN = sum(y .* (true .- ŷ); dims) - - return loss.agg(1 - (TP + α) / (TP + α * FP + β * FN + α)) -end - @doc doc""" GenericLossFunction(loss_fn; agg = mean) diff --git a/src/utils.jl b/src/utils.jl index ca89047c8..40d4b8871 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -364,7 +364,12 @@ end return __fused_agg(sum, lfn, x, y) / length(x) end +@inline __fused_agg(::typeof(sum), op::OP, x::Number) where {OP} = op(x) @inline __fused_agg(::typeof(sum), op::OP, x) where {OP} = sum(op, x) +@inline function __fused_agg( + ::typeof(sum), lfn::LossFunctions.Traits.Loss, x::Number, y::Number) + return lfn(x, y) +end @inline function __fused_agg(::typeof(sum), lfn::LossFunctions.Traits.Loss, x, y) fast_scalar_indexing(x) && fast_scalar_indexing(y) && return sum(lfn, x, y) return mapreduce(Broadcast.BroadcastFunction(lfn), +, x, y) diff --git a/test/helpers/loss_tests.jl b/test/helpers/loss_tests.jl new file mode 100644 index 000000000..3da1aa50e --- /dev/null +++ b/test/helpers/loss_tests.jl @@ -0,0 +1,386 @@ +@testitem "xlogx & xlogy" setup=[SharedTestSetup] tags=[:helpers] begin + using Lux: xlogx, xlogy + using ForwardDiff, Zygote + + @test iszero(xlogx(0)) + @test isnan(xlogx(NaN)) + @test xlogx(2) ≈ 2.0 * log(2.0) + + ∂x1 = ForwardDiff.derivative(xlogx, 2.0) + ∂x2 = Zygote.gradient(xlogx, 2.0)[1] + @test ∂x1 ≈ ∂x2 + + @inferred xlogx(2) + @inferred xlogx(0) + @jet xlogx(2) + + @test iszero(xlogy(0, 1)) + @test isnan(xlogy(NaN, 1)) + @test isnan(xlogy(1, NaN)) + @test isnan(xlogy(NaN, NaN)) + @test xlogy(2, 3) ≈ 2.0 * log(3.0) + + ∂x1 = ForwardDiff.derivative(Base.Fix2(xlogy, 3.0), 2.0) + ∂y1 = ForwardDiff.derivative(Base.Fix1(xlogy, 2.0), 3.0) + ∂x2, ∂y2 = Zygote.gradient(xlogy, 2.0, 3.0) + @test ∂x1 ≈ ∂x2 + @test ∂y1 ≈ ∂y2 + + @inferred xlogy(2, 3) + @inferred xlogy(0, 1) + @jet xlogy(2, 3) + + @testset "$mode" for (mode, aType, dev, ongpu) in MODES + x = rand(10) |> aType + __f = sum ∘ Broadcast.BroadcastFunction(xlogx) + @eval @test_gradients $__f $x gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 + + y = rand(10) |> aType + __f = sum ∘ Broadcast.BroadcastFunction(xlogy) + @eval @test_gradients $__f $x $y gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 + end +end + +@testitem "Regression Loss" setup=[SharedTestSetup] tags=[:helpers] begin + using Zygote + + @testset "$mode" for (mode, aType, dev, ongpu) in MODES + y = [1.0, 1.0, 0.0, 0.0] |> aType + ŷ = [0.9, 0.1, 0.1, 0.9] |> aType + + loss_res_map = Dict( + "MSE" => (0.1^2 + 0.9^2) / 2, "MAE" => (0.1 + 0.9) / 2, "Huber" => 0.205) + + @testset "$(loss)" for (loss, loss_res) in loss_res_map + loss_mean = eval(Symbol(loss * "Loss"))() + loss_sum = eval(Symbol(loss * "Loss"))(; agg=sum) + loss_sum2 = eval(Symbol(loss * "Loss"))(; agg=(args...) 
-> sum(args...)) + + @test loss_mean(ŷ, y) ≈ loss_res + @test loss_sum(ŷ, y) ≈ loss_res * 4 + @test loss_sum2(ŷ, y) ≈ loss_res * 4 + + @inferred Zygote.gradient(loss_mean, ŷ, y) + + @jet loss_mean(ŷ, y) + @jet loss_sum(ŷ, y) + + __f = Base.Fix2(loss_mean, y) + @eval @test_gradients $__f $ŷ gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 skip_tracker=$ongpu + end + + @testset "MSLE" begin + y = [123.0, 456.0, 789.0] |> aType + ŷ = [345.0, 332.0, 789.0] |> aType + + @test MSLELoss()(ŷ, y) ≈ 0.38813985859136585 + + @jet MSLELoss()(ŷ, y) + + @test_broken @inferred Zygote.gradient(MSLELoss(), ŷ, y) + + __f = Base.Fix2(MSLELoss(), y) + @eval @test_gradients $__f $ŷ gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 skip_tracker=$ongpu + end + end +end + +@testitem "Classification Loss" setup=[SharedTestSetup] tags=[:helpers] begin + using OneHotArrays, Zygote + + @testset "$mode" for (mode, aType, dev, ongpu) in MODES + y = onehotbatch([1, 1, 0, 0], 0:1) |> dev + y_smoothed = Lux.__label_smoothing(0.1, y, Float32) + + ŷ = [0.1 0.9; 0.9 0.1; 0.9 0.1; 0.1 0.9]' |> dev + v = log(0.1 / 0.9) + logŷ = [v 0.0; 0.0 v; 0.0 v; v 0.0]' |> dev + lossvalue = 1.203972804325936 + lossvalue_smoothed = 1.2039728043259348 + + yl = onehotbatch([1], 0:1) |> dev + sf = 0.1 + yls = [sf (1 - sf)]' |> dev + ylp = [0.9 0.1]' |> dev + logylp = [0.0 v]' |> dev + + ya = onehotbatch([1, 1, 1, 0, 0], 0:1) |> dev + ya_smoothed = Lux.__label_smoothing(2sf, ya, Float32) + y_same = Float32.(ya) + y_sim = y_same .* (1 - 2 * sf) .+ sf + y_dis = copy(y_sim) + y_dis[1, :], y_dis[2, :] = y_dis[2, :], y_dis[1, :] + + @testset "CrossEntropyLoss" begin + celoss = CrossEntropyLoss() + + @test celoss([0.1, 0.0, 0.9] |> aType, [0.1, 0.0, 0.9] |> aType) ≈ + celoss([0.1, 0.9] |> aType, [0.1, 0.9] |> aType) + + @test celoss(ŷ, y) ≈ lossvalue + @test celoss(ŷ, y_smoothed) ≈ lossvalue_smoothed + + celoss_smooth = CrossEntropyLoss(; label_smoothing=0.1) + @test celoss_smooth(ŷ, y) ≈ lossvalue_smoothed + + celoss_smooth2 = CrossEntropyLoss(; label_smoothing=2sf) + @test celoss_smooth2(ylp, yl) ≈ sum(-yls .* log.(ylp)) + + @test celoss(ylp, yl) ≈ sum(-yl .* log.(ylp)) + + @test iszero(CrossEntropyLoss(; epsilon=0)(y_same, ya)) + + @test celoss(y_sim, ya) < celoss_smooth(y_sim, ya) + @test celoss(y_dis, ya) > celoss_smooth(y_dis, ya) + + @jet celoss(ŷ, y) + @jet celoss_smooth(ŷ, y) + + @inferred Zygote.gradient(celoss, ŷ, y) + + __f = Base.Fix2(celoss, y) + @eval @test_gradients $__f $ŷ gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 skip_tracker=$ongpu + end + + @testset "Logit CrossEntropyLoss" begin + logitceloss = CrossEntropyLoss(; logits=Val(true)) + + @test logitceloss(logŷ, y) ≈ lossvalue + @test logitceloss(logylp, yl) ≈ sum(-yl .* log.(softmax(logylp))) + + logitceloss_smooth = CrossEntropyLoss(; logits=Val(true), label_smoothing=0.1) + + @test logitceloss(logŷ, y_smoothed) ≈ lossvalue_smoothed + @test logitceloss_smooth(logŷ, y) ≈ lossvalue_smoothed + + logitceloss_smooth2 = CrossEntropyLoss(; logits=Val(true), label_smoothing=2sf) + @test logitceloss_smooth2(logylp, yl) ≈ sum(-yls .* log.(softmax(logylp))) + + @jet logitceloss(logŷ, y) + @jet logitceloss_smooth(logŷ, y) + + @inferred Zygote.gradient(logitceloss, logŷ, y) + + __f = Base.Fix2(logitceloss, y) + @eval @test_gradients $__f $logŷ gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 skip_tracker=$ongpu + end + + logŷ, y = randn(3) |> aType, rand(3) |> aType + yls = y .* (1 - 2sf) .+ sf + + @testset "BinaryCrossEntropyLoss" begin + bceloss = BinaryCrossEntropyLoss() + bceloss_smooth = 
BinaryCrossEntropyLoss(; label_smoothing=2sf, epsilon=0) + + @test bceloss_smooth(σ.(logŷ), y) ≈ + -mean(yls .* log.(σ.(logŷ)) .+ (1 .- yls) .* log.(1 .- σ.(logŷ))) + + @test bceloss(σ.(logŷ), y) ≈ + mean(-y .* log.(σ.(logŷ)) .- (1 .- y) .* log.(1 .- σ.(logŷ))) + + @test bceloss(σ.(logŷ), y) ≈ mean(-y .* log.(σ.(logŷ) .+ eps.(σ.(logŷ))) - + (1 .- y) .* log.(1 .- σ.(logŷ) .+ eps.(σ.(logŷ)))) + + @test bceloss([0.1, 0.2, 0.9] |> aType, 1) ≈ + -mean(log, [0.1, 0.2, 0.9] |> aType) # constant label + + @jet bceloss(σ.(logŷ), y) + @jet bceloss_smooth(σ.(logŷ), y) + + @inferred Zygote.gradient(bceloss, σ.(logŷ), y) + + __f = Base.Fix2(bceloss, y) + σlogŷ = σ.(logŷ) + @eval @test_gradients $__f $σlogŷ gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 skip_tracker=$ongpu + end + + @testset "Logit BinaryCrossEntropyLoss" begin + logitbceloss = BinaryCrossEntropyLoss(; logits=Val(true)) + logitbceloss_smooth = BinaryCrossEntropyLoss(; + logits=Val(true), label_smoothing=2sf, epsilon=0) + + @test logitbceloss_smooth(logŷ, y) ≈ + -mean(yls .* log.(sigmoid(logŷ)) .+ + (1 .- yls) .* log.(1 .- sigmoid(logŷ))) + + @test logitbceloss(logŷ, y) ≈ + mean(-y .* log.(sigmoid(logŷ)) .- (1 .- y) .* log.(1 .- sigmoid(logŷ))) + + @jet logitbceloss(logŷ, y) + @jet logitbceloss_smooth(logŷ, y) + + @inferred Zygote.gradient(logitbceloss, logŷ, y) + + __f = Base.Fix2(logitbceloss, y) + @eval @test_gradients $__f $logŷ gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 skip_tracker=$ongpu + end + + @testset "BinaryFocalLoss" begin + y = [0 1 0 + 1 0 1] |> aType + ŷ = [0.268941 0.5 0.268941 + 0.731059 0.5 0.731059] |> aType + + y1 = [1 0 + 0 1] |> aType + ŷ1 = [0.6 0.3 + 0.4 0.7] |> aType + + @test BinaryFocalLoss()(ŷ, y) ≈ 0.0728675615927385 + @test BinaryFocalLoss()(ŷ1, y1) ≈ 0.05691642237852222 + @test BinaryFocalLoss(; gamma=0)(ŷ, y) ≈ Lux.CrossEntropyLoss()(ŷ, y) + + @jet BinaryFocalLoss()(ŷ, y) + + @inferred Zygote.gradient(BinaryFocalLoss(), ŷ, y) + + __f = Base.Fix2(BinaryFocalLoss(), y) + @eval @test_gradients $__f $ŷ gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 skip_tracker=$ongpu + end + + @testset "FocalLoss" begin + y = [1 0 0 0 1 + 0 1 0 1 0 + 0 0 1 0 0] |> aType + ŷ = softmax(reshape(-7:7, 3, 5) .* 1.0f0) + y1 = [1 0 + 0 0 + 0 1] |> aType + ŷ1 = [0.4 0.2 + 0.5 0.5 + 0.1 0.3] |> aType + + @test FocalLoss()(ŷ, y) ≈ 1.1277571935622628 + @test FocalLoss()(ŷ1, y1) ≈ 0.45990566879720157 + @test FocalLoss(; gamma=0)(ŷ, y) ≈ CrossEntropyLoss()(ŷ, y) + + @jet FocalLoss()(ŷ, y) + + @inferred Zygote.gradient(FocalLoss(), ŷ, y) + + __f = Base.Fix2(FocalLoss(), y) + # FD will lead to out of domain errors + @eval @test_gradients $__f $ŷ gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 skip_tracker=$ongpu skip_finite_differences=true + end + end +end + +@testitem "Other Losses" setup=[SharedTestSetup] tags=[:helpers] begin + using Zygote + + @testset "$mode" for (mode, aType, dev, ongpu) in MODES + @testset "KLDivergenceLoss" begin + y = [1 2 3] |> aType + ŷ = [4.0 5.0 6.0] |> aType + + @test KLDivergenceLoss()([0.1, 0.0, 0.9] |> aType, [0.1, 0.0, 0.9] |> aType) ≈ + KLDivergenceLoss()([0.1, 0.9] |> aType, [0.1, 0.9] |> aType) + @test KLDivergenceLoss()(ŷ, y) ≈ -1.7661057888493457 + @test KLDivergenceLoss()(y, y) ≈ 0 + + @jet KLDivergenceLoss()(ŷ, y) + @inferred Zygote.gradient(KLDivergenceLoss(), ŷ, y) + + __f = Base.Fix2(KLDivergenceLoss(), y) + @eval @test_gradients $__f $ŷ gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 skip_tracker=$ongpu + end + + @testset "HingeLoss" begin + y = [1, 2, 3, 4] |> aType + ŷ = [5.0, 6.0, 7.0, 8.0] |> aType + + 
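            # For reference: hinge loss is mean(max(0, 1 - y .* ŷ)). Here y .* ŷ = [5, 12, 21, 32],
            # so every term is clipped to zero and the loss is exactly 0. In the second check the
            # products are 0.5 .* y .^ 2 = [0.5, 2, 4.5, 8], so only the first element contributes
            # max(0, 1 - 0.5) = 0.5, giving a mean of 0.5 / 4 = 0.125.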
@test Lux.HingeLoss()(ŷ, y) ≈ 0 + @test Lux.HingeLoss()(y, 0.5 .* y) ≈ 0.125 + + @jet Lux.HingeLoss()(ŷ, y) + @inferred Zygote.gradient(Lux.HingeLoss(), ŷ, y) + + __f = Base.Fix2(Lux.HingeLoss(), y) + @eval @test_gradients $__f $ŷ gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 skip_tracker=$ongpu + end + + @testset "SquaredHingeLoss" begin + y = [1, 2, 3, 4] |> aType + ŷ = [5.0, 6.0, 7.0, 8.0] |> aType + + @test SquaredHingeLoss()(ŷ, y) ≈ 0 + @test SquaredHingeLoss()(y, 0.5 .* y) ≈ 0.0625 + + @jet SquaredHingeLoss()(ŷ, y) + @inferred Zygote.gradient(SquaredHingeLoss(), ŷ, y) + + __f = Base.Fix2(SquaredHingeLoss(), y) + @eval @test_gradients $__f $ŷ gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 skip_tracker=$ongpu + end + + @testset "PoissonLoss" begin + y = [0.1, 0.2, 0.3] |> aType + ŷ = [0.4, 0.5, 0.6] |> aType + + @test Lux.PoissonLoss()(ŷ, y) ≈ 0.6278353988097339 + @test Lux.PoissonLoss()(y, y) ≈ 0.5044459776946685 + + @jet Lux.PoissonLoss()(ŷ, y) + @test_broken @inferred Zygote.gradient(Lux.PoissonLoss(), ŷ, y) + + __f = Base.Fix2(Lux.PoissonLoss(), y) + @eval @test_gradients $__f $ŷ gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 skip_tracker=$ongpu + end + + @testset "DiceCoeffLoss" begin + y = [1.0, 0.5, 0.3, 2.4] |> aType + ŷ = [0.0, 1.4, 0.5, 1.2] |> aType + + @test DiceCoeffLoss()(ŷ, y) ≈ 0.2799999999999999 + @test DiceCoeffLoss()(y, y) ≈ 0.0 + + @jet DiceCoeffLoss()(ŷ, y) + @test_broken @inferred Zygote.gradient(DiceCoeffLoss(), ŷ, y) + + __f = Base.Fix2(DiceCoeffLoss(), y) + @eval @test_gradients $__f $ŷ gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 skip_tracker=$ongpu + end + + @testset "Siamese Contrastive Loss" begin + y = [1 0 + 0 0 + 0 1] |> aType + ŷ = [0.4 0.2 + 0.5 0.5 + 0.1 0.3] |> aType + y1 = [1 0 0 0 1 + 0 1 0 1 0 + 0 0 1 0 0] |> aType + ŷ1 = softmax(reshape(-7:7, 3, 5) .* 1.0f0) + y2 = [1 + 0 + 0 + 1 + 1] |> aType + ŷ2 = [0.6 + 0.4 + 0.1 + 0.2 + 0.7] |> aType + + @test SiameseContrastiveLoss()(ŷ, y) ≈ 0.2333333333333333 + @test SiameseContrastiveLoss(; margin=0.5f0)(ŷ, y) ≈ 0.10000000000000002 + @test SiameseContrastiveLoss(; margin=1.5f0)(ŷ, y) ≈ 0.5333333333333333 + @test SiameseContrastiveLoss()(ŷ1, y1) ≈ 0.32554644f0 + @test SiameseContrastiveLoss(; margin=0.5f0)(ŷ1, y1) ≈ 0.16271012f0 + @test SiameseContrastiveLoss(; margin=1.5f0)(ŷ1, y1) ≈ 0.6532292f0 + @test SiameseContrastiveLoss(; margin=1)(ŷ, y) ≈ SiameseContrastiveLoss()(ŷ, y) + @test SiameseContrastiveLoss()(y, y) ≈ 0.0 + @test SiameseContrastiveLoss()(y1, y1) ≈ 0.0 + @test SiameseContrastiveLoss(; margin=0)(ŷ, y) ≈ 0.09166666666666667 + @test SiameseContrastiveLoss(; margin=0)(ŷ1, y1) ≈ 0.13161165f0 + @test SiameseContrastiveLoss()(ŷ2, y2) ≈ 0.21200000000000005 + @test SiameseContrastiveLoss()(ŷ2, ŷ2) ≈ 0.18800000000000003 + + @jet SiameseContrastiveLoss()(ŷ, y) + + @test_throws ArgumentError SiameseContrastiveLoss(; margin=-0.5) + @test_throws ArgumentError SiameseContrastiveLoss(; margin=-1) + end + end +end From d1c84485ec804498d5b1b7dd9b2c560101e56e36 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 19 Jun 2024 20:07:43 -0700 Subject: [PATCH 11/13] Remove PartialFunctions --- Project.toml | 2 -- src/Lux.jl | 1 - src/helpers/losses.jl | 6 +++--- src/utils.jl | 8 ++++++++ test/helpers/loss_tests.jl | 16 ++++++++++++---- 5 files changed, 23 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index 8c800a24d..101e945a0 100644 --- a/Project.toml +++ b/Project.toml @@ -23,7 +23,6 @@ LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Markdown = 
"d6f4376e-aef5-505a-96c1-9c027394607a" OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5" -PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -105,7 +104,6 @@ NCCL = "0.1.1" OhMyThreads = "0.5.1" OneHotArrays = "0.2.5" Optimisers = "0.3" -PartialFunctions = "1.2.0" Pkg = "1.10" PrecompileTools = "1.2" Preferences = "1.4.3" diff --git a/src/Lux.jl b/src/Lux.jl index 0eac19186..7e74e8d67 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -16,7 +16,6 @@ using PrecompileTools: @recompile_invalidations using LossFunctions: LossFunctions using Markdown: @doc_str using OhMyThreads: tmapreduce - using PartialFunctions: @$ using Preferences: @load_preference using Random: Random, AbstractRNG using Reexport: @reexport diff --git a/src/helpers/losses.jl b/src/helpers/losses.jl index 90b269e73..cf02c07d2 100644 --- a/src/helpers/losses.jl +++ b/src/helpers/losses.jl @@ -488,7 +488,7 @@ true ``` """ function MSLELoss(; agg=mean, epsilon=nothing) - return GenericLossFunction(@$(__msle_loss(_, _, epsilon)); agg) + return GenericLossFunction(__Fix3(__msle_loss, epsilon); agg) end @doc doc""" @@ -509,7 +509,7 @@ true ``` """ function PoissonLoss(; agg=mean, epsilon=nothing) - return GenericLossFunction(@$(__poisson_loss(_, _, epsilon)); agg) + return GenericLossFunction(__Fix3(__poisson_loss, epsilon); agg) end @doc doc""" @@ -542,7 +542,7 @@ recognition (CVPR'06). Vol. 2. IEEE, 2006. """ function SiameseContrastiveLoss(; margin::Real=true, agg=mean) @argcheck margin ≥ 0 - return GenericLossFunction(@$(__siamese_contrastive_loss(_, _, margin)); agg) + return GenericLossFunction(__Fix3(__siamese_contrastive_loss, margin); agg) end @doc doc""" diff --git a/src/utils.jl b/src/utils.jl index 40d4b8871..479a36438 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -330,6 +330,14 @@ Return `x * log(y)` for `y > 0`, and zero when `x == 0`. 
end # Some functional forms of losses +@concrete struct __Fix3 + f + x +end + +Broadcast.broadcastable(f::__Fix3) = Ref(f) + +@inline (f::__Fix3)(a, b) = f.f(a, b, f.x) @inline function __siamese_contrastive_loss(x::T1, y::T2, margin=true) where {T1, T2} return (1 - y) * x^2 + y * max(promote_type(T1, T2)(0), margin - x)^2 diff --git a/test/helpers/loss_tests.jl b/test/helpers/loss_tests.jl index 3da1aa50e..92303312a 100644 --- a/test/helpers/loss_tests.jl +++ b/test/helpers/loss_tests.jl @@ -77,7 +77,11 @@ end @jet MSLELoss()(ŷ, y) - @test_broken @inferred Zygote.gradient(MSLELoss(), ŷ, y) + if ongpu + @test_broken @inferred Zygote.gradient(MSLELoss(), ŷ, y) + else + @inferred Zygote.gradient(MSLELoss(), ŷ, y) + end __f = Base.Fix2(MSLELoss(), y) @eval @test_gradients $__f $ŷ gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 skip_tracker=$ongpu @@ -187,7 +191,11 @@ end @jet bceloss(σ.(logŷ), y) @jet bceloss_smooth(σ.(logŷ), y) - @inferred Zygote.gradient(bceloss, σ.(logŷ), y) + if ongpu + @test_broken @inferred Zygote.gradient(bceloss, σ.(logŷ), y) + else + @inferred Zygote.gradient(bceloss, σ.(logŷ), y) + end __f = Base.Fix2(bceloss, y) σlogŷ = σ.(logŷ) @@ -242,7 +250,7 @@ end y = [1 0 0 0 1 0 1 0 1 0 0 0 1 0 0] |> aType - ŷ = softmax(reshape(-7:7, 3, 5) .* 1.0f0) + ŷ = softmax(reshape(-7:7, 3, 5) .* 1.0f0) |> aType y1 = [1 0 0 0 0 1] |> aType @@ -351,7 +359,7 @@ end y1 = [1 0 0 0 1 0 1 0 1 0 0 0 1 0 0] |> aType - ŷ1 = softmax(reshape(-7:7, 3, 5) .* 1.0f0) + ŷ1 = softmax(reshape(-7:7, 3, 5) .* 1.0f0) |> aType y2 = [1 0 0 From ce2ec21b82f35d6fc4d5fe55dd434d4471853627 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 19 Jun 2024 20:59:22 -0700 Subject: [PATCH 12/13] Start replacing loss functions --- Project.toml | 2 +- examples/Basics/Project.toml | 5 ++-- examples/Basics/main.jl | 48 +++++++++++++++--------------------- src/helpers/losses.jl | 5 ++++ test/helpers/loss_tests.jl | 24 +++++++++--------- 5 files changed, 41 insertions(+), 43 deletions(-) diff --git a/Project.toml b/Project.toml index 101e945a0..08a2820b1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "0.5.55" +version = "0.5.56" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/examples/Basics/Project.toml b/examples/Basics/Project.toml index 11e49de55..9e0c4c294 100644 --- a/examples/Basics/Project.toml +++ b/examples/Basics/Project.toml @@ -6,14 +6,15 @@ Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -ComponentArrays = "0.13, 0.14, 0.15" +ComponentArrays = "0.15" ForwardDiff = "0.10" Literate = "2" -Lux = "0.5" +Lux = "0.5.56" LuxCUDA = "0.2, 0.3" Optimisers = "0.2, 0.3" Zygote = "0.6" diff --git a/examples/Basics/main.jl b/examples/Basics/main.jl index 860aaf7e3..15f03f741 100644 --- a/examples/Basics/main.jl +++ b/examples/Basics/main.jl @@ -294,34 +294,26 @@ println("x shape: ", size(x_samples), "; y shape: ", size(y_samples)) # [Optimisers.jl](https://github.com/FluxML/Optimisers.jl). We will use Stochastic Gradient # Descent (SGD) with a learning rate of `0.01`. 
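# (The loss objects such as `MSELoss()` are also callable as `lossfn(model, ps, st, (x, y))`
# and return `(loss, st, stats)`, which is the objective form that
# `Lux.Experimental.single_train_step!` expects; the rewritten loop below relies on this.)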
-using Optimisers - -opt = Optimisers.Descent(0.01f0) - -# Initialize the initial state of the optimiser -opt_state = Optimisers.setup(opt, ps) +using Optimisers, Printf # Define the loss function -function sse(model, ps, st, X, y) - y_pred, st_new = model(X, ps, st) - return sum(abs2, y_pred .- y), st_new -end -sse(weight, bias, X, y) = sum(abs2, weight * X .+ bias .- y) -loss_function(ps, X, y) = sse(model, ps, st, X, y) - -println("Loss Value with ground true parameters: ", sse(W, b, x_samples, y_samples)) - -for i in 1:100 - ## In actual code, don't use globals. But here I will simply for the sake of - ## demonstration - global ps, st, opt_state - ## Compute the gradient using the pullback API to update the states - (loss, st), pb_f = Zygote.pullback(loss_function, ps, x_samples, y_samples) - ## We pass nothing as the seed for `st`, since we don't want to propagate any gradient - ## for st - gs = pb_f((one(loss), nothing))[1] - ## Update model parameters - ## `Optimisers.update` can be used if mutation is not desired - opt_state, ps = Optimisers.update!(opt_state, ps, gs) - (i % 10 == 1 || i == 100) && println(lazy"Loss Value after $i iterations: $loss") +lossfn = MSELoss() + +println("Loss Value with ground true parameters: ", lossfn(W * x_samples .+ b, y_samples)) + +# We will train the model using our training API. +function train_model!(model, ps, st, opt, nepochs::Int) + tstate = Lux.Experimental.TrainState(model, ps, st, opt) + for i in 1:nepochs + grads, loss, _, tstate = Lux.Experimental.single_train_step!( + AutoZygote(), lossfn, (x_samples, y_samples), tstate) + if i % 1000 == 1 || i == nepochs + @printf "Loss Value after %6d iterations: %.8f\n" i loss + end + end + return tstate.model, tstate.parameters, tstate.states end + +model, ps, st = train_model!(model, ps, st, Descent(0.01f0), 10000) + +println("Loss Value after training: ", lossfn(first(model(x_samples, ps, st)), y_samples)) diff --git a/src/helpers/losses.jl b/src/helpers/losses.jl index cf02c07d2..91db1dfb3 100644 --- a/src/helpers/losses.jl +++ b/src/helpers/losses.jl @@ -4,6 +4,11 @@ DocTestFilters = r"[0-9\.]+f0" ``` abstract type AbstractLossFunction <: Function end +function (loss::AbstractLossFunction)(model::AbstractExplicitLayer, ps, st, (x, y)) + ŷ, st_ = model(x, ps, st) + return loss(ŷ, y), st_, (;) +end + function (loss::AbstractLossFunction)(ŷ, y) __check_sizes(ŷ, y) return __unsafe_apply_loss(loss, ŷ, y) diff --git a/test/helpers/loss_tests.jl b/test/helpers/loss_tests.jl index 92303312a..6faa8cd03 100644 --- a/test/helpers/loss_tests.jl +++ b/test/helpers/loss_tests.jl @@ -77,11 +77,7 @@ end @jet MSLELoss()(ŷ, y) - if ongpu - @test_broken @inferred Zygote.gradient(MSLELoss(), ŷ, y) - else - @inferred Zygote.gradient(MSLELoss(), ŷ, y) - end + @test_broken @inferred Zygote.gradient(MSLELoss(), ŷ, y) __f = Base.Fix2(MSLELoss(), y) @eval @test_gradients $__f $ŷ gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 skip_tracker=$ongpu @@ -191,11 +187,7 @@ end @jet bceloss(σ.(logŷ), y) @jet bceloss_smooth(σ.(logŷ), y) - if ongpu - @test_broken @inferred Zygote.gradient(bceloss, σ.(logŷ), y) - else - @inferred Zygote.gradient(bceloss, σ.(logŷ), y) - end + @inferred Zygote.gradient(bceloss, σ.(logŷ), y) __f = Base.Fix2(bceloss, y) σlogŷ = σ.(logŷ) @@ -240,7 +232,11 @@ end @jet BinaryFocalLoss()(ŷ, y) - @inferred Zygote.gradient(BinaryFocalLoss(), ŷ, y) + if ongpu + @test_broken @inferred Zygote.gradient(BinaryFocalLoss(), ŷ, y) + else + @inferred Zygote.gradient(BinaryFocalLoss(), ŷ, y) + end __f = 
Base.Fix2(BinaryFocalLoss(), y) @eval @test_gradients $__f $ŷ gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 skip_tracker=$ongpu @@ -264,7 +260,11 @@ end @jet FocalLoss()(ŷ, y) - @inferred Zygote.gradient(FocalLoss(), ŷ, y) + if ongpu + @test_broken @inferred Zygote.gradient(FocalLoss(), ŷ, y) + else + @inferred Zygote.gradient(FocalLoss(), ŷ, y) + end __f = Base.Fix2(FocalLoss(), y) # FD will lead to out of domain errors From 514414e79ac0caad80bc5817f191ebbcc77b1db6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 19 Jun 2024 21:14:42 -0700 Subject: [PATCH 13/13] Update examples using the new losses --- docs/src/api/Lux/contrib.md | 2 +- docs/src/api/Lux/utilities.md | 8 ++++++++ examples/ConvMixer/main.jl | 9 ++------- examples/DDIM/main.jl | 6 ++++-- examples/GravitationalWaveForm/main.jl | 7 ++++--- examples/HyperNet/main.jl | 9 ++------- examples/NeuralODE/main.jl | 9 ++------- examples/PolynomialFitting/main.jl | 9 +++------ examples/SimpleChains/main.jl | 7 +------ examples/SimpleRNN/main.jl | 23 ++++++++--------------- 10 files changed, 35 insertions(+), 54 deletions(-) diff --git a/docs/src/api/Lux/contrib.md b/docs/src/api/Lux/contrib.md index c258145f1..a5a4c3f15 100644 --- a/docs/src/api/Lux/contrib.md +++ b/docs/src/api/Lux/contrib.md @@ -20,7 +20,7 @@ All features listed on this page are **experimental** which means: Pages = ["contrib.md"] ``` -## Training +## [Training](@id Training-API) Helper Functions making it easier to train `Lux.jl` models. diff --git a/docs/src/api/Lux/utilities.md b/docs/src/api/Lux/utilities.md index e20b1956e..76aeb585f 100644 --- a/docs/src/api/Lux/utilities.md +++ b/docs/src/api/Lux/utilities.md @@ -8,6 +8,14 @@ Pages = ["utilities.md"] ## Loss Functions +Loss Functions Objects take 2 forms of inputs: + +1. $y\hat$ and $y$ where $y\hat$ is the predicted output and $y$ is the target output. +2. `model`, `ps`, `st`, `(x, y)` where `model` is the model, `ps` are the parameters, + `st` are the states and `(x, y)` are the input and target pair. Then it returns the + loss, updated states, and an empty named tuple. This makes them compatible with the + [Experimental Training API](@ref Training-API). + !!! warning When using ChainRules.jl compatible AD (like Zygote), we only compute the gradients diff --git a/examples/ConvMixer/main.jl b/examples/ConvMixer/main.jl index 5b0218073..a5d5a532c 100644 --- a/examples/ConvMixer/main.jl +++ b/examples/ConvMixer/main.jl @@ -55,13 +55,6 @@ function ConvMixer(; dim, depth, kernel_size=5, patch_size=2) #! 
format: on end -logitcrossentropy(y_pred, y) = mean(-sum(y .* logsoftmax(y_pred); dims=1)) - -function loss(model, ps, st, (x, y)) - y_pred, st = model(x, ps, st) - return logitcrossentropy(y_pred, y), st, (;) -end - function accuracy(model, ps, st, dataloader; dev=gpu_device()) total_correct, total = 0, 0 st = Lux.testmode(st) @@ -94,6 +87,8 @@ end lr_schedule = linear_interpolation( [0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0]) + loss = CrossEntropyLoss(; logits=Val(true)) + for epoch in 1:epochs stime = time() lr = 0 diff --git a/examples/DDIM/main.jl b/examples/DDIM/main.jl index c957cf9b2..448bceaac 100644 --- a/examples/DDIM/main.jl +++ b/examples/DDIM/main.jl @@ -291,10 +291,12 @@ function preprocess_image(image::Matrix{<:RGB}, image_size::Int) return apply(CenterResizeCrop((image_size, image_size)), Image(image)) |> itemdata end +const maeloss = MAELoss() + function loss_function(model, ps, st, data) (noises, images, pred_noises, pred_images), st = Lux.apply(model, data, ps, st) - noise_loss = mean(abs, noises .- pred_noises) - image_loss = mean(abs, images .- pred_images) + noise_loss = maeloss(noises, pred_noises) + image_loss = maeloss(images, pred_images) return noise_loss, st, (; image_loss, noise_loss) end diff --git a/examples/GravitationalWaveForm/main.jl b/examples/GravitationalWaveForm/main.jl index fc8bd4eee..718eefd46 100644 --- a/examples/GravitationalWaveForm/main.jl +++ b/examples/GravitationalWaveForm/main.jl @@ -22,7 +22,7 @@ CUDA.allowscalar(false) # We need a very crude 2-body path. Assume the 1-body motion is a newtonian 2-body position # vector $r = r_1 - r_2$ and use Newtonian formulas to get $r_1$, $r_2$ (e.g. Theoretical -# Mechanics of Particles and Continua 4.3) +# Mechanics of Particles and Continua 4.3) function one2two(path, m₁, m₂) M = m₁ + m₂ @@ -290,11 +290,12 @@ end # Next, we define the objective (loss) function to be minimized when training the neural # differential equations. 
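# As an illustrative sketch (assuming the default `mean` aggregation of the new loss
# objects), `MSELoss()(ŷ, y)` computes `mean(abs2, ŷ .- y)`; it stands in for the earlier
# hand-written `sum(abs2, waveform .- pred_waveform)` up to the choice of reduction.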
+const mseloss = MSELoss() + function loss(θ) pred = Array(solve(prob_nn, RK4(); u0, p=θ, saveat=tsteps, dt, adaptive=false)) pred_waveform = first(compute_waveform(dt_data, pred, mass_ratio, ode_model_params)) - loss = sum(abs2, waveform .- pred_waveform) - return loss, pred_waveform + return mseloss(waveform, pred_waveform), pred_waveform end # Warmup the loss function diff --git a/examples/HyperNet/main.jl b/examples/HyperNet/main.jl index 7e1f2065d..3c6cdbd81 100644 --- a/examples/HyperNet/main.jl +++ b/examples/HyperNet/main.jl @@ -57,12 +57,7 @@ function create_model() end # ## Define Utility Functions -logitcrossentropy(y_pred, y) = mean(-sum(y .* logsoftmax(y_pred); dims=1)) - -function loss(model, ps, st, (data_idx, x, y)) - y_pred, st = model((data_idx, x), ps, st) - return logitcrossentropy(y_pred, y), st, (;) -end +const loss = CrossEntropyLoss(; logits=Val(true)) function accuracy(model, ps, st, dataloader, data_idx, gdev=gpu_device()) total_correct, total = 0, 0 @@ -101,7 +96,7 @@ function train() x = x |> dev y = y |> dev (_, _, _, train_state) = Lux.Experimental.single_train_step!( - AutoZygote(), loss, (data_idx, x, y), train_state) + AutoZygote(), loss, ((data_idx, x), y), train_state) end ttime = time() - stime diff --git a/examples/NeuralODE/main.jl b/examples/NeuralODE/main.jl index c7cda2421..039654491 100644 --- a/examples/NeuralODE/main.jl +++ b/examples/NeuralODE/main.jl @@ -105,12 +105,7 @@ function create_model(model_fn=NeuralODE; dev=gpu_device(), use_named_tuple::Boo end # ## Define Utility Functions -logitcrossentropy(y_pred, y) = mean(-sum(y .* logsoftmax(y_pred); dims=1)) - -function loss(model, ps, st, (x, y)) - y_pred, st = model(x, ps, st) - return logitcrossentropy(y_pred, y), st, (;) -end +const logitcrossentropy = CrossEntropyLoss(; logits=Val(true)) function accuracy(model, ps, st, dataloader; dev=gpu_device()) total_correct, total = 0, 0 @@ -143,7 +138,7 @@ function train(model_function; cpu::Bool=false, kwargs...) x = dev(x) y = dev(y) _, _, _, tstate = Lux.Experimental.single_train_step!( - AutoZygote(), loss, (x, y), tstate) + AutoZygote(), logitcrossentropy, (x, y), tstate) end ttime = time() - stime diff --git a/examples/PolynomialFitting/main.jl b/examples/PolynomialFitting/main.jl index 703b9234d..4c0c1f1d2 100644 --- a/examples/PolynomialFitting/main.jl +++ b/examples/PolynomialFitting/main.jl @@ -51,12 +51,9 @@ opt = Adam(0.03f0) # We will use the `Lux.Training` API so we need to ensure that our loss function takes 4 # inputs -- model, parameters, states and data. The function must return 3 values -- loss, -# updated_state, and any computed statistics. -function loss_function(model, ps, st, data) - y_pred, st = Lux.apply(model, data[1], ps, st) - mse_loss = mean(abs2, y_pred .- data[2]) - return mse_loss, st, () -end +# updated_state, and any computed statistics. This is already satisfied by the loss +# functions provided by Lux. 
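# As a sketch of that interface (the method added in `src/helpers/losses.jl`), a loss
# object called with four arguments runs the model and returns the loss, the updated
# states, and an empty named tuple of statistics, e.g. (with hypothetical `x`, `y`
# standing in for the data of this example):
#
#   loss_val, st_, stats = MSELoss()(model, ps, st, (x, y))
#
# while the two-argument form `MSELoss()(ŷ, y)` works on precomputed predictions.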
+const loss_function = MSELoss() # ## Training diff --git a/examples/SimpleChains/main.jl b/examples/SimpleChains/main.jl index da408300f..8bde951cc 100644 --- a/examples/SimpleChains/main.jl +++ b/examples/SimpleChains/main.jl @@ -44,12 +44,7 @@ adaptor = ToSimpleChainsAdaptor((static(28), static(28), static(1))) simple_chains_model = adaptor(lux_model) # ## Helper Functions -logitcrossentropy(y_pred, y) = mean(-sum(y .* logsoftmax(y_pred); dims=1)) - -function loss(model, ps, st, (x, y)) - y_pred, st = model(x, ps, st) - return logitcrossentropy(y_pred, y), st, (;) -end +const loss = CrossEntropyLoss(; logits=Val(true)) function accuracy(model, ps, st, dataloader) total_correct, total = 0, 0 diff --git a/examples/SimpleRNN/main.jl b/examples/SimpleRNN/main.jl index 0cdb3f05a..ffa3982ee 100644 --- a/examples/SimpleRNN/main.jl +++ b/examples/SimpleRNN/main.jl @@ -116,20 +116,12 @@ end # Now let's define the binarycrossentropy loss. Typically it is recommended to use # `logitbinarycrossentropy` since it is more numerically stable, but for the sake of # simplicity we will use `binarycrossentropy`. - -function xlogy(x, y) - result = x * log(y) - return ifelse(iszero(x), zero(result), result) -end - -function binarycrossentropy(y_pred, y_true) - y_pred = y_pred .+ eps(eltype(y_pred)) - return mean(@. -xlogy(y_true, y_pred) - xlogy(1 - y_true, 1 - y_pred)) -end +const lossfn = BinaryCrossEntropyLoss() function compute_loss(model, ps, st, (x, y)) - y_pred, st = model(x, ps, st) - return binarycrossentropy(y_pred, y), st, (; y_pred=y_pred) + ŷ, st_ = model(x, ps, st) + loss = lossfn(ŷ, y) + return loss, st_, (; y_pred=ŷ) end matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true) @@ -156,7 +148,7 @@ function main(model_type) y = y |> dev (_, loss, _, train_state) = Lux.Experimental.single_train_step!( - AutoZygote(), compute_loss, (x, y), train_state) + AutoZygote(), lossfn, (x, y), train_state) @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss end @@ -166,8 +158,9 @@ function main(model_type) for (x, y) in val_loader x = x |> dev y = y |> dev - loss, st_, ret = compute_loss(model, train_state.parameters, st_, (x, y)) - acc = accuracy(ret.y_pred, y) + ŷ, st_ = model(x, train_state.parameters, st_) + loss = lossfn(ŷ, y) + acc = accuracy(ŷ, y) @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc end end
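Taken together, these patches give every loss object two calling conventions. The snippet below is a minimal, self-contained sketch of how they compose with the experimental training API; it assumes Lux 0.5.56 (as pinned above) with Optimisers and Zygote available, and the toy `Dense` model, random data, and variable names are illustrative rather than taken from any of the examples in the diffs.

```julia
using Lux, Optimisers, Random, Zygote

rng = Random.default_rng()

# A toy model and some random data (illustrative only)
model = Dense(2 => 1)
ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 2, 16)
y = randn(rng, Float32, 1, 16)

lossfn = MSELoss()

# Form 1: lossfn(ŷ, y) -- sizes are checked, then the loss is reduced (mean by default)
ŷ, _ = model(x, ps, st)
l1 = lossfn(ŷ, y)

# Form 2: lossfn(model, ps, st, (x, y)) -- runs the model and returns
# (loss, updated_states, empty_stats), matching the Training API contract
l2, st_, _ = lossfn(model, ps, st, (x, y))

# Hence a loss object can be passed directly to single_train_step!
tstate = Lux.Experimental.TrainState(model, ps, st, Descent(0.01f0))
_, l3, _, tstate = Lux.Experimental.single_train_step!(
    AutoZygote(), lossfn, (x, y), tstate)
```

The return order of `single_train_step!` follows the `(grads, loss, stats, train_state)` convention used in the updated examples above, so the loss value is the second element.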