From 65e449dd0bc323ab0b08fb3e3ab5a35e80d3a302 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 14 May 2024 09:58:22 +0200 Subject: [PATCH] Preparation of pushforward, pullback and hvp for same point x (#255) --- DifferentiationInterface/Project.toml | 2 +- DifferentiationInterface/docs/Project.toml | 1 - DifferentiationInterface/docs/src/api.md | 5 +- DifferentiationInterface/docs/src/overview.md | 32 +- ...fferentiationInterfaceChainRulesCoreExt.jl | 3 +- .../differentiate_with.jl | 7 +- .../reverse_onearg.jl | 37 +- .../onearg.jl | 2 +- .../twoarg.jl | 4 +- .../utils.jl | 3 + .../DifferentiationInterfaceTrackerExt.jl | 36 +- .../DifferentiationInterfaceZygoteExt.jl | 36 +- .../src/DifferentiationInterface.jl | 7 +- .../src/first_order/jacobian.jl | 86 +++- .../src/first_order/pullback.jl | 128 ++--- .../src/first_order/pushforward.jl | 49 +- .../src/second_order/hessian.jl | 10 +- .../src/second_order/hvp.jl | 23 + .../src/sparse/fallbacks.jl | 27 +- .../src/sparse/hessian.jl | 5 +- .../src/sparse/jacobian.jl | 26 +- DifferentiationInterfaceTest/Project.toml | 4 +- .../docs/Project.toml | 2 - .../src/scenarios/scenario.jl | 4 +- .../src/tests/benchmark.jl | 96 +--- .../src/tests/correctness.jl | 457 ++++++++++-------- .../src/tests/type_stability.jl | 12 - .../src/utils/misc.jl | 5 +- 28 files changed, 580 insertions(+), 529 deletions(-) diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 4035cad95..bfee09e6b 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.3.4" +version = "0.4.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/DifferentiationInterface/docs/Project.toml b/DifferentiationInterface/docs/Project.toml index 72feb54b5..16d0930c3 100644 --- a/DifferentiationInterface/docs/Project.toml +++ b/DifferentiationInterface/docs/Project.toml @@ -22,5 +22,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -DifferentiationInterface = "0.3" Documenter = "1" diff --git a/DifferentiationInterface/docs/src/api.md b/DifferentiationInterface/docs/src/api.md index 07fcddcd9..a0500ac4b 100644 --- a/DifferentiationInterface/docs/src/api.md +++ b/DifferentiationInterface/docs/src/api.md @@ -53,6 +53,7 @@ second_derivative! ```@docs prepare_hvp +prepare_hvp_same_point hvp hvp! ``` @@ -67,6 +68,7 @@ hessian! ```@docs prepare_pushforward +prepare_pushforward_same_point pushforward pushforward! value_and_pushforward @@ -75,12 +77,11 @@ value_and_pushforward! ```@docs prepare_pullback +prepare_pullback_same_point pullback pullback! value_and_pullback value_and_pullback! -value_and_pullback_split -value_and_pullback!_split ``` ## Backend queries diff --git a/DifferentiationInterface/docs/src/overview.md b/DifferentiationInterface/docs/src/overview.md index 0eb66918f..b5d69de65 100644 --- a/DifferentiationInterface/docs/src/overview.md +++ b/DifferentiationInterface/docs/src/overview.md @@ -62,16 +62,16 @@ However they have different signatures: In many cases, AD can be accelerated if the function has been run at least once (e.g. to create a config or record a tape) and if some cache objects are provided. This is a backend-specific procedure, but we expose a common syntax to achieve it. -| operator | preparation function | -| :------------------ | :---------------------------------- | -| `derivative` | [`prepare_derivative`](@ref) | -| `gradient` | [`prepare_gradient`](@ref) | -| `jacobian` | [`prepare_jacobian`](@ref) | -| `second_derivative` | [`prepare_second_derivative`](@ref) | -| `hessian` | [`prepare_hessian`](@ref) | -| `pushforward` | [`prepare_pushforward`](@ref) | -| `pullback` | [`prepare_pullback`](@ref) | -| `hvp` | [`prepare_hvp`](@ref) | +| operator | preparation function | preparation function (same point) | +| :------------------ | :---------------------------------- | ---------------------------------------- | +| `derivative` | [`prepare_derivative`](@ref) | - | +| `gradient` | [`prepare_gradient`](@ref) | - | +| `jacobian` | [`prepare_jacobian`](@ref) | - | +| `second_derivative` | [`prepare_second_derivative`](@ref) | - | +| `hessian` | [`prepare_hessian`](@ref) | - | +| `pushforward` | [`prepare_pushforward`](@ref) | [`prepare_pushforward_same_point`](@ref) | +| `pullback` | [`prepare_pullback`](@ref) | [`prepare_pullback_same_point`](@ref) | +| `hvp` | [`prepare_hvp`](@ref) | [`prepare_hvp_same_point`](@ref) | Unsurprisingly, preparation syntax depends on the number of arguments: @@ -89,6 +89,9 @@ This is especially worth it if you plan to call `operator` several times in simi !!! warning The `extras` object is nearly always mutated when given to an operator, even when said operator does not have a bang `!` in its name. +With `pushforward`, `pullback` and `hvp`, you can also choose to prepare for the same point `x`, assuming only the seed `v` will change. +Such is the purpose of `prepare_operator_same_point(f, backend, x, v)`, which is otherwise similar to standard preparation. + ### Second order We offer two ways to perform second-order differentiation (for [`second_derivative`](@ref), [`hvp`](@ref) and [`hessian`](@ref)): @@ -115,15 +118,6 @@ We offer two ways to perform second-order differentiation (for [`second_derivati Just wrap it around any backend, with an appropriate choice of sparsity detector and coloring algorithm, and call `jacobian` or `hessian`: the result will be sparse. See the [tutorial section on sparsity](@ref sparsity-tutorial) for details. -### Split reverse mode - -Some reverse mode AD backends expose a "split" option, which runs only the forward sweep, and encapsulates the reverse sweep in a closure. -We make this available for all backends with the following operators: - -| out-of-place | in-place | -| :--------------------------------- | :---------------------------------- | -| [`value_and_pullback_split`](@ref) | [`value_and_pullback!_split`](@ref) | - ### Translation The wrapper [`DifferentiateWith`](@ref) allows you to translate between AD backends. diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl index f0bd09d57..0fdac68ba 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl @@ -11,7 +11,8 @@ using ChainRulesCore: rrule_via_ad using Compat import DifferentiationInterface as DI -using DifferentiationInterface: DifferentiateWith, NoPullbackExtras, NoPushforwardExtras +using DifferentiationInterface: + DifferentiateWith, NoPullbackExtras, NoPushforwardExtras, PullbackExtras ruleconfig(backend::AutoChainRules) = backend.ruleconfig diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl index f0ce941e8..6dafb6a2a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl @@ -6,7 +6,8 @@ end function ChainRulesCore.rrule(dw::DifferentiateWith, x) @compat (; f, backend) = dw - y, pullbackfunc = DI.value_and_pullback_split(f, backend, x) - pullbackfunc_adjusted(dy) = (NoTangent(), pullbackfunc(dy)) - return y, pullbackfunc_adjusted + y = f(x) + extras_same = DI.prepare_pullback_same_point(f, backend, x, y) + pullbackfunc(dy) = (NoTangent(), DI.pullback(f, backend, x, dy, extras_same)) + return y, pullbackfunc end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl index 706af6583..4390a5a3d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl @@ -1,27 +1,36 @@ ## Pullback +struct ChainRulesPullbackExtrasSamePoint{Y,PB} <: PullbackExtras + y::Y + pb::PB +end + DI.prepare_pullback(f, ::AutoReverseChainRules, x, dy) = NoPullbackExtras() -function DI.value_and_pullback_split( - f, backend::AutoReverseChainRules, x, ::NoPullbackExtras +function DI.prepare_pullback_same_point( + f, backend::AutoReverseChainRules, x, dy, ::PullbackExtras=NoPullbackExtras() ) rc = ruleconfig(backend) - y, pullback = rrule_via_ad(rc, f, x) - pullbackfunc(dy) = last(pullback(dy)) - return y, pullbackfunc + y, pb = rrule_via_ad(rc, f, x) + return ChainRulesPullbackExtrasSamePoint(y, pb) end -function DI.value_and_pullback!_split( - f, backend::AutoReverseChainRules, x, extras::NoPullbackExtras -) - y, pullbackfunc = DI.value_and_pullback_split(f, backend, x, extras) - pullbackfunc!(dx, dy) = copyto!(dx, pullbackfunc(dy)) - return y, pullbackfunc! +function DI.value_and_pullback(f, backend::AutoReverseChainRules, x, dy, ::NoPullbackExtras) + rc = ruleconfig(backend) + y, pb = rrule_via_ad(rc, f, x) + return y, last(pb(dy)) end function DI.value_and_pullback( - f, backend::AutoReverseChainRules, x, dy, extras::NoPullbackExtras + f, ::AutoReverseChainRules, x, dy, extras::ChainRulesPullbackExtrasSamePoint +) + @compat (; y, pb) = extras + return copy(y), last(pb(dy)) +end + +function DI.pullback( + f, ::AutoReverseChainRules, x, dy, extras::ChainRulesPullbackExtrasSamePoint ) - y, pullbackfunc = DI.value_and_pullback_split(f, backend, x, extras) - return y, pullbackfunc(dy) + @compat (; pb) = extras + return last(pb(dy)) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl index 26354d357..6326cbc4f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl @@ -6,7 +6,7 @@ end function DI.prepare_pushforward(f::F, backend::AutoForwardDiff, x, dx) where {F} T = tag_type(f, backend, x) - xdual_tmp = make_dual(T, x, dx) + xdual_tmp = make_dual_similar(T, x) return ForwardDiffOneArgPushforwardExtras{T,typeof(xdual_tmp)}(xdual_tmp) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl index 7f587a074..0feccd13b 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl @@ -7,8 +7,8 @@ end function DI.prepare_pushforward(f!::F, y, backend::AutoForwardDiff, x, dx) where {F} T = tag_type(f!, backend, x) - xdual_tmp = make_dual(T, x, dx) - ydual_tmp = make_dual(T, y, similar(y)) + xdual_tmp = make_dual_similar(T, x) + ydual_tmp = make_dual_similar(T, y) return ForwardDiffTwoArgPushforwardExtras{T,typeof(xdual_tmp),typeof(ydual_tmp)}( xdual_tmp, ydual_tmp ) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl index d39638a15..9d7c08928 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl @@ -7,6 +7,9 @@ tag_type(f, ::AutoForwardDiff{C,Nothing}, x) where {C} = Tag{typeof(f),eltype(x) make_dual(::Type{T}, x::Number, dx) where {T} = Dual{T}(x, dx) make_dual(::Type{T}, x, dx) where {T} = Dual{T}.(x, dx) # TODO: map causes Enzyme to fail +make_dual_similar(::Type{T}, x::Number) where {T} = Dual{T}(x, x) +make_dual_similar(::Type{T}, x) where {T} = similar(x, Dual{T,eltype(x),1}) + make_dual!(::Type{T}, xdual, x, dx) where {T} = map!(Dual{T}, xdual, x, dx) myvalue(::Type{T}, ydual::Number) where {T} = value(T, ydual) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl index 2ffd2c525..961aa269e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl @@ -2,7 +2,7 @@ module DifferentiationInterfaceTrackerExt using ADTypes: AutoTracker import DifferentiationInterface as DI -using DifferentiationInterface: NoGradientExtras, NoPullbackExtras +using DifferentiationInterface: NoGradientExtras, NoPullbackExtras, PullbackExtras using Tracker: Tracker, back, data, forward, gradient, jacobian, param, withgradient using Compat @@ -11,23 +11,35 @@ DI.twoarg_support(::AutoTracker) = DI.TwoArgNotSupported() ## Pullback +struct TrackerPullbackExtrasSamePoint{Y,PB} <: PullbackExtras + y::Y + pb::PB +end + DI.prepare_pullback(f, ::AutoTracker, x, dy) = NoPullbackExtras() -function DI.value_and_pullback_split(f, ::AutoTracker, x, ::NoPullbackExtras) - y, back = forward(f, x) - pullbackfunc(dy) = data(only(back(dy))) - return y, pullbackfunc +function DI.prepare_pullback_same_point( + f, ::AutoTracker, x, dy, ::PullbackExtras=NoPullbackExtras() +) + y, pb = forward(f, x) + return TrackerPullbackExtrasSamePoint(y, pb) +end + +function DI.value_and_pullback(f, ::AutoTracker, x, dy, ::NoPullbackExtras) + y, pb = forward(f, x) + return y, data(only(pb(dy))) end -function DI.value_and_pullback!_split(f, backend::AutoTracker, x, extras::NoPullbackExtras) - y, pullbackfunc = DI.value_and_pullback_split(f, backend, x, extras) - pullbackfunc!(dx, dy) = copyto!(dx, pullbackfunc(dy)) - return y, pullbackfunc! +function DI.value_and_pullback( + f, ::AutoTracker, x, dy, extras::TrackerPullbackExtrasSamePoint +) + @compat (; y, pb) = extras + return copy(y), data(only(pb(dy))) end -function DI.value_and_pullback(f, backend::AutoTracker, x, dy, extras::NoPullbackExtras) - y, pullbackfunc = DI.value_and_pullback_split(f, backend, x, extras) - return y, pullbackfunc(dy) +function DI.pullback(f, ::AutoTracker, x, dy, extras::TrackerPullbackExtrasSamePoint) + @compat (; pb) = extras + return data(only(pb(dy))) end ## Gradient diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index bd8ad1e52..1a14341ff 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -3,7 +3,7 @@ module DifferentiationInterfaceZygoteExt using ADTypes: AutoZygote import DifferentiationInterface as DI using DifferentiationInterface: - NoGradientExtras, NoHessianExtras, NoJacobianExtras, NoPullbackExtras + NoGradientExtras, NoHessianExtras, NoJacobianExtras, NoPullbackExtras, PullbackExtras using DocStringExtensions using Zygote: ZygoteRuleConfig, gradient, hessian, jacobian, pullback, withgradient, withjacobian @@ -14,23 +14,35 @@ DI.twoarg_support(::AutoZygote) = DI.TwoArgNotSupported() ## Pullback +struct ZygotePullbackExtrasSamePoint{Y,PB} <: PullbackExtras + y::Y + pb::PB +end + DI.prepare_pullback(f, ::AutoZygote, x, dy) = NoPullbackExtras() -function DI.value_and_pullback_split(f, ::AutoZygote, x, ::NoPullbackExtras) - y, back = pullback(f, x) - pullbackfunc(dy) = only(back(dy)) - return y, pullbackfunc +function DI.prepare_pullback_same_point( + f, ::AutoZygote, x, dy, ::PullbackExtras=NoPullbackExtras() +) + y, pb = pullback(f, x) + return ZygotePullbackExtrasSamePoint(y, pb) +end + +function DI.value_and_pullback(f, ::AutoZygote, x, dy, ::NoPullbackExtras) + y, pb = pullback(f, x) + return y, only(pb(dy)) end -function DI.value_and_pullback!_split(f, backend::AutoZygote, x, extras::NoPullbackExtras) - y, pullbackfunc = DI.value_and_pullback_split(f, backend, x, extras) - pullbackfunc!(dx, dy) = copyto!(dx, pullbackfunc(dy)) - return y, pullbackfunc! +function DI.value_and_pullback( + f, ::AutoZygote, x, dy, extras::ZygotePullbackExtrasSamePoint +) + @compat (; y, pb) = extras + return copy(y), only(pb(dy)) end -function DI.value_and_pullback(f, backend::AutoZygote, x, dy, extras::NoPullbackExtras) - y, pullbackfunc = DI.value_and_pullback_split(f, backend, x, extras) - return y, pullbackfunc(dy) +function DI.pullback(f, ::AutoZygote, x, dy, extras::ZygotePullbackExtrasSamePoint) + @compat (; pb) = extras + return only(pb(dy)) end ## Gradient diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index d30deff6e..751d789af 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -74,7 +74,6 @@ export SecondOrder export value_and_pushforward!, value_and_pushforward export value_and_pullback!, value_and_pullback -export value_and_pullback!_split, value_and_pullback_split export value_and_derivative!, value_and_derivative export value_and_gradient!, value_and_gradient @@ -91,9 +90,11 @@ export second_derivative!, second_derivative export hvp!, hvp export hessian!, hessian -export prepare_pushforward, prepare_pullback +export prepare_pushforward, prepare_pushforward_same_point +export prepare_pullback, prepare_pullback_same_point +export prepare_hvp, prepare_hvp_same_point export prepare_derivative, prepare_gradient, prepare_jacobian -export prepare_second_derivative, prepare_hvp, prepare_hessian +export prepare_second_derivative, prepare_hessian export check_available, check_twoarg, check_hessian diff --git a/DifferentiationInterface/src/first_order/jacobian.jl b/DifferentiationInterface/src/first_order/jacobian.jl index 77c87f28b..b68bcab7f 100644 --- a/DifferentiationInterface/src/first_order/jacobian.jl +++ b/DifferentiationInterface/src/first_order/jacobian.jl @@ -99,10 +99,17 @@ end function value_and_jacobian_onearg_aux( f::F, backend, x::AbstractArray, extras::PushforwardJacobianExtras ) where {F} - y = f(x) + y = f(x) # TODO: remove + pushforward_extras_same = prepare_pushforward_same_point( + f, + backend, + x, + basis(backend, x, first(CartesianIndices(x))), + extras.pushforward_extras, + ) jac = stack(CartesianIndices(x); dims=2) do j dx_j = basis(backend, x, j) - jac_col_j = pushforward(f, backend, x, dx_j, extras.pushforward_extras) + jac_col_j = pushforward(f, backend, x, dx_j, pushforward_extras_same) vec(jac_col_j) end return y, jac @@ -111,10 +118,13 @@ end function value_and_jacobian_onearg_aux( f::F, backend, x::AbstractArray, extras::PullbackJacobianExtras ) where {F} - y, pullbackfunc = value_and_pullback_split(f, backend, x, extras.pullback_extras) + y = f(x) # TODO: remove + pullback_extras_same = prepare_pullback_same_point( + f, backend, x, basis(backend, y, first(CartesianIndices(y))), extras.pullback_extras + ) jac = stack(CartesianIndices(y); dims=1) do i dy_i = basis(backend, y, i) - jac_row_i = pullbackfunc(dy_i) + jac_row_i = pullback(f, backend, x, dy_i, pullback_extras_same) vec(jac_row_i) end return y, jac @@ -133,11 +143,18 @@ end function value_and_jacobian_onearg_aux!( f::F, jac::AbstractMatrix, backend, x::AbstractArray, extras::PushforwardJacobianExtras ) where {F} - y = f(x) + y = f(x) # TODO: remove + pushforward_extras_same = prepare_pushforward_same_point( + f, + backend, + x, + basis(backend, x, first(CartesianIndices(x))), + extras.pushforward_extras, + ) for (k, j) in enumerate(CartesianIndices(x)) dx_j = basis(backend, x, j) jac_col_j = reshape(view(jac, :, k), size(y)) - pushforward!(f, jac_col_j, backend, x, dx_j, extras.pushforward_extras) + pushforward!(f, jac_col_j, backend, x, dx_j, pushforward_extras_same) end return y, jac end @@ -145,11 +162,14 @@ end function value_and_jacobian_onearg_aux!( f::F, jac::AbstractMatrix, backend, x::AbstractArray, extras::PullbackJacobianExtras ) where {F} - y, pullbackfunc! = value_and_pullback!_split(f, backend, x, extras.pullback_extras) + y = f(x) # TODO: remove + pullback_extras_same = prepare_pullback_same_point( + f, backend, x, basis(backend, y, first(CartesianIndices(y))), extras.pullback_extras + ) for (k, i) in enumerate(CartesianIndices(y)) dy_i = basis(backend, y, i) jac_row_i = reshape(view(jac, k, :), size(x)) - pullbackfunc!(jac_row_i, dy_i) + pullback!(f, jac_row_i, backend, x, dy_i, pullback_extras_same) end return y, jac end @@ -185,25 +205,40 @@ end function value_and_jacobian_twoarg_aux( f!::F, y, backend, x::AbstractArray, extras::PushforwardJacobianExtras ) where {F} + pushforward_extras_same = prepare_pushforward_same_point( + f!, + y, + backend, + x, + basis(backend, x, first(CartesianIndices(x))), + extras.pushforward_extras, + ) jac = stack(CartesianIndices(x); dims=2) do j dx_j = basis(backend, x, j) - jac_col_j = pushforward(f!, y, backend, x, dx_j, extras.pushforward_extras) + jac_col_j = pushforward(f!, y, backend, x, dx_j, pushforward_extras_same) vec(jac_col_j) end - f!(y, x) + f!(y, x) # TODO: remove return y, jac end function value_and_jacobian_twoarg_aux( f!::F, y, backend, x::AbstractArray, extras::PullbackJacobianExtras ) where {F} - y, pullbackfunc = value_and_pullback_split(f!, y, backend, x, extras.pullback_extras) + pullback_extras_same = prepare_pullback_same_point( + f!, + y, + backend, + x, + basis(backend, y, first(CartesianIndices(y))), + extras.pullback_extras, + ) jac = stack(CartesianIndices(y); dims=1) do i dy_i = basis(backend, y, i) - jac_row_i = pullbackfunc(y, dy_i) + jac_row_i = pullback(f!, y, backend, x, dy_i, pullback_extras_same) vec(jac_row_i) end - f!(y, x) + f!(y, x) # TODO: remove return y, jac end @@ -226,25 +261,40 @@ function value_and_jacobian_twoarg_aux!( x::AbstractArray, extras::PushforwardJacobianExtras, ) where {F} + pushforward_extras_same = prepare_pushforward_same_point( + f!, + y, + backend, + x, + basis(backend, x, first(CartesianIndices(x))), + extras.pushforward_extras, + ) for (k, j) in enumerate(CartesianIndices(x)) dx_j = basis(backend, x, j) jac_col_j = reshape(view(jac, :, k), size(y)) - pushforward!(f!, y, jac_col_j, backend, x, dx_j, extras.pushforward_extras) + pushforward!(f!, y, jac_col_j, backend, x, dx_j, pushforward_extras_same) end - f!(y, x) + f!(y, x) # TODO: remove return y, jac end function value_and_jacobian_twoarg_aux!( f!::F, y, jac::AbstractMatrix, backend, x::AbstractArray, extras::PullbackJacobianExtras ) where {F} - y, pullbackfunc! = value_and_pullback!_split(f!, y, backend, x, extras.pullback_extras) + pullback_extras_same = prepare_pullback_same_point( + f!, + y, + backend, + x, + basis(backend, y, first(CartesianIndices(y))), + extras.pullback_extras, + ) for (k, i) in enumerate(CartesianIndices(y)) dy_i = basis(backend, y, i) jac_row_i = reshape(view(jac, k, :), size(x)) - pullbackfunc!(y, jac_row_i, dy_i) + pullback!(f!, y, jac_row_i, backend, x, dy_i, pullback_extras_same) end - f!(y, x) + f!(y, x) # TODO: remove return y, jac end diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index c8231bc80..2edea75a2 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -12,6 +12,18 @@ Create an `extras` object subtyping [`PullbackExtras`](@ref) that can be given t """ function prepare_pullback end +""" + prepare_pullback_same_point(f, backend, x, dy) -> extras_same + prepare_pullback_same_point(f!, y, backend, x, dy) -> extras_same + +Create an `extras_same` object subtyping [`PullbackExtras`](@ref) that can be given to pullback operators _if they are applied at the same point `x`_. + +!!! warning + If the function or the point changes in any way, the result of preparation will be invalidated, and you will need to run it again. + In the two-argument case, `y` is mutated by `f!` during preparation. +""" +function prepare_pullback_same_point end + """ value_and_pullback(f, backend, x, dy, [extras]) -> (y, dx) value_and_pullback(f!, y, backend, x, dy, [extras]) -> (y, dx) @@ -90,6 +102,30 @@ end prepare_pullback_aux(f::F, backend, x, dy, ::PullbackFast) where {F} = throw(MissingBackendError(backend)) prepare_pullback_aux(f!::F, y, backend, x, dy, ::PullbackFast) where {F} = throw(MissingBackendError(backend)) +## Preparation (same point) + +function prepare_pullback_same_point( + f::F, backend::AbstractADType, x, dy, extras::PullbackExtras +) where {F} + return extras +end + +function prepare_pullback_same_point( + f!::F, y, backend::AbstractADType, x, dy, extras::PullbackExtras +) where {F} + return extras +end + +function prepare_pullback_same_point(f::F, backend::AbstractADType, x, dy) where {F} + extras = prepare_pullback(f, backend, x, dy) + return prepare_pullback_same_point(f, backend, x, dy, extras) +end + +function prepare_pullback_same_point(f!::F, y, backend::AbstractADType, x, dy) where {F} + extras = prepare_pullback(f!, y, backend, x, dy) + return prepare_pullback_same_point(f!, y, backend, x, dy, extras) +end + ## One argument function value_and_pullback( @@ -219,95 +255,3 @@ function pullback!( ) where {F} return value_and_pullback!(f!, y, dx, backend, x, dy, extras)[2] end - -## Split one argument - -struct OneArgPullbackFunc{B,F,X,E} - f::F - backend::B - x::X - extras::E -end - -struct OneArgPullbackFunc!{B,F,X,E} - f::F - backend::B - x::X - extras::E -end - -function (pbf::OneArgPullbackFunc)(dy) - @compat (; f, backend, x, extras) = pbf - return pullback(f, backend, x, dy, extras) -end - -function (pbf::OneArgPullbackFunc!)(dx, dy) - @compat (; f, backend, x, extras) = pbf - return pullback!(f, dx, backend, x, dy, extras) -end - -function value_and_pullback_split( - f::F, - backend::AbstractADType, - x, - extras::PullbackExtras=prepare_pullback(f, backend, x, f(x)), -) where {F} - return f(x), OneArgPullbackFunc(f, backend, x, extras) -end - -function value_and_pullback!_split( - f::F, - backend::AbstractADType, - x, - extras::PullbackExtras=prepare_pullback(f, backend, x, f(x)), -) where {F} - return f(x), OneArgPullbackFunc!(f, backend, x, extras) -end - -## Split two argument - -struct TwoArgPullbackFunc{B,F,X,E} - f!::F - backend::B - x::X - extras::E -end - -struct TwoArgPullbackFunc!{B,F,X,E} - f!::F - backend::B - x::X - extras::E -end - -function (pbf::TwoArgPullbackFunc)(y, dy) - @compat (; f!, backend, x, extras) = pbf - return pullback(f!, y, backend, x, dy, extras) -end - -function (pbf::TwoArgPullbackFunc!)(y, dx, dy) - @compat (; f!, backend, x, extras) = pbf - return pullback!(f!, y, dx, backend, x, dy, extras) -end - -function value_and_pullback_split( - f!::F, - y, - backend::AbstractADType, - x, - extras::PullbackExtras=prepare_pullback(f!, y, backend, x, similar(y)), -) where {F} - f!(y, x) - return y, TwoArgPullbackFunc(f!, backend, x, extras) -end - -function value_and_pullback!_split( - f!::F, - y, - backend::AbstractADType, - x, - extras::PullbackExtras=prepare_pullback(f!, y, backend, x, similar(y)), -) where {F} - f!(y, x) - return y, TwoArgPullbackFunc!(f!, backend, x, extras) -end diff --git a/DifferentiationInterface/src/first_order/pushforward.jl b/DifferentiationInterface/src/first_order/pushforward.jl index 6d35f3e2c..47b2d0e24 100644 --- a/DifferentiationInterface/src/first_order/pushforward.jl +++ b/DifferentiationInterface/src/first_order/pushforward.jl @@ -8,9 +8,22 @@ Create an `extras` object subtyping [`PushforwardExtras`](@ref) that can be give !!! warning If the function changes in any way, the result of preparation will be invalidated, and you will need to run it again. - In the two-argument case, `y` is mutated by `f!` during preparation.""" + In the two-argument case, `y` is mutated by `f!` during preparation. +""" function prepare_pushforward end +""" + prepare_pushforward_same_point(f, backend, x, dx) -> extras_same + prepare_pushforward_same_point(f!, y, backend, x, dx) -> extras_same + +Create an `extras_same` object subtyping [`PushforwardExtras`](@ref) that can be given to pushforward operators _if they are applied at the same point `x`_. + +!!! warning + If the function or the point changes in any way, the result of preparation will be invalidated, and you will need to run it again. + In the two-argument case, `y` is mutated by `f!` during preparation. +""" +function prepare_pushforward_same_point end + """ value_and_pushforward(f, backend, x, dx, [extras]) -> (y, dy) value_and_pushforward(f!, y, backend, x, dx, [extras]) -> (y, dy) @@ -78,6 +91,30 @@ end prepare_pushforward_aux(f::F, backend, x, dy, ::PushforwardFast) where {F} = throw(MissingBackendError(backend)) prepare_pushforward_aux(f!::F, y, backend, x, dy, ::PushforwardFast) where {F} = throw(MissingBackendError(backend)) +## Preparation (same point) + +function prepare_pushforward_same_point( + f::F, backend::AbstractADType, x, dx, extras::PushforwardExtras +) where {F} + return extras +end + +function prepare_pushforward_same_point( + f!::F, y, backend::AbstractADType, x, dx, extras::PushforwardExtras +) where {F} + return extras +end + +function prepare_pushforward_same_point(f::F, backend::AbstractADType, x, dx) where {F} + extras = prepare_pushforward(f, backend, x, dx) + return prepare_pushforward_same_point(f, backend, x, dx, extras) +end + +function prepare_pushforward_same_point(f!::F, y, backend::AbstractADType, x, dx) where {F} + extras = prepare_pushforward(f!, y, backend, x, dx) + return prepare_pushforward_same_point(f!, y, backend, x, dx, extras) +end + ## One argument function value_and_pushforward( @@ -94,18 +131,18 @@ function value_and_pushforward_onearg_aux( f::F, backend, x, dx, extras::PullbackPushforwardExtras ) where {F} @compat (; pullback_extras) = extras - y, pullbackfunc = value_and_pullback_split(f, backend, x, pullback_extras) + y = f(x) dy = if x isa Number && y isa Number - dx * pullbackfunc(one(y)) + dx * pullback(f, backend, x, one(y), pullback_extras) elseif x isa AbstractArray && y isa Number - dot(dx, pullbackfunc(one(y))) + dot(dx, pullback(f, backend, x, one(y), pullback_extras)) elseif x isa Number && y isa AbstractArray map(CartesianIndices(y)) do i - dx * pullbackfunc(basis(backend, y, i)) + dx * pullback(f, backend, x, basis(backend, y, i), pullback_extras) end elseif x isa AbstractArray && y isa AbstractArray map(CartesianIndices(y)) do i - dot(dx, pullbackfunc(basis(backend, y, i))) + dot(dx, pullback(f, backend, x, basis(backend, y, i), pullback_extras)) end end return y, dy diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index 04e56e024..f4a16c5fe 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -54,8 +54,11 @@ end function hessian( f::F, backend::SecondOrder, x, extras::HessianExtras=prepare_hessian(f, backend, x) ) where {F} + hvp_extras_same = prepare_hvp_same_point( + f, backend, x, basis(backend, x, first(CartesianIndices(x))), extras.hvp_extras + ) hess = stack(vec(CartesianIndices(x))) do j - hess_col_j = hvp(f, backend, x, basis(backend, x, j), extras.hvp_extras) + hess_col_j = hvp(f, backend, x, basis(backend, x, j), hvp_extras_same) vec(hess_col_j) end return hess @@ -80,9 +83,12 @@ function hessian!( x, extras::HessianExtras=prepare_hessian(f, backend, x), ) where {F} + hvp_extras_same = prepare_hvp_same_point( + f, backend, x, basis(backend, x, first(CartesianIndices(x))), extras.hvp_extras + ) for (k, j) in enumerate(CartesianIndices(x)) hess_col_j = reshape(view(hess, :, k), size(x)) - hvp!(f, hess_col_j, backend, x, basis(backend, x, j), extras.hvp_extras) + hvp!(f, hess_col_j, backend, x, basis(backend, x, j), hvp_extras_same) end return hess end diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index d303763af..7f31c75e4 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -10,6 +10,16 @@ Create an `extras` object subtyping [`HVPExtras`](@ref) that can be given to Hes """ function prepare_hvp end +""" + prepare_hvp_same_point(f, backend, x, v) -> extras_same + +Create an `extras_same` object subtyping [`HVPExtras`](@ref) that can be given to Hessian-vector product operators _if they are applied at the same point `x`_. + +!!! warning + If the function or the point changes in any way, the result of preparation will be invalidated, and you will need to run it again. +""" +function prepare_hvp_same_point end + """ hvp(f, backend, x, v, [extras]) -> p """ @@ -104,6 +114,19 @@ function prepare_hvp_aux(f::F, backend::SecondOrder, x, v, ::ReverseOverReverse) return ReverseOverReverseHVPExtras(inner_gradient_closure, outer_pullback_extras) end +## Preparation (same point) + +function prepare_hvp_same_point( + f::F, backend::AbstractADType, x, v, extras::HVPExtras +) where {F} + return extras +end + +function prepare_hvp_same_point(f::F, backend::AbstractADType, x, v) where {F} + extras = prepare_hvp(f, backend, x, v) + return prepare_hvp_same_point(f, backend, x, v, extras) +end + ## One argument function hvp( diff --git a/DifferentiationInterface/src/sparse/fallbacks.jl b/DifferentiationInterface/src/sparse/fallbacks.jl index 8338cea81..0aec9dea2 100644 --- a/DifferentiationInterface/src/sparse/fallbacks.jl +++ b/DifferentiationInterface/src/sparse/fallbacks.jl @@ -17,6 +17,7 @@ for op in (:pushforward, :pullback, :hvp) valop = Symbol("value_and_", op) valop! = Symbol("value_and_", op, "!") prep = Symbol("prepare_", op) + prepsame = Symbol("prepare_", op, "_same_point") E = if op == :pushforward :PushforwardExtras elseif op == :pullback @@ -28,6 +29,9 @@ for op in (:pushforward, :pullback, :hvp) ## One argument @eval begin $prep(f::F, ba::AutoSparse, x, v) where {F} = $prep(f, dense_ad(ba), x, v) + $prepsame(f::F, ba::AutoSparse, x, v) where {F} = $prepsame(f, dense_ad(ba), x, v) + $prepsame(f::F, ba::AutoSparse, x, v, ex::$E) where {F} = + $prepsame(f, dense_ad(ba), x, v, ex) $op(f::F, ba::AutoSparse, x, v, ex::$E=$prep(f, ba, x, v)) where {F} = $op(f, dense_ad(ba), x, v, ex) $valop(f::F, ba::AutoSparse, x, v, ex::$E=$prep(f, ba, x, v)) where {F} = @@ -41,6 +45,10 @@ for op in (:pushforward, :pullback, :hvp) ## Two arguments @eval begin $prep(f!::F, y, ba::AutoSparse, x, v) where {F} = $prep(f!, y, dense_ad(ba), x, v) + $prepsame(f!::F, y, ba::AutoSparse, x, v) where {F} = + $prepsame(f!, y, dense_ad(ba), x, v) + $prepsame(f!::F, y, ba::AutoSparse, x, v, ex::$E) where {F} = + $prepsame(f!, y, dense_ad(ba), x, v, ex) $op(f!::F, y, ba::AutoSparse, x, v, ex::$E=$prep(f!, y, ba, x, v)) where {F} = $op(f!, y, dense_ad(ba), x, v, ex) $valop(f!::F, y, ba::AutoSparse, x, v, ex::$E=$prep(f!, y, ba, x, v)) where {F} = @@ -51,25 +59,6 @@ for op in (:pushforward, :pullback, :hvp) f!::F, y, res, ba::AutoSparse, x, v, ex::$E=$prep(f!, y, ba, x, v) ) where {F} = $valop!(f!, y, res, dense_ad(ba), x, v, ex) end - - ## Split - if op == :pullback - valop_split = Symbol("value_and_", op, "_split") - valop!_split = Symbol("value_and_", op!, "_split") - - @eval begin - $valop_split(f::F, ba::AutoSparse, x, ex::$E=$prep(f, ba, x, f(x))) where {F} = - $valop_split(f, dense_ad(ba), x, ex) - $valop!_split(f::F, ba::AutoSparse, x, ex::$E=$prep(f, ba, x, f(x))) where {F} = - $valop!_split(f, dense_ad(ba), x, ex) - $valop_split( - f!::F, y, ba::AutoSparse, x, ex::$E=$prep(f, ba, x, similar(y)) - ) where {F} = $valop_split(f!, y, dense_ad(ba), x, ex) - $valop!_split( - f!::F, y, ba::AutoSparse, x, ex::$E=$prep(f, ba, x, similar(y)) - ) where {F} = $valop!_split(f!, y, dense_ad(ba), x, ex) - end - end end for op in (:derivative, :gradient, :second_derivative) diff --git a/DifferentiationInterface/src/sparse/hessian.jl b/DifferentiationInterface/src/sparse/hessian.jl index f5c32bf82..488fea2b9 100644 --- a/DifferentiationInterface/src/sparse/hessian.jl +++ b/DifferentiationInterface/src/sparse/hessian.jl @@ -21,7 +21,7 @@ function prepare_hessian(f::F, backend::AutoSparse, x) where {F} end hvp_extras = prepare_hvp(f, backend, x, first(seeds)) products = map(seeds) do seed - hvp(f, backend, x, seed, hvp_extras) + similar(x) end aggregates = stack(vec, products; dims=2) compressed = CompressedMatrix{:col}(sparsity, colors, groups, aggregates) @@ -30,8 +30,9 @@ end function hessian!(f::F, hess, backend::AutoSparse, x, extras::SparseHessianExtras) where {F} @compat (; compressed, seeds, products, hvp_extras) = extras + hvp_extras_same = prepare_hvp_same_point(f, backend, x, seeds[1], hvp_extras) for k in eachindex(seeds, products) - hvp!(f, products[k], backend, x, seeds[k], hvp_extras) + hvp!(f, products[k], backend, x, seeds[k], hvp_extras_same) copyto!(view(compressed.aggregates, :, k), vec(products[k])) end decompress!(hess, compressed) diff --git a/DifferentiationInterface/src/sparse/jacobian.jl b/DifferentiationInterface/src/sparse/jacobian.jl index baac01c88..f7af44fbe 100644 --- a/DifferentiationInterface/src/sparse/jacobian.jl +++ b/DifferentiationInterface/src/sparse/jacobian.jl @@ -35,7 +35,7 @@ function prepare_jacobian(f::F, backend::AutoSparse, x) where {F} end jp_extras = prepare_pushforward(f, backend, x, first(seeds)) products = map(seeds) do seed - pushforward(f, backend, x, seed, jp_extras) + similar(y) end aggregates = stack(vec, products; dims=2) compressed = CompressedMatrix{:col}(sparsity, colors, groups, aggregates) @@ -50,7 +50,7 @@ function prepare_jacobian(f::F, backend::AutoSparse, x) where {F} end jp_extras = prepare_pullback(f, backend, x, first(seeds)) products = map(seeds) do seed - pullback(f, backend, x, seed, jp_extras) + similar(x) end aggregates = stack(vec, products; dims=1) compressed = CompressedMatrix{:row}(sparsity, colors, groups, aggregates) @@ -62,8 +62,11 @@ function jacobian!( f::F, jac, backend::AutoSparse, x, extras::SparseJacobianExtras{1,:col} ) where {F} @compat (; compressed, seeds, products, jp_extras) = extras + pushforward_extras_same = prepare_pushforward_same_point( + f, backend, x, seeds[1], jp_extras + ) for k in eachindex(seeds, products) - pushforward!(f, products[k], backend, x, seeds[k], jp_extras) + pushforward!(f, products[k], backend, x, seeds[k], pushforward_extras_same) copyto!(view(compressed.aggregates, :, k), vec(products[k])) end decompress!(jac, compressed) @@ -74,8 +77,9 @@ function jacobian!( f::F, jac, backend::AutoSparse, x, extras::SparseJacobianExtras{1,:row} ) where {F} @compat (; compressed, seeds, products, jp_extras) = extras + pullback_extras_same = prepare_pullback_same_point(f, backend, x, seeds[1], jp_extras) for k in eachindex(seeds, products) - pullback!(f, products[k], backend, x, seeds[k], jp_extras) + pullback!(f, products[k], backend, x, seeds[k], pullback_extras_same) copyto!(view(compressed.aggregates, k, :), vec(products[k])) end decompress!(jac, compressed) @@ -114,7 +118,7 @@ function prepare_jacobian(f!::F, y, backend::AutoSparse, x) where {F} end jp_extras = prepare_pushforward(f!, y, backend, x, first(seeds)) products = map(seeds) do seed - pushforward(f!, y, backend, x, seed, jp_extras) + similar(y) end aggregates = stack(vec, products; dims=2) compressed = CompressedMatrix{:col}(sparsity, colors, groups, aggregates) @@ -129,7 +133,7 @@ function prepare_jacobian(f!::F, y, backend::AutoSparse, x) where {F} end jp_extras = prepare_pullback(f!, y, backend, x, first(seeds)) products = map(seeds) do seed - pullback(f!, y, backend, x, seed, jp_extras) + similar(x) end aggregates = stack(vec, products; dims=1) compressed = CompressedMatrix{:row}(sparsity, colors, groups, aggregates) @@ -141,8 +145,11 @@ function jacobian!( f!::F, y, jac, backend::AutoSparse, x, extras::SparseJacobianExtras{2,:col} ) where {F} @compat (; compressed, seeds, products, jp_extras) = extras + pushforward_extras_same = prepare_pushforward_same_point( + f!, y, backend, x, seeds[1], jp_extras + ) for k in eachindex(seeds, products) - pushforward!(f!, y, products[k], backend, x, seeds[k], jp_extras) + pushforward!(f!, y, products[k], backend, x, seeds[k], pushforward_extras_same) copyto!(view(compressed.aggregates, :, k), vec(products[k])) end decompress!(jac, compressed) @@ -153,8 +160,11 @@ function jacobian!( f!::F, y, jac, backend::AutoSparse, x, extras::SparseJacobianExtras{2,:row} ) where {F} @compat (; compressed, seeds, products, jp_extras) = extras + pullback_extras_same = prepare_pullback_same_point( + f!, y, backend, x, seeds[1], jp_extras + ) for k in eachindex(seeds, products) - pullback!(f!, y, products[k], backend, x, seeds[k], jp_extras) + pullback!(f!, y, products[k], backend, x, seeds[k], pullback_extras_same) copyto!(view(compressed.aggregates, k, :), vec(products[k])) end decompress!(jac, compressed) diff --git a/DifferentiationInterfaceTest/Project.toml b/DifferentiationInterfaceTest/Project.toml index b95c53b57..077e3a9a2 100644 --- a/DifferentiationInterfaceTest/Project.toml +++ b/DifferentiationInterfaceTest/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterfaceTest" uuid = "a82114a7-5aa3-49a8-9643-716bb13727a3" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.3.1" +version = "0.4.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -23,7 +23,7 @@ ADTypes = "1.0.0" Chairmarks = "1.2.1" Compat = "4" ComponentArrays = "0.15" -DifferentiationInterface = "0.3.4" +DifferentiationInterface = "0.4.0" DocStringExtensions = "0.9" JET = "0.4 - 0.8" JLArrays = "0.1" diff --git a/DifferentiationInterfaceTest/docs/Project.toml b/DifferentiationInterfaceTest/docs/Project.toml index ebb759898..ddaf7becc 100644 --- a/DifferentiationInterfaceTest/docs/Project.toml +++ b/DifferentiationInterfaceTest/docs/Project.toml @@ -13,6 +13,4 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" [compat] -DifferentiationInterface = "0.3" -DifferentiationInterfaceTest = "0.3" Documenter = "1" diff --git a/DifferentiationInterfaceTest/src/scenarios/scenario.jl b/DifferentiationInterfaceTest/src/scenarios/scenario.jl index 60f21235e..f2b4e72db 100644 --- a/DifferentiationInterfaceTest/src/scenarios/scenario.jl +++ b/DifferentiationInterfaceTest/src/scenarios/scenario.jl @@ -218,7 +218,7 @@ for S in (:PushforwardScenario, :HVPScenario) y = f(x) end if isnothing(dx) - dx = mysimilar_random(x) + dx = mycopy_random(x) end return ($S){args,place,F,X,typeof(y),typeof(dx),R}(f, x, y, dx, ref) end @@ -247,7 +247,7 @@ for S in (:PullbackScenario,) y = f(x) end if isnothing(dy) - dy = mysimilar_random(y) + dy = mycopy_random(y) end return ($S){args,place,F,X,typeof(y),typeof(dy),R}(f, x, y, dy, ref) end diff --git a/DifferentiationInterfaceTest/src/tests/benchmark.jl b/DifferentiationInterfaceTest/src/tests/benchmark.jl index 7d73e80de..80648708e 100644 --- a/DifferentiationInterfaceTest/src/tests/benchmark.jl +++ b/DifferentiationInterfaceTest/src/tests/benchmark.jl @@ -287,16 +287,12 @@ function run_benchmark!( logging::Bool, ) @compat (; f, x, y, dy) = deepcopy(scen) - @compat (; - bench0, bench1, bench2, bench3, bench4, calls0, calls1, calls2, calls3, calls4 - ) = try + @compat (; bench0, bench1, bench2, calls0, calls1, calls2) = try # benchmark extras = prepare_pullback(f, ba, x, dy) bench0 = @be prepare_pullback(f, ba, x, dy) samples = 1 evals = 1 bench1 = @be deepcopy(extras) value_and_pullback(f, ba, x, dy, _) bench2 = @be deepcopy(extras) pullback(f, ba, x, dy, _) - bench3 = @be deepcopy(extras) value_and_pullback_split(f, ba, x, _) - bench4 = @be last(value_and_pullback_split(f, ba, x, deepcopy(extras))) _(dy) # count cc = CallCounter(f) extras = prepare_pullback(cc, ba, x, dy) @@ -305,23 +301,17 @@ function run_benchmark!( calls1 = reset_count!(cc) pullback(cc, ba, x, dy, extras) calls2 = reset_count!(cc) - _, pullbackfunc = value_and_pullback_split(cc, ba, x, extras) - calls3 = reset_count!(cc) - pullbackfunc(dy) - calls4 = reset_count!(cc) - (; bench0, bench1, bench2, bench3, bench4, calls0, calls1, calls2, calls3, calls4) + (; bench0, bench1, bench2, calls0, calls1, calls2) catch e logging && @warn "Error during benchmarking" ba scen e - bench0, bench1, bench2, bench3, bench4 = failed_benchs(5) - calls0, calls1, calls2, calls3, calls4 = -1, -1, -1, -1, -1 - (; bench0, bench1, bench2, bench3, bench4, calls0, calls1, calls2, calls3, calls4) + bench0, bench1, bench2 = failed_benchs(3) + calls0, calls1, calls2 = -1, -1, -1 + (; bench0, bench1, bench2, calls0, calls1, calls2) end # record record!(data, ba, scen, :prepare_pullback, bench0, calls0) record!(data, ba, scen, :value_and_pullback, bench1, calls1) record!(data, ba, scen, :pullback, bench2, calls2) - record!(data, ba, scen, :value_and_pullback_split, bench3, calls3) - record!(data, ba, scen, :pullbackfunc, bench4, calls4) return nothing end @@ -332,9 +322,7 @@ function run_benchmark!( logging::Bool, ) @compat (; f, x, y, dy) = deepcopy(scen) - @compat (; - bench0, bench1, bench2, bench3, bench4, calls0, calls1, calls2, calls3, calls4 - ) = try + @compat (; bench0, bench1, bench2, calls0, calls1, calls2) = try # benchmark extras = prepare_pullback(f, ba, x, dy) bench0 = @be prepare_pullback(f, ba, x, dy) samples = 1 evals = 1 @@ -344,11 +332,6 @@ function run_benchmark!( bench2 = @be (dx=mysimilar(x), ext=deepcopy(extras)) pullback!( f, _.dx, ba, x, dy, _.ext ) evals = 1 - bench3 = @be deepcopy(extras) value_and_pullback!_split(f, ba, x, _) - bench4 = @be ( - dx=mysimilar(x), - (pullbackfunc!)=last(value_and_pullback!_split(f, ba, x, deepcopy(extras))), - ) _.pullbackfunc!(_.dx, dy) evals = 1 # count cc = CallCounter(f) extras = prepare_pullback(cc, ba, x, dy) @@ -357,23 +340,17 @@ function run_benchmark!( calls1 = reset_count!(cc) pullback!(cc, mysimilar(x), ba, x, dy, extras) calls2 = reset_count!(cc) - _, pullbackfunc! = value_and_pullback!_split(cc, ba, x, extras) - calls3 = reset_count!(cc) - pullbackfunc!(mysimilar(x), dy) - calls4 = reset_count!(cc) - (; bench0, bench1, bench2, bench3, bench4, calls0, calls1, calls2, calls3, calls4) + (; bench0, bench1, bench2, calls0, calls1, calls2) catch e logging && @warn "Error during benchmarking" ba scen e - bench0, bench1, bench2, bench3, bench4 = failed_benchs(5) - calls0, calls1, calls2, calls3, calls4 = -1, -1, -1, -1, -1 - (; bench0, bench1, bench2, bench3, bench4, calls0, calls1, calls2, calls3, calls4) + bench0, bench1, bench2 = failed_benchs(3) + calls0, calls1, calls2 = -1, -1, -1 + (; bench0, bench1, bench2, calls0, calls1, calls2) end # record record!(data, ba, scen, :prepare_pullback, bench0, calls0) record!(data, ba, scen, :value_and_pullback!, bench1, calls1) record!(data, ba, scen, :pullback!, bench2, calls2) - record!(data, ba, scen, :value_and_pullback!_split, bench3, calls3) - record!(data, ba, scen, :pullbackfunc!, bench4, calls4) return nothing end @@ -385,9 +362,7 @@ function run_benchmark!( ) @compat (; f, x, y, dy) = deepcopy(scen) f! = f - @compat (; - bench0, bench1, bench2, bench3, bench4, calls0, calls1, calls2, calls3, calls4 - ) = try + @compat (; bench0, bench1, bench2, calls0, calls1, calls2) = try # benchmark extras = prepare_pullback(f!, mysimilar(y), ba, x, dy) bench0 = @be mysimilar(y) prepare_pullback(f!, _, ba, x, dy) samples = 1 evals = @@ -398,13 +373,6 @@ function run_benchmark!( bench2 = @be (y=mysimilar(y), ext=deepcopy(extras)) pullback( f!, _.y, ba, x, dy, _.ext ) evals = 1 - bench3 = @be deepcopy(extras) value_and_pullback_split(f!, y, ba, x, _) - bench4 = @be ( - y=mysimilar(y), - pullbackfunc=last( - value_and_pullback_split(f!, mysimilar(y), ba, x, deepcopy(extras)) - ), - ) _.pullbackfunc(_.y, dy) evals = 1 # count cc! = CallCounter(f!) extras = prepare_pullback(cc!, mysimilar(y), ba, x, dy) @@ -413,23 +381,17 @@ function run_benchmark!( calls1 = reset_count!(cc!) pullback(cc!, mysimilar(y), ba, x, dy, extras) calls2 = reset_count!(cc!) - _, pullbackfunc = value_and_pullback_split(cc!, y, ba, x, extras) - calls3 = reset_count!(cc!) - pullbackfunc(y, dy) - calls4 = reset_count!(cc!) - (; bench0, bench1, bench2, bench3, bench4, calls0, calls1, calls2, calls3, calls4) + (; bench0, bench1, bench2, calls0, calls1, calls2) catch e logging && @warn "Error during benchmarking" ba scen e - bench0, bench1, bench2, bench3, bench4 = failed_benchs(5) - calls0, calls1, calls2, calls3, calls4 = -1, -1, -1, -1, -1 - (; bench0, bench1, bench2, bench3, bench4, calls0, calls1, calls2, calls3, calls4) + bench0, bench1, bench2 = failed_benchs(3) + calls0, calls1, calls2 = -1, -1, -1 + (; bench0, bench1, bench2, calls0, calls1, calls2) end # record record!(data, ba, scen, :prepare_pullback, bench0, calls0) record!(data, ba, scen, :value_and_pullback, bench1, calls1) record!(data, ba, scen, :pullback, bench2, calls2) - record!(data, ba, scen, :value_and_pullback_split, bench3, calls3) - record!(data, ba, scen, :pullbackfunc, bench4, calls4) return nothing end @@ -441,9 +403,7 @@ function run_benchmark!( ) @compat (; f, x, y, dy) = deepcopy(scen) f! = f - @compat (; - bench0, bench1, bench2, bench3, bench4, calls0, calls1, calls2, calls3, calls4 - ) = try + @compat (; bench0, bench1, bench2, calls0, calls1, calls2) = try # benchmark extras = prepare_pullback(f!, mysimilar(y), ba, x, dy) bench0 = @be mysimilar(y) prepare_pullback(f!, _, ba, x, dy) samples = 1 evals = @@ -454,16 +414,6 @@ function run_benchmark!( bench2 = @be (y=mysimilar(y), dx=mysimilar(x), ext=deepcopy(extras)) pullback!( f!, _.y, _.dx, ba, x, dy, _.ext ) evals = 1 - bench3 = @be (y=mysimilar(y), ext=deepcopy(extras)) value_and_pullback!_split( - f!, _.y, ba, x, _.ext - ) - bench4 = @be ( - y=mysimilar(y), - dx=mysimilar(x), - (pullbackfunc!)=last( - value_and_pullback!_split(f!, mysimilar(y), ba, x, deepcopy(extras)) - ), - ) _.pullbackfunc!(_.y, _.dx, dy) evals = 1 # count cc! = CallCounter(f!) extras = prepare_pullback(cc!, mysimilar(y), ba, x, dy) @@ -472,23 +422,17 @@ function run_benchmark!( calls1 = reset_count!(cc!) pullback!(cc!, mysimilar(y), mysimilar(x), ba, x, dy, extras) calls2 = reset_count!(cc!) - _, pullbackfunc! = value_and_pullback!_split(cc!, y, ba, x, extras) - calls3 = reset_count!(cc!) - pullbackfunc!(y, mysimilar(x), dy) - calls4 = reset_count!(cc!) - (; bench0, bench1, bench2, bench3, bench4, calls0, calls1, calls2, calls3, calls4) + (; bench0, bench1, bench2, calls0, calls1, calls2) catch e logging && @warn "Error during benchmarking" ba scen e - bench0, bench1, bench2, bench3, bench4 = failed_benchs(5) - calls0, calls1, calls2, calls3, calls4 = -1, -1, -1, -1, -1 - (; bench0, bench1, bench2, bench3, bench4, calls0, calls1, calls2, calls3, calls4) + bench0, bench1, bench2 = failed_benchs(3) + calls0, calls1, calls2 = -1, -1, -1 + (; bench0, bench1, bench2, calls0, calls1, calls2) end # record record!(data, ba, scen, :prepare_pullback, bench0, calls0) record!(data, ba, scen, :value_and_pullback!, bench1, calls1) record!(data, ba, scen, :pullback!, bench2, calls2) - record!(data, ba, scen, :value_and_pullback!_split, bench3, calls3) - record!(data, ba, scen, :pullbackfunc!, bench4, calls4) return nothing end diff --git a/DifferentiationInterfaceTest/src/tests/correctness.jl b/DifferentiationInterfaceTest/src/tests/correctness.jl index 5c1c7322d..e8e3796c5 100644 --- a/DifferentiationInterfaceTest/src/tests/correctness.jl +++ b/DifferentiationInterfaceTest/src/tests/correctness.jl @@ -20,26 +20,33 @@ function test_correctness( ref_backend, ) @compat (; f, x, y, dx) = new_scen = deepcopy(scen) - extras = prepare_pushforward(f, ba, mysimilar_random(x), mysimilar_random(dx)) dy_true = if ref_backend isa AbstractADType pushforward(f, ref_backend, x, dx) else new_scen.ref(x, dx) end - y1, dy1 = value_and_pushforward(f, ba, x, dx, extras) - dy2 = pushforward(f, ba, x, dx, extras) - - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa PushforwardExtras - end - @testset "Primal value" begin - @test y1 ≈ y - end - @testset "Tangent value" begin - @test dy1 ≈ dy_true - @test dy2 ≈ dy_true + for (k, extras) in enumerate([ + prepare_pushforward(f, ba, mycopy_random(x), mycopy_random(dx)), + prepare_pushforward_same_point(f, ba, x, mycopy_random(dx)), + ]) + testset_name = k == 1 ? "Different point" : "Same point" + @testset "$testset_name" begin + y1, dy1 = value_and_pushforward(f, ba, x, dx, extras) + dy2 = pushforward(f, ba, x, dx, extras) + + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa PushforwardExtras + end + @testset "Primal value" begin + @test y1 ≈ y + end + @testset "Tangent value" begin + @test dy1 ≈ dy_true + @test dy2 ≈ dy_true + end + end end end test_scen_intact(new_scen, scen) @@ -55,31 +62,38 @@ function test_correctness( ref_backend, ) @compat (; f, x, y, dx) = new_scen = deepcopy(scen) - extras = prepare_pushforward(f, ba, mysimilar_random(x), mysimilar_random(dx)) dy_true = if ref_backend isa AbstractADType pushforward(f, ref_backend, x, dx) else new_scen.ref(x, dx) end - dy1_in = mysimilar(y) - y1, dy1 = value_and_pushforward!(f, dy1_in, ba, x, dx, extras) - - dy2_in = mysimilar(y) - dy2 = pushforward!(f, dy2_in, ba, x, dx, extras) - - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa PushforwardExtras - end - @testset "Primal value" begin - @test y1 ≈ y - end - @testset "Tangent value" begin - @test dy1_in ≈ dy_true - @test dy1 ≈ dy_true - @test dy2_in ≈ dy_true - @test dy2 ≈ dy_true + for (k, extras) in enumerate([ + prepare_pushforward(f, ba, mycopy_random(x), mycopy_random(dx)), + prepare_pushforward_same_point(f, ba, x, mycopy_random(dx)), + ]) + testset_name = k == 1 ? "Different point" : "Same point" + @testset "$testset_name" begin + dy1_in = mysimilar(y) + y1, dy1 = value_and_pushforward!(f, dy1_in, ba, x, dx, extras) + + dy2_in = mysimilar(y) + dy2 = pushforward!(f, dy2_in, ba, x, dx, extras) + + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa PushforwardExtras + end + @testset "Primal value" begin + @test y1 ≈ y + end + @testset "Tangent value" begin + @test dy1_in ≈ dy_true + @test dy1 ≈ dy_true + @test dy2_in ≈ dy_true + @test dy2 ≈ dy_true + end + end end end test_scen_intact(new_scen, scen) @@ -96,32 +110,37 @@ function test_correctness( ) @compat (; f, x, y, dx) = new_scen = deepcopy(scen) f! = f - extras = prepare_pushforward( - f!, mysimilar(y), ba, mysimilar_random(x), mysimilar_random(dx) - ) dy_true = if ref_backend isa AbstractADType pushforward(f!, mysimilar(y), ref_backend, x, dx) else new_scen.ref(x, dx) end - y1_in = mysimilar(y) - y1, dy1 = value_and_pushforward(f!, y1_in, ba, x, dx, extras) - - y2_in = mysimilar(y) - dy2 = pushforward(f!, y2_in, ba, x, dx, extras) - - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa PushforwardExtras - end - @testset "Primal value" begin - @test y1_in ≈ y - @test y1 ≈ y - end - @testset "Tangent value" begin - @test dy1 ≈ dy_true - @test dy2 ≈ dy_true + for (k, extras) in enumerate([ + prepare_pushforward(f!, mysimilar(y), ba, mycopy_random(x), mycopy_random(dx)), + prepare_pushforward_same_point(f!, mysimilar(y), ba, x, mycopy_random(dx)), + ]) + testset_name = k == 1 ? "Different point" : "Same point" + @testset "$testset_name" begin + y1_in = mysimilar(y) + y1, dy1 = value_and_pushforward(f!, y1_in, ba, x, dx, extras) + + y2_in = mysimilar(y) + dy2 = pushforward(f!, y2_in, ba, x, dx, extras) + + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa PushforwardExtras + end + @testset "Primal value" begin + @test y1_in ≈ y + @test y1 ≈ y + end + @testset "Tangent value" begin + @test dy1 ≈ dy_true + @test dy2 ≈ dy_true + end + end end end test_scen_intact(new_scen, scen) @@ -138,34 +157,39 @@ function test_correctness( ) @compat (; f, x, y, dx) = new_scen = deepcopy(scen) f! = f - extras = prepare_pushforward( - f!, mysimilar(y), ba, mysimilar_random(x), mysimilar_random(dx) - ) dy_true = if ref_backend isa AbstractADType pushforward(f!, mysimilar(y), ref_backend, x, dx) else new_scen.ref(x, dx) end - y1_in, dy1_in = mysimilar(y), mysimilar(y) - y1, dy1 = value_and_pushforward!(f!, y1_in, dy1_in, ba, x, dx, extras) - - y2_in, dy2_in = mysimilar(y), mysimilar(y) - dy2 = pushforward!(f!, y2_in, dy2_in, ba, x, dx, extras) - - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa PushforwardExtras - end - @testset "Primal value" begin - @test y1_in ≈ y - @test y1 ≈ y - end - @testset "Tangent value" begin - @test dy1_in ≈ dy_true - @test dy1 ≈ dy_true - @test dy2_in ≈ dy_true - @test dy2 ≈ dy_true + for (k, extras) in enumerate([ + prepare_pushforward(f!, mysimilar(y), ba, mycopy_random(x), mycopy_random(dx)), + prepare_pushforward_same_point(f!, mysimilar(y), ba, x, mycopy_random(dx)), + ]) + testset_name = k == 1 ? "Different point" : "Same point" + @testset "$testset_name" begin + y1_in, dy1_in = mysimilar(y), mysimilar(y) + y1, dy1 = value_and_pushforward!(f!, y1_in, dy1_in, ba, x, dx, extras) + + y2_in, dy2_in = mysimilar(y), mysimilar(y) + dy2 = pushforward!(f!, y2_in, dy2_in, ba, x, dx, extras) + + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa PushforwardExtras + end + @testset "Primal value" begin + @test y1_in ≈ y + @test y1 ≈ y + end + @testset "Tangent value" begin + @test dy1_in ≈ dy_true + @test dy1 ≈ dy_true + @test dy2_in ≈ dy_true + @test dy2 ≈ dy_true + end + end end end test_scen_intact(new_scen, scen) @@ -183,33 +207,34 @@ function test_correctness( ref_backend, ) @compat (; f, x, y, dy) = new_scen = deepcopy(scen) - extras = prepare_pullback(f, ba, mysimilar_random(x), mysimilar_random(dy)) dx_true = if ref_backend isa AbstractADType pullback(f, ref_backend, x, dy) else new_scen.ref(x, dy) end - y1, dx1 = value_and_pullback(f, ba, x, dy, extras) + for (k, extras) in enumerate([ + prepare_pullback(f, ba, mycopy_random(x), mycopy_random(dy)), + prepare_pullback_same_point(f, ba, x, mycopy_random(dy)), + ]) + testset_name = k == 1 ? "Different point" : "Same point" + @testset "$testset_name" begin + y1, dx1 = value_and_pullback(f, ba, x, dy, extras) - dx2 = pullback(f, ba, x, dy, extras) + dx2 = pullback(f, ba, x, dy, extras) - y3, pullbackfunc = value_and_pullback_split(f, ba, x, extras) - pullbackfunc(dy) # call once in case the second errors - dx3 = pullbackfunc(dy) - - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa PullbackExtras - end - @testset "Primal value" begin - @test y1 ≈ y - @test y3 ≈ y - end - @testset "Cotangent value" begin - @test dx1 ≈ dx_true - @test dx2 ≈ dx_true - @test dx3 ≈ dx_true + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa PullbackExtras + end + @testset "Primal value" begin + @test y1 ≈ y + end + @testset "Cotangent value" begin + @test dx1 ≈ dx_true + @test dx2 ≈ dx_true + end + end end end test_scen_intact(new_scen, scen) @@ -225,39 +250,38 @@ function test_correctness( ref_backend, ) @compat (; f, x, y, dy) = new_scen = deepcopy(scen) - extras = prepare_pullback(f, ba, mysimilar_random(x), mysimilar_random(dy)) dx_true = if ref_backend isa AbstractADType pullback(f, ref_backend, x, dy) else new_scen.ref(x, dy) end - dx1_in = mysimilar(x) - y1, dx1 = value_and_pullback!(f, dx1_in, ba, x, dy, extras) - - dx2_in = mysimilar(x) - dx2 = pullback!(f, dx2_in, ba, x, dy, extras) - - y3, pullbackfunc! = value_and_pullback!_split(f, ba, x, extras) - pullbackfunc!(mysimilar(x), dy) # call once in case the second errors - dx3_in = mysimilar(x) - dx3 = pullbackfunc!(dx3_in, dy) - - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa PullbackExtras - end - @testset "Primal value" begin - @test y1 ≈ y - @test y3 ≈ y - end - @testset "Cotangent value" begin - @test dx1_in ≈ dx_true - @test dx1 ≈ dx_true - @test dx2_in ≈ dx_true - @test dx2 ≈ dx_true - @test dx3_in ≈ dx_true - @test dx3 ≈ dx_true + for (k, extras) in enumerate([ + prepare_pullback(f, ba, mycopy_random(x), mycopy_random(dy)), + prepare_pullback_same_point(f, ba, x, mycopy_random(dy)), + ]) + testset_name = k == 1 ? "Different point" : "Same point" + @testset "$testset_name" begin + dx1_in = mysimilar(x) + y1, dx1 = value_and_pullback!(f, dx1_in, ba, x, dy, extras) + + dx2_in = mysimilar(x) + dx2 = pullback!(f, dx2_in, ba, x, dy, extras) + + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa PullbackExtras + end + @testset "Primal value" begin + @test y1 ≈ y + end + @testset "Cotangent value" begin + @test dx1_in ≈ dx_true + @test dx1 ≈ dx_true + @test dx2_in ≈ dx_true + @test dx2 ≈ dx_true + end + end end end test_scen_intact(new_scen, scen) @@ -274,41 +298,37 @@ function test_correctness( ) @compat (; f, x, y, dy) = new_scen = deepcopy(scen) f! = f - extras = prepare_pullback( - f!, mysimilar(y), ba, mysimilar_random(x), mysimilar_random(dy) - ) dx_true = if ref_backend isa AbstractADType pullback(f!, mysimilar(y), ref_backend, x, dy) else new_scen.ref(x, dy) end - y1_in = mysimilar(y) - y1, dx1 = value_and_pullback(f!, y1_in, ba, x, dy, extras) - - y2_in = mysimilar(y) - dx2 = pullback(f!, y2_in, ba, x, dy, extras) - - y3_in = mysimilar(y) - y3, pullbackfunc = value_and_pullback_split(f!, y3_in, ba, x, extras) - pullbackfunc(mysimilar(y), dy) # call once in case the second errors - y3_in2 = mysimilar(y) - dx3 = pullbackfunc(y3_in2, dy) - - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa PullbackExtras - end - @testset "Primal value" begin - @test y1_in ≈ y - @test y1 ≈ y - @test y3_in ≈ y - @test y3 ≈ y - end - @testset "Cotangent value" begin - @test dx1 ≈ dx_true - @test dx2 ≈ dx_true - @test dx3 ≈ dx_true + for (k, extras) in enumerate([ + prepare_pullback(f!, mysimilar(y), ba, mycopy_random(x), mycopy_random(dy)), + prepare_pullback_same_point(f!, mysimilar(y), ba, x, mycopy_random(dy)), + ]) + testset_name = k == 1 ? "Different point" : "Same point" + @testset "$testset_name" begin + y1_in = mysimilar(y) + y1, dx1 = value_and_pullback(f!, y1_in, ba, x, dy, extras) + + y2_in = mysimilar(y) + dx2 = pullback(f!, y2_in, ba, x, dy, extras) + + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa PullbackExtras + end + @testset "Primal value" begin + @test y1_in ≈ y + @test y1 ≈ y + end + @testset "Cotangent value" begin + @test dx1 ≈ dx_true + @test dx2 ≈ dx_true + end + end end end test_scen_intact(new_scen, scen) @@ -325,44 +345,39 @@ function test_correctness( ) @compat (; f, x, y, dy) = new_scen = deepcopy(scen) f! = f - extras = prepare_pullback( - f!, mysimilar(y), ba, mysimilar_random(x), mysimilar_random(dy) - ) dx_true = if ref_backend isa AbstractADType pullback(f!, mysimilar(y), ref_backend, x, dy) else new_scen.ref(x, dy) end - y1_in, dx1_in = mysimilar(y), mysimilar(x) - y1, dx1 = value_and_pullback!(f!, y1_in, dx1_in, ba, x, dy, extras) - - y2_in, dx2_in = mysimilar(y), mysimilar(x) - dx2 = pullback!(f!, y2_in, dx2_in, ba, x, dy, extras) - - y3_in = mysimilar(y) - y3, pullbackfunc! = value_and_pullback!_split(f!, y3_in, ba, x, extras) - pullbackfunc!(mysimilar(y), mysimilar(x), dy) # call once in case the second errors - y3_in2, dx3_in = mysimilar(y), mysimilar(x) - dx3 = pullbackfunc!(y3_in2, dx3_in, dy) - - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa PullbackExtras - end - @testset "Primal value" begin - @test y1_in ≈ y - @test y1 ≈ y - @test y3_in ≈ y - @test y3 ≈ y - end - @testset "Cotangent value" begin - @test dx1_in ≈ dx_true - @test dx1 ≈ dx_true - @test dx2_in ≈ dx_true - @test dx2 ≈ dx_true - @test dx3_in ≈ dx_true - @test dx3 ≈ dx_true + for (k, extras) in enumerate([ + prepare_pullback(f!, mysimilar(y), ba, mycopy_random(x), mycopy_random(dy)), + prepare_pullback_same_point(f!, mysimilar(y), ba, x, mycopy_random(dy)), + ]) + testset_name = k == 1 ? "Different point" : "Same point" + @testset "$testset_name" begin + y1_in, dx1_in = mysimilar(y), mysimilar(x) + y1, dx1 = value_and_pullback!(f!, y1_in, dx1_in, ba, x, dy, extras) + + y2_in, dx2_in = mysimilar(y), mysimilar(x) + dx2 = pullback!(f!, y2_in, dx2_in, ba, x, dy, extras) + + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa PullbackExtras + end + @testset "Primal value" begin + @test y1_in ≈ y + @test y1 ≈ y + end + @testset "Cotangent value" begin + @test dx1_in ≈ dx_true + @test dx1 ≈ dx_true + @test dx2_in ≈ dx_true + @test dx2 ≈ dx_true + end + end end end test_scen_intact(new_scen, scen) @@ -380,7 +395,7 @@ function test_correctness( ref_backend, ) @compat (; f, x, y) = new_scen = deepcopy(scen) - extras = prepare_derivative(f, ba, mysimilar_random(x)) + extras = prepare_derivative(f, ba, mycopy_random(x)) der_true = if ref_backend isa AbstractADType derivative(f, ref_backend, x) else @@ -416,7 +431,7 @@ function test_correctness( ref_backend, ) @compat (; f, x, y) = new_scen = deepcopy(scen) - extras = prepare_derivative(f, ba, mysimilar_random(x)) + extras = prepare_derivative(f, ba, mycopy_random(x)) der_true = if ref_backend isa AbstractADType derivative(f, ref_backend, x) else @@ -457,7 +472,7 @@ function test_correctness( ) @compat (; f, x, y) = new_scen = deepcopy(scen) f! = f - extras = prepare_derivative(f!, mysimilar(y), ba, mysimilar_random(x)) + extras = prepare_derivative(f!, mysimilar(y), ba, mycopy_random(x)) der_true = if ref_backend isa AbstractADType derivative(f!, mysimilar(y), ref_backend, x) else @@ -497,7 +512,7 @@ function test_correctness( ) @compat (; f, x, y) = new_scen = deepcopy(scen) f! = f - extras = prepare_derivative(f!, mysimilar(y), ba, mysimilar_random(x)) + extras = prepare_derivative(f!, mysimilar(y), ba, mycopy_random(x)) der_true = if ref_backend isa AbstractADType derivative(f!, mysimilar(y), ref_backend, x) else @@ -540,7 +555,7 @@ function test_correctness( ref_backend, ) @compat (; f, x, y) = new_scen = deepcopy(scen) - extras = prepare_gradient(f, ba, mysimilar_random(x)) + extras = prepare_gradient(f, ba, mycopy_random(x)) grad_true = if ref_backend isa AbstractADType gradient(f, ref_backend, x) else @@ -576,7 +591,7 @@ function test_correctness( ref_backend, ) @compat (; f, x, y) = new_scen = deepcopy(scen) - extras = prepare_gradient(f, ba, mysimilar_random(x)) + extras = prepare_gradient(f, ba, mycopy_random(x)) grad_true = if ref_backend isa AbstractADType gradient(f, ref_backend, x) else @@ -618,7 +633,7 @@ function test_correctness( ref_backend, ) @compat (; f, x, y) = new_scen = deepcopy(scen) - extras = prepare_jacobian(f, ba, mysimilar_random(x)) + extras = prepare_jacobian(f, ba, mycopy_random(x)) jac_true = if ref_backend isa AbstractADType jacobian(f, ref_backend, x) else @@ -654,7 +669,7 @@ function test_correctness( ref_backend, ) @compat (; f, x, y) = new_scen = deepcopy(scen) - extras = prepare_jacobian(f, ba, mysimilar_random(x)) + extras = prepare_jacobian(f, ba, mycopy_random(x)) jac_true = if ref_backend isa AbstractADType jacobian(f, ref_backend, x) else @@ -695,7 +710,7 @@ function test_correctness( ) @compat (; f, x, y) = new_scen = deepcopy(scen) f! = f - extras = prepare_jacobian(f!, mysimilar(y), ba, mysimilar_random(x)) + extras = prepare_jacobian(f!, mysimilar(y), ba, mycopy_random(x)) jac_true = if ref_backend isa AbstractADType jacobian(f!, mysimilar(y), ref_backend, x) else @@ -735,7 +750,7 @@ function test_correctness( ) @compat (; f, x, y) = new_scen = deepcopy(scen) f! = f - extras = prepare_jacobian(f!, mysimilar(y), ba, mysimilar_random(x)) + extras = prepare_jacobian(f!, mysimilar(y), ba, mycopy_random(x)) jac_true = if ref_backend isa AbstractADType jacobian(f!, mysimilar(y), ref_backend, x) else @@ -778,7 +793,7 @@ function test_correctness( ref_backend, ) @compat (; f, x, y) = new_scen = deepcopy(scen) - extras = prepare_second_derivative(f, ba, mysimilar_random(x)) + extras = prepare_second_derivative(f, ba, mycopy_random(x)) der2_true = if ref_backend isa AbstractADType second_derivative(f, ref_backend, x) else @@ -808,7 +823,7 @@ function test_correctness( ref_backend, ) @compat (; f, x, y) = new_scen = deepcopy(scen) - extras = prepare_second_derivative(f, ba, mysimilar_random(x)) + extras = prepare_second_derivative(f, ba, mycopy_random(x)) der2_true = if ref_backend isa AbstractADType second_derivative(f, ref_backend, x) else @@ -842,21 +857,28 @@ function test_correctness( ref_backend, ) @compat (; f, x, dx) = new_scen = deepcopy(scen) - extras = prepare_hvp(f, ba, mysimilar_random(x), mysimilar_random(dx)) p_true = if ref_backend isa AbstractADType hvp(f, ref_backend, x, dx) else new_scen.ref(x, dx) end - p1 = hvp(f, ba, x, dx, extras) + for (k, extras) in enumerate([ + prepare_hvp(f, ba, mycopy_random(x), mycopy_random(dx)), + prepare_hvp_same_point(f, ba, x, mycopy_random(dx)), + ]) + testset_name = k == 1 ? "Different point" : "Same point" + @testset "$testset_name" begin + p1 = hvp(f, ba, x, dx, extras) - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa HVPExtras - end - @testset "HVP value" begin - @test p1 ≈ p_true + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa HVPExtras + end + @testset "HVP value" begin + @test p1 ≈ p_true + end + end end end test_scen_intact(new_scen, scen) @@ -872,23 +894,30 @@ function test_correctness( ref_backend, ) @compat (; f, x, dx) = new_scen = deepcopy(scen) - extras = prepare_hvp(f, ba, mysimilar_random(x), mysimilar_random(dx)) p_true = if ref_backend isa AbstractADType hvp(f, ref_backend, x, dx) else new_scen.ref(x, dx) end - p1_in = mysimilar(x) - p1 = hvp!(f, p1_in, ba, x, dx, extras) + for (k, extras) in enumerate([ + prepare_hvp(f, ba, mycopy_random(x), mycopy_random(dx)), + prepare_hvp_same_point(f, ba, x, mycopy_random(dx)), + ]) + testset_name = k == 1 ? "Different point" : "Same point" + @testset "$testset_name" begin + p1_in = mysimilar(x) + p1 = hvp!(f, p1_in, ba, x, dx, extras) - let (≈)(x, y) = isapprox(x, y; atol, rtol) - @testset "Extras type" begin - @test extras isa HVPExtras - end - @testset "HVP value" begin - @test p1_in ≈ p_true - @test p1 ≈ p_true + let (≈)(x, y) = isapprox(x, y; atol, rtol) + @testset "Extras type" begin + @test extras isa HVPExtras + end + @testset "HVP value" begin + @test p1_in ≈ p_true + @test p1 ≈ p_true + end + end end end test_scen_intact(new_scen, scen) @@ -906,7 +935,7 @@ function test_correctness( ref_backend, ) @compat (; f, x, y) = new_scen = deepcopy(scen) - extras = prepare_hessian(f, ba, mysimilar_random(x)) + extras = prepare_hessian(f, ba, mycopy_random(x)) hess_true = if ref_backend isa AbstractADType hessian(f, ref_backend, x) else @@ -936,7 +965,7 @@ function test_correctness( ref_backend, ) @compat (; f, x, y) = new_scen = deepcopy(scen) - extras = prepare_hessian(f, ba, mysimilar_random(x)) + extras = prepare_hessian(f, ba, mycopy_random(x)) hess_true = if ref_backend isa AbstractADType hessian(f, ref_backend, x) else diff --git a/DifferentiationInterfaceTest/src/tests/type_stability.jl b/DifferentiationInterfaceTest/src/tests/type_stability.jl index a32e56209..ecdb97f0b 100644 --- a/DifferentiationInterfaceTest/src/tests/type_stability.jl +++ b/DifferentiationInterfaceTest/src/tests/type_stability.jl @@ -55,12 +55,9 @@ function test_jet(ba::AbstractADType, scen::PullbackScenario{1,:outofplace}; ref @compat (; f, x, dy) = deepcopy(scen) extras = prepare_pullback(f, ba, x, dy) - _, pullbackfunc = value_and_pullback_split(f, ba, x, extras) - if Bool(pullback_performance(ba)) JET.@test_opt value_and_pullback(f, ba, x, dy, extras) JET.@test_opt pullback(f, ba, x, dy, extras) - JET.@test_opt pullbackfunc(dy) end return nothing end @@ -70,12 +67,9 @@ function test_jet(ba::AbstractADType, scen::PullbackScenario{1,:inplace}; ref_ba extras = prepare_pullback(f, ba, x, dy) dx_in = mysimilar(x) - _, pullbackfunc! = value_and_pullback!_split(f, ba, x, extras) - if Bool(pullback_performance(ba)) JET.@test_opt value_and_pullback!(f, dx_in, ba, x, dy, extras) JET.@test_opt pullback!(f, dx_in, ba, x, dy, extras) - JET.@test_opt pullbackfunc!(dx_in, dy) end return nothing end @@ -86,12 +80,9 @@ function test_jet(ba::AbstractADType, scen::PullbackScenario{2,:outofplace}; ref extras = prepare_pullback(f!, mysimilar(y), ba, x, dy) y_in = mysimilar(y) - _, pullbackfunc = value_and_pullback_split(f!, y, ba, x, extras) - if Bool(pullback_performance(ba)) JET.@test_opt value_and_pullback(f!, y_in, ba, x, dy, extras) JET.@test_opt pullback(f!, y_in, ba, x, dy, extras) - JET.@test_opt pullbackfunc(y_in, dy) end return nothing end @@ -102,12 +93,9 @@ function test_jet(ba::AbstractADType, scen::PullbackScenario{2,:inplace}; ref_ba extras = prepare_pullback(f!, mysimilar(y), ba, x, dy) y_in, dx_in = mysimilar(y), mysimilar(x) - _, pullbackfunc! = value_and_pullback!_split(f!, y, ba, x, extras) - if Bool(pullback_performance(ba)) JET.@test_opt value_and_pullback!(f!, y_in, dx_in, ba, x, dy, extras) JET.@test_opt pullback!(f!, y_in, dx_in, ba, x, dy, extras) - JET.@test_opt pullbackfunc!(y_in, dx_in, dy) end return nothing end diff --git a/DifferentiationInterfaceTest/src/utils/misc.jl b/DifferentiationInterfaceTest/src/utils/misc.jl index 08acab64a..5de37dcc9 100644 --- a/DifferentiationInterfaceTest/src/utils/misc.jl +++ b/DifferentiationInterfaceTest/src/utils/misc.jl @@ -1,5 +1,4 @@ -# mysimilar(x::Number) = zero(x) mysimilar(x::AbstractArray) = similar(x) -mysimilar_random(x::Number) = randn(typeof(x)) -mysimilar_random(x::AbstractArray) = map(mysimilar_random, similar(x)) +mycopy_random(x::Number) = randn(typeof(x)) +mycopy_random(x::AbstractArray) = map(mycopy_random, x)