From 2751235a3e40887df353e804425ecd4ff108e8b9 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 16 Apr 2024 18:22:11 +0200 Subject: [PATCH] Fallback from sparse to dense (#183) * Fallback from sparse to dense * Fix ambiguities * Missing Any * Fix more fixes --- DifferentiationInterface/docs/src/backends.md | 3 +- ...ntiationInterfaceFastDifferentiationExt.jl | 8 +- .../onearg.jl | 57 +++----- .../twoarg.jl | 28 ++-- ...ferentiationInterfaceSparseDiffToolsExt.jl | 25 +--- .../onearg.jl | 19 ++- .../twoarg.jl | 20 ++- .../DifferentiationInterfaceSymbolicsExt.jl | 8 +- .../onearg.jl | 34 ++--- .../twoarg.jl | 22 ++- .../DifferentiationInterfaceZygoteExt.jl | 2 +- .../src/DifferentiationInterface.jl | 1 + DifferentiationInterface/src/sparse.jl | 138 ++++++++++++++++++ DifferentiationInterface/test/first_order.jl | 17 ++- DifferentiationInterface/test/second_order.jl | 24 ++- DifferentiationInterface/test/sparsity.jl | 2 + 16 files changed, 269 insertions(+), 139 deletions(-) create mode 100644 DifferentiationInterface/src/sparse.jl diff --git a/DifferentiationInterface/docs/src/backends.md b/DifferentiationInterface/docs/src/backends.md index 0df94ddd3..11aca2231 100644 --- a/DifferentiationInterface/docs/src/backends.md +++ b/DifferentiationInterface/docs/src/backends.md @@ -65,8 +65,7 @@ AutoZygote ### Sparse -!!! warning - For sparse backends, only the Jacobian and Hessian operators are implemented. +For sparse backends, only the Jacobian and Hessian operators are implemented differently, the other operators behave the same as for the corresponding dense backend. ```@docs AutoSparseFastDifferentiation diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl index 76397f3b5..be4605f46 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl @@ -31,10 +31,10 @@ const AnyAutoFastDifferentiation = Union{ AutoFastDifferentiation,AutoSparseFastDifferentiation } -DI.check_available(::AnyAutoFastDifferentiation) = true -DI.mode(::AnyAutoFastDifferentiation) = ADTypes.AbstractSymbolicDifferentiationMode -DI.pushforward_performance(::AnyAutoFastDifferentiation) = DI.PushforwardFast() -DI.pullback_performance(::AnyAutoFastDifferentiation) = DI.PullbackSlow() +DI.check_available(::AutoFastDifferentiation) = true +DI.mode(::AutoFastDifferentiation) = ADTypes.AbstractSymbolicDifferentiationMode +DI.pushforward_performance(::AutoFastDifferentiation) = DI.PushforwardFast() +DI.pullback_performance(::AutoFastDifferentiation) = DI.PullbackSlow() monovec(x::Number) = Fill(x, 1) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl index ea4bcbf48..e1d678bee 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl @@ -6,7 +6,7 @@ struct FastDifferentiationOneArgPushforwardExtras{Y,E1,E2} <: PushforwardExtras jvp_exe!::E2 end -function DI.prepare_pushforward(f, ::AnyAutoFastDifferentiation, x, dx) +function DI.prepare_pushforward(f, ::AutoFastDifferentiation, x, dx) y_prototype = f(x) x_var = if x isa Number only(make_variables(:x)) @@ -24,11 +24,7 @@ function DI.prepare_pushforward(f, ::AnyAutoFastDifferentiation, x, dx) end function DI.pushforward( - f, - ::AnyAutoFastDifferentiation, - x, - dx, - extras::FastDifferentiationOneArgPushforwardExtras, + f, ::AutoFastDifferentiation, x, dx, extras::FastDifferentiationOneArgPushforwardExtras ) v_vec = vcat(myvec(x), myvec(dx)) if extras.y_prototype isa Number @@ -41,7 +37,7 @@ end function DI.pushforward!( f, dy, - ::AnyAutoFastDifferentiation, + ::AutoFastDifferentiation, x, dx, extras::FastDifferentiationOneArgPushforwardExtras, @@ -53,7 +49,7 @@ end function DI.value_and_pushforward( f, - backend::AnyAutoFastDifferentiation, + backend::AutoFastDifferentiation, x, dx, extras::FastDifferentiationOneArgPushforwardExtras, @@ -64,7 +60,7 @@ end function DI.value_and_pushforward!( f, dy, - backend::AnyAutoFastDifferentiation, + backend::AutoFastDifferentiation, x, dx, extras::FastDifferentiationOneArgPushforwardExtras, @@ -84,7 +80,7 @@ struct FastDifferentiationOneArgDerivativeExtras{Y,E1,E2} <: DerivativeExtras der_exe!::E2 end -function DI.prepare_derivative(f, ::AnyAutoFastDifferentiation, x) +function DI.prepare_derivative(f, ::AutoFastDifferentiation, x) y_prototype = f(x) x_var = only(make_variables(:x)) y_var = f(x_var) @@ -98,7 +94,7 @@ function DI.prepare_derivative(f, ::AnyAutoFastDifferentiation, x) end function DI.derivative( - f, ::AnyAutoFastDifferentiation, x, extras::FastDifferentiationOneArgDerivativeExtras + f, ::AutoFastDifferentiation, x, extras::FastDifferentiationOneArgDerivativeExtras ) if extras.y_prototype isa Number return only(extras.der_exe(monovec(x))) @@ -108,11 +104,7 @@ function DI.derivative( end function DI.derivative!( - f, - der, - ::AnyAutoFastDifferentiation, - x, - extras::FastDifferentiationOneArgDerivativeExtras, + f, der, ::AutoFastDifferentiation, x, extras::FastDifferentiationOneArgDerivativeExtras ) extras.der_exe!(vec(der), monovec(x)) return der @@ -120,7 +112,7 @@ end function DI.value_and_derivative( f, - backend::AnyAutoFastDifferentiation, + backend::AutoFastDifferentiation, x, extras::FastDifferentiationOneArgDerivativeExtras, ) @@ -130,7 +122,7 @@ end function DI.value_and_derivative!( f, der, - backend::AnyAutoFastDifferentiation, + backend::AutoFastDifferentiation, x, extras::FastDifferentiationOneArgDerivativeExtras, ) @@ -144,7 +136,7 @@ struct FastDifferentiationOneArgGradientExtras{E1,E2} <: GradientExtras jac_exe!::E2 end -function DI.prepare_gradient(f, backend::AnyAutoFastDifferentiation, x) +function DI.prepare_gradient(f, backend::AutoFastDifferentiation, x) y_prototype = f(x) x_var = make_variables(:x, size(x)...) y_var = f(x_var) @@ -158,7 +150,7 @@ function DI.prepare_gradient(f, backend::AnyAutoFastDifferentiation, x) end function DI.gradient( - f, ::AnyAutoFastDifferentiation, x, extras::FastDifferentiationOneArgGradientExtras + f, ::AutoFastDifferentiation, x, extras::FastDifferentiationOneArgGradientExtras ) jac = extras.jac_exe(vec(x)) grad_vec = @view jac[1, :] @@ -166,21 +158,14 @@ function DI.gradient( end function DI.gradient!( - f, - grad, - ::AnyAutoFastDifferentiation, - x, - extras::FastDifferentiationOneArgGradientExtras, + f, grad, ::AutoFastDifferentiation, x, extras::FastDifferentiationOneArgGradientExtras ) extras.jac_exe!(reshape(grad, 1, length(grad)), vec(x)) return grad end function DI.value_and_gradient( - f, - backend::AnyAutoFastDifferentiation, - x, - extras::FastDifferentiationOneArgGradientExtras, + f, backend::AutoFastDifferentiation, x, extras::FastDifferentiationOneArgGradientExtras ) return f(x), DI.gradient(f, backend, x, extras) end @@ -188,7 +173,7 @@ end function DI.value_and_gradient!( f, grad, - backend::AnyAutoFastDifferentiation, + backend::AutoFastDifferentiation, x, extras::FastDifferentiationOneArgGradientExtras, ) @@ -261,7 +246,7 @@ struct FastDifferentiationAllocatingSecondDerivativeExtras{Y,E1,E2} <: der2_exe!::E2 end -function DI.prepare_second_derivative(f, ::AnyAutoFastDifferentiation, x) +function DI.prepare_second_derivative(f, ::AutoFastDifferentiation, x) y_prototype = f(x) x_var = only(make_variables(:x)) y_var = f(x_var) @@ -278,7 +263,7 @@ end function DI.second_derivative( f, - ::AnyAutoFastDifferentiation, + ::AutoFastDifferentiation, x, extras::FastDifferentiationAllocatingSecondDerivativeExtras, ) @@ -292,7 +277,7 @@ end function DI.second_derivative!( f, der2, - backend::AnyAutoFastDifferentiation, + backend::AutoFastDifferentiation, x, extras::FastDifferentiationAllocatingSecondDerivativeExtras, ) @@ -307,7 +292,7 @@ struct FastDifferentiationHVPExtras{E1,E2} <: HVPExtras hvp_exe!::E2 end -function DI.prepare_hvp(f, ::AnyAutoFastDifferentiation, x, v) +function DI.prepare_hvp(f, ::AutoFastDifferentiation, x, v) x_var = make_variables(:x, size(x)...) y_var = f(x_var) @@ -318,14 +303,14 @@ function DI.prepare_hvp(f, ::AnyAutoFastDifferentiation, x, v) return FastDifferentiationHVPExtras(hvp_exe, hvp_exe!) end -function DI.hvp(f, ::AnyAutoFastDifferentiation, x, v, extras::FastDifferentiationHVPExtras) +function DI.hvp(f, ::AutoFastDifferentiation, x, v, extras::FastDifferentiationHVPExtras) v_vec = vcat(vec(x), vec(v)) hv_vec = extras.hvp_exe(v_vec) return reshape(hv_vec, size(x)) end function DI.hvp!( - f, p, ::AnyAutoFastDifferentiation, x, v, extras::FastDifferentiationHVPExtras + f, p, ::AutoFastDifferentiation, x, v, extras::FastDifferentiationHVPExtras ) v_vec = vcat(vec(x), vec(v)) extras.hvp_exe!(p, v_vec) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl index 3225bbe44..25014f102 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl @@ -5,7 +5,7 @@ struct FastDifferentiationTwoArgPushforwardExtras{E1,E2} <: PushforwardExtras jvp_exe!::E2 end -function DI.prepare_pushforward(f!, y, ::AnyAutoFastDifferentiation, x, dx) +function DI.prepare_pushforward(f!, y, ::AutoFastDifferentiation, x, dx) x_var = if x isa Number only(make_variables(:x)) else @@ -25,7 +25,7 @@ end function DI.value_and_pushforward( f!, y, - ::AnyAutoFastDifferentiation, + ::AutoFastDifferentiation, x, dx, extras::FastDifferentiationTwoArgPushforwardExtras, @@ -40,7 +40,7 @@ function DI.value_and_pushforward!( f!, y, dy, - ::AnyAutoFastDifferentiation, + ::AutoFastDifferentiation, x, dx, extras::FastDifferentiationTwoArgPushforwardExtras, @@ -54,7 +54,7 @@ end function DI.pushforward( f!, y, - ::AnyAutoFastDifferentiation, + ::AutoFastDifferentiation, x, dx, extras::FastDifferentiationTwoArgPushforwardExtras, @@ -68,7 +68,7 @@ function DI.pushforward!( f!, y, dy, - ::AnyAutoFastDifferentiation, + ::AutoFastDifferentiation, x, dx, extras::FastDifferentiationTwoArgPushforwardExtras, @@ -85,7 +85,7 @@ struct FastDifferentiationTwoArgDerivativeExtras{E1,E2} <: DerivativeExtras der_exe!::E2 end -function DI.prepare_derivative(f!, y, ::AnyAutoFastDifferentiation, x) +function DI.prepare_derivative(f!, y, ::AutoFastDifferentiation, x) x_var = only(make_variables(:x)) y_var = make_variables(:y, size(y)...) f!(y_var, x_var) @@ -99,11 +99,7 @@ function DI.prepare_derivative(f!, y, ::AnyAutoFastDifferentiation, x) end function DI.value_and_derivative( - f!, - y, - ::AnyAutoFastDifferentiation, - x, - extras::FastDifferentiationTwoArgDerivativeExtras, + f!, y, ::AutoFastDifferentiation, x, extras::FastDifferentiationTwoArgDerivativeExtras ) f!(y, x) der = reshape(extras.der_exe(monovec(x)), size(y)) @@ -114,7 +110,7 @@ function DI.value_and_derivative!( f!, y, der, - ::AnyAutoFastDifferentiation, + ::AutoFastDifferentiation, x, extras::FastDifferentiationTwoArgDerivativeExtras, ) @@ -124,11 +120,7 @@ function DI.value_and_derivative!( end function DI.derivative( - f!, - y, - ::AnyAutoFastDifferentiation, - x, - extras::FastDifferentiationTwoArgDerivativeExtras, + f!, y, ::AutoFastDifferentiation, x, extras::FastDifferentiationTwoArgDerivativeExtras ) der = reshape(extras.der_exe(monovec(x)), size(y)) return der @@ -138,7 +130,7 @@ function DI.derivative!( f!, y, der, - ::AnyAutoFastDifferentiation, + ::AutoFastDifferentiation, x, extras::FastDifferentiationTwoArgDerivativeExtras, ) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseDiffToolsExt/DifferentiationInterfaceSparseDiffToolsExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseDiffToolsExt/DifferentiationInterfaceSparseDiffToolsExt.jl index 8aa5873d6..8eb7e5b8a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseDiffToolsExt/DifferentiationInterfaceSparseDiffToolsExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseDiffToolsExt/DifferentiationInterfaceSparseDiffToolsExt.jl @@ -3,7 +3,7 @@ module DifferentiationInterfaceSparseDiffToolsExt using ADTypes import DifferentiationInterface as DI using DifferentiationInterface: - HessianExtras, JacobianExtras, NoHessianExtras, SecondOrder, inner, outer + AnyAutoSparse, HessianExtras, JacobianExtras, NoHessianExtras, SecondOrder, inner, outer using SparseDiffTools: JacPrototypeSparsityDetection, SymbolicsSparsityDetection, @@ -12,7 +12,7 @@ using SparseDiffTools: sparse_jacobian_cache using Symbolics: Symbolics -AnyOneArgAutoSparse = Union{ +AnyAutoSparseNoSymbolic = Union{ AutoSparseFiniteDiff, AutoSparseForwardDiff, AutoSparsePolyesterForwardDiff, @@ -20,27 +20,6 @@ AnyOneArgAutoSparse = Union{ AutoSparseZygote, } -AnyTwoArgAutoSparse = Union{ - AutoSparseFiniteDiff, - AutoSparseForwardDiff, - AutoSparsePolyesterForwardDiff, - AutoSparseReverseDiff, -} - -dense(::AutoSparseFiniteDiff) = AutoFiniteDiff() -dense(backend::AutoSparseReverseDiff) = AutoReverseDiff(backend.compile) -dense(::AutoSparseZygote) = AutoZygote() - -function dense(backend::AutoSparseForwardDiff{chunksize,T}) where {chunksize,T} - return AutoForwardDiff{chunksize,T}(backend.tag) -end - -function dense(::AutoSparsePolyesterForwardDiff{chunksize}) where {chunksize} - return AutoSparsePolyesterForwardDiff{chunksize}() -end - -DI.check_available(backend::AnyOneArgAutoSparse) = DI.check_available(dense(backend)) - include("onearg.jl") include("twoarg.jl") diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseDiffToolsExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseDiffToolsExt/onearg.jl index b871283f3..5a475750a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseDiffToolsExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseDiffToolsExt/onearg.jl @@ -9,47 +9,50 @@ end ## Jacobian -function DI.prepare_jacobian(f, backend::AnyOneArgAutoSparse, x::AbstractArray) +function DI.prepare_jacobian(f, backend::AnyAutoSparseNoSymbolic, x::AbstractArray) cache = sparse_jacobian_cache(backend, SymbolicsSparsityDetection(), f, x; fx=f(x)) return SparseDiffToolsOneArgJacobianExtras(cache) end function DI.value_and_jacobian!( - f, jac, backend::AnyOneArgAutoSparse, x, extras::SparseDiffToolsOneArgJacobianExtras + f, jac, backend::AnyAutoSparseNoSymbolic, x, extras::SparseDiffToolsOneArgJacobianExtras ) sparse_jacobian!(jac, backend, extras.cache, f, x) return f(x), jac end function DI.jacobian!( - f, jac, backend::AnyOneArgAutoSparse, x, extras::SparseDiffToolsOneArgJacobianExtras + f, jac, backend::AnyAutoSparseNoSymbolic, x, extras::SparseDiffToolsOneArgJacobianExtras ) sparse_jacobian!(jac, backend, extras.cache, f, x) return jac end function DI.value_and_jacobian( - f, backend::AnyOneArgAutoSparse, x, extras::SparseDiffToolsOneArgJacobianExtras + f, backend::AnyAutoSparseNoSymbolic, x, extras::SparseDiffToolsOneArgJacobianExtras ) return f(x), sparse_jacobian(backend, extras.cache, f, x) end function DI.jacobian( - f, backend::AnyOneArgAutoSparse, x, extras::SparseDiffToolsOneArgJacobianExtras + f, backend::AnyAutoSparseNoSymbolic, x, extras::SparseDiffToolsOneArgJacobianExtras ) return sparse_jacobian(backend, extras.cache, f, x) end ## Hessian -function DI.prepare_hessian(f, backend::SecondOrder{<:AnyOneArgAutoSparse}, x) +function DI.prepare_hessian(f, backend::SecondOrder{<:AnyAutoSparseNoSymbolic}, x) inner_gradient_closure(z) = DI.gradient(f, inner(backend), z) outer_jacobian_extras = DI.prepare_jacobian(inner_gradient_closure, outer(backend), x) return SparseDiffToolsHessianExtras(inner_gradient_closure, outer_jacobian_extras) end function DI.hessian( - f, backend::SecondOrder{<:AnyOneArgAutoSparse}, x, extras::SparseDiffToolsHessianExtras + f, + backend::SecondOrder{<:AnyAutoSparseNoSymbolic}, + x, + extras::SparseDiffToolsHessianExtras, ) return DI.jacobian( extras.inner_gradient_closure, outer(backend), x, extras.outer_jacobian_extras @@ -59,7 +62,7 @@ end function DI.hessian!( f, hess, - backend::SecondOrder{<:AnyOneArgAutoSparse}, + backend::SecondOrder{<:AnyAutoSparseNoSymbolic}, x, extras::SparseDiffToolsHessianExtras, ) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseDiffToolsExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseDiffToolsExt/twoarg.jl index 64368f99b..30c8cecf7 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseDiffToolsExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseDiffToolsExt/twoarg.jl @@ -5,14 +5,14 @@ end ## Jacobian function DI.prepare_jacobian( - f!, y::AbstractArray, backend::AnyTwoArgAutoSparse, x::AbstractArray + f!, y::AbstractArray, backend::AnyAutoSparseNoSymbolic, x::AbstractArray ) cache = sparse_jacobian_cache(backend, SymbolicsSparsityDetection(), f!, similar(y), x) return SparseDiffToolsTwoArgJacobianExtras(cache) end function DI.value_and_jacobian( - f!, y, backend::AnyTwoArgAutoSparse, x, extras::SparseDiffToolsTwoArgJacobianExtras + f!, y, backend::AnyAutoSparseNoSymbolic, x, extras::SparseDiffToolsTwoArgJacobianExtras ) jac = sparse_jacobian(backend, extras.cache, f!, y, x) f!(y, x) @@ -20,7 +20,12 @@ function DI.value_and_jacobian( end function DI.value_and_jacobian!( - f!, y, jac, backend::AnyTwoArgAutoSparse, x, extras::SparseDiffToolsTwoArgJacobianExtras + f!, + y, + jac, + backend::AnyAutoSparseNoSymbolic, + x, + extras::SparseDiffToolsTwoArgJacobianExtras, ) sparse_jacobian!(jac, backend, extras.cache, f!, y, x) f!(y, x) @@ -28,14 +33,19 @@ function DI.value_and_jacobian!( end function DI.jacobian( - f!, y, backend::AnyTwoArgAutoSparse, x, extras::SparseDiffToolsTwoArgJacobianExtras + f!, y, backend::AnyAutoSparseNoSymbolic, x, extras::SparseDiffToolsTwoArgJacobianExtras ) jac = sparse_jacobian(backend, extras.cache, f!, y, x) return jac end function DI.jacobian!( - f!, y, jac, backend::AnyTwoArgAutoSparse, x, extras::SparseDiffToolsTwoArgJacobianExtras + f!, + y, + jac, + backend::AnyAutoSparseNoSymbolic, + x, + extras::SparseDiffToolsTwoArgJacobianExtras, ) sparse_jacobian!(jac, backend, extras.cache, f!, y, x) return jac diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl index 04a6edee9..f09c5fe96 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/DifferentiationInterfaceSymbolicsExt.jl @@ -29,10 +29,10 @@ using Symbolics.RuntimeGeneratedFunctions: RuntimeGeneratedFunction const AnyAutoSymbolics = Union{AutoSymbolics,AutoSparseSymbolics} -DI.check_available(::AnyAutoSymbolics) = true -DI.mode(::AnyAutoSymbolics) = ADTypes.AbstractSymbolicDifferentiationMode -DI.pushforward_performance(::AnyAutoSymbolics) = DI.PushforwardFast() -DI.pullback_performance(::AnyAutoSymbolics) = DI.PullbackSlow() +DI.check_available(::AutoSymbolics) = true +DI.mode(::AutoSymbolics) = ADTypes.AbstractSymbolicDifferentiationMode +DI.pushforward_performance(::AutoSymbolics) = DI.PushforwardFast() +DI.pullback_performance(::AutoSymbolics) = DI.PullbackSlow() monovec(x::Number) = Fill(x, 1) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl index 45555716f..e5564285d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl @@ -5,7 +5,7 @@ struct SymbolicsOneArgPushforwardExtras{E1,E2} <: PushforwardExtras pf_exe!::E2 end -function DI.prepare_pushforward(f, ::AnyAutoSymbolics, x, dx) +function DI.prepare_pushforward(f, ::AutoSymbolics, x, dx) x_var = if x isa Number variable(:x) else @@ -29,16 +29,14 @@ function DI.prepare_pushforward(f, ::AnyAutoSymbolics, x, dx) return SymbolicsOneArgPushforwardExtras(pf_exe, pf_exe!) end -function DI.pushforward( - f, ::AnyAutoSymbolics, x, dx, extras::SymbolicsOneArgPushforwardExtras -) +function DI.pushforward(f, ::AutoSymbolics, x, dx, extras::SymbolicsOneArgPushforwardExtras) v_vec = vcat(myvec(x), myvec(dx)) dy = extras.pf_exe(v_vec) return dy end function DI.pushforward!( - f, dy, ::AnyAutoSymbolics, x, dx, extras::SymbolicsOneArgPushforwardExtras + f, dy, ::AutoSymbolics, x, dx, extras::SymbolicsOneArgPushforwardExtras ) v_vec = vcat(myvec(x), myvec(dx)) extras.pf_exe!(dy, v_vec) @@ -46,13 +44,13 @@ function DI.pushforward!( end function DI.value_and_pushforward( - f, backend::AnyAutoSymbolics, x, dx, extras::SymbolicsOneArgPushforwardExtras + f, backend::AutoSymbolics, x, dx, extras::SymbolicsOneArgPushforwardExtras ) return f(x), DI.pushforward(f, backend, x, dx, extras) end function DI.value_and_pushforward!( - f, dy, backend::AnyAutoSymbolics, x, dx, extras::SymbolicsOneArgPushforwardExtras + f, dy, backend::AutoSymbolics, x, dx, extras::SymbolicsOneArgPushforwardExtras ) return f(x), DI.pushforward!(f, dy, backend, x, dx, extras) end @@ -64,7 +62,7 @@ struct SymbolicsOneArgDerivativeExtras{E1,E2} <: DerivativeExtras der_exe!::E2 end -function DI.prepare_derivative(f, ::AnyAutoSymbolics, x) +function DI.prepare_derivative(f, ::AutoSymbolics, x) x_var = variable(:x) der_var = derivative(f(x_var), x_var) @@ -77,25 +75,23 @@ function DI.prepare_derivative(f, ::AnyAutoSymbolics, x) return SymbolicsOneArgDerivativeExtras(der_exe, der_exe!) end -function DI.derivative(f, ::AnyAutoSymbolics, x, extras::SymbolicsOneArgDerivativeExtras) +function DI.derivative(f, ::AutoSymbolics, x, extras::SymbolicsOneArgDerivativeExtras) return extras.der_exe(x) end -function DI.derivative!( - f, der, ::AnyAutoSymbolics, x, extras::SymbolicsOneArgDerivativeExtras -) +function DI.derivative!(f, der, ::AutoSymbolics, x, extras::SymbolicsOneArgDerivativeExtras) extras.der_exe!(der, x) return der end function DI.value_and_derivative( - f, backend::AnyAutoSymbolics, x, extras::SymbolicsOneArgDerivativeExtras + f, backend::AutoSymbolics, x, extras::SymbolicsOneArgDerivativeExtras ) return f(x), DI.derivative(f, backend, x, extras) end function DI.value_and_derivative!( - f, der, backend::AnyAutoSymbolics, x, extras::SymbolicsOneArgDerivativeExtras + f, der, backend::AutoSymbolics, x, extras::SymbolicsOneArgDerivativeExtras ) return f(x), DI.derivative!(f, der, backend, x, extras) end @@ -107,7 +103,7 @@ struct SymbolicsOneArgGradientExtras{E1,E2} <: GradientExtras grad_exe!::E2 end -function DI.prepare_gradient(f, ::AnyAutoSymbolics, x) +function DI.prepare_gradient(f, ::AutoSymbolics, x) x_var = variables(:x, axes(x)...) # Symbolic.gradient only accepts vectors grad_var = gradient(f(x_var), vec(x_var)) @@ -121,23 +117,23 @@ function DI.prepare_gradient(f, ::AnyAutoSymbolics, x) return SymbolicsOneArgGradientExtras(grad_exe, grad_exe!) end -function DI.gradient(f, ::AnyAutoSymbolics, x, extras::SymbolicsOneArgGradientExtras) +function DI.gradient(f, ::AutoSymbolics, x, extras::SymbolicsOneArgGradientExtras) return reshape(extras.grad_exe(vec(x)), size(x)) end -function DI.gradient!(f, grad, ::AnyAutoSymbolics, x, extras::SymbolicsOneArgGradientExtras) +function DI.gradient!(f, grad, ::AutoSymbolics, x, extras::SymbolicsOneArgGradientExtras) extras.grad_exe!(vec(grad), vec(x)) return grad end function DI.value_and_gradient( - f, backend::AnyAutoSymbolics, x, extras::SymbolicsOneArgGradientExtras + f, backend::AutoSymbolics, x, extras::SymbolicsOneArgGradientExtras ) return f(x), DI.gradient(f, backend, x, extras) end function DI.value_and_gradient!( - f, grad, backend::AnyAutoSymbolics, x, extras::SymbolicsOneArgGradientExtras + f, grad, backend::AutoSymbolics, x, extras::SymbolicsOneArgGradientExtras ) return f(x), DI.gradient!(f, grad, backend, x, extras) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl index 4975d5bc0..5ab435492 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/twoarg.jl @@ -5,7 +5,7 @@ struct SymbolicsTwoArgPushforwardExtras{E1,E2} <: PushforwardExtras pushforward_exe!::E2 end -function DI.prepare_pushforward(f!, y, ::AnyAutoSymbolics, x, dx) +function DI.prepare_pushforward(f!, y, ::AutoSymbolics, x, dx) x_var = if x isa Number variable(:x) else @@ -32,7 +32,7 @@ function DI.prepare_pushforward(f!, y, ::AnyAutoSymbolics, x, dx) end function DI.pushforward( - f!, y, ::AnyAutoSymbolics, x, dx, extras::SymbolicsTwoArgPushforwardExtras + f!, y, ::AutoSymbolics, x, dx, extras::SymbolicsTwoArgPushforwardExtras ) v_vec = vcat(myvec(x), myvec(dx)) dy = extras.pushforward_exe(v_vec) @@ -40,7 +40,7 @@ function DI.pushforward( end function DI.pushforward!( - f!, y, dy, ::AnyAutoSymbolics, x, dx, extras::SymbolicsTwoArgPushforwardExtras + f!, y, dy, ::AutoSymbolics, x, dx, extras::SymbolicsTwoArgPushforwardExtras ) v_vec = vcat(myvec(x), myvec(dx)) extras.pushforward_exe!(dy, v_vec) @@ -48,7 +48,7 @@ function DI.pushforward!( end function DI.value_and_pushforward( - f!, y, backend::AnyAutoSymbolics, x, dx, extras::SymbolicsTwoArgPushforwardExtras + f!, y, backend::AutoSymbolics, x, dx, extras::SymbolicsTwoArgPushforwardExtras ) dy = DI.pushforward(f!, y, backend, x, dx, extras) f!(y, x) @@ -56,7 +56,7 @@ function DI.value_and_pushforward( end function DI.value_and_pushforward!( - f!, y, dy, backend::AnyAutoSymbolics, x, dx, extras::SymbolicsTwoArgPushforwardExtras + f!, y, dy, backend::AutoSymbolics, x, dx, extras::SymbolicsTwoArgPushforwardExtras ) DI.pushforward!(f!, y, dy, backend, x, dx, extras) f!(y, x) @@ -70,7 +70,7 @@ struct SymbolicsTwoArgDerivativeExtras{E1,E2} <: DerivativeExtras der_exe!::E2 end -function DI.prepare_derivative(f!, y, ::AnyAutoSymbolics, x) +function DI.prepare_derivative(f!, y, ::AutoSymbolics, x) x_var = variable(:x) y_var = variables(:y, axes(y)...) f!(y_var, x_var) @@ -85,21 +85,19 @@ function DI.prepare_derivative(f!, y, ::AnyAutoSymbolics, x) return SymbolicsTwoArgDerivativeExtras(der_exe, der_exe!) end -function DI.derivative( - f!, y, ::AnyAutoSymbolics, x, extras::SymbolicsTwoArgDerivativeExtras -) +function DI.derivative(f!, y, ::AutoSymbolics, x, extras::SymbolicsTwoArgDerivativeExtras) return extras.der_exe(x) end function DI.derivative!( - f!, y, der, ::AnyAutoSymbolics, x, extras::SymbolicsTwoArgDerivativeExtras + f!, y, der, ::AutoSymbolics, x, extras::SymbolicsTwoArgDerivativeExtras ) extras.der_exe!(der, x) return der end function DI.value_and_derivative( - f!, y, backend::AnyAutoSymbolics, x, extras::SymbolicsTwoArgDerivativeExtras + f!, y, backend::AutoSymbolics, x, extras::SymbolicsTwoArgDerivativeExtras ) der = DI.derivative(f!, y, backend, x, extras) f!(y, x) @@ -107,7 +105,7 @@ function DI.value_and_derivative( end function DI.value_and_derivative!( - f!, y, der, backend::AnyAutoSymbolics, x, extras::SymbolicsTwoArgDerivativeExtras + f!, y, der, backend::AutoSymbolics, x, extras::SymbolicsTwoArgDerivativeExtras ) DI.derivative!(f!, y, der, backend, x, extras) f!(y, x) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index cfbaf773f..ae123f628 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -9,7 +9,7 @@ using Zygote: ZygoteRuleConfig, gradient, hessian, jacobian, pullback, withgradient, withjacobian DI.check_available(::AutoZygote) = true -DI.mutation_support(::Union{AutoZygote,AutoSparseZygote}) = DI.MutationNotSupported() +DI.mutation_support(::AutoZygote) = DI.MutationNotSupported() ## Pullback diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index efa11debb..e9b155daa 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -56,6 +56,7 @@ include("hvp.jl") include("hessian.jl") include("check.jl") +include("sparse.jl") export AutoChainRules, AutoDiffractor, diff --git a/DifferentiationInterface/src/sparse.jl b/DifferentiationInterface/src/sparse.jl new file mode 100644 index 000000000..2c9922fbe --- /dev/null +++ b/DifferentiationInterface/src/sparse.jl @@ -0,0 +1,138 @@ +AnyAutoSparse = Union{ + AutoSparseFastDifferentiation, + AutoSparseFiniteDiff, + AutoSparseForwardDiff, + AutoSparsePolyesterForwardDiff, + AutoSparseReverseDiff, + AutoSparseSymbolics, + AutoSparseZygote, +} + +## Conversion + +dense_ad(::AutoSparseFastDifferentiation) = AutoFastDifferentiation() +dense_ad(::AutoSparseFiniteDiff) = AutoFiniteDiff() +dense_ad(backend::AutoSparseReverseDiff) = AutoReverseDiff(backend.compile) +dense_ad(::AutoSparseSymbolics) = AutoSymbolics() +dense_ad(::AutoSparseZygote) = AutoZygote() + +function dense_ad(backend::AutoSparseForwardDiff{chunksize,T}) where {chunksize,T} + return AutoForwardDiff{chunksize,T}(backend.tag) +end + +function dense_ad(::AutoSparsePolyesterForwardDiff{chunksize}) where {chunksize} + return AutoSparsePolyesterForwardDiff{chunksize}() +end + +## Traits + +for trait in ( + :check_available, + :mode, + :mutation_support, + :pushforward_performance, + :pullback_performance, + :hvp_mode, +) + @eval $trait(backend::AnyAutoSparse) = $trait(dense_ad(backend)) +end + +## Operators + +for op in (:pushforward, :pullback, :hvp) + op! = Symbol(op, "!") + valop = Symbol("value_and_", op) + valop! = Symbol("value_and_", op, "!") + prep = Symbol("prepare_", op) + E = if op == :pushforward + :PushforwardExtras + elseif op == :pullback + :PullbackExtras + elseif op == :hvp + :HVPExtras + end + + ## One argument + @eval begin + $prep(f, ba::AnyAutoSparse, x, v) = $prep(f, dense_ad(ba), x, v) + $op(f, ba::AnyAutoSparse, x, v, ex::$E=$prep(f, ba, x, v)) = + $op(f, dense_ad(ba), x, v, ex) + $valop(f, ba::AnyAutoSparse, x, v, ex::$E=$prep(f, ba, x, v)) = + $valop(f, dense_ad(ba), x, v, ex) + $op!(f, res, ba::AnyAutoSparse, x, v, ex::$E=$prep(f, ba, x, v)) = + $op!(f, res, dense_ad(ba), x, v, ex) + $valop!(f, res, ba::AnyAutoSparse, x, v, ex::$E=$prep(f, ba, x, v)) = + $valop!(f, res, dense_ad(ba), x, v, ex) + end + + ## Two arguments + @eval begin + $prep(f!, y, ba::AnyAutoSparse, x, v) = $prep(f!, y, dense_ad(ba), x, v) + $op(f!, y, ba::AnyAutoSparse, x, v, ex::$E=$prep(f!, y, ba, x, v)) = + $op(f!, y, dense_ad(ba), x, v, ex) + $valop(f!, y, ba::AnyAutoSparse, x, v, ex::$E=$prep(f!, y, ba, x, v)) = + $valop(f!, y, dense_ad(ba), x, v, ex) + $op!(f!, y, res, ba::AnyAutoSparse, x, v, ex::$E=$prep(f!, y, ba, x, v)) = + $op!(f!, y, res, dense_ad(ba), x, v, ex) + $valop!(f!, y, res, ba::AnyAutoSparse, x, v, ex::$E=$prep(f!, y, ba, x, v)) = + $valop!(f!, y, res, dense_ad(ba), x, v, ex) + end + + ## Split + if op == :pullback + valop_split = Symbol("value_and_", op, "_split") + valop!_split = Symbol("value_and_", op!, "_split") + + @eval begin + $valop_split(f, ba::AnyAutoSparse, x, ex::$E=$prep(f, ba, x, f(x))) = + $valop_split(f, dense_ad(ba), x, ex) + $valop!_split(f, ba::AnyAutoSparse, x, ex::$E=$prep(f, ba, x, f(x))) = + $valop!_split(f, dense_ad(ba), x, ex) + $valop_split(f!, y, ba::AnyAutoSparse, x, ex::$E=$prep(f, ba, x, similar(y))) = + $valop_split(f!, y, dense_ad(ba), x, ex) + $valop!_split(f!, y, ba::AnyAutoSparse, x, ex::$E=$prep(f, ba, x, similar(y))) = + $valop!_split(f!, y, dense_ad(ba), x, ex) + end + end +end + +for op in (:derivative, :gradient, :second_derivative) + op! = Symbol(op, "!") + valop = Symbol("value_and_", op) + valop! = Symbol("value_and_", op, "!") + prep = Symbol("prepare_", op) + E = if op == :derivative + :DerivativeExtras + elseif op == :gradient + :GradientExtras + elseif op == :second_derivative + :SecondDerivativeExtras + end + + ## One argument + @eval begin + $prep(f, ba::AnyAutoSparse, x) = $prep(f, dense_ad(ba), x) + $op(f, ba::AnyAutoSparse, x, ex::$E=$prep(f, ba, x)) = $op(f, dense_ad(ba), x, ex) + $valop(f, ba::AnyAutoSparse, x, ex::$E=$prep(f, ba, x)) = + $valop(f, dense_ad(ba), x, ex) + $op!(f, res, ba::AnyAutoSparse, x, ex::$E=$prep(f, ba, x)) = + $op!(f, res, dense_ad(ba), x, ex) + $valop!(f, res, ba::AnyAutoSparse, x, ex::$E=$prep(f, ba, x)) = + $valop!(f, res, dense_ad(ba), x, ex) + end + + ## Two arguments + if op in (:derivative, :jacobian) + @eval begin + $prep(f!, y, ba::AnyAutoSparse, x) = $prep(f!, y, dense_ad(ba), x) + $op(f!, y, ba::AnyAutoSparse, x, ex::$E=$prep(f!, y, ba, x)) = + $op(f!, y, dense_ad(ba), x, ex) + $valop(f!, y, ba::AnyAutoSparse, x, ex::$E=$prep(f!, y, ba, x)) = + $valop(f!, y, dense_ad(ba), x, ex) + $op!(f!, y, res, ba::AnyAutoSparse, x, ex::$E=$prep(f!, y, ba, x)) = + $op!(f!, y, res, dense_ad(ba), x, ex) + $valop!(f!, y, res, ba::AnyAutoSparse, x, ex::$E=$prep(f!, y, ba, x)) = + $valop!(f!, y, res, dense_ad(ba), x, ex) + end + end +end diff --git a/DifferentiationInterface/test/first_order.jl b/DifferentiationInterface/test/first_order.jl index 94da35da9..75ce8b81d 100644 --- a/DifferentiationInterface/test/first_order.jl +++ b/DifferentiationInterface/test/first_order.jl @@ -1,4 +1,4 @@ -all_backends = [ +dense_backends = [ AutoChainRules(Zygote.ZygoteRuleConfig()), AutoDiffractor(), AutoEnzyme(Enzyme.Forward), @@ -15,12 +15,23 @@ all_backends = [ AutoZygote(), ] +sparse_backends = [ + AutoSparseFastDifferentiation(), AutoSparseForwardDiff(), AutoSparseSymbolics() +] + ## -for backend in all_backends +for backend in vcat(dense_backends, sparse_backends) @test check_available(backend) end test_differentiation( - all_backends; second_order=false, logging=get(ENV, "CI", "false") == "false" + dense_backends; second_order=false, logging=get(ENV, "CI", "false") == "false" +); + +test_differentiation( + sparse_backends; + second_order=false, + excluded=[JacobianScenario], + logging=get(ENV, "CI", "false") == "false", ); diff --git a/DifferentiationInterface/test/second_order.jl b/DifferentiationInterface/test/second_order.jl index 7fcda9044..72fc327ef 100644 --- a/DifferentiationInterface/test/second_order.jl +++ b/DifferentiationInterface/test/second_order.jl @@ -1,4 +1,4 @@ -second_order_backends = [ +dense_second_order_backends = [ AutoForwardDiff(), # AutoPolyesterForwardDiff(; chunksize=1), # AutoFastDifferentiation(), @@ -6,7 +6,13 @@ second_order_backends = [ AutoReverseDiff(), ] -second_order_mixed_backends = [ +sparse_second_order_backends = [ + AutoSparseForwardDiff(), # + AutoSparseFastDifferentiation(), + AutoSparseSymbolics(), +] + +mixed_second_order_backends = [ # forward over forward SecondOrder(AutoForwardDiff(), AutoEnzyme(Enzyme.Forward)), # forward over reverse @@ -19,14 +25,24 @@ second_order_mixed_backends = [ ## -for backend in vcat(second_order_backends, second_order_mixed_backends) +for backend in vcat( + dense_second_order_backends, sparse_second_order_backends, mixed_second_order_backends +) check_hessian(backend) end test_differentiation( - vcat(second_order_backends, second_order_mixed_backends); + vcat(dense_second_order_backends, mixed_second_order_backends); + first_order=false, + second_order=true, + logging=get(ENV, "CI", "false") == "false", +); + +test_differentiation( + sparse_second_order_backends; first_order=false, second_order=true, + excluded=[HessianScenario], logging=get(ENV, "CI", "false") == "false", ); diff --git a/DifferentiationInterface/test/sparsity.jl b/DifferentiationInterface/test/sparsity.jl index fd64b2482..a0837f994 100644 --- a/DifferentiationInterface/test/sparsity.jl +++ b/DifferentiationInterface/test/sparsity.jl @@ -8,6 +8,8 @@ sparse_backends = [ sparse_second_order_backends = [ AutoSparseFastDifferentiation(), + AutoSparseForwardDiff(), + AutoSparseSymbolics(), SecondOrder(AutoSparseForwardDiff(), AutoZygote()), SecondOrder(AutoSparseFiniteDiff(), AutoZygote()), ]