Skip to content

Commit

Permalink
[BREAKING] Move sparse functionality into package extensions (#448)
Browse files Browse the repository at this point in the history
* Move extras in core code

* Update backend extensions

* Update docs

* Typos

* Typos

* Fixes

* Typos

* Fix

* Fix ForwardDiff

* Fixes

* Fixes

* Fix Enzyme

* Bump versions and compats

* Move sparse functionality to extensions

* Remove prefixes and add docs
  • Loading branch information
gdalle authored Sep 5, 2024
1 parent de23245 commit 1069266
Show file tree
Hide file tree
Showing 13 changed files with 226 additions and 171 deletions.
6 changes: 4 additions & 2 deletions DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -23,6 +21,8 @@ FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Expand All @@ -38,6 +38,8 @@ DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences"
DifferentiationInterfaceForwardDiffExt = "ForwardDiff"
DifferentiationInterfacePolyesterForwardDiffExt = "PolyesterForwardDiff"
DifferentiationInterfaceReverseDiffExt = "ReverseDiff"
DifferentiationInterfaceSparseArraysExt = "SparseArrays"
DifferentiationInterfaceSparseMatrixColoringsExt = "SparseMatrixColorings"
DifferentiationInterfaceSymbolicsExt = "Symbolics"
DifferentiationInterfaceTapirExt = "Tapir"
DifferentiationInterfaceTrackerExt = "Tracker"
Expand Down
7 changes: 4 additions & 3 deletions DifferentiationInterface/docs/src/operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,16 +168,17 @@ For this to work, three ingredients are needed (read [this survey](https://epubs
- [`DenseSparsityDetector`](@ref) from DifferentiationInterface.jl (beware that this detector only gives a locally valid pattern)
3. A coloring algorithm: [`GreedyColoringAlgorithm`](@extref SparseMatrixColorings.GreedyColoringAlgorithm) from [SparseMatrixColorings.jl](https://github.com/gdalle/SparseMatrixColorings.jl) is the only one we support.

!!! warning
Generic sparse AD is now located in a package extension which depends on SparseMatrixColorings.jl.

These ingredients can be combined within the [`AutoSparse`](@extref ADTypes.AutoSparse) wrapper, which DifferentiationInterface.jl re-exports.
Note that for sparse Hessians, you need to put the `SecondOrder` backend inside `AutoSparse`, and not the other way around.
`AutoSparse` backends only support operators [`jacobian`](@ref) and [`hessian`](@ref) (as well as their variants).

The preparation step of `jacobian` or `hessian` with an `AutoSparse` backend can be long, because it needs to detect the sparsity pattern and color the resulting sparse matrix.
But after preparation, the more zeros are present in the matrix, the greater the speedup will be compared to dense differentiation.

!!! danger
`AutoSparse` backends only support operators [`jacobian`](@ref) and [`hessian`](@ref) (as well as their variants).

!!! warning
The result of preparation for an `AutoSparse` backend cannot be reused if the sparsity pattern changes.

!!! info
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module DifferentiationInterfaceSparseArraysExt

using ADTypes: ADTypes
using Compat
using DifferentiationInterface
using DifferentiationInterface:
DenseSparsityDetector, PushforwardFast, PushforwardSlow, basis, pushforward_performance
import DifferentiationInterface as DI
using SparseArrays: SparseMatrixCSC, nonzeros, nzrange, rowvals, sparse

include("sparsity_detector.jl")

end
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
## Direct

function ADTypes.jacobian_sparsity(f, x, detector::DenseSparsityDetector{:direct})
@compat (; backend, atol) = detector
J = jacobian(f, backend, x)
return sparse(abs.(J) .> atol)
end

function ADTypes.jacobian_sparsity(f!, y, x, detector::DenseSparsityDetector{:direct})
@compat (; backend, atol) = detector
J = jacobian(f!, y, backend, x)
return sparse(abs.(J) .> atol)
end

function ADTypes.hessian_sparsity(f, x, detector::DenseSparsityDetector{:direct})
@compat (; backend, atol) = detector
H = hessian(f, backend, x)
return sparse(abs.(H) .> atol)
end

## Iterative

function ADTypes.jacobian_sparsity(f, x, detector::DenseSparsityDetector{:iterative})
@compat (; backend, atol) = detector
y = f(x)
n, m = length(x), length(y)
I, J = Int[], Int[]
if pushforward_performance(backend) isa PushforwardFast
p = similar(y)
extras = prepare_pushforward_same_point(
f, backend, x, basis(backend, x, first(CartesianIndices(x)))
)
for (kj, j) in enumerate(CartesianIndices(x))
pushforward!(f, p, extras, backend, x, basis(backend, x, j))
for ki in LinearIndices(p)
if abs(p[ki]) > atol
push!(I, ki)
push!(J, kj)
end
end
end
else
p = similar(x)
extras = prepare_pullback_same_point(
f, backend, x, basis(backend, y, first(CartesianIndices(y)))
)
for (ki, i) in enumerate(CartesianIndices(y))
pullback!(f, p, extras, backend, x, basis(backend, y, i))
for kj in LinearIndices(p)
if abs(p[kj]) > atol
push!(I, ki)
push!(J, kj)
end
end
end
end
return sparse(I, J, ones(Bool, length(I)), m, n)
end

function ADTypes.jacobian_sparsity(f!, y, x, detector::DenseSparsityDetector{:iterative})
@compat (; backend, atol) = detector
n, m = length(x), length(y)
I, J = Int[], Int[]
if pushforward_performance(backend) isa PushforwardFast
p = similar(y)
extras = prepare_pushforward_same_point(
f!, y, backend, x, basis(backend, x, first(CartesianIndices(x)))
)
for (kj, j) in enumerate(CartesianIndices(x))
pushforward!(f!, y, p, extras, backend, x, basis(backend, x, j))
for ki in LinearIndices(p)
if abs(p[ki]) > atol
push!(I, ki)
push!(J, kj)
end
end
end
else
p = similar(x)
extras = prepare_pullback_same_point(
f!, y, backend, x, basis(backend, y, first(CartesianIndices(y)))
)
for (ki, i) in enumerate(CartesianIndices(y))
pullback!(f!, y, p, extras, backend, x, basis(backend, y, i))
for kj in LinearIndices(p)
if abs(p[kj]) > atol
push!(I, ki)
push!(J, kj)
end
end
end
end
return sparse(I, J, ones(Bool, length(I)), m, n)
end

function ADTypes.hessian_sparsity(f, x, detector::DenseSparsityDetector{:iterative})
@compat (; backend, atol) = detector
n = length(x)
I, J = Int[], Int[]
p = similar(x)
extras = prepare_hvp_same_point(
f, backend, x, basis(backend, x, first(CartesianIndices(x)))
)
for (kj, j) in enumerate(CartesianIndices(x))
hvp!(f, p, extras, backend, x, basis(backend, x, j))
for ki in LinearIndices(p)
if abs(p[ki]) > atol
push!(I, ki)
push!(J, kj)
end
end
end
return sparse(I, J, ones(Bool, length(I)), n, n)
end
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
module DifferentiationInterfaceSparseMatrixColoringsExt

using ADTypes:
ADTypes,
AbstractADType,
AutoSparse,
dense_ad,
coloring_algorithm,
sparsity_detector,
jacobian_sparsity,
hessian_sparsity
using Compat
using DifferentiationInterface
using DifferentiationInterface:
GradientExtras,
HessianExtras,
HVPExtras,
JacobianExtras,
PullbackExtras,
PushforwardExtras,
PushforwardFast,
PushforwardSlow,
Tangents,
dense_ad,
maybe_dense_ad,
maybe_inner,
maybe_outer,
multibasis,
pick_batchsize,
pushforward_performance
import DifferentiationInterface as DI
using SparseMatrixColorings:
AbstractColoringResult,
ColoringProblem,
GreedyColoringAlgorithm,
coloring,
column_colors,
row_colors,
column_groups,
row_groups,
decompress,
decompress!

include("jacobian.jl")
include("hessian.jl")

end
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ end

## Hessian, one argument

function prepare_hessian(f::F, backend::AutoSparse, x) where {F}
function DI.prepare_hessian(f::F, backend::AutoSparse, x) where {F}
dense_backend = dense_ad(backend)
sparsity = hessian_sparsity(f, x, sparsity_detector(backend))
problem = ColoringProblem{:symmetric,:column}()
Expand Down Expand Up @@ -64,7 +64,9 @@ function prepare_hessian(f::F, backend::AutoSparse, x) where {F}
)
end

function hessian(f::F, extras::SparseHessianExtras{B}, backend::AutoSparse, x) where {F,B}
function DI.hessian(
f::F, extras::SparseHessianExtras{B}, backend::AutoSparse, x
) where {F,B}
@compat (; coloring_result, batched_seeds, hvp_extras) = extras
dense_backend = dense_ad(backend)
Ng = length(column_groups(coloring_result))
Expand All @@ -85,7 +87,7 @@ function hessian(f::F, extras::SparseHessianExtras{B}, backend::AutoSparse, x) w
return decompress(compressed_matrix, coloring_result)
end

function hessian!(
function DI.hessian!(
f::F, hess, extras::SparseHessianExtras{B}, backend::AutoSparse, x
) where {F,B}
@compat (;
Expand Down Expand Up @@ -113,7 +115,7 @@ function hessian!(
return hess
end

function value_gradient_and_hessian!(
function DI.value_gradient_and_hessian!(
f::F, grad, hess, extras::SparseHessianExtras, backend::AutoSparse, x
) where {F}
y, _ = value_and_gradient!(
Expand All @@ -123,7 +125,7 @@ function value_gradient_and_hessian!(
return y, grad, hess
end

function value_gradient_and_hessian(
function DI.value_gradient_and_hessian(
f::F, extras::SparseHessianExtras, backend::AutoSparse, x
) where {F}
y, grad = value_and_gradient(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,14 @@ function PullbackSparseJacobianExtras{B}(;
)
end

function prepare_jacobian(f::F, backend::AutoSparse, x) where {F}
function DI.prepare_jacobian(f::F, backend::AutoSparse, x) where {F}
y = f(x)
return _prepare_sparse_jacobian_aux(
(f,), backend, x, y, pushforward_performance(backend)
)
end

function prepare_jacobian(f!::F, y, backend::AutoSparse, x) where {F}
function DI.prepare_jacobian(f!::F, y, backend::AutoSparse, x) where {F}
return _prepare_sparse_jacobian_aux(
(f!, y), backend, x, y, pushforward_performance(backend)
)
Expand Down Expand Up @@ -137,49 +137,51 @@ end

## One argument

function jacobian(f::F, extras::SparseJacobianExtras, backend::AutoSparse, x) where {F}
function DI.jacobian(f::F, extras::SparseJacobianExtras, backend::AutoSparse, x) where {F}
return _sparse_jacobian_aux((f,), extras, backend, x)
end

function jacobian!(
function DI.jacobian!(
f::F, jac, extras::SparseJacobianExtras, backend::AutoSparse, x
) where {F}
return _sparse_jacobian_aux!((f,), jac, extras, backend, x)
end

function value_and_jacobian(
function DI.value_and_jacobian(
f::F, extras::SparseJacobianExtras, backend::AutoSparse, x
) where {F}
return f(x), jacobian(f, extras, backend, x)
end

function value_and_jacobian!(
function DI.value_and_jacobian!(
f::F, jac, extras::SparseJacobianExtras, backend::AutoSparse, x
) where {F}
return f(x), jacobian!(f, jac, extras, backend, x)
end

## Two arguments

function jacobian(f!::F, y, extras::SparseJacobianExtras, backend::AutoSparse, x) where {F}
function DI.jacobian(
f!::F, y, extras::SparseJacobianExtras, backend::AutoSparse, x
) where {F}
return _sparse_jacobian_aux((f!, y), extras, backend, x)
end

function jacobian!(
function DI.jacobian!(
f!::F, y, jac, extras::SparseJacobianExtras, backend::AutoSparse, x
) where {F}
return _sparse_jacobian_aux!((f!, y), jac, extras, backend, x)
end

function value_and_jacobian(
function DI.value_and_jacobian(
f!::F, y, extras::SparseJacobianExtras, backend::AutoSparse, x
) where {F}
jac = jacobian(f!, y, extras, backend, x)
f!(y, x)
return y, jac
end

function value_and_jacobian!(
function DI.value_and_jacobian!(
f!::F, y, jac, extras::SparseJacobianExtras, backend::AutoSparse, x
) where {F}
jacobian!(f!, y, jac, extras, backend, x)
Expand Down
Loading

0 comments on commit 1069266

Please sign in to comment.