From ccd0ead9e18111cdc7810b7dfc628d08faae5d48 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 10 Jul 2024 22:36:31 -0700 Subject: [PATCH] fix!: remove potentially incorrect Tracker gradients for SimpleChains --- ext/LuxTrackerExt.jl | 24 ------------------------ src/layers/extension.jl | 11 +---------- test/transform/simple_chains_tests.jl | 6 +++--- 3 files changed, 4 insertions(+), 37 deletions(-) diff --git a/ext/LuxTrackerExt.jl b/ext/LuxTrackerExt.jl index c65ed3644b..a0880d9469 100644 --- a/ext/LuxTrackerExt.jl +++ b/ext/LuxTrackerExt.jl @@ -70,30 +70,6 @@ LuxCore.apply(m::AbstractExplicitLayer, x::TrackedArray, ps, st) = m(x, ps, st) return ArrayInterface.aos_to_soa(reverse(x; dims)) end -# SimpleChains.jl: DON'T REPLACE THESE WITH @grad_from_chainrules -for T1 in (:TrackedArray, :AbstractArray), T2 in (:TrackedArray, :AbstractArray) - T1 === :AbstractArray && T2 === :AbstractArray && continue - - @eval function Lux.__apply_simple_chain(layer, x::$(T1), ps::$(T2), dev::LuxCPUDevice) - return Tracker.track(Lux.__apply_simple_chain, layer, x, ps, dev) - end -end - -Tracker.@grad function Lux.__apply_simple_chain(layer, x, ps, ::LuxCPUDevice) - @warn "`Tracker.jl` often produces incorrect gradients for `SimpleChains.jl` models. \ - As such please test your model with FiniteDifferences or Zygote before using \ - `Tracker.jl` for your model." maxlog=1 - y, pb_f = CRC.rrule(layer, Tracker.data(x), Tracker.data(ps)) - __∇apply_simple_chain = let pb_f = pb_f - Δ -> begin - _, ∂x, ∂ps = pb_f(convert(Array, Tracker.data(Δ))) - return Tracker.nobacksies(:__apply_simple_chain, (nothing, ∂x, ∂ps, nothing)) - end - end - # Tracker is not great at handling arbitrary types, so we convert to Array - return Array(y), __∇apply_simple_chain -end - # DynamicExpressions.jl for T1 in (:TrackedArray, :AbstractArray), T2 in (:TrackedArray, :AbstractArray) T1 === :AbstractArray && T2 === :AbstractArray && continue diff --git a/src/layers/extension.jl b/src/layers/extension.jl index 3fac617904..972da75a38 100644 --- a/src/layers/extension.jl +++ b/src/layers/extension.jl @@ -210,16 +210,7 @@ regular `Array` or not. Default is `false`. ## Arguments - `layer`: SimpleChains layer - -!!! note - - If using `Tracker.jl`, the output will always be a regular `Array`. - -!!! danger - - `Tracker.jl` sometimes produces incorrect gradients for `SimpleChains.jl` models. As - such please test your model with FiniteDifferences or Zygote before using `Tracker.jl` - for your model. + - `lux_layer`: Potentially equivalent Lux layer that is used for printing """ struct SimpleChainsLayer{ToArray, SL, LL <: Union{Nothing, AbstractExplicitLayer}} <: AbstractExplicitLayer diff --git a/test/transform/simple_chains_tests.jl b/test/transform/simple_chains_tests.jl index b688a467c3..450fe88082 100644 --- a/test/transform/simple_chains_tests.jl +++ b/test/transform/simple_chains_tests.jl @@ -28,7 +28,7 @@ @test size(gs[1]) == size(x) @test length(gs[2].params) == length(ps.params) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true x = randn(Float32, 28, 28, 1, 15) @test size(first(simple_chains_model(x, ps, st))) == (10, 15) @@ -39,7 +39,7 @@ @test size(gs[1]) == size(x) @test length(gs[2].params) == length(ps.params) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true @testset "Array Output" begin adaptor = ToSimpleChainsAdaptor((static(28), static(28), static(1)), true) @@ -98,7 +98,7 @@ @test size(gs[1]) == size(x) @test length(gs[2].params) == length(ps.params) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true end end