Skip to content

Commit

Permalink
Contexts for Zygote (#474)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle authored Sep 19, 2024
1 parent 71ea7fd commit ee11b70
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ module DifferentiationInterfaceZygoteExt
using ADTypes: AutoForwardDiff, AutoZygote
import DifferentiationInterface as DI
using DifferentiationInterface:
Context,
HVPExtras,
NoGradientExtras,
NoHessianExtras,
NoJacobianExtras,
NoPullbackExtras,
PullbackExtras,
Tangents
Tangents,
unwrap
using ForwardDiff: ForwardDiff
using Zygote:
ZygoteRuleConfig, gradient, hessian, jacobian, pullback, withgradient, withjacobian
Expand All @@ -25,63 +27,83 @@ struct ZygotePullbackExtrasSamePoint{Y,PB} <: PullbackExtras
pb::PB
end

DI.prepare_pullback(f, ::AutoZygote, x, ty::Tangents) = NoPullbackExtras()
function DI.prepare_pullback(f, ::AutoZygote, x, ty::Tangents, contexts::Vararg{Context})
return NoPullbackExtras()
end

function DI.prepare_pullback_same_point(
f, ::NoPullbackExtras, ::AutoZygote, x, ty::Tangents
f, ::NoPullbackExtras, ::AutoZygote, x, ty::Tangents, contexts::Vararg{Context}
)
y, pb = pullback(f, x)
y, pb = pullback(f, x, map(unwrap, contexts)...)
return ZygotePullbackExtrasSamePoint(y, pb)
end

function DI.value_and_pullback(f, ::NoPullbackExtras, ::AutoZygote, x, ty::Tangents)
y, pb = pullback(f, x)
function DI.value_and_pullback(
f, ::NoPullbackExtras, ::AutoZygote, x, ty::Tangents, contexts::Vararg{Context}
)
y, pb = pullback(f, x, map(unwrap, contexts)...)
tx = map(ty) do dy
only(pb(dy))
first(pb(dy))
end
return y, tx
end

function DI.value_and_pullback(
f, extras::ZygotePullbackExtrasSamePoint, ::AutoZygote, x, ty::Tangents
f,
extras::ZygotePullbackExtrasSamePoint,
::AutoZygote,
x,
ty::Tangents,
contexts::Vararg{Context},
)
@compat (; y, pb) = extras
tx = map(ty) do dy
only(pb(dy))
first(pb(dy))
end
return copy(y), tx
end

function DI.pullback(
f, extras::ZygotePullbackExtrasSamePoint, ::AutoZygote, x, ty::Tangents
f,
extras::ZygotePullbackExtrasSamePoint,
::AutoZygote,
x,
ty::Tangents,
contexts::Vararg{Context},
)
@compat (; pb) = extras
tx = map(ty) do dy
only(pb(dy))
first(pb(dy))
end
return tx
end

## Gradient

DI.prepare_gradient(f, ::AutoZygote, x) = NoGradientExtras()
DI.prepare_gradient(f, ::AutoZygote, x, contexts::Vararg{Context}) = NoGradientExtras()

function DI.value_and_gradient(f, ::NoGradientExtras, ::AutoZygote, x)
@compat (; val, grad) = withgradient(f, x)
return val, only(grad)
function DI.value_and_gradient(
f, ::NoGradientExtras, ::AutoZygote, x, contexts::Vararg{Context}
)
@compat (; val, grad) = withgradient(f, x, map(unwrap, contexts)...)
return val, first(grad)
end

function DI.gradient(f, ::NoGradientExtras, ::AutoZygote, x)
return only(gradient(f, x))
function DI.gradient(f, ::NoGradientExtras, ::AutoZygote, x, contexts::Vararg{Context})
return first(gradient(f, x, map(unwrap, contexts)...))
end

function DI.value_and_gradient!(f, grad, extras::NoGradientExtras, backend::AutoZygote, x)
y, new_grad = DI.value_and_gradient(f, extras, backend, x)
function DI.value_and_gradient!(
f, grad, extras::NoGradientExtras, backend::AutoZygote, x, contexts::Vararg{Context}
)
y, new_grad = DI.value_and_gradient(f, extras, backend, x, contexts...)
return y, copyto!(grad, new_grad)
end

function DI.gradient!(f, grad, extras::NoGradientExtras, backend::AutoZygote, x)
return copyto!(grad, DI.gradient(f, extras, backend, x))
function DI.gradient!(
f, grad, extras::NoGradientExtras, backend::AutoZygote, x, contexts::Vararg{Context}
)
return copyto!(grad, DI.gradient(f, extras, backend, x, contexts...))
end

## Jacobian
Expand Down
7 changes: 7 additions & 0 deletions DifferentiationInterface/test/Back/Zygote/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ end

test_differentiation(AutoZygote(); excluded=[:second_derivative], logging=LOGGING);

test_differentiation(
AutoZygote(),
default_scenarios(; include_normal=false, include_constantified=true);
second_order=false,
logging=LOGGING,
);

if VERSION >= v"1.10"
test_differentiation(
AutoZygote(),
Expand Down

0 comments on commit ee11b70

Please sign in to comment.