Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Jacobians #22

Closed
wants to merge 14 commits into from
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.1.0"

[deps]
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[weakdeps]
Expand All @@ -26,8 +27,9 @@ DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"]
ChainRulesCore = "1.19"
DiffResults = "1.1"
DocStringExtensions = "0.9"
FiniteDiff = "2.22"
Enzyme = "0.11"
FillArrays = "1"
FiniteDiff = "2.22"
ForwardDiff = "0.10"
LinearAlgebra = "1"
ReverseDiff = "1.15"
Expand Down
4 changes: 2 additions & 2 deletions ext/DifferentiationInterfaceChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ update!(_old::Number, new::Number) = new
update!(old, new) = old .= new

function DifferentiationInterface.value_and_pushforward!(
dy::Y, backend::ChainRulesForwardBackend, f, x::X, dx::X
dy::Y, backend::ChainRulesForwardBackend, f, x::X, dx
) where {X,Y}
rc = ruleconfig(backend)
y, new_dy = frule_via_ad(rc, (NoTangent(), dx), f, x)
return y, update!(dy, new_dy)
end

function DifferentiationInterface.value_and_pullback!(
dx::X, backend::ChainRulesReverseBackend, f, x::X, dy::Y
dx, backend::ChainRulesReverseBackend, f, x::X, dy::Y
) where {X,Y}
rc = ruleconfig(backend)
y, pullback = rrule_via_ad(rc, f, x)
Expand Down
32 changes: 29 additions & 3 deletions ext/DifferentiationInterfaceEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,30 @@ module DifferentiationInterfaceEnzymeExt
using DifferentiationInterface
using Enzyme: Forward, ReverseWithPrimal, Active, Duplicated, autodiff

const EnzymeBackends = Union{EnzymeForwardBackend,EnzymeReverseBackend}

## Unit vector

# Enzyme's `Duplicated(x, dx)` expects both arguments to be of the same type,
# so the basis vector is built from `zero(v)` rather than from a lazy array type.
function DifferentiationInterface.unitvector(
    ::EnzymeBackends, v::AbstractVector{T}, i
) where {T}
    basis = zero(v)
    # Set the single non-zero entry of the `i`-th basis vector.
    basis[i] = one(eltype(v))
    return basis
end

## Forward mode

function DifferentiationInterface.value_and_pushforward!(
_dy::Y, ::EnzymeForwardBackend, f, x::X, dx::X
_dy::Y, ::EnzymeForwardBackend, f, x::X, dx
) where {X,Y<:Real}
y, new_dy = autodiff(Forward, f, Duplicated, Duplicated(x, dx))
return y, new_dy
end

function DifferentiationInterface.value_and_pushforward!(
dy::Y, ::EnzymeForwardBackend, f, x::X, dx::X
dy::Y, ::EnzymeForwardBackend, f, x::X, dx
) where {X,Y<:AbstractArray}
y, new_dy = autodiff(Forward, f, Duplicated, Duplicated(x, dx))
dy .= new_dy
Expand All @@ -23,7 +36,7 @@ end
## Reverse mode

