diff --git a/Project.toml b/Project.toml index 7a7fad8..d6b8020 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" authors = [ "Vaibhav Dixit , Guillaume Dalle and contributors", ] -version = "1.6.0" +version = "1.6.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/dense.jl b/src/dense.jl index 8f52f62..6757476 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -50,7 +50,8 @@ Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl). AutoEnzyme(; mode=nothing, constant_function::Bool=false) The `constant_function` keyword argument (and type parameter) determines whether the function object itself should be considered constant or not during differentiation with Enzyme.jl. -For simple functions, `constant_function` should usually be set to `false`, but in the case of closures or callable structs which contain differentiated data that can be treated as constant, `constant_function` should be set to `true` for increased performance (more details below). +For simple functions, `constant_function` should usually be set to `true`, which leads to increased performance. +However, in the case of closures or callable structs which contain differentiated data, `constant_function` should be set to `false` to ensure correctness (more details below). # Fields @@ -61,30 +62,39 @@ For simple functions, `constant_function` should usually be set to `false`, but # Notes -If `constant_function = true` but the enclosed data is not truly constant, then Enzyme.jl will not compute the correct derivative values. -An example of such a function is: +We now give several examples of functions. +For each one, we explain how `constant_function` should be set in order to compute the correct derivative with respect to the input `x`. ```julia -cache = [0.0] -function f(x) - cache[1] = x[1]^2 - cache[1] + x[1] +function f1(x) + return x[1] end ``` -In this case, the enclosed cache is a function of the differentiated input, and thus its values are non-constant with respect to the input. -Thus, in order to compute the correct derivative of the output, the derivative must propagate through the `cache` value, and said `cache` must not be treated as constant. - -Conversely, the following function can treat `parameter` as a constant, because `parameter` is never modified based on the input `x`: +The function `f1` is not a closure, it does not contain any data. +Thus `f1` can be differentiated with `AutoEnzyme(constant_function=true)` (although here setting `constant_function=false` would change neither correctness nor performance). ```julia parameter = [0.0] -function f(x) - parameter[1] + x[1] +function f2(x) + return parameter[1] + x[1] +end +``` + +The function `f2` is a closure over `parameter`, but `parameter` is never modified based on the input `x`. +Thus, `f2` can be differentiated with `AutoEnzyme(constant_function=true)` (setting `constant_function=false` would not change correctness but would hinder performance). + +```julia +cache = [0.0] +function f3(x) + cache[1] = x[1] + return cache[1] + x[1] end ``` -In this case, `constant_function = true` would allow the chosen differentiation system to perform extra memory and compute optimizations, under the assumption that `parameter` is kept constant. +The function `f3` is a closure over `cache`, and `cache` is modified based on the input `x`. +That means `cache` cannot be treated as constant, since derivative values must be propagated through it. +Thus `f3` must be differentiated with `AutoEnzyme(constant_function=false)` (setting `constant_function=true` would make the result incorrect). """ struct AutoEnzyme{M, constant_function} <: AbstractADType mode::M