From b99218055e37b9a2316e0ce2017fdb0ebbe9e1c3 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 3 Sep 2024 14:46:23 +0200 Subject: [PATCH] Use reverse Jacobian and hvp from Enzyme (#445) * Use reverse Jacobian and hvp from Enzyme * Fixes --- DifferentiationInterface/Project.toml | 2 +- .../DifferentiationInterfaceEnzymeExt.jl | 5 ++ .../reverse_onearg.jl | 51 ++++++++++++++++++- .../second_order.jl | 14 +++++ .../test/Back/Enzyme/test.jl | 1 + 5 files changed, 71 insertions(+), 2 deletions(-) create mode 100644 DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/second_order.jl diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index a9da4fe9e..87b2f46a6 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -49,7 +49,7 @@ ChainRulesCore = "1.23.0" Compat = "3.46,4.2" Diffractor = "=0.2.6" DocStringExtensions = "0.8,0.9" -Enzyme = "0.12.28" +Enzyme = "0.12.35" FastDifferentiation = "0.3.9, 0.4" FillArrays = "1.7.0" FiniteDiff = "2.23.1" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl index c3600dfd0..46b0d566b 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl @@ -11,6 +11,7 @@ using DifferentiationInterface: PushforwardExtras, NoDerivativeExtras, NoGradientExtras, + NoHVPExtras, NoJacobianExtras, NoPullbackExtras, NoPushforwardExtras, @@ -41,6 +42,8 @@ using Enzyme: gradient, gradient!, guess_activity, + hvp, + hvp!, jacobian, make_zero, make_zero!, @@ -54,4 +57,6 @@ include("forward_twoarg.jl") include("reverse_onearg.jl") include("reverse_twoarg.jl") +include("second_order.jl") + end # module diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl index 92163a12f..259579678 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl @@ -216,4 +216,53 @@ end ## Jacobian -# see https://github.com/EnzymeAD/Enzyme.jl/issues/1391 +struct EnzymeReverseOneArgJacobianExtras{M,B} <: JacobianExtras end + +function DI.prepare_jacobian(f, backend::AutoEnzyme{<:ReverseMode,Nothing}, x) + y = f(x) + M = length(y) + B = pick_batchsize(backend, M) + return EnzymeReverseOneArgJacobianExtras{M,B}() +end + +function DI.jacobian( + f, + backend::AutoEnzyme{<:ReverseMode,Nothing}, + x, + ::EnzymeReverseOneArgJacobianExtras{M,B}, +) where {M,B} + jac_wrongshape = jacobian(reverse_mode(backend), f, x, Val(M), Val(B)) + nx = length(x) + ny = length(jac_wrongshape) รท length(x) + return reshape(jac_wrongshape, ny, nx) +end + +function DI.value_and_jacobian( + f, + backend::AutoEnzyme{<:ReverseMode,Nothing}, + x, + extras::EnzymeReverseOneArgJacobianExtras, +) + return f(x), DI.jacobian(f, backend, x, extras) +end + +function DI.jacobian!( + f, + jac, + backend::AutoEnzyme{<:ReverseMode,Nothing}, + x, + extras::EnzymeReverseOneArgJacobianExtras, +) + return copyto!(jac, DI.jacobian(f, backend, x, extras)) +end + +function DI.value_and_jacobian!( + f, + jac, + backend::AutoEnzyme{<:ReverseMode,Nothing}, + x, + extras::EnzymeReverseOneArgJacobianExtras, +) + y, new_jac = DI.value_and_jacobian(f, backend, x, extras) + return y, copyto!(jac, new_jac) +end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/second_order.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/second_order.jl new file mode 100644 index 000000000..b585e171c --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/second_order.jl @@ -0,0 +1,14 @@ +function DI.prepare_hvp(f, ::AnyAutoEnzyme{Nothing,Nothing}, x, tx::Tangents{1}) + return NoHVPExtras() +end + +function DI.hvp(f, ::AnyAutoEnzyme{Nothing,Nothing}, x, tx::Tangents{1}, ::NoHVPExtras) + return SingleTangent(hvp(f, x, only(tx))) +end + +function DI.hvp!( + f, tg::Tangents{1}, ::AnyAutoEnzyme{Nothing,Nothing}, x, tx::Tangents{1}, ::NoHVPExtras +) + hvp!(only(tg), f, x, only(tx)) + return tg +end diff --git a/DifferentiationInterface/test/Back/Enzyme/test.jl b/DifferentiationInterface/test/Back/Enzyme/test.jl index e80cdb59d..1c6b82b9b 100644 --- a/DifferentiationInterface/test/Back/Enzyme/test.jl +++ b/DifferentiationInterface/test/Back/Enzyme/test.jl @@ -13,6 +13,7 @@ LOGGING = get(ENV, "CI", "false") == "false" dense_backends = [ AutoEnzyme(; mode=nothing), + AutoEnzyme(; mode=nothing, function_annotation=Enzyme.Const), AutoEnzyme(; mode=Enzyme.Forward), AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const), AutoEnzyme(; mode=Enzyme.Reverse),