Skip to content

Commit

Permalink
Improve type stability tests, better use of AutoZero backends (#437)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle authored Sep 1, 2024
1 parent e89aab3 commit ea46769
Show file tree
Hide file tree
Showing 13 changed files with 205 additions and 222 deletions.
1 change: 1 addition & 0 deletions DifferentiationInterface/src/DifferentiationInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ include("sparse/hessian.jl")
include("misc/differentiate_with.jl")
include("misc/sparsity_detector.jl")
include("misc/from_primitive.jl")
include("misc/zero_backends.jl")

function __init__()
@require_extensions
Expand Down
8 changes: 4 additions & 4 deletions DifferentiationInterface/src/fallbacks/no_tangents.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ for op in (:pushforward, :pullback, :hvp)
return only(t)
end
@eval function $op!(f::F, result, backend::AbstractADType, x, seed, ex::$E) where {F}
@assert !isa(seed, Tangents)
@assert !isa(seed, Tangents) && !isa(result, Tangents)
t = $op!(f, SingleTangent(result), backend, x, SingleTangent(seed), ex)
return only(t)
end
Expand All @@ -41,7 +41,7 @@ for op in (:pushforward, :pullback, :hvp)
@eval function $val_and_op!(
f::F, result, backend::AbstractADType, x, seed, ex::$E
) where {F}
@assert !isa(seed, Tangents)
@assert !isa(seed, Tangents) && !isa(result, Tangents)
y, t = $val_and_op!(f, SingleTangent(result), backend, x, SingleTangent(seed), ex)
return y, only(t)
end
Expand All @@ -60,7 +60,7 @@ for op in (:pushforward, :pullback, :hvp)
@eval function $op!(
f!::F, y, result, backend::AbstractADType, x, seed, ex::$E
) where {F}
@assert !isa(seed, Tangents)
@assert !isa(seed, Tangents) && !isa(result, Tangents)
t = $op!(f!, y, SingleTangent(result), backend, x, SingleTangent(seed), ex)
return only(t)
end
Expand All @@ -72,7 +72,7 @@ for op in (:pushforward, :pullback, :hvp)
@eval function $val_and_op!(
f!::F, y, result, backend::AbstractADType, x, seed, ex::$E
) where {F}
@assert !isa(seed, Tangents)
@assert !isa(seed, Tangents) && !isa(result, Tangents)
y, t = $val_and_op!(
f!, y, SingleTangent(result), backend, x, SingleTangent(seed), ex
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
zero!(x::AbstractArray) = x .= zero(eltype(x))
struct ReturnZero{T}
template::T
end

(rz::ReturnZero)(i) = zero(rz.template)

_zero!(x::AbstractArray) = x .= zero(eltype(x))

## Forward

Expand All @@ -11,44 +17,44 @@ Used in testing and benchmarking.
struct AutoZeroForward <: AbstractADType end

ADTypes.mode(::AutoZeroForward) = ForwardMode()
DI.check_available(::AutoZeroForward) = true
DI.twoarg_support(::AutoZeroForward) = DI.TwoArgSupported()
check_available(::AutoZeroForward) = true
twoarg_support(::AutoZeroForward) = TwoArgSupported()

DI.prepare_pushforward(f, ::AutoZeroForward, x, tx::Tangents) = NoPushforwardExtras()
DI.prepare_pushforward(f!, y, ::AutoZeroForward, x, tx::Tangents) = NoPushforwardExtras()
prepare_pushforward(f, ::AutoZeroForward, x, tx::Tangents) = NoPushforwardExtras()
prepare_pushforward(f!, y, ::AutoZeroForward, x, tx::Tangents) = NoPushforwardExtras()

function DI.value_and_pushforward(
function value_and_pushforward(
f, ::AutoZeroForward, x, tx::Tangents{B}, ::NoPushforwardExtras
) where {B}
y = f(x)
dys = ntuple(Returns(zero(y)), Val(B))
dys = ntuple(ReturnZero(y), Val(B))
return y, Tangents(dys)
end

function DI.value_and_pushforward(
function value_and_pushforward(
f!, y, ::AutoZeroForward, x, tx::Tangents{B}, ::NoPushforwardExtras
) where {B}
f!(y, x)
dys = ntuple(Returns(zero(y)), Val(B))
dys = ntuple(ReturnZero(y), Val(B))
return y, Tangents(dys)
end

function DI.value_and_pushforward!(
function value_and_pushforward!(
f, ty::Tangents, ::AutoZeroForward, x, tx::Tangents, ::NoPushforwardExtras
)
y = f(x)
for b in eachindex(ty.d)
zero!(ty.d[b])
_zero!(ty.d[b])
end
return y, ty
end

function DI.value_and_pushforward!(
function value_and_pushforward!(
f!, y, ty::Tangents, ::AutoZeroForward, x, tx::Tangents, ::NoPushforwardExtras
)
f!(y, x)
for b in eachindex(ty.d)
zero!(ty.d[b])
_zero!(ty.d[b])
end
return y, ty
end
Expand All @@ -64,44 +70,44 @@ Used in testing and benchmarking.
struct AutoZeroReverse <: AbstractADType end

ADTypes.mode(::AutoZeroReverse) = ReverseMode()
DI.check_available(::AutoZeroReverse) = true
DI.twoarg_support(::AutoZeroReverse) = DI.TwoArgSupported()
check_available(::AutoZeroReverse) = true
twoarg_support(::AutoZeroReverse) = TwoArgSupported()

DI.prepare_pullback(f, ::AutoZeroReverse, x, ty::Tangents) = NoPullbackExtras()
DI.prepare_pullback(f!, y, ::AutoZeroReverse, x, ty::Tangents) = NoPullbackExtras()
prepare_pullback(f, ::AutoZeroReverse, x, ty::Tangents) = NoPullbackExtras()
prepare_pullback(f!, y, ::AutoZeroReverse, x, ty::Tangents) = NoPullbackExtras()

function DI.value_and_pullback(
function value_and_pullback(
f, ::AutoZeroReverse, x, ty::Tangents{B}, ::NoPullbackExtras
) where {B}
y = f(x)
dxs = ntuple(Returns(zero(x)), Val(B))
dxs = ntuple(ReturnZero(x), Val(B))
return y, Tangents(dxs)
end

function DI.value_and_pullback(
function value_and_pullback(
f!, y, ::AutoZeroReverse, x, ty::Tangents{B}, ::NoPullbackExtras
) where {B}
f!(y, x)
dxs = ntuple(Returns(zero(x)), Val(B))
dxs = ntuple(ReturnZero(x), Val(B))
return y, Tangents(dxs)
end

function DI.value_and_pullback!(
function value_and_pullback!(
f, tx::Tangents, ::AutoZeroReverse, x, ty::Tangents, ::NoPullbackExtras
)
y = f(x)
for b in eachindex(tx.d)
zero!(tx.d[b])
_zero!(tx.d[b])
end
return y, tx
end

function DI.value_and_pullback!(
function value_and_pullback!(
f!, y, tx::Tangents, ::AutoZeroReverse, x, ty::Tangents, ::NoPullbackExtras
)
f!(y, x)
for b in eachindex(tx.d)
zero!(tx.d[b])
_zero!(tx.d[b])
end
return y, tx
end
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,4 @@ for backend in vcat(fromprimitive_backends)
@test DifferentiationInterface.pick_batchsize(backend, 100) == 5
end

## Dense backends

test_differentiation(fromprimitive_backends, default_scenarios(); logging=LOGGING);

test_differentiation(
fromprimitive_backends[1],
default_scenarios();
correctness=false,
type_stability=true,
second_order=false,
logging=LOGGING,
);
57 changes: 57 additions & 0 deletions DifferentiationInterface/test/Internals/zero_backends.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
using DifferentiationInterface
using DifferentiationInterface: AutoZeroForward, AutoZeroReverse
using DifferentiationInterfaceTest
using ComponentArrays: ComponentArrays
using JLArrays: JLArrays
using StaticArrays: StaticArrays
using Test

LOGGING = get(ENV, "CI", "false") == "false"

zero_backends = [AutoZeroForward(), AutoZeroReverse()]

for backend in zero_backends
@test check_available(backend)
@test check_twoarg(backend)
end

## Type stability

test_differentiation(
zero_backends,
zero.(default_scenarios());
correctness=true,
type_stability=true,
excluded=[:second_derivative],
logging=LOGGING,
)

test_differentiation(
[
SecondOrder(AutoZeroForward(), AutoZeroReverse()),
SecondOrder(AutoZeroReverse(), AutoZeroForward()),
],
default_scenarios();
correctness=false,
type_stability=true,
first_order=false,
logging=LOGGING,
)

## Weird arrays

test_differentiation(
[AutoZeroForward(), AutoZeroReverse()],
zero.(vcat(component_scenarios(), static_scenarios()));
correctness=true,
logging=LOGGING,
)

if VERSION >= v"1.10"
test_differentiation(
[AutoZeroForward(), AutoZeroReverse()],
zero.(gpu_scenarios());
correctness=true,
logging=LOGGING,
)
end
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ include("scenarios/sparse.jl")
include("scenarios/allocfree.jl")
include("scenarios/extensions.jl")

include("utils/zero_backends.jl")
include("utils/misc.jl")
include("utils/filter.jl")

Expand Down
2 changes: 1 addition & 1 deletion DifferentiationInterfaceTest/src/scenarios/modify.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ maybe_zero(x::AbstractArray) = zero(x)
maybe_zero(x::Tangents) = Tangents(map(maybe_zero, x.d))
maybe_zero(::Nothing) = nothing

function scenario_to_zero(scen::Scenario{op,args,pl}) where {op,args,pl}
function Base.zero(scen::Scenario{op,args,pl}) where {op,args,pl}
return Scenario{op,args,pl}(
scen.f;
x=scen.x,
Expand Down
36 changes: 18 additions & 18 deletions DifferentiationInterfaceTest/src/tests/benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,19 +168,19 @@ function run_benchmark!(
# benchmark
extras = prepare_pushforward(f, ba, x, seed)
bench0 = @be prepare_pushforward(f, ba, x, seed) samples = 1 evals = 1
bench1 = @be (dy=mysimilar(y), ext=deepcopy(extras)) value_and_pushforward!(
bench1 = @be (dy=mysimilar(scen.res1), ext=deepcopy(extras)) value_and_pushforward!(
f, _.dy, ba, x, seed, _.ext
) evals = 1
bench2 = @be (dy=mysimilar(y), ext=deepcopy(extras)) pushforward!(
bench2 = @be (dy=mysimilar(scen.res1), ext=deepcopy(extras)) pushforward!(
f, _.dy, ba, x, seed, _.ext
) evals = 1
# count
cc = CallCounter(f)
extras = prepare_pushforward(cc, ba, x, seed)
calls0 = reset_count!(cc)
value_and_pushforward!(cc, mysimilar(y), ba, x, seed, extras)
value_and_pushforward!(cc, mysimilar(scen.res1), ba, x, seed, extras)
calls1 = reset_count!(cc)
pushforward!(cc, mysimilar(y), ba, x, seed, extras)
pushforward!(cc, mysimilar(scen.res1), ba, x, seed, extras)
calls2 = reset_count!(cc)
(; bench0, bench1, bench2, calls0, calls1, calls2)
catch e
Expand Down Expand Up @@ -250,19 +250,19 @@ function run_benchmark!(
extras = prepare_pushforward(f!, y, ba, x, seed)
bench0 = @be mysimilar(y) prepare_pushforward(f!, _, ba, x, seed) evals = 1 samples =
1
bench1 = @be (y=mysimilar(y), dy=mysimilar(y), ext=deepcopy(extras)) value_and_pushforward!(
bench1 = @be (y=mysimilar(y), dy=mysimilar(scen.res1), ext=deepcopy(extras)) value_and_pushforward!(
f!, _.y, _.dy, ba, x, seed, _.ext
) evals = 1
bench2 = @be (y=mysimilar(y), dy=mysimilar(y), ext=deepcopy(extras)) pushforward!(
bench2 = @be (y=mysimilar(y), dy=mysimilar(scen.res1), ext=deepcopy(extras)) pushforward!(
f!, _.y, _.dy, ba, x, seed, _.ext
) evals = 1
# count
cc! = CallCounter(f!)
extras = prepare_pushforward(cc!, mysimilar(y), ba, x, seed)
calls0 = reset_count!(cc!)
value_and_pushforward!(cc!, mysimilar(y), mysimilar(y), ba, x, seed, extras)
value_and_pushforward!(cc!, mysimilar(y), mysimilar(scen.res1), ba, x, seed, extras)
calls1 = reset_count!(cc!)
pushforward!(cc!, mysimilar(y), mysimilar(y), ba, x, seed, extras)
pushforward!(cc!, mysimilar(y), mysimilar(scen.res1), ba, x, seed, extras)
calls2 = reset_count!(cc!)
(; bench0, bench1, bench2, calls0, calls1, calls2)
catch e
Expand Down Expand Up @@ -326,19 +326,19 @@ function run_benchmark!(
# benchmark
extras = prepare_pullback(f, ba, x, seed)
bench0 = @be prepare_pullback(f, ba, x, seed) samples = 1 evals = 1
bench1 = @be (dx=mysimilar(x), ext=deepcopy(extras)) value_and_pullback!(
bench1 = @be (dx=mysimilar(scen.res1), ext=deepcopy(extras)) value_and_pullback!(
f, _.dx, ba, x, seed, _.ext
) evals = 1
bench2 = @be (dx=mysimilar(x), ext=deepcopy(extras)) pullback!(
bench2 = @be (dx=mysimilar(scen.res1), ext=deepcopy(extras)) pullback!(
f, _.dx, ba, x, seed, _.ext
) evals = 1
# count
cc = CallCounter(f)
extras = prepare_pullback(cc, ba, x, seed)
calls0 = reset_count!(cc)
value_and_pullback!(cc, mysimilar(x), ba, x, seed, extras)
value_and_pullback!(cc, mysimilar(scen.res1), ba, x, seed, extras)
calls1 = reset_count!(cc)
pullback!(cc, mysimilar(x), ba, x, seed, extras)
pullback!(cc, mysimilar(scen.res1), ba, x, seed, extras)
calls2 = reset_count!(cc)
(; bench0, bench1, bench2, calls0, calls1, calls2)
catch e
Expand Down Expand Up @@ -408,19 +408,19 @@ function run_benchmark!(
extras = prepare_pullback(f!, mysimilar(y), ba, x, seed)
bench0 = @be mysimilar(y) prepare_pullback(f!, _, ba, x, seed) samples = 1 evals =
1
bench1 = @be (y=mysimilar(y), dx=mysimilar(x), ext=deepcopy(extras)) value_and_pullback!(
bench1 = @be (y=mysimilar(y), dx=mysimilar(scen.res1), ext=deepcopy(extras)) value_and_pullback!(
f!, _.y, _.dx, ba, x, seed, _.ext
) evals = 1
bench2 = @be (y=mysimilar(y), dx=mysimilar(x), ext=deepcopy(extras)) pullback!(
bench2 = @be (y=mysimilar(y), dx=mysimilar(scen.res1), ext=deepcopy(extras)) pullback!(
f!, _.y, _.dx, ba, x, seed, _.ext
) evals = 1
# count
cc! = CallCounter(f!)
extras = prepare_pullback(cc!, mysimilar(y), ba, x, seed)
calls0 = reset_count!(cc!)
value_and_pullback!(cc!, mysimilar(y), mysimilar(x), ba, x, seed, extras)
value_and_pullback!(cc!, mysimilar(y), mysimilar(scen.res1), ba, x, seed, extras)
calls1 = reset_count!(cc!)
pullback!(cc!, mysimilar(y), mysimilar(x), ba, x, seed, extras)
pullback!(cc!, mysimilar(y), mysimilar(scen.res1), ba, x, seed, extras)
calls2 = reset_count!(cc!)
(; bench0, bench1, bench2, calls0, calls1, calls2)
catch e
Expand Down Expand Up @@ -946,14 +946,14 @@ function run_benchmark!(
# benchmark
extras = prepare_hvp(f, ba, x, seed)
bench0 = @be prepare_hvp(f, ba, x, seed) samples = 1 evals = 1
bench1 = @be (dg=mysimilar(x), ext=deepcopy(extras)) hvp!(
bench1 = @be (dg=mysimilar(scen.res2), ext=deepcopy(extras)) hvp!(
f, _.dg, ba, x, seed, _.ext
) evals = 1
# count
cc = CallCounter(f)
extras = prepare_hvp(cc, ba, x, seed)
calls0 = reset_count!(cc)
hvp!(cc, mysimilar(x), ba, x, seed, extras)
hvp!(cc, mysimilar(scen.res2), ba, x, seed, extras)
calls1 = reset_count!(cc)
(; bench0, bench1, calls0, calls1)
catch e
Expand Down
Loading

0 comments on commit ea46769

Please sign in to comment.