cleanup Reactant and Enzyme tests (#2578)
CarloLucibello authored Jan 6, 2025
1 parent 1ec93e9 commit 4eb4454
Showing 6 changed files with 52 additions and 30 deletions.
2 changes: 0 additions & 2 deletions Project.toml
@@ -18,7 +18,6 @@ Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -61,7 +60,6 @@ OneHotArrays = "0.2.4"
Optimisers = "0.4.1"
Preferences = "1"
ProgressLogging = "0.1"
-Reactant = "0.2.16"
Reexport = "1.0"
Setfield = "1.1"
SpecialFunctions = "2.1.2"
1 change: 0 additions & 1 deletion test/Project.toml
@@ -14,7 +14,6 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
22 changes: 11 additions & 11 deletions test/ext_enzyme/enzyme.jl
@@ -12,20 +12,20 @@ using Enzyme: Enzyme, Duplicated, Const, Active
(Flux.Scale([1.0f0 2.0f0 3.0f0 4.0f0], true, abs2), randn(Float32, 2), "Flux.Scale"),
(Conv((3, 3), 2 => 3), randn(Float32, 3, 3, 2, 1), "Conv"),
(Chain(Conv((3, 3), 2 => 3, ), Conv((3, 3), 3 => 1, tanh)), rand(Float32, 5, 5, 2, 1), "Chain(Conv, Conv)"),
-# (Chain(Conv((4, 4), 2 => 2, pad=SamePad()), MeanPool((5, 5), pad=SamePad())), rand(Float32, 5, 5, 2, 2), "Chain(Conv, MeanPool)"),
+(Chain(Conv((4, 4), 2 => 2, pad=SamePad()), MeanPool((5, 5), pad=SamePad())), rand(Float32, 5, 5, 2, 2), "Chain(Conv, MeanPool)"),
(Maxout(() -> Dense(5 => 4, tanh), 3), randn(Float32, 5, 1), "Maxout"),
(SkipConnection(Dense(2 => 2), vcat), randn(Float32, 2, 3), "SkipConnection"),
-# (Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"), # Passes on 1.10, fails on 1.11 with MethodError: no method matching function_attributes(::LLVM.UserOperandSet)
+(Flux.Bilinear((2, 2) => 3), randn(Float32, 2, 1), "Bilinear"),
(ConvTranspose((3, 3), 3 => 2, stride=2), rand(Float32, 5, 5, 3, 1), "ConvTranspose"),
(first ∘ LayerNorm(2), randn(Float32, 2, 10), "LayerNorm"),
-# (BatchNorm(2), randn(Float32, 2, 10), "BatchNorm"), # AssertionError: Base.isconcretetype(typ)
-# (first ∘ MultiHeadAttention(16), randn32(16, 20, 2), "MultiHeadAttention"), # AssertionError: Base.isconcretetype(typ)
+(BatchNorm(2), randn(Float32, 2, 10), "BatchNorm"),
+(first ∘ MultiHeadAttention(16), randn32(16, 20, 2), "MultiHeadAttention"),
]

for (model, x, name) in models_xs
@testset "Enzyme grad check $name" begin
println("testing $name with Enzyme")
-test_gradients(model, x; loss, compare_finite_diff=false, compare_enzyme=true)
+test_gradients(model, x; loss, compare_finite_diff=false, test_enzyme=true)
end
end
end
@@ -36,17 +36,17 @@ end
end

models_xs = [
-# (RNN(3 => 2), randn(Float32, 3, 2), "RNN"),
-# (LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"),
-# (GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
-# (Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"),
-# (Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"),
+(RNN(3 => 2), randn(Float32, 3, 2), "RNN"),
+(LSTM(3 => 5), randn(Float32, 3, 2), "LSTM"),
+(GRU(3 => 5), randn(Float32, 3, 10), "GRU"),
+(Chain(RNN(3 => 4), RNN(4 => 3)), randn(Float32, 3, 2), "Chain(RNN, RNN)"),
+(Chain(LSTM(3 => 5), LSTM(5 => 3)), randn(Float32, 3, 2), "Chain(LSTM, LSTM)"),
]

for (model, x, name) in models_xs
@testset "check grad $name" begin
println("testing $name")
-test_gradients(model, x; loss, compare_finite_diff=false, compare_enzyme=true)
+test_gradients(model, x; loss, compare_finite_diff=false, test_enzyme=true)
end
end
end
13 changes: 13 additions & 0 deletions test/ext_reactant/test_utils_reactant.jl
@@ -0,0 +1,13 @@
+# These are used only in test_utils.jl but cannot live there,
+# because Reactant is only optionally loaded and the macros fail when it is not loaded.
+
+function reactant_withgradient(f, x...)
+    y, g = Reactant.@jit enzyme_withgradient(f, x...)
+    return y, g
+end
+
+function reactant_loss(loss, x...)
+    l = Reactant.@jit loss(x...)
+    @test l isa Reactant.ConcreteRNumber
+    return l
+end
22 changes: 14 additions & 8 deletions test/runtests.jl
@@ -13,8 +13,6 @@ using Pkg
using FiniteDifferences: FiniteDifferences
using Functors: fmapstructure_with_path

-using Reactant
-
## Uncomment below to change the default test settings
# ENV["FLUX_TEST_AMDGPU"] = "true"
# ENV["FLUX_TEST_CUDA"] = "true"
@@ -23,20 +21,20 @@ using Reactant
# ENV["FLUX_TEST_DISTRIBUTED_MPI"] = "true"
# ENV["FLUX_TEST_DISTRIBUTED_NCCL"] = "true"
# ENV["FLUX_TEST_ENZYME"] = "false"
# ENV["FLUX_TEST_REACTANT"] = "false"

const FLUX_TEST_ENZYME = get(ENV, "FLUX_TEST_ENZYME", VERSION < v"1.12-" ? "true" : "false") == "true"
-const FLUX_TEST_REACTANT = get(ENV, "FLUX_TEST_REACTANT", VERSION < v"1.12-" && !Sys.iswindows() ? "true" : "false") == "true"

+# Reactant will automatically select a GPU backend, if available, and a TPU backend, if available.
+# Otherwise it will fall back to CPU.
+const FLUX_TEST_REACTANT = get(ENV, "FLUX_TEST_REACTANT",
+    VERSION < v"1.12-" && !Sys.iswindows() ? "true" : "false") == "true"

if FLUX_TEST_ENZYME || FLUX_TEST_REACTANT
Pkg.add("Enzyme")
using Enzyme: Enzyme
end

-if FLUX_TEST_REACTANT
-Pkg.add("Reactant")
-using Reactant: Reactant
-end

include("test_utils.jl") # for test_gradients

Random.seed!(0)
@@ -182,7 +180,15 @@
end

if FLUX_TEST_REACTANT
+## This Pkg.add has to be done after Pkg.add("CUDA"), otherwise CUDA.jl
+## will not be functional and will complain with:
+# ┌ Error: CUDA.jl could not find an appropriate CUDA runtime to use.
+#
+# │ CUDA.jl's JLLs were precompiled without an NVIDIA driver present.
+Pkg.add("Reactant")
+using Reactant: Reactant
@testset "Reactant" begin
+include("ext_reactant/test_utils_reactant.jl")
include("ext_reactant/reactant.jl")
end
else
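Since the Enzyme and Reactant test groups are gated on the FLUX_TEST_* flags above, they can be toggled from the environment before running the suite; a hypothetical invocation:

    # Run Flux's tests with Reactant disabled and Enzyme enabled.
    ENV["FLUX_TEST_REACTANT"] = "false"
    ENV["FLUX_TEST_ENZYME"] = "true"

    using Pkg
    Pkg.test("Flux")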
22 changes: 14 additions & 8 deletions test/test_utils.jl
@@ -37,6 +37,7 @@ end

function check_equal_leaves(a, b; rtol=1e-4, atol=1e-4)
fmapstructure_with_path(a, b) do kp, x, y
+# @show kp
if x isa AbstractArray
@test x ≈ y rtol=rtol atol=atol
elseif x isa Number
@@ -45,23 +46,29 @@ function check_equal_leaves(a, b; rtol=1e-4, atol=1e-4)
end
end

+# By default, this computes the gradients on CPU using the default AD (Zygote)
+# and compares them with finite differences.
+# By changing the arguments, you can instead take the CPU Zygote gradients as the
+# ground truth and test other scenarios.
function test_gradients(
f,
xs...;
rtol=1e-4, atol=1e-4,
test_gpu = false,
test_reactant = false,
+test_enzyme = false,
test_grad_f = true,
test_grad_x = true,
compare_finite_diff = true,
-compare_enzyme = false,
loss = (f, xs...) -> mean(f(xs...)),
)

-if !test_gpu && !compare_finite_diff && !compare_enzyme && !test_reactant
+if !test_gpu && !compare_finite_diff && !test_enzyme && !test_reactant
error("You should either compare numerical gradients methods or CPU vs GPU.")
end

+Flux.trainmode!(f) # for layers like BatchNorm

## Let's make sure first that the forward pass works.
l = loss(f, xs...)
@test l isa Number
@@ -79,8 +86,7 @@ function test_gradients(
cpu_dev = cpu_device()
xs_re = xs |> reactant_dev
f_re = f |> reactant_dev
-l_re = Reactant.@jit loss(f_re, xs_re...)
-@test l_re isa Reactant.ConcreteRNumber
+l_re = reactant_loss(loss, f_re, xs_re...)
@test l ≈ l_re rtol=rtol atol=atol
end

@@ -97,7 +103,7 @@ function test_gradients(
check_equal_leaves(g, g_fd; rtol, atol)
end

-if compare_enzyme
+if test_enzyme
y_ez, g_ez = enzyme_withgradient((xs...) -> loss(f, xs...), xs...)
@test y ≈ y_ez rtol=rtol atol=atol
check_equal_leaves(g, g_ez; rtol, atol)
@@ -113,7 +119,7 @@ function test_gradients(

if test_reactant
# Enzyme gradient with respect to input on Reactant.
-y_re, g_re = Reactant.@jit enzyme_withgradient((xs...) -> loss(f_re, xs...), xs_re...)
+y_re, g_re = reactant_withgradient((xs...) -> loss(f_re, xs...), xs_re...)
@test y ≈ y_re rtol=rtol atol=atol
check_equal_leaves(g_re |> cpu_dev, g; rtol, atol)
end
@@ -133,7 +139,7 @@ function test_gradients(
check_equal_leaves(g, g_fd; rtol, atol)
end

-if compare_enzyme
+if test_enzyme
y_ez, g_ez = enzyme_withgradient(f -> loss(f, xs...), f)
@test y ≈ y_ez rtol=rtol atol=atol
check_equal_leaves(g, g_ez; rtol, atol)
@@ -149,7 +155,7 @@ function test_gradients(

if test_reactant
# Enzyme gradient with respect to the model on Reactant.
-y_re, g_re = Reactant.@jit enzyme_withgradient(f -> loss(f, xs_re...), f_re)
+y_re, g_re = reactant_withgradient(f -> loss(f, xs_re...), f_re)
@test y ≈ y_re rtol=rtol atol=atol
check_equal_leaves(g_re |> cpu_dev, g; rtol, atol)
end
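For reference, a minimal sketch of how a test drives test_gradients after the rename, mirroring the call sites in test/ext_enzyme/enzyme.jl above (hypothetical layer and data):

    using Flux, Statistics

    loss(f, xs...) = mean(f(xs...))

    # Default: Zygote gradients on CPU, checked against finite differences.
    test_gradients(Dense(2 => 3, tanh), randn(Float32, 2, 4); loss)

    # Check Zygote against Enzyme instead (compare_enzyme was renamed to test_enzyme).
    test_gradients(Dense(2 => 3, tanh), randn(Float32, 2, 4);
                   loss, compare_finite_diff=false, test_enzyme=true)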
