Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HVP for everyone #77

Merged
merged 2 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions docs/src/backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ For second-order differentiation, you can either

In Hessian computations, the most efficient combination is often forward-over-reverse, i.e. `SecondOrder(reverse_backend, forward_backend)`.

!!! danger
!!! info
Many backend combinations will fail for second order.
Some because of our implementation, and some because the outer backend cannot differentiate through code generated by the inner backend.

Expand All @@ -120,7 +120,6 @@ end # hide
Markdown.parse(join(vcat(header, subheader, rows...), "\n") * "\n") # hide
```


## Package extensions

```@meta
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ using DifferentiationInterface:
MutationSupported,
MutationNotSupported,
mode,
mutation_behavior
mutation_behavior,
outer
using DifferentiationInterface.DifferentiationTest
import DifferentiationInterface.DifferentiationTest as DT
using Test
Expand Down Expand Up @@ -63,7 +64,7 @@ function DT.run_benchmark(
@testset "$op" for op in operators
results = all_results[backend_string(backend)][op]
if op == :pushforward_allocating
for s in allocating(scenarios)
@testset "$(scen_string(s))" for s in allocating(scenarios)
merge!(
results[scen_id(s)...],
run_benchmark_pushforward_allocating(
Expand All @@ -72,7 +73,7 @@ function DT.run_benchmark(
)
end
elseif op == :pushforward_mutating
for s in mutating(scenarios)
@testset "$(scen_string(s))" for s in mutating(scenarios)
merge!(
results[scen_id(s)...],
run_benchmark_pushforward_mutating(
Expand All @@ -82,22 +83,23 @@ function DT.run_benchmark(
end

elseif op == :pullback_allocating
for s in allocating(scenarios)
@testset "$(scen_string(s))" for s in allocating(scenarios)
merge!(
results[scen_id(s)...],
run_benchmark_pullback_allocating(backend, s; test_allocations),
)
end
elseif op == :pullback_mutating
for s in mutating(scenarios)
@testset "$(scen_string(s))" for s in mutating(scenarios)
merge!(
results[scen_id(s)...],
run_benchmark_pullback_mutating(backend, s; test_allocations),
)
end

elseif op == :derivative_allocating
for s in allocating(scalar_scalar(scenarios))
@testset "$(scen_string(s))" for s in
allocating(scalar_scalar(scenarios))
merge!(
results[scen_id(s)...],
run_benchmark_derivative_allocating(
Expand All @@ -107,7 +109,8 @@ function DT.run_benchmark(
end

elseif op == :multiderivative_allocating
for s in allocating(scalar_array(scenarios))
@testset "$(scen_string(s))" for s in
allocating(scalar_array(scenarios))
merge!(
results[scen_id(s)...],
run_benchmark_multiderivative_allocating(
Expand All @@ -116,7 +119,7 @@ function DT.run_benchmark(
)
end
elseif op == :multiderivative_mutating
for s in mutating(scalar_array(scenarios))
@testset "$(scen_string(s))" for s in mutating(scalar_array(scenarios))
merge!(
results[scen_id(s)...],
run_benchmark_multiderivative_mutating(
Expand All @@ -126,30 +129,32 @@ function DT.run_benchmark(
end

elseif op == :gradient_allocating
for s in allocating(array_scalar(scenarios))
@testset "$(scen_string(s))" for s in
allocating(array_scalar(scenarios))
merge!(
results[scen_id(s)...],
run_benchmark_gradient_allocating(backend, s; test_allocations),
)
end

elseif op == :jacobian_allocating
for s in allocating(array_array(scenarios))
@testset "$(scen_string(s))" for s in allocating(array_array(scenarios))
merge!(
results[scen_id(s)...],
run_benchmark_jacobian_allocating(backend, s; test_allocations),
)
end
elseif op == :jacobian_mutating
for s in mutating(array_array(scenarios))
@testset "$(scen_string(s))" for s in mutating(array_array(scenarios))
merge!(
results[scen_id(s)...],
run_benchmark_jacobian_mutating(backend, s; test_allocations),
)
end

elseif op == :second_derivative_allocating
for s in allocating(scalar_scalar(scenarios))
@testset "$(scen_string(s))" for s in
allocating(scalar_scalar(scenarios))
merge!(
results[scen_id(s)...],
run_benchmark_second_derivative_allocating(
Expand All @@ -159,7 +164,8 @@ function DT.run_benchmark(
end

elseif op == :hessian_vector_product_allocating
for s in allocating(array_scalar(scenarios))
@testset "$(scen_string(s))" for s in
allocating(array_scalar(scenarios))
merge!(
results[scen_id(s)...],
run_benchmark_hessian_vector_product_allocating(
Expand All @@ -168,7 +174,8 @@ function DT.run_benchmark(
)
end
elseif op == :hessian_allocating
for s in allocating(array_scalar(scenarios))
@testset "$(scen_string(s))" for s in
allocating(array_scalar(scenarios))
merge!(
results[scen_id(s)...],
run_benchmark_hessian_allocating(backend, s; test_allocations),
Expand Down Expand Up @@ -357,10 +364,16 @@ function run_benchmark_hessian_vector_product_allocating(
(; f, x, dx) = deepcopy(scenario)
extras = prepare_hessian_vector_product(ba, f, x)
bench1 = @be zero(dx) hessian_vector_product!(_, ba, f, x, dx, extras)
if test_allocations
bench2 = @be (zero(dx), zero(dx)) gradient_and_hessian_vector_product!(
_[1], _[2], ba, f, x, dx, extras
)
if test_allocations # TODO: distinguish
soft_test_zero(minimum(bench1).allocs)
soft_test_zero(minimum(bench2).allocs)
end
return Dict(:hessian_vector_product! => bench1)
return Dict(
:hessian_vector_product! => bench1, :gradient_and_hessian_vector_product! => bench2
)
end

## Hessian
Expand Down
18 changes: 12 additions & 6 deletions ext/DifferentiationInterfaceEnzymeExt/forward_allocating.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,38 +3,44 @@
function DI.value_and_pushforward!(
_dy::Real, backend::AutoForwardEnzyme, f, x, dx, extras::Nothing
)
y, new_dy = autodiff(backend.mode, f, Duplicated, Duplicated(x, dx))
dx_sametype = convert(typeof(x), dx)
y, new_dy = autodiff(backend.mode, f, Duplicated, Duplicated(x, dx_sametype))
return y, new_dy
end

function DI.value_and_pushforward!(
dy::AbstractArray, backend::AutoForwardEnzyme, f, x, dx, extras::Nothing
)
y, new_dy = autodiff(backend.mode, f, Duplicated, Duplicated(x, dx))
dx_sametype = convert(typeof(x), dx)
y, new_dy = autodiff(backend.mode, f, Duplicated, Duplicated(x, dx_sametype))
dy .= new_dy
return y, dy
end

function DI.pushforward!(_dy::Real, backend::AutoForwardEnzyme, f, x, dx, extras::Nothing)
new_dy = only(autodiff(backend.mode, f, DuplicatedNoNeed, Duplicated(x, dx)))
dx_sametype = convert(typeof(x), dx)
new_dy = only(autodiff(backend.mode, f, DuplicatedNoNeed, Duplicated(x, dx_sametype)))
return new_dy
end

function DI.pushforward!(
dy::AbstractArray, backend::AutoForwardEnzyme, f, x, dx, extras::Nothing
)
new_dy = only(autodiff(backend.mode, f, DuplicatedNoNeed, Duplicated(x, dx)))
dx_sametype = convert(typeof(x), dx)
new_dy = only(autodiff(backend.mode, f, DuplicatedNoNeed, Duplicated(x, dx_sametype)))
dy .= new_dy
return dy
end

function DI.value_and_pushforward(backend::AutoForwardEnzyme, f, x, dx, extras::Nothing)
y, dy = autodiff(backend.mode, f, Duplicated, Duplicated(x, dx))
dx_sametype = convert(typeof(x), dx)
y, dy = autodiff(backend.mode, f, Duplicated, Duplicated(x, dx_sametype))
return y, dy
end

function DI.pushforward(backend::AutoForwardEnzyme, f, x, dx, extras::Nothing)
dy = only(autodiff(backend.mode, f, DuplicatedNoNeed, Duplicated(x, dx)))
dx_sametype = convert(typeof(x), dx)
dy = only(autodiff(backend.mode, f, DuplicatedNoNeed, Duplicated(x, dx_sametype)))
return dy
end

Expand Down
12 changes: 8 additions & 4 deletions ext/DifferentiationInterfaceEnzymeExt/reverse_allocating.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ end
function DI.value_and_pullback!(
dx::AbstractArray, ::AutoReverseEnzyme, f, x::AbstractArray, dy::Number, extras::Nothing
)
dx .= zero(eltype(dx))
_, y = autodiff(ReverseWithPrimal, f, Active, Duplicated(x, dx))
dx_sametype = convert(typeof(x), dx)
dx_sametype .= zero(eltype(dx_sametype))
_, y = autodiff(ReverseWithPrimal, f, Active, Duplicated(x, dx_sametype))
dx .= dx_sametype
dx .*= dy
return y, dx
end
Expand All @@ -39,8 +41,10 @@ end
function DI.pullback!(
dx::AbstractArray, ::AutoReverseEnzyme, f, x::AbstractArray, dy::Number, extras::Nothing
)
dx .= zero(eltype(dx))
autodiff(Reverse, f, Active, Duplicated(x, dx))
dx_sametype = convert(typeof(x), dx)
dx_sametype .= zero(eltype(dx_sametype))
autodiff(Reverse, f, Active, Duplicated(x, dx_sametype))
dx .= dx_sametype
dx .*= dy
return dx
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ using DifferentiationInterface:
MutationSupported,
MutationNotSupported,
mode,
mutation_behavior
mutation_behavior,
outer
import DifferentiationInterface as DI
using DifferentiationInterface.DifferentiationTest
import DifferentiationInterface.DifferentiationTest as DT
Expand Down
50 changes: 36 additions & 14 deletions ext/DifferentiationInterfaceForwardDiffExt/test_correctness.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,62 +10,68 @@ function DT.test_correctness(
@testset verbose = true "$(backend_string(backend))" for backend in backends
@testset "$op" for op in operators
if op == :pushforward_allocating
@testset "$(typeof(s))" for s in allocating(scenarios)
@testset "$(scen_string(s))" for s in allocating(scenarios)
test_correctness_pushforward_allocating(backend, s)
end
elseif op == :pushforward_mutating
@testset "$(typeof(s))" for s in mutating(scenarios)
@testset "$(scen_string(s))" for s in mutating(scenarios)
test_correctness_pushforward_mutating(backend, s)
end

elseif op == :pullback_allocating
@testset "$(typeof(s))" for s in allocating(scenarios)
@testset "$(scen_string(s))" for s in allocating(scenarios)
test_correctness_pullback_allocating(backend, s)
end
elseif op == :pullback_mutating
@testset "$(typeof(s))" for s in mutating(scenarios)
@testset "$(scen_string(s))" for s in mutating(scenarios)
test_correctness_pullback_mutating(backend, s)
end

elseif op == :derivative_allocating
@testset "$(typeof(s))" for s in allocating(scalar_scalar(scenarios))
@testset "$(scen_string(s))" for s in
allocating(scalar_scalar(scenarios))
test_correctness_derivative_allocating(backend, s)
end

elseif op == :multiderivative_allocating
@testset "$(typeof(s))" for s in allocating(scalar_array(scenarios))
@testset "$(scen_string(s))" for s in
allocating(scalar_array(scenarios))
test_correctness_multiderivative_allocating(backend, s)
end
elseif op == :multiderivative_mutating
@testset "$(typeof(s))" for s in mutating(scalar_array(scenarios))
@testset "$(scen_string(s))" for s in mutating(scalar_array(scenarios))
test_correctness_multiderivative_mutating(backend, s)
end

elseif op == :gradient_allocating
@testset "$(typeof(s))" for s in allocating(array_scalar(scenarios))
@testset "$(scen_string(s))" for s in
allocating(array_scalar(scenarios))
test_correctness_gradient_allocating(backend, s)
end

elseif op == :jacobian_allocating
@testset "$(typeof(s))" for s in allocating(array_array(scenarios))
@testset "$(scen_string(s))" for s in allocating(array_array(scenarios))
test_correctness_jacobian_allocating(backend, s)
end
elseif op == :jacobian_mutating
@testset "$(typeof(s))" for s in mutating(array_array(scenarios))
@testset "$(scen_string(s))" for s in mutating(array_array(scenarios))
test_correctness_jacobian_mutating(backend, s)
end

elseif op == :second_derivative_allocating
@testset "$(typeof(s))" for s in allocating(scalar_scalar(scenarios))
@testset "$(scen_string(s))" for s in
allocating(scalar_scalar(scenarios))
test_correctness_second_derivative_allocating(backend, s)
end

elseif op == :hessian_vector_product_allocating
@testset "$(typeof(s))" for s in allocating(array_scalar(scenarios))
@testset "$(scen_string(s))" for s in
allocating(array_scalar(scenarios))
test_correctness_hessian_vector_product_allocating(backend, s)
end
elseif op == :hessian_allocating
@testset "$(typeof(s))" for s in allocating(array_scalar(scenarios))
@testset "$(scen_string(s))" for s in
allocating(array_scalar(scenarios))
test_correctness_hessian_allocating(backend, s)
end

Expand Down Expand Up @@ -393,14 +399,30 @@ function test_correctness_hessian_vector_product_allocating(
hvp_in6 = zero(hvp_out5)
hvp_out6 = DI.hessian_vector_product!(hvp_in6, ba, f, x, dx)

grad_out7, hvp_out7 = DI.gradient_and_hessian_vector_product(ba, f, x, dx)
grad_in8, hvp_in8 = zero(grad_out7), zero(hvp_out7)
grad_out8, hvp_out8 = DI.gradient_and_hessian_vector_product!(
grad_in8, hvp_in8, ba, f, x, dx
)

@testset "Gradient value" begin
@test grad_out7 ≈ grad_true rtol = 1e-3
@test grad_out8 ≈ grad_true rtol = 1e-3
@testset "Mutation" begin
@test grad_in8 ≈ grad_true rtol = 1e-3
end
end

@testset "Hessian-vector product value" begin
@test hvp_out5 ≈ hvp_true rtol = 1e-3
@test hvp_out6 ≈ hvp_true rtol = 1e-3
@test hvp_out7 ≈ hvp_true rtol = 1e-3
@test hvp_out8 ≈ hvp_true rtol = 1e-3
@testset "Mutation" begin
@test hvp_in6 ≈ hvp_true rtol = 1e-3
@test hvp_in8 ≈ hvp_true rtol = 1e-3
end
end
# TODO: add gradient
end

## Hessian
Expand Down
Loading
Loading