diff --git a/Project.toml b/Project.toml index 2ada300..7a7fad8 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" authors = [ "Vaibhav Dixit , Guillaume Dalle and contributors", ] -version = "1.5.4" +version = "1.6.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/dense.jl b/src/dense.jl index 4ca7045..8f52f62 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -39,7 +39,7 @@ struct AutoDiffractor <: AbstractADType end mode(::AutoDiffractor) = ForwardOrReverseMode() """ - AutoEnzyme{M} + AutoEnzyme{M,constant_function} Struct used to select the [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) backend for automatic differentiation. @@ -47,7 +47,10 @@ Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl). # Constructors - AutoEnzyme(; mode=nothing) + AutoEnzyme(; mode=nothing, constant_function::Bool=false) + +The `constant_function` keyword argument (and type parameter) determines whether the function object itself should be considered constant or not during differentiation with Enzyme.jl. +For simple functions, `constant_function` should usually be set to `false`, but in the case of closures or callable structs which contain differentiated data that can be treated as constant, `constant_function` should be set to `true` for increased performance (more details below). # Fields @@ -55,9 +58,44 @@ Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl). + an object subtyping `EnzymeCore.Mode` (like `EnzymeCore.Forward` or `EnzymeCore.Reverse`) if a specific mode is required + `nothing` to choose the best mode automatically + +# Notes + +If `constant_function = true` but the enclosed data is not truly constant, then Enzyme.jl will not compute the correct derivative values. +An example of such a function is: + +```julia +cache = [0.0] +function f(x) + cache[1] = x[1]^2 + cache[1] + x[1] +end +``` + +In this case, the enclosed cache is a function of the differentiated input, and thus its values are non-constant with respect to the input. +Thus, in order to compute the correct derivative of the output, the derivative must propagate through the `cache` value, and said `cache` must not be treated as constant. + +Conversely, the following function can treat `parameter` as a constant, because `parameter` is never modified based on the input `x`: + +```julia +parameter = [0.0] +function f(x) + parameter[1] + x[1] +end +``` + +In this case, `constant_function = true` would allow the chosen differentiation system to perform extra memory and compute optimizations, under the assumption that `parameter` is kept constant. """ -Base.@kwdef struct AutoEnzyme{M} <: AbstractADType - mode::M = nothing +struct AutoEnzyme{M, constant_function} <: AbstractADType + mode::M +end + +function AutoEnzyme(mode::M; constant_function::Bool = false) where {M} + return AutoEnzyme{M, constant_function}(mode) +end + +function AutoEnzyme(; mode::M = nothing, constant_function::Bool = false) where {M} + return AutoEnzyme{M, constant_function}(mode) end mode(::AutoEnzyme) = ForwardOrReverseMode() # specialized in the extension diff --git a/test/dense.jl b/test/dense.jl index 15f784d..739cf59 100644 --- a/test/dense.jl +++ b/test/dense.jl @@ -28,19 +28,25 @@ end @testset "AutoEnzyme" begin ad = AutoEnzyme() @test ad isa AbstractADType - @test ad isa AutoEnzyme{Nothing} + @test ad isa AutoEnzyme{Nothing, false} @test mode(ad) isa ForwardOrReverseMode @test ad.mode === nothing + ad = AutoEnzyme(EnzymeCore.Forward; constant_function = true) + @test ad isa AbstractADType + @test ad isa AutoEnzyme{typeof(EnzymeCore.Forward), true} + @test mode(ad) isa ForwardMode + @test ad.mode == EnzymeCore.Forward + ad = AutoEnzyme(; mode = EnzymeCore.Forward) @test ad isa AbstractADType - @test ad isa AutoEnzyme{typeof(EnzymeCore.Forward)} + @test ad isa AutoEnzyme{typeof(EnzymeCore.Forward), false} @test mode(ad) isa ForwardMode @test ad.mode == EnzymeCore.Forward - ad = AutoEnzyme(; mode = EnzymeCore.Reverse) + ad = AutoEnzyme(; mode = EnzymeCore.Reverse, constant_function = true) @test ad isa AbstractADType - @test ad isa AutoEnzyme{typeof(EnzymeCore.Reverse)} + @test ad isa AutoEnzyme{typeof(EnzymeCore.Reverse), true} @test mode(ad) isa ReverseMode @test ad.mode == EnzymeCore.Reverse end diff --git a/test/misc.jl b/test/misc.jl index 0ca1ddd..b3e501f 100644 --- a/test/misc.jl +++ b/test/misc.jl @@ -21,11 +21,6 @@ end @test length(string(sparse_backend1)) < length(string(sparse_backend2)) end -import ADTypes - -struct FakeSparsityDetector <: ADTypes.AbstractSparsityDetector end -struct FakeColoringAlgorithm <: ADTypes.AbstractColoringAlgorithm end - for backend in [ # dense ADTypes.AutoChainRules(; ruleconfig = :rc),