diff --git a/src/dense.jl b/src/dense.jl
index ad2fcbf..8f52f62 100644
--- a/src/dense.jl
+++ b/src/dense.jl
@@ -50,7 +50,7 @@ Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
     AutoEnzyme(; mode=nothing, constant_function::Bool=false)
 
 The `constant_function` keyword argument (and type parameter) determines whether the function object itself should be considered constant or not during differentiation with Enzyme.jl.
-For simple functions, `constant_function` should usually be set to `false`, but in the case of closures or callable structs which contain differentiated data that can be treated as constant, `constant_function` should be set to `true` for increased performance.
+For simple functions, `constant_function` should usually be set to `false`, but in the case of closures or callable structs which contain differentiated data that can be treated as constant, `constant_function` should be set to `true` for increased performance (more details below).
 
 # Fields
 
@@ -58,6 +58,33 @@ For simple functions, `constant_function` should usually be set to `false`, but
 
       + an object subtyping `EnzymeCore.Mode` (like `EnzymeCore.Forward` or `EnzymeCore.Reverse`) if a specific mode is required
       + `nothing` to choose the best mode automatically
+
+# Notes
+
+If `constant_function = true` but the enclosed data is not truly constant, then Enzyme.jl will not compute the correct derivative values.
+An example of such a function is:
+
+```julia
+cache = [0.0]
+function f(x)
+    cache[1] = x[1]^2
+    cache[1] + x[1]
+end
+```
+
+In this case, the enclosed cache is a function of the differentiated input, and thus its values are non-constant with respect to the input.
+Thus, in order to compute the correct derivative of the output, the derivative must propagate through the `cache` value, and said `cache` must not be treated as constant.
+
+Conversely, the following function can treat `parameter` as a constant, because `parameter` is never modified based on the input `x`:
+
+```julia
+parameter = [0.0]
+function f(x)
+    parameter[1] + x[1]
+end
+```
+
+In this case, `constant_function = true` would allow the chosen differentiation system to perform extra memory and compute optimizations, under the assumption that `parameter` is kept constant.
 """
 struct AutoEnzyme{M, constant_function} <: AbstractADType
     mode::M