Skip to content

Commit

Permalink
Return primal value in pushforward! and pullback! (#17)
Browse files Browse the repository at this point in the history
* Return primal value in `pushforward!` and `pullback!`

* Update tests

* Allow BenchmarkCI workflow to write comments

* Try removing PR write permission

* Add `value_and_pushforward!` and `value_and_pullback!` functions

* Add fallback for `pullback!` and `pushforward!`, restructure src

* Rm developer.md

* Skip broken tests

* Fix tests and add DiffResults as trigger

---------

Co-authored-by: Guillaume Dalle <[email protected]>
  • Loading branch information
adrhill and gdalle authored Feb 28, 2024
1 parent e37adca commit 31d1be9
Show file tree
Hide file tree
Showing 13 changed files with 338 additions and 202 deletions.
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -18,11 +19,12 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
DifferentiationInterfaceChainRulesCoreExt = "ChainRulesCore"
DifferentiationInterfaceEnzymeExt = "Enzyme"
DifferentiationInterfaceFiniteDiffExt = "FiniteDiff"
DifferentiationInterfaceForwardDiffExt = "ForwardDiff"
DifferentiationInterfaceReverseDiffExt = "ReverseDiff"
DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"]
DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"]

[compat]
ChainRulesCore = "1.19"
DiffResults = "1.1"
DocStringExtensions = "0.9"
FiniteDiff = "2.22"
Enzyme = "0.11"
Expand Down
12 changes: 6 additions & 6 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,30 +33,30 @@ SUITE = BenchmarkGroup()
for n in n_values
for backend in forward_backends
SUITE["forward"]["scalar_to_scalar"][n][string(backend)] = @benchmarkable begin
pushforward!(dy, $backend, scalar_to_scalar, x, dx)
value_and_pushforward!(dy, $backend, scalar_to_scalar, x, dx)
end setup = (x = 1.0; dx = 1.0; dy = 0.0) evals = 1
if backend != EnzymeForwardBackend() # type instability?
SUITE["forward"]["scalar_to_vector"][n][string(backend)] = @benchmarkable begin
pushforward!(dy, $backend, Fix2(scalar_to_vector, $n), x, dx)
value_and_pushforward!(dy, $backend, Fix2(scalar_to_vector, $n), x, dx)
end setup = (x = 1.0; dx = 1.0; dy = zeros($n)) evals = 1
end
SUITE["forward"]["vector_to_vector"][n][string(backend)] = @benchmarkable begin
pushforward!(dy, $backend, vector_to_vector, x, dx)
value_and_pushforward!(dy, $backend, vector_to_vector, x, dx)
end setup = (x = randn($n); dx = randn($n); dy = zeros($n)) evals = 1
end

for backend in reverse_backends
if backend != ReverseDiffBackend()
SUITE["reverse"]["scalar_to_scalar"][n][string(backend)] = @benchmarkable begin
pullback!(dx, $backend, scalar_to_scalar, x, dy)
value_and_pullback!(dx, $backend, scalar_to_scalar, x, dy)
end setup = (x = 1.0; dy = 1.0; dx = 0.0) evals = 1
end
SUITE["reverse"]["vector_to_scalar"][n][string(backend)] = @benchmarkable begin
pullback!(dx, $backend, vector_to_scalar, x, dy)
value_and_pullback!(dx, $backend, vector_to_scalar, x, dy)
end setup = (x = randn($n); dy = 1.0; dx = zeros($n)) evals = 1
if backend != EnzymeReverseBackend()
SUITE["reverse"]["vector_to_vector"][n][string(backend)] = @benchmarkable begin
pullback!(dx, $backend, vector_to_vector, x, dy)
value_and_pullback!(dx, $backend, vector_to_vector, x, dy)
end setup = (x = randn($n); dy = randn($n); dx = zeros($n)) evals = 1
end
end
Expand Down
24 changes: 12 additions & 12 deletions ext/DifferentiationInterfaceChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,40 +7,40 @@ using LinearAlgebra
ruleconfig(backend::ChainRulesForwardBackend) = backend.ruleconfig
ruleconfig(backend::ChainRulesReverseBackend) = backend.ruleconfig

function DifferentiationInterface.pushforward!(
function DifferentiationInterface.value_and_pushforward!(
_dy::Y, backend::ChainRulesForwardBackend, f, x::X, dx::X
) where {X,Y<:Number}
rc = ruleconfig(backend)
_, new_dy = frule_via_ad(rc, (NoTangent(), dx), f, x)
return new_dy
y, new_dy = frule_via_ad(rc, (NoTangent(), dx), f, x)
return y, new_dy
end

function DifferentiationInterface.pushforward!(
function DifferentiationInterface.value_and_pushforward!(
dy::Y, backend::ChainRulesForwardBackend, f, x::X, dx::X
) where {X,Y<:AbstractArray}
rc = ruleconfig(backend)
_, new_dy = frule_via_ad(rc, (NoTangent(), dx), f, x)
y, new_dy = frule_via_ad(rc, (NoTangent(), dx), f, x)
dy .= new_dy
return dy
return y, dy
end

function DifferentiationInterface.pullback!(
function DifferentiationInterface.value_and_pullback!(
_dx::X, backend::ChainRulesReverseBackend, f, x::X, dy::Y
) where {X<:Number,Y}
rc = ruleconfig(backend)
_, pullback = rrule_via_ad(rc, f, x)
y, pullback = rrule_via_ad(rc, f, x)
_, new_dx = pullback(dy)
return new_dx
return y, new_dx
end

function DifferentiationInterface.pullback!(
function DifferentiationInterface.value_and_pullback!(
dx::X, backend::ChainRulesReverseBackend, f, x::X, dy::Y
) where {X<:AbstractArray,Y}
rc = ruleconfig(backend)
_, pullback = rrule_via_ad(rc, f, x)
y, pullback = rrule_via_ad(rc, f, x)
_, new_dx = pullback(dy)
dx .= new_dx
return dx
return y, dx
end

end
24 changes: 14 additions & 10 deletions ext/DifferentiationInterfaceEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,43 +9,47 @@ using Enzyme
"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pushforward!(
function DifferentiationInterface.value_and_pushforward!(
_dy::Y, ::EnzymeForwardBackend, f, x::X, dx::X
) where {X,Y<:Real}
return only(autodiff(Forward, f, DuplicatedNoNeed, Duplicated(x, dx)))
y, new_dy = autodiff(Forward, f, Duplicated, Duplicated(x, dx))
return y, new_dy
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pushforward!(
function DifferentiationInterface.value_and_pushforward!(
dy::Y, ::EnzymeForwardBackend, f, x::X, dx::X
) where {X,Y<:AbstractArray}
dy .= only(autodiff(Forward, f, DuplicatedNoNeed, Duplicated(x, dx)))
return dy
y, new_dy = autodiff(Forward, f, Duplicated, Duplicated(x, dx))
dy .= new_dy
return y, dy
end

## Reverse-mode

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pullback!(
function DifferentiationInterface.value_and_pullback!(
_dx::X, ::EnzymeReverseBackend, f, x::X, dy::Y
) where {X<:Number,Y<:Union{Real,Nothing}}
return only(first(autodiff(Reverse, f, Active, Active(x)))) * dy
der, y = autodiff(ReverseWithPrimal, f, Active, Active(x))
new_dx = dy * only(der)
return y, new_dx
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pullback!(
function DifferentiationInterface.value_and_pullback!(
dx::X, ::EnzymeReverseBackend, f, x::X, dy::Y
) where {X<:AbstractArray,Y<:Union{Real,Nothing}}
dx .= zero(eltype(dx))
autodiff(Reverse, f, Active, Duplicated(x, dx))
_, y = autodiff(ReverseWithPrimal, f, Active, Duplicated(x, dx))
dx .*= dy
return dx
return y, dx
end

end # module
59 changes: 46 additions & 13 deletions ext/DifferentiationInterfaceFiniteDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,47 +5,80 @@ using DocStringExtensions
using FiniteDiff
using LinearAlgebra

const DEFAULT_FDTYPE = Val{:central}

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pushforward!(
function DifferentiationInterface.value_and_pushforward!(
dy::Y, ::FiniteDiffBackend, f, x::X, dx::X
) where {X<:Number,Y<:Number}
new_dy = FiniteDiff.finite_difference_derivative(f, x) * dx
return new_dy
y = f(x)
der = FiniteDiff.finite_difference_derivative(
f,
x,
DEFAULT_FDTYPE, # fdtype
eltype(dy), # returntype
y, # fx
)
new_dy = der * dx
return y, new_dy
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pushforward!(
function DifferentiationInterface.value_and_pushforward!(
dy::Y, ::FiniteDiffBackend, f, x::X, dx::X
) where {X<:Number,Y<:AbstractArray}
new_dy = FiniteDiff.finite_difference_derivative(f, x)
dy .= new_dy .* dx
return dy
y = f(x)
FiniteDiff.finite_difference_gradient!(
dy,
f,
x,
DEFAULT_FDTYPE, # fdtype
eltype(dy), # returntype
Val{false}, # inplace
y, # fx
)
dy .*= dx
return y, dy
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pushforward!(
function DifferentiationInterface.value_and_pushforward!(
dy::Y, ::FiniteDiffBackend, f, x::X, dx::X
) where {X<:AbstractArray,Y<:Number}
g = FiniteDiff.finite_difference_gradient(f, x)
y = f(x)
g = FiniteDiff.finite_difference_gradient(
f,
x,
DEFAULT_FDTYPE, # fdtype
eltype(dy), # returntype
Val{false}, # inplace
y, # fx
)
new_dy = dot(g, dx)
return new_dy
return y, new_dy
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pushforward!(
function DifferentiationInterface.value_and_pushforward!(
dy::Y, ::FiniteDiffBackend, f, x::X, dx::X
) where {X<:AbstractArray,Y<:AbstractArray}
J = FiniteDiff.finite_difference_jacobian(f, x)
y = f(x)
J = FiniteDiff.finite_difference_jacobian(
f,
x,
DEFAULT_FDTYPE, # fdtype
eltype(dy), # returntype
)
mul!(dy, J, dx)
return dy
return y, dy
end

end # module
52 changes: 35 additions & 17 deletions ext/DifferentiationInterfaceForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,51 +1,69 @@
module DifferentiationInterfaceForwardDiffExt

using DifferentiationInterface
using DiffResults
using DocStringExtensions
using ForwardDiff
using ForwardDiff: Dual, Tag, value, extract_derivative, extract_derivative!
using LinearAlgebra

function extract_value(::Type{T}, ydual) where {T}
return value.(T, ydual)
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pushforward!(
dy::Y, ::ForwardDiffBackend, f, x::X, dx::X
function DifferentiationInterface.value_and_pushforward!(
_dy::Y, ::ForwardDiffBackend, f, x::X, dx::X
) where {X<:Real,Y<:Real}
new_dy = ForwardDiff.derivative(f, x) * dx
return new_dy
T = typeof(Tag(f, X))
xdual = Dual{T}(x, dx)
ydual = f(xdual)
y = extract_value(T, ydual)
new_dy = extract_derivative(T, ydual)
return y, new_dy
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pushforward!(
function DifferentiationInterface.value_and_pushforward!(
dy::Y, ::ForwardDiffBackend, f, x::X, dx::X
) where {X<:Real,Y<:AbstractArray}
ForwardDiff.derivative!(dy, f, x)
dy .*= dx
return dy
T = typeof(Tag(f, X))
xdual = Dual{T}(x, dx)
ydual = f(xdual)
y = extract_value(T, ydual)
dy = extract_derivative!(T, dy, ydual)
return y, dy
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pushforward!(
dy::Y, ::ForwardDiffBackend, f, x::X, dx::X
function DifferentiationInterface.value_and_pushforward!(
_dy::Y, ::ForwardDiffBackend, f, x::X, dx::X
) where {X<:AbstractArray,Y<:Real}
g = ForwardDiff.gradient(f, x) # TODO: replace with duals, n times too slow
new_dy = dot(g, dx)
return new_dy
res = DiffResults.GradientResult(x)
ForwardDiff.gradient!(res, f, x)
y = DiffResults.value(res)
new_dy = dot(DiffResults.gradient(res), dx)
return y, new_dy
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pushforward!(
function DifferentiationInterface.value_and_pushforward!(
dy::Y, ::ForwardDiffBackend, f, x::X, dx::X
) where {X<:AbstractArray,Y<:AbstractArray}
J = ForwardDiff.jacobian(f, x) # TODO: replace with duals, n times too slow
res = DiffResults.JacobianResult(x)
ForwardDiff.jacobian!(res, f, x) # TODO: replace with duals, n times too slow
y = DiffResults.value(res)
J = DiffResults.jacobian(res)
mul!(dy, J, dx)
return dy
return y, dy
end

end
end # module
22 changes: 14 additions & 8 deletions ext/DifferentiationInterfaceReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,36 @@
module DifferentiationInterfaceReverseDiffExt

using DifferentiationInterface
using DiffResults
using DocStringExtensions
using ReverseDiff
using LinearAlgebra

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pullback!(
function DifferentiationInterface.value_and_pullback!(
dx::X, ::ReverseDiffBackend, f, x::X, dy::Y
) where {X<:AbstractArray,Y<:Real}
ReverseDiff.gradient!(dx, f, x)
dx .*= dy
return dx
res = DiffResults.GradientResult(x)
ReverseDiff.gradient!(res, f, x)
y = DiffResults.value(res)
dx .= dy .* DiffResults.gradient(res)
return y, dx
end

"""
$(TYPEDSIGNATURES)
"""
function DifferentiationInterface.pullback!(
function DifferentiationInterface.value_and_pullback!(
dx::X, ::ReverseDiffBackend, f, x::X, dy::Y
) where {X<:AbstractArray,Y<:AbstractArray}
J = ReverseDiff.jacobian(f, x)
res = DiffResults.JacobianResult(x)
ReverseDiff.jacobian!(res, f, x)
y = DiffResults.value(res)
J = DiffResults.jacobian(res)
mul!(dx, transpose(J), dy)
return dx
return y, dx
end

end
end # module
Loading

0 comments on commit 31d1be9

Please sign in to comment.