From 19c3769f8cbfe0aae0a15e67909673fd9978d549 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sat, 7 Dec 2024 01:45:46 +0100 Subject: [PATCH 1/5] adapt to Flux v0.15 --- DifferentiationInterfaceTest/Project.toml | 4 +- .../DifferentiationInterfaceTestFluxExt.jl | 73 +++++++++---------- 2 files changed, 38 insertions(+), 39 deletions(-) diff --git a/DifferentiationInterfaceTest/Project.toml b/DifferentiationInterfaceTest/Project.toml index bfa84d4b4..77f5a5116 100644 --- a/DifferentiationInterfaceTest/Project.toml +++ b/DifferentiationInterfaceTest/Project.toml @@ -15,6 +15,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [weakdeps] @@ -46,7 +47,7 @@ DifferentiationInterface = "0.6.0" DocStringExtensions = "0.8,0.9" ExplicitImports = "1.10.1" FiniteDifferences = "0.12" -Flux = "0.13,0.14" +Flux = "0.15" ForwardDiff = "0.10.36" Functors = "0.4, 0.5" JET = "0.4 - 0.8, 0.9" @@ -61,6 +62,7 @@ SparseArrays = "<0.0.1,1" SparseConnectivityTracer = "0.5.0,0.6" SparseMatrixColorings = "0.4.9" StaticArrays = "1.9" +Statistics = "1" Test = "<0.0.1,1" Zygote = "0.6" julia = "1.10" diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl index 9b7ef48a1..c6ba08f27 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl @@ -10,11 +10,11 @@ using Flux: Conv, ConvTranspose, Dense, - GRU, - LSTM, + GRU, GRUCell, + LSTM, LSTMCell, Maxout, MeanPool, - RNN, + RNN, RNNCell, SamePad, Scale, SkipConnection, @@ -24,6 +24,7 @@ using Flux: relu using Functors: @functor, fmapstructure_with_path, fleaves using LinearAlgebra +using Statistics: mean using Random: AbstractRNG, default_rng #= @@ -57,17 +58,18 @@ function DIT.flux_isapprox(a, b; atol, rtol) end function square_loss(model, x) - Flux.reset!(model) - return sum(abs2, model(x)) + y = model(x) + y = y isa Tuple ? y[1] : y # handle LSTM + return mean(abs2, y) end -function square_loss_iterated(model, x) - Flux.reset!(model) - y = copy(x) - for _ in 1:3 - y = model(y) +function square_loss_iterated(cell, x) + st = cell(x) # uses default initial state + for _ in 1:2 + st = cell(x, st) end - return sum(abs2, y) + y = st isa Tuple ? st[1] : st # handle LSTM + return mean(abs2, y) end struct SimpleDense{W,B,F} @@ -132,23 +134,7 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng()) Maxout(() -> Dense(5 => 4, tanh; init), 3), randn(rng, Float32, 5, 1) ), - ( - RNN(3 => 2; init), - randn(rng, Float32, 3, 2) - ), - ( - Chain(RNN(3 => 4; init), RNN(4 => 3; init)), - randn(rng, Float32, 3, 2) - ), - ( - LSTM(3 => 5; init), - randn(rng, Float32, 3, 2) - ), - ( - Chain(LSTM(3 => 5; init), LSTM(5 => 3; init)), - randn(rng, Float32, 3, 2) - ), - ( + ( SkipConnection(Dense(2 => 2; init), vcat), randn(rng, Float32, 2, 3) ), @@ -156,14 +142,22 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng()) Bilinear((2, 2) => 3; init), randn(rng, Float32, 2, 1) ), - ( - GRU(3 => 5; init), - randn(rng, Float32, 3, 10) - ), ( ConvTranspose((3, 3), 3 => 2; stride=2, init), rand(rng, Float32, 5, 5, 3, 1) ), + ( + RNN(3 => 4; init_kernel=init, init_recurrent_kernel=init), + randn(rng, Float32, 3, 2, 1) + ), + ( + LSTM(3 => 4; init_kernel=init, init_recurrent_kernel=init), + randn(rng, Float32, 3, 2, 1) + ), + ( + GRU(3 => 4; init_kernel=init, init_recurrent_kernel=init), + randn(rng, Float32, 3, 2, 1) + ), #! format: on ] @@ -176,16 +170,20 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng()) push!(scens, scen) end - # Recurrence + # Recurrent Cells recurrent_models_and_xs = [ #! format: off ( - RNN(3 => 3; init), + RNNCell(3 => 3; init_kernel=init, init_recurrent_kernel=init), + randn(rng, Float32, 3, 2) + ), + ( + LSTMCell(3 => 3; init_kernel=init, init_recurrent_kernel=init), randn(rng, Float32, 3, 2) ), ( - LSTM(3 => 3; init), + GRUCell(3 => 3; init_kernel=init, init_recurrent_kernel=init), randn(rng, Float32, 3, 2) ), #! format: on @@ -193,12 +191,11 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng()) for (model, x) in recurrent_models_and_xs Flux.trainmode!(model) - g = gradient_finite_differences(square_loss, model, x) + g = gradient_finite_differences(square_loss_iterated, model, x) scen = DIT.Scenario{:gradient,:out}( square_loss_iterated, model; contexts=(DI.Constant(x),), res1=g ) - # TODO: figure out why these tests are broken - # push!(scens, scen) + push!(scens, scen) end return scens From 70f723e7d6ece2f71e4b4924be074ea0faa06b18 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sat, 7 Dec 2024 01:55:13 +0100 Subject: [PATCH 2/5] test also Enzyme --- DifferentiationInterface/test/Down/Flux/Project.toml | 7 +++++++ DifferentiationInterface/test/Down/Flux/test.jl | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) create mode 100644 DifferentiationInterface/test/Down/Flux/Project.toml diff --git a/DifferentiationInterface/test/Down/Flux/Project.toml b/DifferentiationInterface/test/Down/Flux/Project.toml new file mode 100644 index 000000000..3235152e8 --- /dev/null +++ b/DifferentiationInterface/test/Down/Flux/Project.toml @@ -0,0 +1,7 @@ +[deps] +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/DifferentiationInterface/test/Down/Flux/test.jl b/DifferentiationInterface/test/Down/Flux/test.jl index be1b7ad1a..4d917be4f 100644 --- a/DifferentiationInterface/test/Down/Flux/test.jl +++ b/DifferentiationInterface/test/Down/Flux/test.jl @@ -15,7 +15,7 @@ LOGGING = get(ENV, "CI", "false") == "false" test_differentiation( [ AutoZygote(), - # AutoEnzyme() # TODO: fix + AutoEnzyme(), ], DIT.flux_scenarios(Random.MersenneTwister(0)); isapprox=DIT.flux_isapprox, From 1e97549a4a6991827c27ee41764226e6537d09cc Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sat, 7 Dec 2024 01:59:13 +0100 Subject: [PATCH 3/5] cleanup --- DifferentiationInterface/test/Down/Flux/Project.toml | 7 ------- DifferentiationInterface/test/Down/Flux/test.jl | 2 +- 2 files changed, 1 insertion(+), 8 deletions(-) delete mode 100644 DifferentiationInterface/test/Down/Flux/Project.toml diff --git a/DifferentiationInterface/test/Down/Flux/Project.toml b/DifferentiationInterface/test/Down/Flux/Project.toml deleted file mode 100644 index 3235152e8..000000000 --- a/DifferentiationInterface/test/Down/Flux/Project.toml +++ /dev/null @@ -1,7 +0,0 @@ -[deps] -DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" -DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3" -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" -FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/DifferentiationInterface/test/Down/Flux/test.jl b/DifferentiationInterface/test/Down/Flux/test.jl index 4d917be4f..804bad70e 100644 --- a/DifferentiationInterface/test/Down/Flux/test.jl +++ b/DifferentiationInterface/test/Down/Flux/test.jl @@ -15,7 +15,7 @@ LOGGING = get(ENV, "CI", "false") == "false" test_differentiation( [ AutoZygote(), - AutoEnzyme(), + # AutoEnzyme(), # TODO a few scenarios fail ], DIT.flux_scenarios(Random.MersenneTwister(0)); isapprox=DIT.flux_isapprox, From ff355284fadf7f1d3ccef41711b87e9305b30d0e Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 31 Dec 2024 18:32:44 +0100 Subject: [PATCH 4/5] update flux v0.16 --- DifferentiationInterfaceTest/Project.toml | 2 +- .../DifferentiationInterfaceTestFluxExt.jl | 25 ++++++++----------- 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/DifferentiationInterfaceTest/Project.toml b/DifferentiationInterfaceTest/Project.toml index 77f5a5116..c5f84e22e 100644 --- a/DifferentiationInterfaceTest/Project.toml +++ b/DifferentiationInterfaceTest/Project.toml @@ -47,7 +47,7 @@ DifferentiationInterface = "0.6.0" DocStringExtensions = "0.8,0.9" ExplicitImports = "1.10.1" FiniteDifferences = "0.12" -Flux = "0.15" +Flux = "0.16" ForwardDiff = "0.10.36" Functors = "0.4, 0.5" JET = "0.4 - 0.8, 0.9" diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl index c6ba08f27..ee0d5b168 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl @@ -44,31 +44,22 @@ end function DIT.flux_isapprox(a, b; atol, rtol) isapprox_results = fmapstructure_with_path(a, b) do kp, x, y - if :state in kp # ignore RNN and LSTM state + if x isa AbstractArray{<:Number} + return isapprox(x, y; atol, rtol) + else # ignore non-arrays return true - else - if x isa AbstractArray{<:Number} - return isapprox(x, y; atol, rtol) - else # ignore non-arrays - return true - end end end return all(fleaves(isapprox_results)) end -function square_loss(model, x) - y = model(x) - y = y isa Tuple ? y[1] : y # handle LSTM - return mean(abs2, y) -end +square_loss(model, x) = mean(abs2, model(x)) function square_loss_iterated(cell, x) - st = cell(x) # uses default initial state + y, st = cell(x) # uses default initial state for _ in 1:2 - st = cell(x, st) + y, st = cell(x, st) end - y = st isa Tuple ? st[1] : st # handle LSTM return mean(abs2, y) end @@ -158,6 +149,10 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng()) GRU(3 => 4; init_kernel=init, init_recurrent_kernel=init), randn(rng, Float32, 3, 2, 1) ), + ( + Chain(LSTM(3 => 4), RNN(4 => 5), Dense(5 => 2)), + randn(rng, Float32, 3, 2, 1) + ), #! format: on ] From 0bea8610758e93398d48d2e5772a1fd4c719cc8a Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 2 Jan 2025 10:04:09 +0100 Subject: [PATCH 5/5] Format --- .../DifferentiationInterfaceTestFluxExt.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl index ee0d5b168..d0825cee3 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl @@ -10,11 +10,14 @@ using Flux: Conv, ConvTranspose, Dense, - GRU, GRUCell, - LSTM, LSTMCell, + GRU, + GRUCell, + LSTM, + LSTMCell, Maxout, MeanPool, - RNN, RNNCell, + RNN, + RNNCell, SamePad, Scale, SkipConnection,