diff --git a/docs/src/backends.md b/docs/src/backends.md index 8ea1646a7..3ff08aeaa 100644 --- a/docs/src/backends.md +++ b/docs/src/backends.md @@ -96,7 +96,7 @@ For second-order differentiation, you can either In Hessian computations, the most efficient combination is often forward-over-reverse, i.e. `SecondOrder(reverse_backend, forward_backend)`. -!!! danger +!!! info Many backend combinations will fail for second order. Some because of our implementation, and some because the outer backend cannot differentiate through code generated by the inner backend. @@ -120,7 +120,6 @@ end # hide Markdown.parse(join(vcat(header, subheader, rows...), "\n") * "\n") # hide ``` - ## Package extensions ```@meta diff --git a/ext/DifferentiationInterfaceChairmarksExt/DifferentiationInterfaceChairmarksExt.jl b/ext/DifferentiationInterfaceChairmarksExt/DifferentiationInterfaceChairmarksExt.jl index adcd2f95c..2b282980d 100644 --- a/ext/DifferentiationInterfaceChairmarksExt/DifferentiationInterfaceChairmarksExt.jl +++ b/ext/DifferentiationInterfaceChairmarksExt/DifferentiationInterfaceChairmarksExt.jl @@ -9,7 +9,8 @@ using DifferentiationInterface: MutationSupported, MutationNotSupported, mode, - mutation_behavior + mutation_behavior, + outer using DifferentiationInterface.DifferentiationTest import DifferentiationInterface.DifferentiationTest as DT using Test @@ -63,7 +64,7 @@ function DT.run_benchmark( @testset "$op" for op in operators results = all_results[backend_string(backend)][op] if op == :pushforward_allocating - for s in allocating(scenarios) + @testset "$(scen_string(s))" for s in allocating(scenarios) merge!( results[scen_id(s)...], run_benchmark_pushforward_allocating( @@ -72,7 +73,7 @@ function DT.run_benchmark( ) end elseif op == :pushforward_mutating - for s in mutating(scenarios) + @testset "$(scen_string(s))" for s in mutating(scenarios) merge!( results[scen_id(s)...], run_benchmark_pushforward_mutating( @@ -82,14 +83,14 @@ function DT.run_benchmark( end elseif op == :pullback_allocating - for s in allocating(scenarios) + @testset "$(scen_string(s))" for s in allocating(scenarios) merge!( results[scen_id(s)...], run_benchmark_pullback_allocating(backend, s; test_allocations), ) end elseif op == :pullback_mutating - for s in mutating(scenarios) + @testset "$(scen_string(s))" for s in mutating(scenarios) merge!( results[scen_id(s)...], run_benchmark_pullback_mutating(backend, s; test_allocations), @@ -97,7 +98,8 @@ function DT.run_benchmark( end elseif op == :derivative_allocating - for s in allocating(scalar_scalar(scenarios)) + @testset "$(scen_string(s))" for s in + allocating(scalar_scalar(scenarios)) merge!( results[scen_id(s)...], run_benchmark_derivative_allocating( @@ -107,7 +109,8 @@ function DT.run_benchmark( end elseif op == :multiderivative_allocating - for s in allocating(scalar_array(scenarios)) + @testset "$(scen_string(s))" for s in + allocating(scalar_array(scenarios)) merge!( results[scen_id(s)...], run_benchmark_multiderivative_allocating( @@ -116,7 +119,7 @@ function DT.run_benchmark( ) end elseif op == :multiderivative_mutating - for s in mutating(scalar_array(scenarios)) + @testset "$(scen_string(s))" for s in mutating(scalar_array(scenarios)) merge!( results[scen_id(s)...], run_benchmark_multiderivative_mutating( @@ -126,7 +129,8 @@ function DT.run_benchmark( end elseif op == :gradient_allocating - for s in allocating(array_scalar(scenarios)) + @testset "$(scen_string(s))" for s in + allocating(array_scalar(scenarios)) merge!( results[scen_id(s)...], run_benchmark_gradient_allocating(backend, s; test_allocations), @@ -134,14 +138,14 @@ function DT.run_benchmark( end elseif op == :jacobian_allocating - for s in allocating(array_array(scenarios)) + @testset "$(scen_string(s))" for s in allocating(array_array(scenarios)) merge!( results[scen_id(s)...], run_benchmark_jacobian_allocating(backend, s; test_allocations), ) end elseif op == :jacobian_mutating - for s in mutating(array_array(scenarios)) + @testset "$(scen_string(s))" for s in mutating(array_array(scenarios)) merge!( results[scen_id(s)...], run_benchmark_jacobian_mutating(backend, s; test_allocations), @@ -149,7 +153,8 @@ function DT.run_benchmark( end elseif op == :second_derivative_allocating - for s in allocating(scalar_scalar(scenarios)) + @testset "$(scen_string(s))" for s in + allocating(scalar_scalar(scenarios)) merge!( results[scen_id(s)...], run_benchmark_second_derivative_allocating( @@ -159,7 +164,8 @@ function DT.run_benchmark( end elseif op == :hessian_vector_product_allocating - for s in allocating(array_scalar(scenarios)) + @testset "$(scen_string(s))" for s in + allocating(array_scalar(scenarios)) merge!( results[scen_id(s)...], run_benchmark_hessian_vector_product_allocating( @@ -168,7 +174,8 @@ function DT.run_benchmark( ) end elseif op == :hessian_allocating - for s in allocating(array_scalar(scenarios)) + @testset "$(scen_string(s))" for s in + allocating(array_scalar(scenarios)) merge!( results[scen_id(s)...], run_benchmark_hessian_allocating(backend, s; test_allocations), @@ -357,10 +364,16 @@ function run_benchmark_hessian_vector_product_allocating( (; f, x, dx) = deepcopy(scenario) extras = prepare_hessian_vector_product(ba, f, x) bench1 = @be zero(dx) hessian_vector_product!(_, ba, f, x, dx, extras) - if test_allocations + bench2 = @be (zero(dx), zero(dx)) gradient_and_hessian_vector_product!( + _[1], _[2], ba, f, x, dx, extras + ) + if test_allocations # TODO: distinguish soft_test_zero(minimum(bench1).allocs) + soft_test_zero(minimum(bench2).allocs) end - return Dict(:hessian_vector_product! => bench1) + return Dict( + :hessian_vector_product! => bench1, :gradient_and_hessian_vector_product! => bench2 + ) end ## Hessian diff --git a/ext/DifferentiationInterfaceEnzymeExt/forward_allocating.jl b/ext/DifferentiationInterfaceEnzymeExt/forward_allocating.jl index f07c7ab7e..735ba8eaf 100644 --- a/ext/DifferentiationInterfaceEnzymeExt/forward_allocating.jl +++ b/ext/DifferentiationInterfaceEnzymeExt/forward_allocating.jl @@ -3,38 +3,44 @@ function DI.value_and_pushforward!( _dy::Real, backend::AutoForwardEnzyme, f, x, dx, extras::Nothing ) - y, new_dy = autodiff(backend.mode, f, Duplicated, Duplicated(x, dx)) + dx_sametype = convert(typeof(x), dx) + y, new_dy = autodiff(backend.mode, f, Duplicated, Duplicated(x, dx_sametype)) return y, new_dy end function DI.value_and_pushforward!( dy::AbstractArray, backend::AutoForwardEnzyme, f, x, dx, extras::Nothing ) - y, new_dy = autodiff(backend.mode, f, Duplicated, Duplicated(x, dx)) + dx_sametype = convert(typeof(x), dx) + y, new_dy = autodiff(backend.mode, f, Duplicated, Duplicated(x, dx_sametype)) dy .= new_dy return y, dy end function DI.pushforward!(_dy::Real, backend::AutoForwardEnzyme, f, x, dx, extras::Nothing) - new_dy = only(autodiff(backend.mode, f, DuplicatedNoNeed, Duplicated(x, dx))) + dx_sametype = convert(typeof(x), dx) + new_dy = only(autodiff(backend.mode, f, DuplicatedNoNeed, Duplicated(x, dx_sametype))) return new_dy end function DI.pushforward!( dy::AbstractArray, backend::AutoForwardEnzyme, f, x, dx, extras::Nothing ) - new_dy = only(autodiff(backend.mode, f, DuplicatedNoNeed, Duplicated(x, dx))) + dx_sametype = convert(typeof(x), dx) + new_dy = only(autodiff(backend.mode, f, DuplicatedNoNeed, Duplicated(x, dx_sametype))) dy .= new_dy return dy end function DI.value_and_pushforward(backend::AutoForwardEnzyme, f, x, dx, extras::Nothing) - y, dy = autodiff(backend.mode, f, Duplicated, Duplicated(x, dx)) + dx_sametype = convert(typeof(x), dx) + y, dy = autodiff(backend.mode, f, Duplicated, Duplicated(x, dx_sametype)) return y, dy end function DI.pushforward(backend::AutoForwardEnzyme, f, x, dx, extras::Nothing) - dy = only(autodiff(backend.mode, f, DuplicatedNoNeed, Duplicated(x, dx))) + dx_sametype = convert(typeof(x), dx) + dy = only(autodiff(backend.mode, f, DuplicatedNoNeed, Duplicated(x, dx_sametype))) return dy end diff --git a/ext/DifferentiationInterfaceEnzymeExt/reverse_allocating.jl b/ext/DifferentiationInterfaceEnzymeExt/reverse_allocating.jl index 734d51cde..f542e50f4 100644 --- a/ext/DifferentiationInterfaceEnzymeExt/reverse_allocating.jl +++ b/ext/DifferentiationInterfaceEnzymeExt/reverse_allocating.jl @@ -22,8 +22,10 @@ end function DI.value_and_pullback!( dx::AbstractArray, ::AutoReverseEnzyme, f, x::AbstractArray, dy::Number, extras::Nothing ) - dx .= zero(eltype(dx)) - _, y = autodiff(ReverseWithPrimal, f, Active, Duplicated(x, dx)) + dx_sametype = convert(typeof(x), dx) + dx_sametype .= zero(eltype(dx_sametype)) + _, y = autodiff(ReverseWithPrimal, f, Active, Duplicated(x, dx_sametype)) + dx .= dx_sametype dx .*= dy return y, dx end @@ -39,8 +41,10 @@ end function DI.pullback!( dx::AbstractArray, ::AutoReverseEnzyme, f, x::AbstractArray, dy::Number, extras::Nothing ) - dx .= zero(eltype(dx)) - autodiff(Reverse, f, Active, Duplicated(x, dx)) + dx_sametype = convert(typeof(x), dx) + dx_sametype .= zero(eltype(dx_sametype)) + autodiff(Reverse, f, Active, Duplicated(x, dx_sametype)) + dx .= dx_sametype dx .*= dy return dx end diff --git a/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl b/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl index fc6db4369..0dadf7745 100644 --- a/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl +++ b/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl @@ -7,7 +7,8 @@ using DifferentiationInterface: MutationSupported, MutationNotSupported, mode, - mutation_behavior + mutation_behavior, + outer import DifferentiationInterface as DI using DifferentiationInterface.DifferentiationTest import DifferentiationInterface.DifferentiationTest as DT diff --git a/ext/DifferentiationInterfaceForwardDiffExt/test_correctness.jl b/ext/DifferentiationInterfaceForwardDiffExt/test_correctness.jl index 3c999b66a..8a2099b64 100644 --- a/ext/DifferentiationInterfaceForwardDiffExt/test_correctness.jl +++ b/ext/DifferentiationInterfaceForwardDiffExt/test_correctness.jl @@ -10,62 +10,68 @@ function DT.test_correctness( @testset verbose = true "$(backend_string(backend))" for backend in backends @testset "$op" for op in operators if op == :pushforward_allocating - @testset "$(typeof(s))" for s in allocating(scenarios) + @testset "$(scen_string(s))" for s in allocating(scenarios) test_correctness_pushforward_allocating(backend, s) end elseif op == :pushforward_mutating - @testset "$(typeof(s))" for s in mutating(scenarios) + @testset "$(scen_string(s))" for s in mutating(scenarios) test_correctness_pushforward_mutating(backend, s) end elseif op == :pullback_allocating - @testset "$(typeof(s))" for s in allocating(scenarios) + @testset "$(scen_string(s))" for s in allocating(scenarios) test_correctness_pullback_allocating(backend, s) end elseif op == :pullback_mutating - @testset "$(typeof(s))" for s in mutating(scenarios) + @testset "$(scen_string(s))" for s in mutating(scenarios) test_correctness_pullback_mutating(backend, s) end elseif op == :derivative_allocating - @testset "$(typeof(s))" for s in allocating(scalar_scalar(scenarios)) + @testset "$(scen_string(s))" for s in + allocating(scalar_scalar(scenarios)) test_correctness_derivative_allocating(backend, s) end elseif op == :multiderivative_allocating - @testset "$(typeof(s))" for s in allocating(scalar_array(scenarios)) + @testset "$(scen_string(s))" for s in + allocating(scalar_array(scenarios)) test_correctness_multiderivative_allocating(backend, s) end elseif op == :multiderivative_mutating - @testset "$(typeof(s))" for s in mutating(scalar_array(scenarios)) + @testset "$(scen_string(s))" for s in mutating(scalar_array(scenarios)) test_correctness_multiderivative_mutating(backend, s) end elseif op == :gradient_allocating - @testset "$(typeof(s))" for s in allocating(array_scalar(scenarios)) + @testset "$(scen_string(s))" for s in + allocating(array_scalar(scenarios)) test_correctness_gradient_allocating(backend, s) end elseif op == :jacobian_allocating - @testset "$(typeof(s))" for s in allocating(array_array(scenarios)) + @testset "$(scen_string(s))" for s in allocating(array_array(scenarios)) test_correctness_jacobian_allocating(backend, s) end elseif op == :jacobian_mutating - @testset "$(typeof(s))" for s in mutating(array_array(scenarios)) + @testset "$(scen_string(s))" for s in mutating(array_array(scenarios)) test_correctness_jacobian_mutating(backend, s) end elseif op == :second_derivative_allocating - @testset "$(typeof(s))" for s in allocating(scalar_scalar(scenarios)) + @testset "$(scen_string(s))" for s in + allocating(scalar_scalar(scenarios)) test_correctness_second_derivative_allocating(backend, s) end elseif op == :hessian_vector_product_allocating - @testset "$(typeof(s))" for s in allocating(array_scalar(scenarios)) + @testset "$(scen_string(s))" for s in + allocating(array_scalar(scenarios)) test_correctness_hessian_vector_product_allocating(backend, s) end elseif op == :hessian_allocating - @testset "$(typeof(s))" for s in allocating(array_scalar(scenarios)) + @testset "$(scen_string(s))" for s in + allocating(array_scalar(scenarios)) test_correctness_hessian_allocating(backend, s) end @@ -393,14 +399,30 @@ function test_correctness_hessian_vector_product_allocating( hvp_in6 = zero(hvp_out5) hvp_out6 = DI.hessian_vector_product!(hvp_in6, ba, f, x, dx) + grad_out7, hvp_out7 = DI.gradient_and_hessian_vector_product(ba, f, x, dx) + grad_in8, hvp_in8 = zero(grad_out7), zero(hvp_out7) + grad_out8, hvp_out8 = DI.gradient_and_hessian_vector_product!( + grad_in8, hvp_in8, ba, f, x, dx + ) + + @testset "Gradient value" begin + @test grad_out7 ≈ grad_true rtol = 1e-3 + @test grad_out8 ≈ grad_true rtol = 1e-3 + @testset "Mutation" begin + @test grad_in8 ≈ grad_true rtol = 1e-3 + end + end + @testset "Hessian-vector product value" begin @test hvp_out5 ≈ hvp_true rtol = 1e-3 @test hvp_out6 ≈ hvp_true rtol = 1e-3 + @test hvp_out7 ≈ hvp_true rtol = 1e-3 + @test hvp_out8 ≈ hvp_true rtol = 1e-3 @testset "Mutation" begin @test hvp_in6 ≈ hvp_true rtol = 1e-3 + @test hvp_in8 ≈ hvp_true rtol = 1e-3 end end - # TODO: add gradient end ## Hessian diff --git a/ext/DifferentiationInterfaceJETExt/DifferentiationInterfaceJETExt.jl b/ext/DifferentiationInterfaceJETExt/DifferentiationInterfaceJETExt.jl index 0c145cde0..5cf3b9810 100644 --- a/ext/DifferentiationInterfaceJETExt/DifferentiationInterfaceJETExt.jl +++ b/ext/DifferentiationInterfaceJETExt/DifferentiationInterfaceJETExt.jl @@ -8,7 +8,8 @@ using DifferentiationInterface: MutationSupported, MutationNotSupported, mode, - mutation_behavior + mutation_behavior, + outer using DifferentiationInterface.DifferentiationTest import DifferentiationInterface.DifferentiationTest as DT using JET: @test_opt @@ -26,62 +27,68 @@ function DT.test_type_stability( @testset verbose = true "$(backend_string(backend))" for backend in backends @testset "$op" for op in operators if op == :pushforward_allocating - @testset "$(typeof(s))" for s in allocating(scenarios) + @testset "$(scen_string(s))" for s in allocating(scenarios) test_type_pushforward_allocating(backend, s) end elseif op == :pushforward_mutating - @testset "$(typeof(s))" for s in mutating(scenarios) + @testset "$(scen_string(s))" for s in mutating(scenarios) test_type_pushforward_mutating(backend, s) end elseif op == :pullback_allocating - @testset "$(typeof(s))" for s in allocating(scenarios) + @testset "$(scen_string(s))" for s in allocating(scenarios) test_type_pullback_allocating(backend, s) end elseif op == :pullback_mutating - @testset "$(typeof(s))" for s in mutating(scenarios) + @testset "$(scen_string(s))" for s in mutating(scenarios) test_type_pullback_mutating(backend, s) end elseif op == :derivative_allocating - @testset "$(typeof(s))" for s in allocating(scalar_scalar(scenarios)) + @testset "$(scen_string(s))" for s in + allocating(scalar_scalar(scenarios)) test_type_derivative_allocating(backend, s) end elseif op == :multiderivative_allocating - @testset "$(typeof(s))" for s in allocating(scalar_array(scenarios)) + @testset "$(scen_string(s))" for s in + allocating(scalar_array(scenarios)) test_type_multiderivative_allocating(backend, s) end elseif op == :multiderivative_mutating - @testset "$(typeof(s))" for s in mutating(scalar_array(scenarios)) + @testset "$(scen_string(s))" for s in mutating(scalar_array(scenarios)) test_type_multiderivative_mutating(backend, s) end elseif op == :gradient_allocating - @testset "$(typeof(s))" for s in allocating(array_scalar(scenarios)) + @testset "$(scen_string(s))" for s in + allocating(array_scalar(scenarios)) test_type_gradient_allocating(backend, s) end elseif op == :jacobian_allocating - @testset "$(typeof(s))" for s in allocating(array_array(scenarios)) + @testset "$(scen_string(s))" for s in allocating(array_array(scenarios)) test_type_jacobian_allocating(backend, s) end elseif op == :jacobian_mutating - @testset "$(typeof(s))" for s in mutating(array_array(scenarios)) + @testset "$(scen_string(s))" for s in mutating(array_array(scenarios)) test_type_jacobian_mutating(backend, s) end elseif op == :second_derivative_allocating - @testset "$(typeof(s))" for s in allocating(scalar_scalar(scenarios)) + @testset "$(scen_string(s))" for s in + allocating(scalar_scalar(scenarios)) test_type_second_derivative_allocating(backend, s) end elseif op == :hessian_vector_product_allocating - @testset "$(typeof(s))" for s in allocating(array_scalar(scenarios)) + @testset "$(scen_string(s))" for s in + allocating(array_scalar(scenarios)) test_type_hessian_vector_product_allocating(backend, s) end elseif op == :hessian_allocating - @testset "$(typeof(s))" for s in allocating(array_scalar(scenarios)) + @testset "$(scen_string(s))" for s in + allocating(array_scalar(scenarios)) test_type_hessian_allocating(backend, s) end @@ -208,13 +215,20 @@ end function test_type_hessian_vector_product_allocating(ba::AbstractADType, scen::Scenario) (; f, x, dx) = deepcopy(scen) + grad_in = zero(dx) hvp_in = zero(dx) @test_opt ignored_modules = (LinearAlgebra,) hessian_vector_product!( hvp_in, ba, f, x, dx ) @test_opt ignored_modules = (LinearAlgebra,) hessian_vector_product(ba, f, x, dx) - # TODO: add gradient + @test_opt ignored_modules = (LinearAlgebra,) gradient_and_hessian_vector_product!( + grad_in, hvp_in, ba, f, x, dx + ) + @test_opt ignored_modules = (LinearAlgebra,) gradient_and_hessian_vector_product( + ba, f, x, dx + ) end + ## Hessian function test_type_hessian_allocating(ba::AbstractADType, scen::Scenario) diff --git a/src/DifferentiationInterface.jl b/src/DifferentiationInterface.jl index 29375f4a8..a1ca19cb8 100644 --- a/src/DifferentiationInterface.jl +++ b/src/DifferentiationInterface.jl @@ -12,7 +12,7 @@ module DifferentiationInterface using ADTypes: ADTypes, AbstractADType using DocStringExtensions using FillArrays: OneElement -using LinearAlgebra: dot +using LinearAlgebra: dot, norm using Test: Test """ diff --git a/src/DifferentiationTest/DifferentiationTest.jl b/src/DifferentiationTest/DifferentiationTest.jl index fddaee2c0..e10b94520 100644 --- a/src/DifferentiationTest/DifferentiationTest.jl +++ b/src/DifferentiationTest/DifferentiationTest.jl @@ -13,7 +13,9 @@ using ..DifferentiationInterface: ReverseMode, SymbolicMode, SecondOrder, - mode + mode, + inner, + outer using ADTypes using ADTypes: AbstractADType using DocStringExtensions @@ -25,7 +27,7 @@ include("default_scenarios.jl") include("test_operators.jl") include("pretty.jl") -export Scenario, default_scenarios +export Scenario, default_scenarios, scen_string export allocating, mutating export scalar_scalar, scalar_array, array_scalar, array_array export test_operators, parse_benchmark diff --git a/src/DifferentiationTest/pretty.jl b/src/DifferentiationTest/pretty.jl index 5d06c324a..cde815cd9 100644 --- a/src/DifferentiationTest/pretty.jl +++ b/src/DifferentiationTest/pretty.jl @@ -33,5 +33,5 @@ function backend_string(backend::AbstractADType) end function backend_string(backend::SecondOrder) - return "$(backend_string(backend.outer)) / $(backend_string(backend.inner))" + return "$(backend_string(inner(backend))) + $(backend_string(outer(backend)))" end diff --git a/src/DifferentiationTest/scenario.jl b/src/DifferentiationTest/scenario.jl index 4676e146e..6a8312c4d 100644 --- a/src/DifferentiationTest/scenario.jl +++ b/src/DifferentiationTest/scenario.jl @@ -23,6 +23,8 @@ $(TYPEDFIELDS) dy::Y end +scen_string(scen::Scenario) = "$(string(scen.f)): $(typeof(scen.x)) -> $(typeof(scen.y))" + function Scenario(f, x::Union{Number,AbstractArray}) y = f(x) dx = similar_random(x) diff --git a/src/DifferentiationTest/test_operators.jl b/src/DifferentiationTest/test_operators.jl index e29e841c2..5014c9226 100644 --- a/src/DifferentiationTest/test_operators.jl +++ b/src/DifferentiationTest/test_operators.jl @@ -110,7 +110,7 @@ function test_operators( operators; first_order, second_order, allocating, mutating, excluded ) result = nothing - @testset verbose = true "Backend tests" begin + set = @testset verbose = true "Backend tests" begin if correctness test_correctness(backends, operators, scenarios) end diff --git a/src/backends.jl b/src/backends.jl index bce797cd2..181946a2d 100644 --- a/src/backends.jl +++ b/src/backends.jl @@ -19,6 +19,8 @@ end available(backend::SecondOrder) = available(inner(backend)) && available(outer(backend)) +square!(y::AbstractArray, x::AbstractArray) = y .= x .^ 2 + """ supports_mutation(backend) @@ -27,13 +29,18 @@ Might take a while due to compilation time. """ function supports_mutation(backend::AbstractADType) try - value_and_jacobian!([0.0], [0.0;;], backend, copyto!, [1.0]) - return true + x = [3.0] + y = [0.0] + jac = [0.0;;] + value_and_jacobian!(y, jac, backend, square!, x) + return isapprox(y, [9.0]; rtol=1e-3) && isapprox(jac, [6.0;;]; rtol=1e-3) catch e return false end end +sqnorm(x::AbstractArray) = sum(abs2, x) + """ supports_hessian(backend) @@ -42,8 +49,9 @@ Might take a while due to compilation time. """ function supports_hessian(backend::AbstractADType) try - hessian(backend, sum, [1.0]) - return true + x = [3.0] + hess = hessian(backend, sqnorm, x) + return isapprox(hess, [2.0;;]; rtol=1e-3) catch e return false end diff --git a/src/hessian_vector_product.jl b/src/hessian_vector_product.jl index 14c6f3dbf..e5f5d7e46 100644 --- a/src/hessian_vector_product.jl +++ b/src/hessian_vector_product.jl @@ -12,9 +12,6 @@ Start by reading the allocating versions gradient_and_hessian_vector_product(backend, f, x, v, [extras]) -> (grad, hvp) Compute the gradient `grad = ∇f(x)` and the Hessian-vector product `hvp = ∇²f(x) * v` of an array-to-scalar function. - -!!! warning - Only works with a forward outer mode. """ function gradient_and_hessian_vector_product( backend::AbstractADType, @@ -48,18 +45,17 @@ function gradient_and_hessian_vector_product_aux( end function gradient_and_hessian_vector_product_aux( - backend, f, x, v, extras, ::AbstractMode, ::ReverseMode -) - throw(ArgumentError("HVP must be computed without gradient for reverse-over-something")) + backend, f::F, x, v, extras, ::AbstractMode, ::ReverseMode +) where {F} + grad = gradient(inner(backend), f, x) + hvp = hessian_vector_product(backend, f, x, v, extras) + return grad, hvp end """ gradient_and_hessian_vector_product!(grad, backend, hvp, backend, f, x, v, [extras]) -> (grad, hvp) Compute the gradient `grad = ∇f(x)` and the Hessian-vector product `hvp = ∇²f(x) * v` of an array-to-scalar function, overwriting `grad` and `hvp`. - -!!! warning - Only works with a forward outer mode. """ function gradient_and_hessian_vector_product!( grad::AbstractArray, @@ -85,12 +81,30 @@ function gradient_and_hessian_vector_product!( extras=prepare_hessian_vector_product(backend, f, x), ) where {F} return gradient_and_hessian_vector_product_aux!( - grad, hvp, backend, f, x, v, extras, mode(inner(backend)), mode(outer(backend)) + grad, + hvp, + backend, + f, + x, + v, + extras, + mode(inner(backend)), + mode(outer(backend)), + mutation_behavior(inner(backend)), ) end function gradient_and_hessian_vector_product_aux!( - grad, hvp, backend, f::F, x, v, extras, ::AbstractMode, ::ForwardMode + grad, + hvp, + backend, + f::F, + x, + v, + extras, + ::AbstractMode, + ::ForwardMode, + ::MutationSupported, ) where {F} function grad_aux!(storage, z) gradient!(storage, inner(backend), f, z, extras) @@ -100,9 +114,38 @@ function gradient_and_hessian_vector_product_aux!( end function gradient_and_hessian_vector_product_aux!( - grad, hvp, backend, f, x, v, extras, ::AbstractMode, ::ReverseMode -) - throw(ArgumentError("HVP must be computed without gradient for reverse-over-something")) + grad, + hvp, + backend, + f::F, + x, + v, + extras, + ::AbstractMode, + ::ForwardMode, + ::MutationNotSupported, +) where {F} + grad_aux(z) = gradient(inner(backend), f, z, extras) + new_grad, hvp = value_and_pushforward!(hvp, outer(backend), grad_aux, x, v, extras) + grad .= new_grad + return grad, hvp +end + +function gradient_and_hessian_vector_product_aux!( + grad, + hvp, + backend, + f::F, + x, + v, + extras, + ::AbstractMode, + ::AbstractMode, + ::MutationBehavior, +) where {F} + grad = gradient!(grad, inner(backend), f, x) + hvp = hessian_vector_product!(hvp, backend, f, x, v, extras) + return grad, hvp end ## All backends can give the HVP diff --git a/src/mutation.jl b/src/mutation.jl index 7d6e5e390..e5b15ca04 100644 --- a/src/mutation.jl +++ b/src/mutation.jl @@ -17,6 +17,10 @@ struct MutationNotSupported <: MutationBehavior end """ mutation_behavior(backend) -Return the mutation behavior of a backend. +Return the mutation behavior of a backend in a statically predictable way. + +# Note + +This is different from [`supports_mutation`](@ref), which performs an actual call to `jacobian!`. """ mutation_behavior(::AbstractADType) = MutationSupported() diff --git a/src/second_order.jl b/src/second_order.jl index 547ecbea7..7f44dccc9 100644 --- a/src/second_order.jl +++ b/src/second_order.jl @@ -16,6 +16,7 @@ end inner(backend::SecondOrder) = backend.inner outer(backend::SecondOrder) = backend.outer +mode(backend::SecondOrder) = (mode(inner(backend)), mode(outer(backend))) function Base.show(io::IO, backend::SecondOrder) return print(io, "SecondOrder($(inner(backend)), $(outer(backend)))") diff --git a/src/utils.jl b/src/utils.jl index bd4bce766..0d43a573a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -19,6 +19,7 @@ mysimilar(x::AbstractArray{T}) where {T} = similar(x, T, axes(x)) # strip struct update!(_old::Number, new::Number) = new update!(old, new) = old .= new +update!(old, new::Nothing) = old zero!(x::Number) = zero(x) zero!(x) = x .= zero(eltype(x)) diff --git a/test/chainrules_reverse.jl b/test/chainrules_reverse.jl index f4e94f7f1..67e0c0253 100644 --- a/test/chainrules_reverse.jl +++ b/test/chainrules_reverse.jl @@ -9,5 +9,6 @@ using Test @test available(AutoChainRules(ZygoteRuleConfig())) @test !supports_mutation(AutoChainRules(ZygoteRuleConfig())) +@test supports_hessian(AutoChainRules(ZygoteRuleConfig())) -test_operators(AutoChainRules(ZygoteRuleConfig()); second_order=false, type_stability=false); +test_operators(AutoChainRules(ZygoteRuleConfig()); type_stability=false); diff --git a/test/diffractor.jl b/test/diffractor.jl index 2dc8d8f81..42a609c7a 100644 --- a/test/diffractor.jl +++ b/test/diffractor.jl @@ -9,5 +9,6 @@ using Test @test available(AutoDiffractor()) @test !supports_mutation(AutoDiffractor()) +@test !supports_hessian(AutoDiffractor()) test_operators(AutoDiffractor(); second_order=false, type_stability=false); diff --git a/test/enzyme_forward.jl b/test/enzyme_forward.jl index f2d8fca4d..42f969069 100644 --- a/test/enzyme_forward.jl +++ b/test/enzyme_forward.jl @@ -9,6 +9,7 @@ using Test @test available(AutoEnzyme(Enzyme.Forward)) @test supports_mutation(AutoEnzyme(Enzyme.Forward)) +@test !supports_hessian(AutoEnzyme(Enzyme.Forward)) test_operators( AutoEnzyme(Enzyme.Forward); second_order=false, excluded=[:jacobian_allocating] diff --git a/test/enzyme_reverse.jl b/test/enzyme_reverse.jl index 14d4c04e1..071dc7ce6 100644 --- a/test/enzyme_reverse.jl +++ b/test/enzyme_reverse.jl @@ -9,5 +9,6 @@ using Test @test available(AutoEnzyme(Enzyme.Reverse)) @test supports_mutation(AutoEnzyme(Enzyme.Reverse)) +@test !supports_hessian(AutoEnzyme(Enzyme.Reverse)) test_operators(AutoEnzyme(Enzyme.Reverse); second_order=false); diff --git a/test/fastdifferentiation.jl b/test/fastdifferentiation.jl index e18ebc5cf..e32f55459 100644 --- a/test/fastdifferentiation.jl +++ b/test/fastdifferentiation.jl @@ -8,6 +8,8 @@ using JET: JET using Test @test available(AutoFastDifferentiation()) +@test !supports_mutation(AutoFastDifferentiation()) +@test !supports_hessian(AutoFastDifferentiation()) test_operators( AutoFastDifferentiation(); diff --git a/test/finitediff.jl b/test/finitediff.jl index 3de000011..996594380 100644 --- a/test/finitediff.jl +++ b/test/finitediff.jl @@ -9,6 +9,7 @@ using Test @test available(AutoFiniteDiff()) @test supports_mutation(AutoFiniteDiff()) +@test !supports_hessian(AutoFiniteDiff()) test_operators(AutoFiniteDiff(); second_order=false, excluded=[:jacobian_allocating]); test_operators(AutoFiniteDiff(), [:jacobian_allocating]; type_stability=false); diff --git a/test/reversediff.jl b/test/reversediff.jl index 39dc47b44..88cdb4b41 100644 --- a/test/reversediff.jl +++ b/test/reversediff.jl @@ -9,6 +9,7 @@ using Test @test available(AutoReverseDiff()) @test supports_mutation(AutoReverseDiff()) +@test supports_hessian(AutoReverseDiff()) -test_operators(AutoReverseDiff(); second_order=false, type_stability=false); -test_operators(AutoReverseDiff(; compile=true); second_order=false, type_stability=false); +test_operators(AutoReverseDiff(); type_stability=false); +test_operators(AutoReverseDiff(; compile=true); type_stability=false); diff --git a/test/second_order.jl b/test/second_order.jl index ece76f2c5..634d3fd3a 100644 --- a/test/second_order.jl +++ b/test/second_order.jl @@ -2,23 +2,72 @@ using ADTypes using DifferentiationInterface using DifferentiationInterface.DifferentiationTest -using Enzyme: Enzyme +using FiniteDiff: FiniteDiff using ForwardDiff: ForwardDiff +using Enzyme: Enzyme using ReverseDiff: ReverseDiff using Zygote: Zygote using JET: JET using Test -cross_backends = [ - # forward over reverse - SecondOrder(AutoZygote(), AutoForwardDiff()), - SecondOrder(AutoZygote(), AutoEnzyme(Enzyme.Forward)), - SecondOrder(AutoReverseDiff(), AutoForwardDiff()), - # reverse over forward - SecondOrder(AutoForwardDiff(), AutoEnzyme(Enzyme.Reverse)), -] +@testset verbose = true "Forward over forward" begin + @testset "$(backend_string(backend))" for backend in [ + SecondOrder(AutoEnzyme(Enzyme.Forward), AutoFiniteDiff()), + SecondOrder(AutoEnzyme(Enzyme.Forward), AutoForwardDiff()), + SecondOrder(AutoForwardDiff(), AutoEnzyme(Enzyme.Forward)), + SecondOrder(AutoForwardDiff(), AutoFiniteDiff()), + SecondOrder(AutoFiniteDiff(), AutoEnzyme(Enzyme.Forward)), + SecondOrder(AutoFiniteDiff(), AutoForwardDiff()), + ] + test_operators(backend; first_order=false, type_stability=false) + end +end; + +@testset verbose = true "Forward over reverse" begin + @testset "$(backend_string(backend))" for backend in [ + # SecondOrder(AutoEnzyme(Enzyme.Reverse), AutoEnzyme(Enzyme.Forward)), + SecondOrder(AutoEnzyme(Enzyme.Reverse), AutoFiniteDiff()), + # SecondOrder(AutoEnzyme(Enzyme.Reverse), AutoForwardDiff()), + SecondOrder(AutoReverseDiff(), AutoEnzyme(Enzyme.Forward)), + SecondOrder(AutoReverseDiff(), AutoFiniteDiff()), + SecondOrder(AutoReverseDiff(), AutoForwardDiff()), + SecondOrder(AutoZygote(), AutoEnzyme(Enzyme.Forward)), + SecondOrder(AutoZygote(), AutoFiniteDiff()), + SecondOrder(AutoZygote(), AutoForwardDiff()), + ] + test_operators(backend; first_order=false, type_stability=false) + end +end; + +@testset verbose = true "Reverse over forward" begin + @testset "$(backend_string(backend))" for backend in [ + # SecondOrder(AutoEnzyme(Enzyme.Forward), AutoEnzyme(Enzyme.Reverse)), + # SecondOrder(AutoEnzyme(Enzyme.Forward), AutoReverseDiff()), + # SecondOrder(AutoEnzyme(Enzyme.Forward), AutoZygote()), + # SecondOrder(AutoFiniteDiff(), AutoEnzyme(Enzyme.Reverse)), + SecondOrder(AutoFiniteDiff(), AutoReverseDiff()), + SecondOrder(AutoFiniteDiff(), AutoZygote()), + SecondOrder(AutoForwardDiff(), AutoEnzyme(Enzyme.Reverse)), + # SecondOrder(AutoForwardDiff(), AutoReverseDiff()), + # SecondOrder(AutoForwardDiff(), AutoZygote()), + ] + test_operators(backend; first_order=false, type_stability=false) + end +end; -@testset "$(backend_string(backend))" for backend in cross_backends - test_operators(backend; first_order=false, mutating=false, type_stability=false) +@testset verbose = true "Reverse over reverse" begin + @testset "$(backend_string(backend))" for backend in [ + # SecondOrder(AutoEnzyme(Enzyme.Reverse), AutoEnzyme(Enzyme.Reverse)), + # SecondOrder(AutoEnzyme(Enzyme.Reverse), AutoReverseDiff()), + # SecondOrder(AutoEnzyme(Enzyme.Reverse), AutoZygote()), + # SecondOrder(AutoReverseDiff(), AutoEnzyme(Enzyme.Reverse)), + SecondOrder(AutoReverseDiff(), AutoReverseDiff()), + # SecondOrder(AutoReverseDiff(), AutoZygote()), + SecondOrder(AutoZygote(), AutoEnzyme(Enzyme.Reverse)), + SecondOrder(AutoZygote(), AutoReverseDiff()), + SecondOrder(AutoZygote(), AutoZygote()), + ] + test_operators(backend; first_order=false, type_stability=false) + end end; diff --git a/test/tracker.jl b/test/tracker.jl index 732427351..649454af6 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -9,6 +9,7 @@ using Test @test available(AutoTracker()) @test !supports_mutation(AutoTracker()) +@test !supports_hessian(AutoTracker()) test_operators( AutoTracker(); diff --git a/test/zero.jl b/test/zero.jl index 15e4c01a7..9621c60de 100644 --- a/test/zero.jl +++ b/test/zero.jl @@ -11,12 +11,16 @@ using Test @test available(AutoZeroReverse()) @test available(SecondOrder(AutoZeroForward(), AutoZeroReverse())) -test_operators([AutoZeroForward(), AutoZeroReverse()]; correctness=false); +test_operators( + [AutoZeroForward(), AutoZeroReverse()]; second_order=false, correctness=false +); test_operators( [ + SecondOrder(AutoZeroForward(), AutoZeroForward()), SecondOrder(AutoZeroForward(), AutoZeroReverse()), SecondOrder(AutoZeroReverse(), AutoZeroForward()), + SecondOrder(AutoZeroReverse(), AutoZeroReverse()), ]; first_order=false, correctness=false, @@ -24,13 +28,19 @@ test_operators( # allocs (experimental) -result = test_operators( +test_operators( [AutoZeroForward(), AutoZeroReverse()]; correctness=false, type_stability=false, - benchmark=true, allocations=true, second_order=false, ); +result = test_operators( + [AutoZeroForward(), AutoZeroReverse()]; + correctness=false, + type_stability=false, + benchmark=true, +); + data = parse_benchmark(result) diff --git a/test/zygote.jl b/test/zygote.jl index 58a00b3c9..34e9f9df7 100644 --- a/test/zygote.jl +++ b/test/zygote.jl @@ -9,5 +9,6 @@ using Test @test available(AutoZygote()) @test !supports_mutation(AutoZygote()) +@test supports_hessian(AutoZygote()) -test_operators(AutoZygote(); second_order=false, type_stability=false); +test_operators(AutoZygote(); type_stability=false);