Skip to content

Commit

Permalink
fix!: remove potentially incorrect Tracker gradients for SimpleChains
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 11, 2024
1 parent 0fff825 commit ccd0ead
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 37 deletions.
24 changes: 0 additions & 24 deletions ext/LuxTrackerExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,30 +70,6 @@ LuxCore.apply(m::AbstractExplicitLayer, x::TrackedArray, ps, st) = m(x, ps, st)
return ArrayInterface.aos_to_soa(reverse(x; dims))
end

# SimpleChains.jl: DON'T REPLACE THESE WITH @grad_from_chainrules
for T1 in (:TrackedArray, :AbstractArray), T2 in (:TrackedArray, :AbstractArray)
T1 === :AbstractArray && T2 === :AbstractArray && continue

@eval function Lux.__apply_simple_chain(layer, x::$(T1), ps::$(T2), dev::LuxCPUDevice)
return Tracker.track(Lux.__apply_simple_chain, layer, x, ps, dev)
end
end

Tracker.@grad function Lux.__apply_simple_chain(layer, x, ps, ::LuxCPUDevice)
@warn "`Tracker.jl` often produces incorrect gradients for `SimpleChains.jl` models. \
As such please test your model with FiniteDifferences or Zygote before using \
`Tracker.jl` for your model." maxlog=1
y, pb_f = CRC.rrule(layer, Tracker.data(x), Tracker.data(ps))
__∇apply_simple_chain = let pb_f = pb_f
Δ -> begin
_, ∂x, ∂ps = pb_f(convert(Array, Tracker.data(Δ)))
return Tracker.nobacksies(:__apply_simple_chain, (nothing, ∂x, ∂ps, nothing))
end
end
# Tracker is not great at handling arbitrary types, so we convert to Array
return Array(y), __∇apply_simple_chain
end

# DynamicExpressions.jl
for T1 in (:TrackedArray, :AbstractArray), T2 in (:TrackedArray, :AbstractArray)
T1 === :AbstractArray && T2 === :AbstractArray && continue
Expand Down
11 changes: 1 addition & 10 deletions src/layers/extension.jl
Original file line number Diff line number Diff line change
Expand Up @@ -210,16 +210,7 @@ regular `Array` or not. Default is `false`.
## Arguments
- `layer`: SimpleChains layer
!!! note
If using `Tracker.jl`, the output will always be a regular `Array`.
!!! danger
`Tracker.jl` sometimes produces incorrect gradients for `SimpleChains.jl` models. As
such please test your model with FiniteDifferences or Zygote before using `Tracker.jl`
for your model.
- `lux_layer`: Potentially equivalent Lux layer that is used for printing
"""
struct SimpleChainsLayer{ToArray, SL, LL <: Union{Nothing, AbstractExplicitLayer}} <:
AbstractExplicitLayer
Expand Down
6 changes: 3 additions & 3 deletions test/transform/simple_chains_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
@test size(gs[1]) == size(x)
@test length(gs[2].params) == length(ps.params)

@eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3
@eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true

x = randn(Float32, 28, 28, 1, 15)
@test size(first(simple_chains_model(x, ps, st))) == (10, 15)
Expand All @@ -39,7 +39,7 @@
@test size(gs[1]) == size(x)
@test length(gs[2].params) == length(ps.params)

@eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3
@eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true

@testset "Array Output" begin
adaptor = ToSimpleChainsAdaptor((static(28), static(28), static(1)), true)
Expand Down Expand Up @@ -98,7 +98,7 @@
@test size(gs[1]) == size(x)
@test length(gs[2].params) == length(ps.params)

@eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3
@eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true
end
end

Expand Down

0 comments on commit ccd0ead

Please sign in to comment.