Skip to content

Commit

Permalink
Handle thunk
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Aug 9, 2024
1 parent 0691de9 commit 656cb3a
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ using Enzyme:
Const,
Duplicated,
DuplicatedNoNeed,
EnzymeCore,
Forward,
ForwardMode,
MixedDuplicated,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,18 @@ end

function DI.value_and_pullback(
f,
backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}},
backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},function_annotation},
x::Number,
dy,
::NoPullbackExtras,
)
) where {function_annotation}
f_and_df = force_annotation(get_f_and_df(f, backend))
forw, rev = autodiff_thunk(
ReverseSplitWithPrimal, typeof(f_and_df), Duplicated, typeof(Active(x))
)
mode = if function_annotation <: Annotation
ReverseSplitWithPrimal
else
EnzymeCore.set_err_if_func_written(ReverseSplitWithPrimal)
end
forw, rev = autodiff_thunk(mode, typeof(f_and_df), Duplicated, typeof(Active(x)))
tape, y, new_dy = forw(f_and_df, Active(x))
copyto!(new_dy, dy)
new_dx = only(only(rev(f_and_df, Active(x), tape)))
Expand Down Expand Up @@ -102,15 +105,23 @@ function DI.value_and_pullback!(
end

function DI.value_and_pullback!(
f, dx, backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing}}, x, dy, ::NoPullbackExtras
)
f,
dx,
backend::AnyAutoEnzyme{<:Union{ReverseMode,Nothing},function_annotation},
x,
dy,
::NoPullbackExtras,
) where {function_annotation}
f_and_df = force_annotation(get_f_and_df(f, backend))
mode = if function_annotation <: Annotation
ReverseSplitWithPrimal
else
EnzymeCore.set_err_if_func_written(ReverseSplitWithPrimal)
end
dx_sametype = convert(typeof(x), dx)
make_zero!(dx_sametype)
x_and_dx = Duplicated(x, dx_sametype)
forw, rev = autodiff_thunk(
ReverseSplitWithPrimal, typeof(f_and_df), Duplicated, typeof(x_and_dx)
)
forw, rev = autodiff_thunk(mode, typeof(f_and_df), Duplicated, typeof(x_and_dx))
tape, y, new_dy = forw(f_and_df, x_and_dx)
copyto!(new_dy, dy)
rev(f_and_df, x_and_dx, tape)
Expand Down

0 comments on commit 656cb3a

Please sign in to comment.