Skip to content

Commit

Permalink
fix: avoid closures in batched_jacobian
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 25, 2024
1 parent ff8e926 commit 1088431
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 11 deletions.
16 changes: 9 additions & 7 deletions ext/LuxEnzymeExt/batched_autodiff.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
function Lux.AutoDiffInternalImpl.batched_jacobian_impl(
f::F, ad::AutoEnzyme, x::AbstractArray) where {F}
function Lux.AutoDiffInternalImpl.batched_jacobian_internal(
f::F, ad::AutoEnzyme, x::AbstractArray, args...) where {F}
backend = normalize_backend(True(), ad)
return batched_enzyme_jacobian_impl(f, backend, ADTypes.mode(backend), x)
return batched_enzyme_jacobian_impl(f, backend, ADTypes.mode(backend), x, args...)
end

function batched_enzyme_jacobian_impl(
f_orig::G, ad::AutoEnzyme, ::ForwardMode, x::AbstractArray) where {G}
f_orig::G, ad::AutoEnzyme, ::ForwardMode, x::AbstractArray, args...) where {G}
# We need to run the function once to get the output type. Can we use ForwardWithPrimal?
y = f_orig(x)
f = annotate_function(ad, f_orig)
Expand All @@ -26,7 +26,8 @@ function batched_enzyme_jacobian_impl(
for i in 1:chunk_size:(length(x) ÷ B)
idxs = i:min(i + chunk_size - 1, length(x) ÷ B)
partials′ = make_onehot!(partials, idxs)
J_partials = only(Enzyme.autodiff(ad.mode, f, BatchDuplicated(x, partials′)))
J_partials = only(Enzyme.autodiff(
ad.mode, f, BatchDuplicated(x, partials′), Const.(args)...))
for (idx, J_partial) in zip(idxs, J_partials)
copyto!(view(J, :, idx, :), reshape(J_partial, :, B))
end
Expand All @@ -36,7 +37,7 @@ function batched_enzyme_jacobian_impl(
end

function batched_enzyme_jacobian_impl(
f_orig::G, ad::AutoEnzyme, ::ReverseMode, x::AbstractArray) where {G}
f_orig::G, ad::AutoEnzyme, ::ReverseMode, x::AbstractArray, args...) where {G}
# We need to run the function once to get the output type. Can we use ReverseWithPrimal?
y = f_orig(x)

Expand All @@ -60,7 +61,8 @@ function batched_enzyme_jacobian_impl(
partials′ = make_onehot!(partials, idxs)
J_partials′ = make_zero!(J_partials, idxs)
Enzyme.autodiff(
ad.mode, fn, BatchDuplicated(y, partials′), BatchDuplicated(x, J_partials′)
ad.mode, fn, BatchDuplicated(y, partials′),
BatchDuplicated(x, J_partials′), Const.(args)...
)
for (idx, J_partial) in zip(idxs, J_partials)
copyto!(view(J, idx, :, :), reshape(J_partial, :, B))
Expand Down
10 changes: 6 additions & 4 deletions src/autodiff/nested_autodiff.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
## Written like this to avoid dynamic dispatch from Zygote
# Input Gradient / Jacobian
function rewrite_autodiff_call(f::ComposedFunction{F, <:StatefulLuxLayer}) where {F}
(f, f.inner.ps)
return f, f.inner.ps
end
function rewrite_autodiff_call(f::ComposedFunction{<:StatefulLuxLayer, F}) where {F}
(@closure((x, ps)->f.outer(f.inner(x), ps)), f.outer.ps)
return @closure((x, ps)->f.outer(f.inner(x), ps)), f.outer.ps
end
rewrite_autodiff_call(f::StatefulLuxLayer) = f, f.ps

Expand All @@ -22,10 +22,12 @@ function rewrite_autodiff_call(f::Base.Fix1{<:StatefulLuxLayer})
end

## Break ambiguity
for op in [ComposedFunction{<:StatefulLuxLayer, <:StatefulLuxLayer},
for op in [
ComposedFunction{<:StatefulLuxLayer, <:StatefulLuxLayer},
ComposedFunction{<:Base.Fix1{<:StatefulLuxLayer}, <:StatefulLuxLayer},
ComposedFunction{<:StatefulLuxLayer, <:Base.Fix1{<:StatefulLuxLayer}},
ComposedFunction{<:Base.Fix1{<:StatefulLuxLayer}, <:Base.Fix1{<:StatefulLuxLayer}}]
ComposedFunction{<:Base.Fix1{<:StatefulLuxLayer}, <:Base.Fix1{<:StatefulLuxLayer}}
]
@eval function rewrite_autodiff_call(::$op)
error("Cannot rewrite ComposedFunction with StatefulLuxLayer as inner and outer \
layers")
Expand Down

0 comments on commit 1088431

Please sign in to comment.