Skip to content

Commit

Permalink
test: logsoftmax and softmax forwarddiff rules
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 31, 2024
1 parent c8ec0e0 commit 50bc991
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 6 deletions.
4 changes: 2 additions & 2 deletions lib/LuxLib/src/impl/forward_diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ function softmax_dual(
y = NNlib.softmax(x_data; dims)
dysᵢ = ntuple(P) do i
v = partial_fn.(x, i)
return y .* (v .- LinearAlgebra.dot(y, v))
return y .* (v .- sum(y .* v; dims))
end

partials = ForwardDiff.Partials.(tuple.(dysᵢ...))
Expand All @@ -84,7 +84,7 @@ function logsoftmax_dual(
y = NNlib.softmax(x_data; dims)
dysᵢ = ntuple(P) do i
v = partial_fn.(x, i)
return v .- LinearAlgebra.dot(y, v)
return v .- sum(y .* v; dims)
end

partials = ForwardDiff.Partials.(tuple.(dysᵢ...))
Expand Down
35 changes: 31 additions & 4 deletions lib/LuxLib/test/others/forwarddiff_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,20 @@

function test_jvp_computation(f::F, x, u, ongpu, nested=false) where {F}
jvp₁ = jvp_forwarddiff(f, x, u)

if !(x isa ComponentArray && ongpu)
# ComponentArray + ForwardDiff on GPU don't play nice
jvp₂ = jvp_forwarddiff_concrete(f, x, u)
@test check_approx(jvp₁, jvp₂; atol=1e-5, rtol=1e-5)
@testset "JVP ForwardDiff Concrete" begin
jvp₂ = jvp_forwarddiff_concrete(f, x, u)
@test check_approx(jvp₁, jvp₂; atol=1e-5, rtol=1e-5)
end
end

if !nested
jvp₃ = jvp_zygote(f, x, u)
@test check_approx(jvp₁, jvp₃; atol=1e-5, rtol=1e-5)
@testset "JVP Zygote" begin
jvp₃ = jvp_zygote(f, x, u)
@test check_approx(jvp₁, jvp₃; atol=1e-5, rtol=1e-5)
end
end
end

Expand Down Expand Up @@ -89,6 +94,28 @@
true)
end
end

@testset for op in (logsoftmax, softmax)
@testset for (input_dim, dim) in zip(
(
(2, 3), (2, 3), (2, 3, 4, 5),
(2, 3, 4, 5), (2, 3, 4, 5), (2, 3, 4, 5)
),
(1, 2, 1, 2, 3, 4)
)
x = randn(Float32, input_dim) |> aType
u = randn(Float32, input_dim) |> aType

test_jvp_computation(x -> op(x; dims=dim), x, u, ongpu)
test_jvp_computation(
x -> op(x; dims=dim), ComponentArray(; x), u, ongpu)

test_jvp_computation(
x -> only(Zygote.gradient(x -> sum(op(x; dims=dim)), x)),
x, u, ongpu, true
)
end
end
end
end

Expand Down

0 comments on commit 50bc991

Please sign in to comment.