Skip to content

Commit

Permalink
Use reverse Jacobian and hvp from Enzyme (#445)
Browse files Browse the repository at this point in the history
* Use reverse Jacobian and hvp from Enzyme

* Fixes
  • Loading branch information
gdalle authored Sep 3, 2024
1 parent 4001119 commit b992180
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 2 deletions.
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ ChainRulesCore = "1.23.0"
Compat = "3.46,4.2"
Diffractor = "=0.2.6"
DocStringExtensions = "0.8,0.9"
Enzyme = "0.12.28"
Enzyme = "0.12.35"
FastDifferentiation = "0.3.9, 0.4"
FillArrays = "1.7.0"
FiniteDiff = "2.23.1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using DifferentiationInterface:
PushforwardExtras,
NoDerivativeExtras,
NoGradientExtras,
NoHVPExtras,
NoJacobianExtras,
NoPullbackExtras,
NoPushforwardExtras,
Expand Down Expand Up @@ -41,6 +42,8 @@ using Enzyme:
gradient,
gradient!,
guess_activity,
hvp,
hvp!,
jacobian,
make_zero,
make_zero!,
Expand All @@ -54,4 +57,6 @@ include("forward_twoarg.jl")
include("reverse_onearg.jl")
include("reverse_twoarg.jl")

include("second_order.jl")

end # module
Original file line number Diff line number Diff line change
Expand Up @@ -216,4 +216,53 @@ end

## Jacobian

# see https://github.com/EnzymeAD/Enzyme.jl/issues/1391
struct EnzymeReverseOneArgJacobianExtras{M,B} <: JacobianExtras end

function DI.prepare_jacobian(f, backend::AutoEnzyme{<:ReverseMode,Nothing}, x)
y = f(x)
M = length(y)
B = pick_batchsize(backend, M)
return EnzymeReverseOneArgJacobianExtras{M,B}()
end

function DI.jacobian(
f,
backend::AutoEnzyme{<:ReverseMode,Nothing},
x,
::EnzymeReverseOneArgJacobianExtras{M,B},
) where {M,B}
jac_wrongshape = jacobian(reverse_mode(backend), f, x, Val(M), Val(B))
nx = length(x)
ny = length(jac_wrongshape) ÷ length(x)
return reshape(jac_wrongshape, ny, nx)
end

function DI.value_and_jacobian(
f,
backend::AutoEnzyme{<:ReverseMode,Nothing},
x,
extras::EnzymeReverseOneArgJacobianExtras,
)
return f(x), DI.jacobian(f, backend, x, extras)
end

function DI.jacobian!(
f,
jac,
backend::AutoEnzyme{<:ReverseMode,Nothing},
x,
extras::EnzymeReverseOneArgJacobianExtras,
)
return copyto!(jac, DI.jacobian(f, backend, x, extras))
end

function DI.value_and_jacobian!(
f,
jac,
backend::AutoEnzyme{<:ReverseMode,Nothing},
x,
extras::EnzymeReverseOneArgJacobianExtras,
)
y, new_jac = DI.value_and_jacobian(f, backend, x, extras)
return y, copyto!(jac, new_jac)
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
function DI.prepare_hvp(f, ::AnyAutoEnzyme{Nothing,Nothing}, x, tx::Tangents{1})
return NoHVPExtras()
end

function DI.hvp(f, ::AnyAutoEnzyme{Nothing,Nothing}, x, tx::Tangents{1}, ::NoHVPExtras)
return SingleTangent(hvp(f, x, only(tx)))
end

function DI.hvp!(
f, tg::Tangents{1}, ::AnyAutoEnzyme{Nothing,Nothing}, x, tx::Tangents{1}, ::NoHVPExtras
)
hvp!(only(tg), f, x, only(tx))
return tg
end
1 change: 1 addition & 0 deletions DifferentiationInterface/test/Back/Enzyme/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ LOGGING = get(ENV, "CI", "false") == "false"

dense_backends = [
AutoEnzyme(; mode=nothing),
AutoEnzyme(; mode=nothing, function_annotation=Enzyme.Const),
AutoEnzyme(; mode=Enzyme.Forward),
AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const),
AutoEnzyme(; mode=Enzyme.Reverse),
Expand Down

0 comments on commit b992180

Please sign in to comment.