function DifferentiationInterface.value_and_pullback!(
_dx::X, ::EnzymeReverseBackend, f, x::X, dy::Y
_dx, ::EnzymeReverseBackend, f, x::X, dy::Y
) where {X<:Number,Y<:Union{Real,Nothing}}
der, y = autodiff(ReverseWithPrimal, f, Active, Active(x))
new_dx = dy * only(der)
Expand All @@ -39,4 +52,17 @@ function DifferentiationInterface.value_and_pullback!(
return y, dx
end

# Enzyme's `Duplicated` assumes `x` and its shadow to be of the same type.
# When writing into pre-allocated arrays (e.g. the rows of a Jacobian),
# `dx` is often a view / `SubArray`, which does not match the type of `x`.
# This method therefore allocates a fresh shadow via `zero(x)` and copies
# the scaled result into `dx` afterwards.
function DifferentiationInterface.value_and_pullback!(
    dx, ::EnzymeReverseBackend, f, x::X, dy::Y
) where {X<:AbstractArray,Y<:Union{Real,Nothing}}
    shadow = zero(x)
    # The first returned element is discarded; the second is used as the primal output.
    _, y = autodiff(ReverseWithPrimal, f, Active, Duplicated(x, shadow))
    # Scale the accumulated shadow by the scalar cotangent and store it in `dx`.
    dx .= shadow .* dy
    return y, dx
end

end # module
8 changes: 4 additions & 4 deletions ext/DifferentiationInterfaceFiniteDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using LinearAlgebra: dot, mul!
const DEFAULT_FDTYPE = Val{:central}

function DifferentiationInterface.value_and_pushforward!(
dy::Y, ::FiniteDiffBackend, f, x::X, dx::X
dy::Y, ::FiniteDiffBackend, f, x::X, dx
) where {X<:Number,Y<:Number}
y = f(x)
der = finite_difference_derivative(
Expand All @@ -26,7 +26,7 @@ function DifferentiationInterface.value_and_pushforward!(
end

function DifferentiationInterface.value_and_pushforward!(
dy::Y, ::FiniteDiffBackend, f, x::X, dx::X
dy::Y, ::FiniteDiffBackend, f, x::X, dx
) where {X<:Number,Y<:AbstractArray}
y = f(x)
finite_difference_gradient!(
Expand All @@ -43,7 +43,7 @@ function DifferentiationInterface.value_and_pushforward!(
end

function DifferentiationInterface.value_and_pushforward!(
dy::Y, ::FiniteDiffBackend, f, x::X, dx::X
dy::Y, ::FiniteDiffBackend, f, x::X, dx
) where {X<:AbstractArray,Y<:Number}
y = f(x)
g = finite_difference_gradient(
Expand All @@ -59,7 +59,7 @@ function DifferentiationInterface.value_and_pushforward!(
end

function DifferentiationInterface.value_and_pushforward!(
dy::Y, ::FiniteDiffBackend, f, x::X, dx::X
dy::Y, ::FiniteDiffBackend, f, x::X, dx
) where {X<:AbstractArray,Y<:AbstractArray}
y = f(x)
J = finite_difference_jacobian(
Expand Down
8 changes: 4 additions & 4 deletions ext/DifferentiationInterfaceForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using ForwardDiff: Dual, Tag, value, extract_derivative, extract_derivative!
using LinearAlgebra: mul!

function DifferentiationInterface.value_and_pushforward!(
_dy::Y, ::ForwardDiffBackend, f, x::X, dx::X
_dy::Y, ::ForwardDiffBackend, f, x::X, dx
) where {X<:Real,Y<:Real}
T = typeof(Tag(f, X))
xdual = Dual{T}(x, dx)
Expand All @@ -17,7 +17,7 @@ function DifferentiationInterface.value_and_pushforward!(
end

function DifferentiationInterface.value_and_pushforward!(
dy::Y, ::ForwardDiffBackend, f, x::X, dx::X
dy::Y, ::ForwardDiffBackend, f, x::X, dx
) where {X<:Real,Y<:AbstractArray}
T = typeof(Tag(f, X))
xdual = Dual{T}(x, dx)
Expand All @@ -28,7 +28,7 @@ function DifferentiationInterface.value_and_pushforward!(
end

function DifferentiationInterface.value_and_pushforward!(
_dy::Y, ::ForwardDiffBackend, f, x::X, dx::X
_dy::Y, ::ForwardDiffBackend, f, x::X, dx
) where {X<:AbstractArray,Y<:Real}
T = typeof(Tag(f, X)) # TODO: unsure
xdual = Dual{T}.(x, dx) # TODO: allocation
Expand All @@ -39,7 +39,7 @@ function DifferentiationInterface.value_and_pushforward!(
end

function DifferentiationInterface.value_and_pushforward!(
dy::Y, ::ForwardDiffBackend, f, x::X, dx::X
dy::Y, ::ForwardDiffBackend, f, x::X, dx
) where {X<:AbstractArray,Y<:AbstractArray}
T = typeof(Tag(f, X)) # TODO: unsure
xdual = Dual{T}.(x, dx) # TODO: allocation
Expand Down
4 changes: 2 additions & 2 deletions ext/DifferentiationInterfaceReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using ReverseDiff: gradient!, jacobian!
using LinearAlgebra: mul!

function DifferentiationInterface.value_and_pullback!(
dx::X, ::ReverseDiffBackend, f, x::X, dy::Y
dx, ::ReverseDiffBackend, f, x::X, dy::Y
) where {X<:AbstractArray,Y<:Real}
res = DiffResults.DiffResult(zero(Y), dx)
res = gradient!(res, f, x)
Expand All @@ -16,7 +16,7 @@ function DifferentiationInterface.value_and_pullback!(
end

function DifferentiationInterface.value_and_pullback!(
dx::X, ::ReverseDiffBackend, f, x::X, dy::Y
dx, ::ReverseDiffBackend, f, x::X, dy::Y
) where {X<:AbstractArray,Y<:AbstractArray}
res = DiffResults.DiffResult(similar(dy), similar(dy, length(dy), length(x)))
res = jacobian!(res, f, x)
Expand Down
5 changes: 5 additions & 0 deletions src/DifferentiationInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ $(EXPORTS)
module DifferentiationInterface

using DocStringExtensions
using FillArrays: OneElement

abstract type AbstractBackend end
abstract type AbstractForwardBackend <: AbstractBackend end
Expand All @@ -19,6 +20,8 @@ abstract type AbstractReverseBackend <: AbstractBackend end
include("backends.jl")
include("forward.jl")
include("reverse.jl")
include("unitvector.jl")
include("jacobian.jl")

export ChainRulesReverseBackend,
ChainRulesForwardBackend,
Expand All @@ -29,5 +32,7 @@ export ChainRulesReverseBackend,
ReverseDiffBackend
export pushforward!, value_and_pushforward!
export pullback!, value_and_pullback!
export jacobian!, value_and_jacobian!
export jacobian, value_and_jacobian

end # module
11 changes: 7 additions & 4 deletions src/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,19 @@ Compute a Jacobian-vector product inside `dy` and return it and the primal outpu
- `dx`: tangent
- `stuff`: optional backend-specific storage (cache, config), might be modified
"""
function value_and_pushforward! end
function value_and_pushforward!(dy::Y, backend::AbstractBackend, f, x::X, dx) where {X,Y}
    # Generic fallback for any `AbstractBackend`: it always raises an error
    # reporting that no package extension is loaded for `backend`.
    msg = "No package extension loaded for backend $backend."
    return error(msg)
end

"""
pushforward!(dy, backend, f, x, dx[, stuff])

Compute a Jacobian-vector product inside `dy` and return it.
Compute a Jacobian-vector product inside `dy`.
Returns the JVP `dy`.

See [`value_and_pushforward!`](@ref).
"""
function pushforward!(dy, backend, f, x, dx, stuff)
_, dy = value_and_pushforward!(dy, backend, f, x, dx, stuff)
function pushforward!(dy, backend, f, x, dx)
_, dy = value_and_pushforward!(dy, backend, f, x, dx)
return dy
end
110 changes: 110 additions & 0 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Shared sentence describing the Jacobian layout; interpolated into the docstrings in this file.
const DOC_JACOBIAN_SHAPE = "For a function `f: ℝⁿ → ℝᵐ`, `J` is returned as a `m × n` matrix."

## In-place mutating functions

"""
value_and_jacobian!(J, backend, f, x[, stuff])

Compute the Jacobian inside the pre-allocated matrix `J`.
$DOC_JACOBIAN_SHAPE
Returns the primal output of the computation `f(x)` and the corresponding Jacobian `J`.

See [`value_and_jacobian`](@ref), [`jacobian!`](@ref) and [`jacobian`](@ref).
"""
function value_and_jacobian!(J::AbstractMatrix, backend, f, x)
    # The primal output is needed both for the size check and by the backend loop.
    y = f(x)
    ny = length(y)
    nx = length(x)
    # Validate the buffer up front: rows index outputs, columns index inputs.
    if size(J) != (ny, nx)
        throw(
            DimensionMismatch(
                "Size of Jacobian buffer doesn't match expected size ($ny, $nx)"
            ),
        )
    end
    # Dispatches on the backend mode (and on scalar `x` / `y`); yields `(y, J)`.
    return _value_and_jacobian!(J, backend, f, x, y)
end

function _value_and_jacobian!(J, backend::AbstractReverseBackend, f, x, y)
    # Reverse mode: one pullback per row of J, seeded with the matching unit cotangent.
    for row in axes(J, 1)
        seed = unitvector(backend, y, row)
        # View onto the `row`-th row of J, reshaped to match the shape of x.
        target = reshape(view(J, row, :), size(x)...)
        # The return value is discarded; pullback! is expected to fill `target` in place.
        pullback!(target, backend, f, x, seed)
    end
    return y, J
end

function _value_and_jacobian!(J, backend::AbstractForwardBackend, f, x, y)
    # Forward mode: one pushforward per column of J, seeded with the matching unit tangent.
    for col in axes(J, 2)
        seed = unitvector(backend, x, col)
        # View onto the `col`-th column of J, reshaped to match the shape of y.
        target = reshape(view(J, :, col), size(y)...)
        # The return value is discarded; pushforward! is expected to fill `target` in place.
        pushforward!(target, backend, f, x, seed)
    end
    return y, J
end

# Special case for scalar x: a scalar `dx` cannot be mutated in place,
# so each entry of J is taken from the return value of pullback! instead.
function _value_and_jacobian!(J, backend::AbstractReverseBackend, f, x::Real, y)
    placeholder = one(x) # only used for dispatch, it is never mutated
    for k in axes(J, 1)
        seed = unitvector(backend, y, k)
        # x is scalar, so J has shape (length(y), 1) and linear index k addresses row k.
        J[k] = pullback!(placeholder, backend, f, x, seed)
    end
    return y, J
end

# Special case for scalar y: a scalar `dy` cannot be mutated in place,
# so each entry of J is taken from the return value of pushforward! instead.
function _value_and_jacobian!(J, backend::AbstractForwardBackend, f, x, y::Real)
    placeholder = one(y) # only used for dispatch, it is never mutated
    for k in axes(J, 2)
        seed = unitvector(backend, x, k)
        # y is scalar, so J has shape (1, length(x)) and linear index k addresses column k.
        J[k] = pushforward!(placeholder, backend, f, x, seed)
    end
    return y, J
end

# Take a view of `A` at the indices `inds` and reshape it to the size of `B`.
# No copy of the selected data is made: writes to the result go through to `A`.
function reshapeview(A, inds, B)
    window = view(A, inds...)
    return reshape(window, size(B)...)
end

"""
jacobian!(J, backend, f, x[, stuff])

Compute the Jacobian of `f` at `x` inside the pre-allocated matrix `J` and return `J`.
$DOC_JACOBIAN_SHAPE

See [`value_and_jacobian!`](@ref), [`value_and_jacobian`](@ref) and [`jacobian`](@ref).
"""
function jacobian!(J::AbstractMatrix, backend::AbstractBackend, f, x)
    # Delegate to the value-returning variant and keep only the Jacobian.
    value_and_J = value_and_jacobian!(J, backend, f, x)
    return value_and_J[2]
end

## Allocating functions

"""
value_and_jacobian(backend, f, x[, stuff])

Return the primal output of the computation `f(x)` and the corresponding Jacobian `J`.
$DOC_JACOBIAN_SHAPE

See [`value_and_jacobian!`](@ref), [`jacobian!`](@ref) and [`jacobian`](@ref).
"""
function value_and_jacobian(backend::AbstractBackend, f, x)
    # The primal output is computed here only to size and type the buffer.
    # NOTE: value_and_jacobian! evaluates `f(x)` again internally.
    buffer = allocate_jacobian_buffer(x, f(x))
    return value_and_jacobian!(buffer, backend, f, x)
end

function allocate_jacobian_buffer(x, y)
    # Element type of a derivative: whatever Julia promotes dy/dx to.
    Tx, Ty = eltype(x), eltype(y)
    T = typeof(one(Ty) / one(Tx))
    # For a function f: ℝⁿ → ℝᵐ, the buffer has m rows and n columns.
    m, n = length(y), length(x)
    # The entries are left uninitialized.
    return Matrix{T}(undef, m, n)
end

"""
jacobian(backend, f, x[, stuff])

Return the Jacobian `J` of function `f` at `x`.
$DOC_JACOBIAN_SHAPE

See [`value_and_jacobian`](@ref), [`value_and_jacobian!`](@ref) and [`jacobian!`](@ref).
"""
function jacobian(backend::AbstractBackend, f, x)
    # Delegate to the value-returning variant and keep only the Jacobian.
    value_and_J = value_and_jacobian(backend, f, x)
    return value_and_J[2]
end
8 changes: 5 additions & 3 deletions src/reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ Compute a vector-Jacobian product inside `dx` and return it and the primal outpu
- `dy`: cotangent
- `stuff`: optional backend-specific storage (cache, config), might be modified
"""
function value_and_pullback! end
function value_and_pullback!(dx, backend::AbstractBackend, f, x::X, dy::Y) where {X,Y}
    # Generic fallback for any `AbstractBackend`: it always raises an error
    # reporting that no package extension is loaded for `backend`.
    msg = "No package extension loaded for backend $backend."
    return error(msg)
end

"""
pullback!(dx, backend, f, x, dy[, stuff])
Expand All @@ -22,7 +24,7 @@ Compute a vector-Jacobian product inside `dx` and return it.

See [`value_and_pullback!`](@ref).
"""
function pullback!(dx, backend, f, x, dy, stuff)
_, dx = value_and_pullback!(dx, backend, f, x, dy, stuff)
function pullback!(dx, backend, f, x, dy)
_, dx = value_and_pullback!(dx, backend, f, x, dy)
return dx
end
15 changes: 15 additions & 0 deletions src/unitvector.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""
    unitvector(backend, v::AbstractVector, i)
    unitvector(backend, v::Real, i)

Construct the `i`-th standard basis vector in the vector space of `v`,
with element type `eltype(v)` and length `length(v)`.
If `v` is a real number, `one(v)` is returned.

## Note
If an AD backend benefits from a more specialized unit vector implementation,
this function can be extended on the backend type.
"""
function unitvector(::AbstractBackend, v::AbstractVector, i)
    # `OneElement(value, index, length)` from FillArrays.
    return OneElement(one(eltype(v)), i, length(v))
end
unitvector(::AbstractBackend, v::Real, i) = one(v)
3 changes: 3 additions & 0 deletions test/diffractor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@ using DifferentiationInterface
test_pushforward(
ChainRulesForwardBackend(Diffractor.DiffractorRuleConfig()); type_stability=false
)
test_jacobian(
ChainRulesForwardBackend(Diffractor.DiffractorRuleConfig()); type_stability=false
)
Loading
Loading