From 10884314848a167818942b7330d624a471f20f6f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 25 Nov 2024 14:28:05 -0500 Subject: [PATCH] fix: avoid closures in batched_jacobian --- ext/LuxEnzymeExt/batched_autodiff.jl | 16 +++++++++------- src/autodiff/nested_autodiff.jl | 10 ++++++---- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/ext/LuxEnzymeExt/batched_autodiff.jl b/ext/LuxEnzymeExt/batched_autodiff.jl index 32cf8c2c3e..e73dcad2ab 100644 --- a/ext/LuxEnzymeExt/batched_autodiff.jl +++ b/ext/LuxEnzymeExt/batched_autodiff.jl @@ -1,11 +1,11 @@ -function Lux.AutoDiffInternalImpl.batched_jacobian_impl( - f::F, ad::AutoEnzyme, x::AbstractArray) where {F} +function Lux.AutoDiffInternalImpl.batched_jacobian_internal( + f::F, ad::AutoEnzyme, x::AbstractArray, args...) where {F} backend = normalize_backend(True(), ad) - return batched_enzyme_jacobian_impl(f, backend, ADTypes.mode(backend), x) + return batched_enzyme_jacobian_impl(f, backend, ADTypes.mode(backend), x, args...) end function batched_enzyme_jacobian_impl( - f_orig::G, ad::AutoEnzyme, ::ForwardMode, x::AbstractArray) where {G} + f_orig::G, ad::AutoEnzyme, ::ForwardMode, x::AbstractArray, args...) where {G} # We need to run the function once to get the output type. Can we use ForwardWithPrimal? y = f_orig(x) f = annotate_function(ad, f_orig) @@ -26,7 +26,8 @@ function batched_enzyme_jacobian_impl( for i in 1:chunk_size:(length(x) ÷ B) idxs = i:min(i + chunk_size - 1, length(x) ÷ B) partials′ = make_onehot!(partials, idxs) - J_partials = only(Enzyme.autodiff(ad.mode, f, BatchDuplicated(x, partials′))) + J_partials = only(Enzyme.autodiff( + ad.mode, f, BatchDuplicated(x, partials′), Const.(args)...)) for (idx, J_partial) in zip(idxs, J_partials) copyto!(view(J, :, idx, :), reshape(J_partial, :, B)) end @@ -36,7 +37,7 @@ function batched_enzyme_jacobian_impl( end function batched_enzyme_jacobian_impl( - f_orig::G, ad::AutoEnzyme, ::ReverseMode, x::AbstractArray) where {G} + f_orig::G, ad::AutoEnzyme, ::ReverseMode, x::AbstractArray, args...) where {G} # We need to run the function once to get the output type. Can we use ReverseWithPrimal? y = f_orig(x) @@ -60,7 +61,8 @@ function batched_enzyme_jacobian_impl( partials′ = make_onehot!(partials, idxs) J_partials′ = make_zero!(J_partials, idxs) Enzyme.autodiff( - ad.mode, fn, BatchDuplicated(y, partials′), BatchDuplicated(x, J_partials′) + ad.mode, fn, BatchDuplicated(y, partials′), + BatchDuplicated(x, J_partials′), Const.(args)... ) for (idx, J_partial) in zip(idxs, J_partials) copyto!(view(J, idx, :, :), reshape(J_partial, :, B)) diff --git a/src/autodiff/nested_autodiff.jl b/src/autodiff/nested_autodiff.jl index dfc94ad6f6..6467fd1f3a 100644 --- a/src/autodiff/nested_autodiff.jl +++ b/src/autodiff/nested_autodiff.jl @@ -1,10 +1,10 @@ ## Written like this to avoid dynamic dispatch from Zygote # Input Gradient / Jacobian function rewrite_autodiff_call(f::ComposedFunction{F, <:StatefulLuxLayer}) where {F} - (f, f.inner.ps) + return f, f.inner.ps end function rewrite_autodiff_call(f::ComposedFunction{<:StatefulLuxLayer, F}) where {F} - (@closure((x, ps)->f.outer(f.inner(x), ps)), f.outer.ps) + return @closure((x, ps)->f.outer(f.inner(x), ps)), f.outer.ps end rewrite_autodiff_call(f::StatefulLuxLayer) = f, f.ps @@ -22,10 +22,12 @@ function rewrite_autodiff_call(f::Base.Fix1{<:StatefulLuxLayer}) end ## Break ambiguity -for op in [ComposedFunction{<:StatefulLuxLayer, <:StatefulLuxLayer}, +for op in [ + ComposedFunction{<:StatefulLuxLayer, <:StatefulLuxLayer}, ComposedFunction{<:Base.Fix1{<:StatefulLuxLayer}, <:StatefulLuxLayer}, ComposedFunction{<:StatefulLuxLayer, <:Base.Fix1{<:StatefulLuxLayer}}, - ComposedFunction{<:Base.Fix1{<:StatefulLuxLayer}, <:Base.Fix1{<:StatefulLuxLayer}}] + ComposedFunction{<:Base.Fix1{<:StatefulLuxLayer}, <:Base.Fix1{<:StatefulLuxLayer}} +] @eval function rewrite_autodiff_call(::$op) error("Cannot rewrite ComposedFunction with StatefulLuxLayer as inner and outer \ layers")