diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index e2709e23f..facb13ac4 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -10,8 +10,6 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" -SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -23,6 +21,8 @@ FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" @@ -38,6 +38,8 @@ DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences" DifferentiationInterfaceForwardDiffExt = "ForwardDiff" DifferentiationInterfacePolyesterForwardDiffExt = "PolyesterForwardDiff" DifferentiationInterfaceReverseDiffExt = "ReverseDiff" +DifferentiationInterfaceSparseArraysExt = "SparseArrays" +DifferentiationInterfaceSparseMatrixColoringsExt = "SparseMatrixColorings" DifferentiationInterfaceSymbolicsExt = "Symbolics" DifferentiationInterfaceTapirExt = "Tapir" DifferentiationInterfaceTrackerExt = "Tracker" diff --git a/DifferentiationInterface/docs/src/operators.md b/DifferentiationInterface/docs/src/operators.md index 062fd9d05..080ae0cc9 100644 --- a/DifferentiationInterface/docs/src/operators.md +++ b/DifferentiationInterface/docs/src/operators.md @@ -168,16 +168,17 @@ For this to work, three ingredients are needed (read [this survey](https://epubs - [`DenseSparsityDetector`](@ref) from DifferentiationInterface.jl (beware that this detector only gives a locally valid pattern) 3. A coloring algorithm: [`GreedyColoringAlgorithm`](@extref SparseMatrixColorings.GreedyColoringAlgorithm) from [SparseMatrixColorings.jl](https://github.com/gdalle/SparseMatrixColorings.jl) is the only one we support. +!!! warning + Generic sparse AD is now located in a package extension which depends on SparseMatrixColorings.jl. + These ingredients can be combined within the [`AutoSparse`](@extref ADTypes.AutoSparse) wrapper, which DifferentiationInterface.jl re-exports. Note that for sparse Hessians, you need to put the `SecondOrder` backend inside `AutoSparse`, and not the other way around. +`AutoSparse` backends only support operators [`jacobian`](@ref) and [`hessian`](@ref) (as well as their variants). The preparation step of `jacobian` or `hessian` with an `AutoSparse` backend can be long, because it needs to detect the sparsity pattern and color the resulting sparse matrix. But after preparation, the more zeros are present in the matrix, the greater the speedup will be compared to dense differentiation. !!! danger - `AutoSparse` backends only support operators [`jacobian`](@ref) and [`hessian`](@ref) (as well as their variants). - -!!! warning The result of preparation for an `AutoSparse` backend cannot be reused if the sparsity pattern changes. !!! info diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseArraysExt/DifferentiationInterfaceSparseArraysExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseArraysExt/DifferentiationInterfaceSparseArraysExt.jl new file mode 100644 index 000000000..5d44b4150 --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseArraysExt/DifferentiationInterfaceSparseArraysExt.jl @@ -0,0 +1,13 @@ +module DifferentiationInterfaceSparseArraysExt + +using ADTypes: ADTypes +using Compat +using DifferentiationInterface +using DifferentiationInterface: + DenseSparsityDetector, PushforwardFast, PushforwardSlow, basis, pushforward_performance +import DifferentiationInterface as DI +using SparseArrays: SparseMatrixCSC, nonzeros, nzrange, rowvals, sparse + +include("sparsity_detector.jl") + +end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseArraysExt/sparsity_detector.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseArraysExt/sparsity_detector.jl new file mode 100644 index 000000000..5e6448082 --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseArraysExt/sparsity_detector.jl @@ -0,0 +1,114 @@ +## Direct + +function ADTypes.jacobian_sparsity(f, x, detector::DenseSparsityDetector{:direct}) + @compat (; backend, atol) = detector + J = jacobian(f, backend, x) + return sparse(abs.(J) .> atol) +end + +function ADTypes.jacobian_sparsity(f!, y, x, detector::DenseSparsityDetector{:direct}) + @compat (; backend, atol) = detector + J = jacobian(f!, y, backend, x) + return sparse(abs.(J) .> atol) +end + +function ADTypes.hessian_sparsity(f, x, detector::DenseSparsityDetector{:direct}) + @compat (; backend, atol) = detector + H = hessian(f, backend, x) + return sparse(abs.(H) .> atol) +end + +## Iterative + +function ADTypes.jacobian_sparsity(f, x, detector::DenseSparsityDetector{:iterative}) + @compat (; backend, atol) = detector + y = f(x) + n, m = length(x), length(y) + I, J = Int[], Int[] + if pushforward_performance(backend) isa PushforwardFast + p = similar(y) + extras = prepare_pushforward_same_point( + f, backend, x, basis(backend, x, first(CartesianIndices(x))) + ) + for (kj, j) in enumerate(CartesianIndices(x)) + pushforward!(f, p, extras, backend, x, basis(backend, x, j)) + for ki in LinearIndices(p) + if abs(p[ki]) > atol + push!(I, ki) + push!(J, kj) + end + end + end + else + p = similar(x) + extras = prepare_pullback_same_point( + f, backend, x, basis(backend, y, first(CartesianIndices(y))) + ) + for (ki, i) in enumerate(CartesianIndices(y)) + pullback!(f, p, extras, backend, x, basis(backend, y, i)) + for kj in LinearIndices(p) + if abs(p[kj]) > atol + push!(I, ki) + push!(J, kj) + end + end + end + end + return sparse(I, J, ones(Bool, length(I)), m, n) +end + +function ADTypes.jacobian_sparsity(f!, y, x, detector::DenseSparsityDetector{:iterative}) + @compat (; backend, atol) = detector + n, m = length(x), length(y) + I, J = Int[], Int[] + if pushforward_performance(backend) isa PushforwardFast + p = similar(y) + extras = prepare_pushforward_same_point( + f!, y, backend, x, basis(backend, x, first(CartesianIndices(x))) + ) + for (kj, j) in enumerate(CartesianIndices(x)) + pushforward!(f!, y, p, extras, backend, x, basis(backend, x, j)) + for ki in LinearIndices(p) + if abs(p[ki]) > atol + push!(I, ki) + push!(J, kj) + end + end + end + else + p = similar(x) + extras = prepare_pullback_same_point( + f!, y, backend, x, basis(backend, y, first(CartesianIndices(y))) + ) + for (ki, i) in enumerate(CartesianIndices(y)) + pullback!(f!, y, p, extras, backend, x, basis(backend, y, i)) + for kj in LinearIndices(p) + if abs(p[kj]) > atol + push!(I, ki) + push!(J, kj) + end + end + end + end + return sparse(I, J, ones(Bool, length(I)), m, n) +end + +function ADTypes.hessian_sparsity(f, x, detector::DenseSparsityDetector{:iterative}) + @compat (; backend, atol) = detector + n = length(x) + I, J = Int[], Int[] + p = similar(x) + extras = prepare_hvp_same_point( + f, backend, x, basis(backend, x, first(CartesianIndices(x))) + ) + for (kj, j) in enumerate(CartesianIndices(x)) + hvp!(f, p, extras, backend, x, basis(backend, x, j)) + for ki in LinearIndices(p) + if abs(p[ki]) > atol + push!(I, ki) + push!(J, kj) + end + end + end + return sparse(I, J, ones(Bool, length(I)), n, n) +end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl new file mode 100644 index 000000000..d7459fefe --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl @@ -0,0 +1,47 @@ +module DifferentiationInterfaceSparseMatrixColoringsExt + +using ADTypes: + ADTypes, + AbstractADType, + AutoSparse, + dense_ad, + coloring_algorithm, + sparsity_detector, + jacobian_sparsity, + hessian_sparsity +using Compat +using DifferentiationInterface +using DifferentiationInterface: + GradientExtras, + HessianExtras, + HVPExtras, + JacobianExtras, + PullbackExtras, + PushforwardExtras, + PushforwardFast, + PushforwardSlow, + Tangents, + dense_ad, + maybe_dense_ad, + maybe_inner, + maybe_outer, + multibasis, + pick_batchsize, + pushforward_performance +import DifferentiationInterface as DI +using SparseMatrixColorings: + AbstractColoringResult, + ColoringProblem, + GreedyColoringAlgorithm, + coloring, + column_colors, + row_colors, + column_groups, + row_groups, + decompress, + decompress! + +include("jacobian.jl") +include("hessian.jl") + +end diff --git a/DifferentiationInterface/src/sparse/hessian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl similarity index 93% rename from DifferentiationInterface/src/sparse/hessian.jl rename to DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl index 42e3ffce4..e33a02ad5 100644 --- a/DifferentiationInterface/src/sparse/hessian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl @@ -35,7 +35,7 @@ end ## Hessian, one argument -function prepare_hessian(f::F, backend::AutoSparse, x) where {F} +function DI.prepare_hessian(f::F, backend::AutoSparse, x) where {F} dense_backend = dense_ad(backend) sparsity = hessian_sparsity(f, x, sparsity_detector(backend)) problem = ColoringProblem{:symmetric,:column}() @@ -64,7 +64,9 @@ function prepare_hessian(f::F, backend::AutoSparse, x) where {F} ) end -function hessian(f::F, extras::SparseHessianExtras{B}, backend::AutoSparse, x) where {F,B} +function DI.hessian( + f::F, extras::SparseHessianExtras{B}, backend::AutoSparse, x +) where {F,B} @compat (; coloring_result, batched_seeds, hvp_extras) = extras dense_backend = dense_ad(backend) Ng = length(column_groups(coloring_result)) @@ -85,7 +87,7 @@ function hessian(f::F, extras::SparseHessianExtras{B}, backend::AutoSparse, x) w return decompress(compressed_matrix, coloring_result) end -function hessian!( +function DI.hessian!( f::F, hess, extras::SparseHessianExtras{B}, backend::AutoSparse, x ) where {F,B} @compat (; @@ -113,7 +115,7 @@ function hessian!( return hess end -function value_gradient_and_hessian!( +function DI.value_gradient_and_hessian!( f::F, grad, hess, extras::SparseHessianExtras, backend::AutoSparse, x ) where {F} y, _ = value_and_gradient!( @@ -123,7 +125,7 @@ function value_gradient_and_hessian!( return y, grad, hess end -function value_gradient_and_hessian( +function DI.value_gradient_and_hessian( f::F, extras::SparseHessianExtras, backend::AutoSparse, x ) where {F} y, grad = value_and_gradient( diff --git a/DifferentiationInterface/src/sparse/jacobian.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl similarity index 94% rename from DifferentiationInterface/src/sparse/jacobian.jl rename to DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl index b71821594..e1e62b9a1 100644 --- a/DifferentiationInterface/src/sparse/jacobian.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl @@ -60,14 +60,14 @@ function PullbackSparseJacobianExtras{B}(; ) end -function prepare_jacobian(f::F, backend::AutoSparse, x) where {F} +function DI.prepare_jacobian(f::F, backend::AutoSparse, x) where {F} y = f(x) return _prepare_sparse_jacobian_aux( (f,), backend, x, y, pushforward_performance(backend) ) end -function prepare_jacobian(f!::F, y, backend::AutoSparse, x) where {F} +function DI.prepare_jacobian(f!::F, y, backend::AutoSparse, x) where {F} return _prepare_sparse_jacobian_aux( (f!, y), backend, x, y, pushforward_performance(backend) ) @@ -137,23 +137,23 @@ end ## One argument -function jacobian(f::F, extras::SparseJacobianExtras, backend::AutoSparse, x) where {F} +function DI.jacobian(f::F, extras::SparseJacobianExtras, backend::AutoSparse, x) where {F} return _sparse_jacobian_aux((f,), extras, backend, x) end -function jacobian!( +function DI.jacobian!( f::F, jac, extras::SparseJacobianExtras, backend::AutoSparse, x ) where {F} return _sparse_jacobian_aux!((f,), jac, extras, backend, x) end -function value_and_jacobian( +function DI.value_and_jacobian( f::F, extras::SparseJacobianExtras, backend::AutoSparse, x ) where {F} return f(x), jacobian(f, extras, backend, x) end -function value_and_jacobian!( +function DI.value_and_jacobian!( f::F, jac, extras::SparseJacobianExtras, backend::AutoSparse, x ) where {F} return f(x), jacobian!(f, jac, extras, backend, x) @@ -161,17 +161,19 @@ end ## Two arguments -function jacobian(f!::F, y, extras::SparseJacobianExtras, backend::AutoSparse, x) where {F} +function DI.jacobian( + f!::F, y, extras::SparseJacobianExtras, backend::AutoSparse, x +) where {F} return _sparse_jacobian_aux((f!, y), extras, backend, x) end -function jacobian!( +function DI.jacobian!( f!::F, y, jac, extras::SparseJacobianExtras, backend::AutoSparse, x ) where {F} return _sparse_jacobian_aux!((f!, y), jac, extras, backend, x) end -function value_and_jacobian( +function DI.value_and_jacobian( f!::F, y, extras::SparseJacobianExtras, backend::AutoSparse, x ) where {F} jac = jacobian(f!, y, extras, backend, x) @@ -179,7 +181,7 @@ function value_and_jacobian( return y, jac end -function value_and_jacobian!( +function DI.value_and_jacobian!( f!::F, y, jac, extras::SparseJacobianExtras, backend::AutoSparse, x ) where {F} jacobian!(f!, y, jac, extras, backend, x) diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index 7b155320a..588000c83 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -9,11 +9,16 @@ $(EXPORTS) """ module DifferentiationInterface -using ADTypes: ADTypes, AbstractADType -using ADTypes: mode, ForwardMode, ForwardOrReverseMode, ReverseMode, SymbolicMode -using ADTypes: AutoSparse, dense_ad -using ADTypes: coloring_algorithm -using ADTypes: sparsity_detector, jacobian_sparsity, hessian_sparsity +using ADTypes: + ADTypes, + AbstractADType, + AutoSparse, + ForwardMode, + ForwardOrReverseMode, + ReverseMode, + SymbolicMode, + dense_ad, + mode using ADTypes: AutoChainRules, AutoDiffractor, @@ -33,18 +38,6 @@ using DocStringExtensions using FillArrays: OneElement using LinearAlgebra: Symmetric, Transpose, dot, parent, transpose using PackageExtensionCompat: @require_extensions -using SparseArrays: SparseMatrixCSC, nonzeros, nzrange, rowvals, sparse -using SparseMatrixColorings: - AbstractColoringResult, - ColoringProblem, - GreedyColoringAlgorithm, - coloring, - column_colors, - row_colors, - column_groups, - row_groups, - decompress, - decompress! include("second_order/second_order.jl") @@ -70,13 +63,9 @@ include("second_order/hessian.jl") include("fallbacks/no_extras.jl") include("fallbacks/no_tangents.jl") -include("sparse/fallbacks.jl") -include("sparse/jacobian.jl") -include("sparse/hessian.jl") - include("misc/differentiate_with.jl") -include("misc/sparsity_detector.jl") include("misc/from_primitive.jl") +include("misc/sparsity_detector.jl") include("misc/zero_backends.jl") function __init__() @@ -136,10 +125,6 @@ export AutoZygote export AutoSparse -## Re-exported from SparseMatrixColorings - -export GreedyColoringAlgorithm - ## Public but not exported @compat public inner diff --git a/DifferentiationInterface/src/misc/sparsity_detector.jl b/DifferentiationInterface/src/misc/sparsity_detector.jl index 1aa81b048..6c5853dad 100644 --- a/DifferentiationInterface/src/misc/sparsity_detector.jl +++ b/DifferentiationInterface/src/misc/sparsity_detector.jl @@ -4,13 +4,14 @@ Sparsity pattern detector satisfying the [detection API](https://sciml.github.io/ADTypes.jl/stable/#Sparse-AD) of [ADTypes.jl](https://github.com/SciML/ADTypes.jl). The nonzeros in a Jacobian or Hessian are detected by computing the relevant matrix with _dense_ AD, and thresholding the entries with a given tolerance (which can be numerically inaccurate). - -!!! warning - This detector can be very slow, and should only be used if its output can be exploited multiple times to compute many sparse matrices. +This process can be very slow, and should only be used if its output can be exploited multiple times to compute many sparse matrices. !!! danger In general, the sparsity pattern you obtain can depend on the provided input `x`. If you want to reuse the pattern, make sure that it is input-agnostic. +!!! warning + `DenseSparsityDetector` functionality is now located in a package extension, please load the SparseArrays.jl standard library before you use it. + # Fields - `backend::AbstractADType` is the dense AD backend used under the hood @@ -94,118 +95,3 @@ function DenseSparsityDetector( end return DenseSparsityDetector{method,typeof(backend)}(backend, atol) end - -## Direct - -function ADTypes.jacobian_sparsity(f, x, detector::DenseSparsityDetector{:direct}) - @compat (; backend, atol) = detector - J = jacobian(f, backend, x) - return sparse(abs.(J) .> atol) -end - -function ADTypes.jacobian_sparsity(f!, y, x, detector::DenseSparsityDetector{:direct}) - @compat (; backend, atol) = detector - J = jacobian(f!, y, backend, x) - return sparse(abs.(J) .> atol) -end - -function ADTypes.hessian_sparsity(f, x, detector::DenseSparsityDetector{:direct}) - @compat (; backend, atol) = detector - H = hessian(f, backend, x) - return sparse(abs.(H) .> atol) -end - -## Iterative - -function ADTypes.jacobian_sparsity(f, x, detector::DenseSparsityDetector{:iterative}) - @compat (; backend, atol) = detector - y = f(x) - n, m = length(x), length(y) - I, J = Int[], Int[] - if pushforward_performance(backend) isa PushforwardFast - p = similar(y) - extras = prepare_pushforward_same_point( - f, backend, x, basis(backend, x, first(CartesianIndices(x))) - ) - for (kj, j) in enumerate(CartesianIndices(x)) - pushforward!(f, p, extras, backend, x, basis(backend, x, j)) - for ki in LinearIndices(p) - if abs(p[ki]) > atol - push!(I, ki) - push!(J, kj) - end - end - end - else - p = similar(x) - extras = prepare_pullback_same_point( - f, backend, x, basis(backend, y, first(CartesianIndices(y))) - ) - for (ki, i) in enumerate(CartesianIndices(y)) - pullback!(f, p, extras, backend, x, basis(backend, y, i)) - for kj in LinearIndices(p) - if abs(p[kj]) > atol - push!(I, ki) - push!(J, kj) - end - end - end - end - return sparse(I, J, ones(Bool, length(I)), m, n) -end - -function ADTypes.jacobian_sparsity(f!, y, x, detector::DenseSparsityDetector{:iterative}) - @compat (; backend, atol) = detector - n, m = length(x), length(y) - I, J = Int[], Int[] - if pushforward_performance(backend) isa PushforwardFast - p = similar(y) - extras = prepare_pushforward_same_point( - f!, y, backend, x, basis(backend, x, first(CartesianIndices(x))) - ) - for (kj, j) in enumerate(CartesianIndices(x)) - pushforward!(f!, y, p, extras, backend, x, basis(backend, x, j)) - for ki in LinearIndices(p) - if abs(p[ki]) > atol - push!(I, ki) - push!(J, kj) - end - end - end - else - p = similar(x) - extras = prepare_pullback_same_point( - f!, y, backend, x, basis(backend, y, first(CartesianIndices(y))) - ) - for (ki, i) in enumerate(CartesianIndices(y)) - pullback!(f!, y, p, extras, backend, x, basis(backend, y, i)) - for kj in LinearIndices(p) - if abs(p[kj]) > atol - push!(I, ki) - push!(J, kj) - end - end - end - end - return sparse(I, J, ones(Bool, length(I)), m, n) -end - -function ADTypes.hessian_sparsity(f, x, detector::DenseSparsityDetector{:iterative}) - @compat (; backend, atol) = detector - n = length(x) - I, J = Int[], Int[] - p = similar(x) - extras = prepare_hvp_same_point( - f, backend, x, basis(backend, x, first(CartesianIndices(x))) - ) - for (kj, j) in enumerate(CartesianIndices(x)) - hvp!(f, p, extras, backend, x, basis(backend, x, j)) - for ki in LinearIndices(p) - if abs(p[ki]) > atol - push!(I, ki) - push!(J, kj) - end - end - end - return sparse(I, J, ones(Bool, length(I)), n, n) -end diff --git a/DifferentiationInterface/src/sparse/fallbacks.jl b/DifferentiationInterface/src/sparse/fallbacks.jl deleted file mode 100644 index e568cfda5..000000000 --- a/DifferentiationInterface/src/sparse/fallbacks.jl +++ /dev/null @@ -1,5 +0,0 @@ -check_available(backend::AutoSparse) = check_available(dense_ad(backend)) -twoarg_support(backend::AutoSparse) = twoarg_support(dense_ad(backend)) -pushforward_performance(backend::AutoSparse) = pushforward_performance(dense_ad(backend)) -pullback_performance(backend::AutoSparse) = pullback_performance(dense_ad(backend)) -hvp_mode(backend::AutoSparse{<:SecondOrder}) = hvp_mode(dense_ad(backend)) diff --git a/DifferentiationInterface/src/utils/check.jl b/DifferentiationInterface/src/utils/check.jl index 6783d4e5f..00efa28ac 100644 --- a/DifferentiationInterface/src/utils/check.jl +++ b/DifferentiationInterface/src/utils/check.jl @@ -9,6 +9,8 @@ function check_available(backend::SecondOrder) return check_available(inner(backend)) && check_available(outer(backend)) end +check_available(backend::AutoSparse) = check_available(dense_ad(backend)) + """ check_twoarg(backend) diff --git a/DifferentiationInterface/src/utils/printing.jl b/DifferentiationInterface/src/utils/printing.jl index e30f58cdc..20d52c72b 100644 --- a/DifferentiationInterface/src/utils/printing.jl +++ b/DifferentiationInterface/src/utils/printing.jl @@ -5,14 +5,14 @@ function package_name(b::AbstractADType) return s[5:(k - 1)] end -package_name(b::AutoSparse) = package_name(dense_ad(b)) - function package_name(b::SecondOrder) p1 = package_name(outer(b)) p2 = package_name(inner(b)) return p1 == p2 ? p1 : "$p1, $p2" end +package_name(b::AutoSparse) = package_name(dense_ad(b)) + function document_preparation(operator_name::AbstractString; same_point=false) if same_point return "To improve performance via operator preparation, refer to [`prepare_$(operator_name)`](@ref) and [`prepare_$(operator_name)_same_point`](@ref)." diff --git a/DifferentiationInterface/src/utils/traits.jl b/DifferentiationInterface/src/utils/traits.jl index a33bf181d..1e3c3b626 100644 --- a/DifferentiationInterface/src/utils/traits.jl +++ b/DifferentiationInterface/src/utils/traits.jl @@ -31,6 +31,8 @@ function twoarg_support(backend::SecondOrder) end end +twoarg_support(backend::AutoSparse) = twoarg_support(dense_ad(backend)) + ## Pushforward abstract type PushforwardPerformance end @@ -59,6 +61,7 @@ pushforward_performance(::ForwardMode) = PushforwardFast() pushforward_performance(::ForwardOrReverseMode) = PushforwardFast() pushforward_performance(::ReverseMode) = PushforwardSlow() pushforward_performance(::SymbolicMode) = PushforwardFast() +pushforward_performance(backend::AutoSparse) = pushforward_performance(dense_ad(backend)) ## Pullback @@ -88,6 +91,7 @@ pullback_performance(::ForwardMode) = PullbackSlow() pullback_performance(::ForwardOrReverseMode) = PullbackFast() pullback_performance(::ReverseMode) = PullbackFast() pullback_performance(::SymbolicMode) = PullbackFast() +pullback_performance(backend::AutoSparse) = pullback_performance(dense_ad(backend)) ## HVP @@ -134,6 +138,8 @@ function hvp_mode(ba::SecondOrder) end end +hvp_mode(backend::AutoSparse{<:SecondOrder}) = hvp_mode(dense_ad(backend)) + ## Conversions Base.Bool(::TwoArgSupported) = true