Preparation of pushforward, pullback and hvp for same point x (#255)
gdalle authored May 14, 2024
1 parent 8843031 commit 65e449d
Showing 28 changed files with 580 additions and 529 deletions.
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
 name = "DifferentiationInterface"
 uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 authors = ["Guillaume Dalle", "Adrian Hill"]
-version = "0.3.4"
+version = "0.4.0"

 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
1 change: 0 additions & 1 deletion DifferentiationInterface/docs/Project.toml
@@ -22,5 +22,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [compat]
-DifferentiationInterface = "0.3"
 Documenter = "1"
5 changes: 3 additions & 2 deletions DifferentiationInterface/docs/src/api.md
@@ -53,6 +53,7 @@ second_derivative!

 ```@docs
 prepare_hvp
+prepare_hvp_same_point
 hvp
 hvp!
 ```
@@ -67,6 +68,7 @@ hessian!

 ```@docs
 prepare_pushforward
+prepare_pushforward_same_point
 pushforward
 pushforward!
 value_and_pushforward
@@ -75,12 +77,11 @@ value_and_pushforward!

 ```@docs
 prepare_pullback
+prepare_pullback_same_point
 pullback
 pullback!
 value_and_pullback
 value_and_pullback!
-value_and_pullback_split
-value_and_pullback!_split
 ```

 ## Backend queries
32 changes: 13 additions & 19 deletions DifferentiationInterface/docs/src/overview.md
@@ -62,16 +62,16 @@ However they have different signatures:
 In many cases, AD can be accelerated if the function has been run at least once (e.g. to create a config or record a tape) and if some cache objects are provided.
 This is a backend-specific procedure, but we expose a common syntax to achieve it.

-| operator | preparation function |
-| :------------------ | :---------------------------------- |
-| `derivative` | [`prepare_derivative`](@ref) |
-| `gradient` | [`prepare_gradient`](@ref) |
-| `jacobian` | [`prepare_jacobian`](@ref) |
-| `second_derivative` | [`prepare_second_derivative`](@ref) |
-| `hessian` | [`prepare_hessian`](@ref) |
-| `pushforward` | [`prepare_pushforward`](@ref) |
-| `pullback` | [`prepare_pullback`](@ref) |
-| `hvp` | [`prepare_hvp`](@ref) |
+| operator | preparation function | preparation function (same point) |
+| :------------------ | :---------------------------------- | ---------------------------------------- |
+| `derivative` | [`prepare_derivative`](@ref) | - |
+| `gradient` | [`prepare_gradient`](@ref) | - |
+| `jacobian` | [`prepare_jacobian`](@ref) | - |
+| `second_derivative` | [`prepare_second_derivative`](@ref) | - |
+| `hessian` | [`prepare_hessian`](@ref) | - |
+| `pushforward` | [`prepare_pushforward`](@ref) | [`prepare_pushforward_same_point`](@ref) |
+| `pullback` | [`prepare_pullback`](@ref) | [`prepare_pullback_same_point`](@ref) |
+| `hvp` | [`prepare_hvp`](@ref) | [`prepare_hvp_same_point`](@ref) |

 Unsurprisingly, preparation syntax depends on the number of arguments:

@@ -89,6 +89,9 @@ This is especially worth it if you plan to call `operator` several times in simi
 !!! warning
     The `extras` object is nearly always mutated when given to an operator, even when said operator does not have a bang `!` in its name.

+With `pushforward`, `pullback` and `hvp`, you can also choose to prepare for the same point `x`, assuming only the seed `v` will change.
+Such is the purpose of `prepare_operator_same_point(f, backend, x, v)`, which is otherwise similar to standard preparation.
+
 ### Second order

 We offer two ways to perform second-order differentiation (for [`second_derivative`](@ref), [`hvp`](@ref) and [`hessian`](@ref)):
@@ -115,15 +118,6 @@ We offer two ways to perform second-order differentiation (for [`second_derivati
 Just wrap it around any backend, with an appropriate choice of sparsity detector and coloring algorithm, and call `jacobian` or `hessian`: the result will be sparse.
 See the [tutorial section on sparsity](@ref sparsity-tutorial) for details.

-### Split reverse mode
-
-Some reverse mode AD backends expose a "split" option, which runs only the forward sweep, and encapsulates the reverse sweep in a closure.
-We make this available for all backends with the following operators:
-
-| out-of-place | in-place |
-| :--------------------------------- | :---------------------------------- |
-| [`value_and_pullback_split`](@ref) | [`value_and_pullback!_split`](@ref) |
-
 ### Translation

 The wrapper [`DifferentiateWith`](@ref) allows you to translate between AD backends.
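
As a usage illustration of the same-point preparation documented in `overview.md`, here is a minimal sketch. It assumes DifferentiationInterface 0.4.0 as of this commit, with Zygote as the backend. The function `f`, the point `x` and the seed values are hypothetical; only the call pattern `prepare_pullback_same_point(f, backend, x, v)` followed by `pullback(f, backend, x, v, extras)` is taken from the diff.

```julia
using ADTypes: AutoZygote
using DifferentiationInterface: prepare_pullback_same_point, pullback
import Zygote  # loading Zygote activates the corresponding extension

# Hypothetical toy function with scalar output; its gradient at x is 2x.
f(x) = sum(abs2, x)

backend = AutoZygote()
x = [1.0, 2.0, 3.0]

# Prepare once for the point x; the seed given here only serves as a typical value.
extras = prepare_pullback_same_point(f, backend, x, 1.0)

# Reuse the preparation with different seeds dy at the same point x.
g1 = pullback(f, backend, x, 1.0, extras)
g2 = pullback(f, backend, x, 2.0, extras)
```

With the reverse-mode extensions touched by this commit, the second call reuses the pullback closure stored in `extras` rather than differentiating `f` again. The `extras` object must not be reused at a different `x`.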
Original file line number Diff line number Diff line change
@@ -11,7 +11,8 @@ using ChainRulesCore:
     rrule_via_ad
 using Compat
 import DifferentiationInterface as DI
-using DifferentiationInterface: DifferentiateWith, NoPullbackExtras, NoPushforwardExtras
+using DifferentiationInterface:
+    DifferentiateWith, NoPullbackExtras, NoPushforwardExtras, PullbackExtras

 ruleconfig(backend::AutoChainRules) = backend.ruleconfig

Original file line number Diff line number Diff line change
@@ -6,7 +6,8 @@ end

 function ChainRulesCore.rrule(dw::DifferentiateWith, x)
     @compat (; f, backend) = dw
-    y, pullbackfunc = DI.value_and_pullback_split(f, backend, x)
-    pullbackfunc_adjusted(dy) = (NoTangent(), pullbackfunc(dy))
-    return y, pullbackfunc_adjusted
+    y = f(x)
+    extras_same = DI.prepare_pullback_same_point(f, backend, x, y)
+    pullbackfunc(dy) = (NoTangent(), DI.pullback(f, backend, x, dy, extras_same))
+    return y, pullbackfunc
 end
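
For code that relied on the removed `value_and_pullback_split`, the `rrule` above suggests a replacement pattern. The following is a minimal sketch, assuming DifferentiationInterface 0.4.0 with the Zygote backend. The helper name `split_pullback` and the test function are hypothetical and not part of the package.

```julia
using ADTypes: AutoZygote
import DifferentiationInterface as DI
import Zygote  # loading Zygote activates the corresponding extension

# Hypothetical helper: returns the primal value and a closure over the seed dy,
# following the same steps as the rrule in this diff.
function split_pullback(f, backend, x)
    y = f(x)  # the output doubles as a typical seed for preparation
    extras_same = DI.prepare_pullback_same_point(f, backend, x, y)
    pullbackfunc(dy) = DI.pullback(f, backend, x, dy, extras_same)
    return y, pullbackfunc
end

y, pb = split_pullback(x -> x .^ 2, AutoZygote(), [1.0, 2.0])
dx = pb([1.0, 1.0])  # vector-Jacobian product with a seed of ones
```

Note that this pattern evaluates `f` once explicitly and once more inside the backend during preparation, as the `rrule` in the diff does.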
Original file line number Diff line number Diff line change
@@ -1,27 +1,36 @@
 ## Pullback

+struct ChainRulesPullbackExtrasSamePoint{Y,PB} <: PullbackExtras
+    y::Y
+    pb::PB
+end
+
 DI.prepare_pullback(f, ::AutoReverseChainRules, x, dy) = NoPullbackExtras()

-function DI.value_and_pullback_split(
-    f, backend::AutoReverseChainRules, x, ::NoPullbackExtras
+function DI.prepare_pullback_same_point(
+    f, backend::AutoReverseChainRules, x, dy, ::PullbackExtras=NoPullbackExtras()
 )
     rc = ruleconfig(backend)
-    y, pullback = rrule_via_ad(rc, f, x)
-    pullbackfunc(dy) = last(pullback(dy))
-    return y, pullbackfunc
+    y, pb = rrule_via_ad(rc, f, x)
+    return ChainRulesPullbackExtrasSamePoint(y, pb)
 end

-function DI.value_and_pullback!_split(
-    f, backend::AutoReverseChainRules, x, extras::NoPullbackExtras
-)
-    y, pullbackfunc = DI.value_and_pullback_split(f, backend, x, extras)
-    pullbackfunc!(dx, dy) = copyto!(dx, pullbackfunc(dy))
-    return y, pullbackfunc!
+function DI.value_and_pullback(f, backend::AutoReverseChainRules, x, dy, ::NoPullbackExtras)
+    rc = ruleconfig(backend)
+    y, pb = rrule_via_ad(rc, f, x)
+    return y, last(pb(dy))
 end

 function DI.value_and_pullback(
-    f, backend::AutoReverseChainRules, x, dy, extras::NoPullbackExtras
+    f, ::AutoReverseChainRules, x, dy, extras::ChainRulesPullbackExtrasSamePoint
 )
+    @compat (; y, pb) = extras
+    return copy(y), last(pb(dy))
+end
+
+function DI.pullback(
+    f, ::AutoReverseChainRules, x, dy, extras::ChainRulesPullbackExtrasSamePoint
+)
-    y, pullbackfunc = DI.value_and_pullback_split(f, backend, x, extras)
-    return y, pullbackfunc(dy)
+    @compat (; pb) = extras
+    return last(pb(dy))
 end
Original file line number Diff line number Diff line change
@@ -6,7 +6,7 @@ end

 function DI.prepare_pushforward(f::F, backend::AutoForwardDiff, x, dx) where {F}
     T = tag_type(f, backend, x)
-    xdual_tmp = make_dual(T, x, dx)
+    xdual_tmp = make_dual_similar(T, x)
     return ForwardDiffOneArgPushforwardExtras{T,typeof(xdual_tmp)}(xdual_tmp)
 end

Original file line number Diff line number Diff line change
@@ -7,8 +7,8 @@ end

 function DI.prepare_pushforward(f!::F, y, backend::AutoForwardDiff, x, dx) where {F}
     T = tag_type(f!, backend, x)
-    xdual_tmp = make_dual(T, x, dx)
-    ydual_tmp = make_dual(T, y, similar(y))
+    xdual_tmp = make_dual_similar(T, x)
+    ydual_tmp = make_dual_similar(T, y)
     return ForwardDiffTwoArgPushforwardExtras{T,typeof(xdual_tmp),typeof(ydual_tmp)}(
         xdual_tmp, ydual_tmp
     )
Original file line number Diff line number Diff line change
@@ -7,6 +7,9 @@ tag_type(f, ::AutoForwardDiff{C,Nothing}, x) where {C} = Tag{typeof(f),eltype(x)
 make_dual(::Type{T}, x::Number, dx) where {T} = Dual{T}(x, dx)
 make_dual(::Type{T}, x, dx) where {T} = Dual{T}.(x, dx) # TODO: map causes Enzyme to fail

+make_dual_similar(::Type{T}, x::Number) where {T} = Dual{T}(x, x)
+make_dual_similar(::Type{T}, x) where {T} = similar(x, Dual{T,eltype(x),1})
+
 make_dual!(::Type{T}, xdual, x, dx) where {T} = map!(Dual{T}, xdual, x, dx)

 myvalue(::Type{T}, ydual::Number) where {T} = value(T, ydual)
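
To illustrate the switch from `make_dual` to `make_dual_similar` during preparation, here is a minimal sketch that uses ForwardDiff directly. The two helper definitions are copied from the diff; the function `f`, the arrays, and the way the tag type is built are hypothetical stand-ins. The apparent intent is that preparation only needs a buffer of the right type, whose contents are overwritten by `make_dual!` at call time.

```julia
using ForwardDiff: Dual, Tag, value, partials

# From the diff: allocate an uninitialized buffer of duals carrying one partial.
make_dual_similar(::Type{T}, x) where {T} = similar(x, Dual{T,eltype(x),1})
# From the diff: fill an existing buffer with values x and tangents dx.
make_dual!(::Type{T}, xdual, x, dx) where {T} = map!(Dual{T}, xdual, x, dx)

f(x) = sum(x)                    # hypothetical function, only used to build the tag
x = [1.0, 2.0]
dx = [0.5, -1.0]
T = typeof(Tag(f, eltype(x)))    # tag type, standing in for tag_type from the diff

xdual = make_dual_similar(T, x)  # preparation: contents are arbitrary at this stage
make_dual!(T, xdual, x, dx)      # call time: overwrite with the actual point and seed

vals = value.(xdual)             # primal parts of the filled buffer
tangents = partials.(xdual, 1)   # first (and only) partial of each entry
```

Since the buffer no longer depends on the values of `x` or `dx`, the same prepared object can be reused when either of them changes.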
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@ module DifferentiationInterfaceTrackerExt

 using ADTypes: AutoTracker
 import DifferentiationInterface as DI
-using DifferentiationInterface: NoGradientExtras, NoPullbackExtras
+using DifferentiationInterface: NoGradientExtras, NoPullbackExtras, PullbackExtras
 using Tracker: Tracker, back, data, forward, gradient, jacobian, param, withgradient
 using Compat

@@ -11,23 +11,35 @@ DI.twoarg_support(::AutoTracker) = DI.TwoArgNotSupported()

 ## Pullback

+struct TrackerPullbackExtrasSamePoint{Y,PB} <: PullbackExtras
+    y::Y
+    pb::PB
+end
+
 DI.prepare_pullback(f, ::AutoTracker, x, dy) = NoPullbackExtras()

-function DI.value_and_pullback_split(f, ::AutoTracker, x, ::NoPullbackExtras)
-    y, back = forward(f, x)
-    pullbackfunc(dy) = data(only(back(dy)))
-    return y, pullbackfunc
+function DI.prepare_pullback_same_point(
+    f, ::AutoTracker, x, dy, ::PullbackExtras=NoPullbackExtras()
+)
+    y, pb = forward(f, x)
+    return TrackerPullbackExtrasSamePoint(y, pb)
 end

+function DI.value_and_pullback(f, ::AutoTracker, x, dy, ::NoPullbackExtras)
+    y, pb = forward(f, x)
+    return y, data(only(pb(dy)))
+end
+
-function DI.value_and_pullback!_split(f, backend::AutoTracker, x, extras::NoPullbackExtras)
-    y, pullbackfunc = DI.value_and_pullback_split(f, backend, x, extras)
-    pullbackfunc!(dx, dy) = copyto!(dx, pullbackfunc(dy))
-    return y, pullbackfunc!
+function DI.value_and_pullback(
+    f, ::AutoTracker, x, dy, extras::TrackerPullbackExtrasSamePoint
+)
+    @compat (; y, pb) = extras
+    return copy(y), data(only(pb(dy)))
 end

-function DI.value_and_pullback(f, backend::AutoTracker, x, dy, extras::NoPullbackExtras)
-    y, pullbackfunc = DI.value_and_pullback_split(f, backend, x, extras)
-    return y, pullbackfunc(dy)
+function DI.pullback(f, ::AutoTracker, x, dy, extras::TrackerPullbackExtrasSamePoint)
+    @compat (; pb) = extras
+    return data(only(pb(dy)))
 end

 ## Gradient
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@ module DifferentiationInterfaceZygoteExt
 using ADTypes: AutoZygote
 import DifferentiationInterface as DI
 using DifferentiationInterface:
-    NoGradientExtras, NoHessianExtras, NoJacobianExtras, NoPullbackExtras
+    NoGradientExtras, NoHessianExtras, NoJacobianExtras, NoPullbackExtras, PullbackExtras
 using DocStringExtensions
 using Zygote:
     ZygoteRuleConfig, gradient, hessian, jacobian, pullback, withgradient, withjacobian
@@ -14,23 +14,35 @@ DI.twoarg_support(::AutoZygote) = DI.TwoArgNotSupported()

 ## Pullback

+struct ZygotePullbackExtrasSamePoint{Y,PB} <: PullbackExtras
+    y::Y
+    pb::PB
+end
+
 DI.prepare_pullback(f, ::AutoZygote, x, dy) = NoPullbackExtras()

-function DI.value_and_pullback_split(f, ::AutoZygote, x, ::NoPullbackExtras)
-    y, back = pullback(f, x)
-    pullbackfunc(dy) = only(back(dy))
-    return y, pullbackfunc
+function DI.prepare_pullback_same_point(
+    f, ::AutoZygote, x, dy, ::PullbackExtras=NoPullbackExtras()
+)
+    y, pb = pullback(f, x)
+    return ZygotePullbackExtrasSamePoint(y, pb)
 end

+function DI.value_and_pullback(f, ::AutoZygote, x, dy, ::NoPullbackExtras)
+    y, pb = pullback(f, x)
+    return y, only(pb(dy))
+end
+
-function DI.value_and_pullback!_split(f, backend::AutoZygote, x, extras::NoPullbackExtras)
-    y, pullbackfunc = DI.value_and_pullback_split(f, backend, x, extras)
-    pullbackfunc!(dx, dy) = copyto!(dx, pullbackfunc(dy))
-    return y, pullbackfunc!
+function DI.value_and_pullback(
+    f, ::AutoZygote, x, dy, extras::ZygotePullbackExtrasSamePoint
+)
+    @compat (; y, pb) = extras
+    return copy(y), only(pb(dy))
 end

-function DI.value_and_pullback(f, backend::AutoZygote, x, dy, extras::NoPullbackExtras)
-    y, pullbackfunc = DI.value_and_pullback_split(f, backend, x, extras)
-    return y, pullbackfunc(dy)
+function DI.pullback(f, ::AutoZygote, x, dy, extras::ZygotePullbackExtrasSamePoint)
+    @compat (; pb) = extras
+    return only(pb(dy))
 end

 ## Gradient
7 changes: 4 additions & 3 deletions DifferentiationInterface/src/DifferentiationInterface.jl
@@ -74,7 +74,6 @@ export SecondOrder

 export value_and_pushforward!, value_and_pushforward
 export value_and_pullback!, value_and_pullback
-export value_and_pullback!_split, value_and_pullback_split

 export value_and_derivative!, value_and_derivative
 export value_and_gradient!, value_and_gradient
@@ -91,9 +90,11 @@ export second_derivative!, second_derivative
 export hvp!, hvp
 export hessian!, hessian

-export prepare_pushforward, prepare_pullback
+export prepare_pushforward, prepare_pushforward_same_point
+export prepare_pullback, prepare_pullback_same_point
+export prepare_hvp, prepare_hvp_same_point
 export prepare_derivative, prepare_gradient, prepare_jacobian
-export prepare_second_derivative, prepare_hvp, prepare_hessian
+export prepare_second_derivative, prepare_hessian

 export check_available, check_twoarg, check_hessian
