Skip to content

Commit

Permalink
Fallback from sparse to dense (#183)
Browse files Browse the repository at this point in the history
* Fallback from sparse to dense

* Fix ambiguities

* Missing Any

* Fix more fixes
  • Loading branch information
gdalle authored Apr 16, 2024
1 parent 19b5be4 commit 2751235
Show file tree
Hide file tree
Showing 16 changed files with 269 additions and 139 deletions.
3 changes: 1 addition & 2 deletions DifferentiationInterface/docs/src/backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@ AutoZygote

### Sparse

!!! warning
For sparse backends, only the Jacobian and Hessian operators are implemented.
For sparse backends, only the Jacobian and Hessian operators are implemented differently, the other operators behave the same as for the corresponding dense backend.

```@docs
AutoSparseFastDifferentiation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ const AnyAutoFastDifferentiation = Union{
AutoFastDifferentiation,AutoSparseFastDifferentiation
}

DI.check_available(::AnyAutoFastDifferentiation) = true
DI.mode(::AnyAutoFastDifferentiation) = ADTypes.AbstractSymbolicDifferentiationMode
DI.pushforward_performance(::AnyAutoFastDifferentiation) = DI.PushforwardFast()
DI.pullback_performance(::AnyAutoFastDifferentiation) = DI.PullbackSlow()
DI.check_available(::AutoFastDifferentiation) = true
DI.mode(::AutoFastDifferentiation) = ADTypes.AbstractSymbolicDifferentiationMode
DI.pushforward_performance(::AutoFastDifferentiation) = DI.PushforwardFast()
DI.pullback_performance(::AutoFastDifferentiation) = DI.PullbackSlow()

monovec(x::Number) = Fill(x, 1)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ struct FastDifferentiationOneArgPushforwardExtras{Y,E1,E2} <: PushforwardExtras
jvp_exe!::E2
end

function DI.prepare_pushforward(f, ::AnyAutoFastDifferentiation, x, dx)
function DI.prepare_pushforward(f, ::AutoFastDifferentiation, x, dx)
y_prototype = f(x)
x_var = if x isa Number
only(make_variables(:x))
Expand All @@ -24,11 +24,7 @@ function DI.prepare_pushforward(f, ::AnyAutoFastDifferentiation, x, dx)
end

function DI.pushforward(
f,
::AnyAutoFastDifferentiation,
x,
dx,
extras::FastDifferentiationOneArgPushforwardExtras,
f, ::AutoFastDifferentiation, x, dx, extras::FastDifferentiationOneArgPushforwardExtras
)
v_vec = vcat(myvec(x), myvec(dx))
if extras.y_prototype isa Number
Expand All @@ -41,7 +37,7 @@ end
function DI.pushforward!(
f,
dy,
::AnyAutoFastDifferentiation,
::AutoFastDifferentiation,
x,
dx,
extras::FastDifferentiationOneArgPushforwardExtras,
Expand All @@ -53,7 +49,7 @@ end

function DI.value_and_pushforward(
f,
backend::AnyAutoFastDifferentiation,
backend::AutoFastDifferentiation,
x,
dx,
extras::FastDifferentiationOneArgPushforwardExtras,
Expand All @@ -64,7 +60,7 @@ end
function DI.value_and_pushforward!(
f,
dy,
backend::AnyAutoFastDifferentiation,
backend::AutoFastDifferentiation,
x,
dx,
extras::FastDifferentiationOneArgPushforwardExtras,
Expand All @@ -84,7 +80,7 @@ struct FastDifferentiationOneArgDerivativeExtras{Y,E1,E2} <: DerivativeExtras
der_exe!::E2
end

function DI.prepare_derivative(f, ::AnyAutoFastDifferentiation, x)
function DI.prepare_derivative(f, ::AutoFastDifferentiation, x)
y_prototype = f(x)
x_var = only(make_variables(:x))
y_var = f(x_var)
Expand All @@ -98,7 +94,7 @@ function DI.prepare_derivative(f, ::AnyAutoFastDifferentiation, x)
end

function DI.derivative(
f, ::AnyAutoFastDifferentiation, x, extras::FastDifferentiationOneArgDerivativeExtras
f, ::AutoFastDifferentiation, x, extras::FastDifferentiationOneArgDerivativeExtras
)
if extras.y_prototype isa Number
return only(extras.der_exe(monovec(x)))
Expand All @@ -108,19 +104,15 @@ function DI.derivative(
end

function DI.derivative!(
f,
der,
::AnyAutoFastDifferentiation,
x,
extras::FastDifferentiationOneArgDerivativeExtras,
f, der, ::AutoFastDifferentiation, x, extras::FastDifferentiationOneArgDerivativeExtras
)
extras.der_exe!(vec(der), monovec(x))
return der
end

function DI.value_and_derivative(
f,
backend::AnyAutoFastDifferentiation,
backend::AutoFastDifferentiation,
x,
extras::FastDifferentiationOneArgDerivativeExtras,
)
Expand All @@ -130,7 +122,7 @@ end
function DI.value_and_derivative!(
f,
der,
backend::AnyAutoFastDifferentiation,
backend::AutoFastDifferentiation,
x,
extras::FastDifferentiationOneArgDerivativeExtras,
)
Expand All @@ -144,7 +136,7 @@ struct FastDifferentiationOneArgGradientExtras{E1,E2} <: GradientExtras
jac_exe!::E2
end

function DI.prepare_gradient(f, backend::AnyAutoFastDifferentiation, x)
function DI.prepare_gradient(f, backend::AutoFastDifferentiation, x)
y_prototype = f(x)
x_var = make_variables(:x, size(x)...)
y_var = f(x_var)
Expand All @@ -158,37 +150,30 @@ function DI.prepare_gradient(f, backend::AnyAutoFastDifferentiation, x)
end

function DI.gradient(
f, ::AnyAutoFastDifferentiation, x, extras::FastDifferentiationOneArgGradientExtras
f, ::AutoFastDifferentiation, x, extras::FastDifferentiationOneArgGradientExtras
)
jac = extras.jac_exe(vec(x))
grad_vec = @view jac[1, :]
return reshape(grad_vec, size(x))
end

function DI.gradient!(
f,
grad,
::AnyAutoFastDifferentiation,
x,
extras::FastDifferentiationOneArgGradientExtras,
f, grad, ::AutoFastDifferentiation, x, extras::FastDifferentiationOneArgGradientExtras
)
extras.jac_exe!(reshape(grad, 1, length(grad)), vec(x))
return grad
end

function DI.value_and_gradient(
f,
backend::AnyAutoFastDifferentiation,
x,
extras::FastDifferentiationOneArgGradientExtras,
f, backend::AutoFastDifferentiation, x, extras::FastDifferentiationOneArgGradientExtras
)
return f(x), DI.gradient(f, backend, x, extras)
end

function DI.value_and_gradient!(
f,
grad,
backend::AnyAutoFastDifferentiation,
backend::AutoFastDifferentiation,
x,
extras::FastDifferentiationOneArgGradientExtras,
)
Expand Down Expand Up @@ -261,7 +246,7 @@ struct FastDifferentiationAllocatingSecondDerivativeExtras{Y,E1,E2} <:
der2_exe!::E2
end

function DI.prepare_second_derivative(f, ::AnyAutoFastDifferentiation, x)
function DI.prepare_second_derivative(f, ::AutoFastDifferentiation, x)
y_prototype = f(x)
x_var = only(make_variables(:x))
y_var = f(x_var)
Expand All @@ -278,7 +263,7 @@ end

function DI.second_derivative(
f,
::AnyAutoFastDifferentiation,
::AutoFastDifferentiation,
x,
extras::FastDifferentiationAllocatingSecondDerivativeExtras,
)
Expand All @@ -292,7 +277,7 @@ end
function DI.second_derivative!(
f,
der2,
backend::AnyAutoFastDifferentiation,
backend::AutoFastDifferentiation,
x,
extras::FastDifferentiationAllocatingSecondDerivativeExtras,
)
Expand All @@ -307,7 +292,7 @@ struct FastDifferentiationHVPExtras{E1,E2} <: HVPExtras
hvp_exe!::E2
end

function DI.prepare_hvp(f, ::AnyAutoFastDifferentiation, x, v)
function DI.prepare_hvp(f, ::AutoFastDifferentiation, x, v)
x_var = make_variables(:x, size(x)...)
y_var = f(x_var)

Expand All @@ -318,14 +303,14 @@ function DI.prepare_hvp(f, ::AnyAutoFastDifferentiation, x, v)
return FastDifferentiationHVPExtras(hvp_exe, hvp_exe!)
end

function DI.hvp(f, ::AnyAutoFastDifferentiation, x, v, extras::FastDifferentiationHVPExtras)
function DI.hvp(f, ::AutoFastDifferentiation, x, v, extras::FastDifferentiationHVPExtras)
v_vec = vcat(vec(x), vec(v))
hv_vec = extras.hvp_exe(v_vec)
return reshape(hv_vec, size(x))
end

function DI.hvp!(
f, p, ::AnyAutoFastDifferentiation, x, v, extras::FastDifferentiationHVPExtras
f, p, ::AutoFastDifferentiation, x, v, extras::FastDifferentiationHVPExtras
)
v_vec = vcat(vec(x), vec(v))
extras.hvp_exe!(p, v_vec)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ struct FastDifferentiationTwoArgPushforwardExtras{E1,E2} <: PushforwardExtras
jvp_exe!::E2
end

function DI.prepare_pushforward(f!, y, ::AnyAutoFastDifferentiation, x, dx)
function DI.prepare_pushforward(f!, y, ::AutoFastDifferentiation, x, dx)
x_var = if x isa Number
only(make_variables(:x))
else
Expand All @@ -25,7 +25,7 @@ end
function DI.value_and_pushforward(
f!,
y,
::AnyAutoFastDifferentiation,
::AutoFastDifferentiation,
x,
dx,
extras::FastDifferentiationTwoArgPushforwardExtras,
Expand All @@ -40,7 +40,7 @@ function DI.value_and_pushforward!(
f!,
y,
dy,
::AnyAutoFastDifferentiation,
::AutoFastDifferentiation,
x,
dx,
extras::FastDifferentiationTwoArgPushforwardExtras,
Expand All @@ -54,7 +54,7 @@ end
function DI.pushforward(
f!,
y,
::AnyAutoFastDifferentiation,
::AutoFastDifferentiation,
x,
dx,
extras::FastDifferentiationTwoArgPushforwardExtras,
Expand All @@ -68,7 +68,7 @@ function DI.pushforward!(
f!,
y,
dy,
::AnyAutoFastDifferentiation,
::AutoFastDifferentiation,
x,
dx,
extras::FastDifferentiationTwoArgPushforwardExtras,
Expand All @@ -85,7 +85,7 @@ struct FastDifferentiationTwoArgDerivativeExtras{E1,E2} <: DerivativeExtras
der_exe!::E2
end

function DI.prepare_derivative(f!, y, ::AnyAutoFastDifferentiation, x)
function DI.prepare_derivative(f!, y, ::AutoFastDifferentiation, x)
x_var = only(make_variables(:x))
y_var = make_variables(:y, size(y)...)
f!(y_var, x_var)
Expand All @@ -99,11 +99,7 @@ function DI.prepare_derivative(f!, y, ::AnyAutoFastDifferentiation, x)
end

function DI.value_and_derivative(
f!,
y,
::AnyAutoFastDifferentiation,
x,
extras::FastDifferentiationTwoArgDerivativeExtras,
f!, y, ::AutoFastDifferentiation, x, extras::FastDifferentiationTwoArgDerivativeExtras
)
f!(y, x)
der = reshape(extras.der_exe(monovec(x)), size(y))
Expand All @@ -114,7 +110,7 @@ function DI.value_and_derivative!(
f!,
y,
der,
::AnyAutoFastDifferentiation,
::AutoFastDifferentiation,
x,
extras::FastDifferentiationTwoArgDerivativeExtras,
)
Expand All @@ -124,11 +120,7 @@ function DI.value_and_derivative!(
end

function DI.derivative(
f!,
y,
::AnyAutoFastDifferentiation,
x,
extras::FastDifferentiationTwoArgDerivativeExtras,
f!, y, ::AutoFastDifferentiation, x, extras::FastDifferentiationTwoArgDerivativeExtras
)
der = reshape(extras.der_exe(monovec(x)), size(y))
return der
Expand All @@ -138,7 +130,7 @@ function DI.derivative!(
f!,
y,
der,
::AnyAutoFastDifferentiation,
::AutoFastDifferentiation,
x,
extras::FastDifferentiationTwoArgDerivativeExtras,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module DifferentiationInterfaceSparseDiffToolsExt
using ADTypes
import DifferentiationInterface as DI
using DifferentiationInterface:
HessianExtras, JacobianExtras, NoHessianExtras, SecondOrder, inner, outer
AnyAutoSparse, HessianExtras, JacobianExtras, NoHessianExtras, SecondOrder, inner, outer
using SparseDiffTools:
JacPrototypeSparsityDetection,
SymbolicsSparsityDetection,
Expand All @@ -12,35 +12,14 @@ using SparseDiffTools:
sparse_jacobian_cache
using Symbolics: Symbolics

AnyOneArgAutoSparse = Union{
AnyAutoSparseNoSymbolic = Union{
AutoSparseFiniteDiff,
AutoSparseForwardDiff,
AutoSparsePolyesterForwardDiff,
AutoSparseReverseDiff,
AutoSparseZygote,
}

AnyTwoArgAutoSparse = Union{
AutoSparseFiniteDiff,
AutoSparseForwardDiff,
AutoSparsePolyesterForwardDiff,
AutoSparseReverseDiff,
}

dense(::AutoSparseFiniteDiff) = AutoFiniteDiff()
dense(backend::AutoSparseReverseDiff) = AutoReverseDiff(backend.compile)
dense(::AutoSparseZygote) = AutoZygote()

function dense(backend::AutoSparseForwardDiff{chunksize,T}) where {chunksize,T}
return AutoForwardDiff{chunksize,T}(backend.tag)
end

function dense(::AutoSparsePolyesterForwardDiff{chunksize}) where {chunksize}
return AutoSparsePolyesterForwardDiff{chunksize}()
end

DI.check_available(backend::AnyOneArgAutoSparse) = DI.check_available(dense(backend))

include("onearg.jl")
include("twoarg.jl")

Expand Down
Loading

0 comments on commit 2751235

Please sign in to comment.