Err and correctly handle complex returns in reverse mode (#1309)
* Err and correctly handle complex returns in reverse mode

* Update index.md

* fixup
wsmoses authored Feb 28, 2024
1 parent 23f6222 commit 599658d
Showing 12 changed files with 635 additions and 171 deletions.
180 changes: 180 additions & 0 deletions docs/src/index.md
@@ -312,3 +312,183 @@ da2

Sometimes, determining how to perform this zeroing can be complicated.
That is why Enzyme provides a helper function `Enzyme.make_zero` that does this automatically.
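For example, `make_zero` recursively constructs a zero object with the same structure as its argument. A minimal sketch (assuming the standard `using Enzyme` doctest setup):

```jldoctest
dx = [1.0, 2.0]   # a shadow that has accumulated derivatives
Enzyme.make_zero(dx)   # a fresh all-zero shadow with the same structure
# output
2-element Vector{Float64}:
 0.0
 0.0
```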

### Complex Numbers

Differentiation of a function which returns a complex number is ambiguous, because there are several different gradients that may be desired. Rather than assume one of these conventions and potentially cause user error when the resulting derivative is not the desired one, Enzyme forces users to specify the desired convention by requiring the differentiated function to return a real number instead.

Consider the function `f(z) = z*z`. If we were to differentiate this with real inputs and outputs, the derivative `f'(z)` would be unambiguously `2*z`. However, consider breaking a complex number down into its real and imaginary parts. Suppose we call `f` with explicit real and imaginary components, `z = x + i y`. This means that `f` is a function that takes two input values and returns two values: `f(x, y) = u(x, y) + i v(x, y)`. In the case of `z*z`, this means `u(x,y) = x*x-y*y` and `v(x,y) = 2*x*y`.
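As a quick illustrative check, we can verify that this decomposition reproduces `f`:

```jldoctest
u(x, y) = x*x - y*y
v(x, y) = 2*x*y
x, y = 3.1, 2.7
# u + i*v agrees with evaluating f directly on x + i*y
u(x, y) + im * v(x, y) == (x + im*y) * (x + im*y)
# output
true
```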


If we were to look at all first-order derivatives in total, we would end up with a 2x2 matrix (i.e., the Jacobian), containing the derivative of each output with respect to each input. Let's compute it, first by hand, then with Enzyme.

```
grad u(x, y) = [d/dx u, d/dy u] = [d/dx x*x-y*y, d/dy x*x-y*y] = [2*x, -2*y];
grad v(x, y) = [d/dx v, d/dy v] = [d/dx 2*x*y, d/dy 2*x*y] = [2*y, 2*x];
```
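For reference, evaluating these hand-derived gradients at the test point used below, `x = 3.1, y = 2.7` (an illustrative check):

```jldoctest
x, y = 3.1, 2.7
# (grad u, grad v) evaluated by hand
([2*x, -2*y], [2*y, 2*x])
# output
([6.2, -5.4], [5.4, 6.2])
```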

Reverse mode differentiation computes the derivative of a single output with respect to all inputs by propagating the derivative of the return value back to the inputs. Here, we can explicitly differentiate with respect to the real and imaginary results, respectively, to find this matrix.

```jldoctest complex
f(z) = z * z
# a fixed input to use for testing
z = 3.1 + 2.7im
grad_u = Enzyme.autodiff(Reverse, z->real(f(z)), Active, Active(z))[1][1]
grad_v = Enzyme.autodiff(Reverse, z->imag(f(z)), Active, Active(z))[1][1]
(grad_u, grad_v)
# output
(6.2 - 5.4im, 5.4 + 6.2im)
```

This is somewhat inefficient, since we need to run the forward pass twice: once for the real part and once for the imaginary part. We can solve this using batched derivatives in Enzyme, which compute several derivatives for the same function in one go. To make this work, we need split mode, which allows us to provide a custom differential return value.

```jldoctest complex
fwd, rev = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(f)}, Active, Active{ComplexF64})
# Compute the reverse pass seeded with a differential return of 1.0 + 0.0im
grad_u = rev(Const(f), Active(z), 1.0 + 0.0im, fwd(Const(f), Active(z))[1])[1][1]
# Compute the reverse pass seeded with a differential return of 0.0 + 1.0im
grad_v = rev(Const(f), Active(z), 0.0 + 1.0im, fwd(Const(f), Active(z))[1])[1][1]
(grad_u, grad_v)
# output
(6.2 - 5.4im, 5.4 + 6.2im)
```

Now let's make this batched:

```jldoctest complex
fwd, rev = Enzyme.autodiff_thunk(ReverseSplitWidth(ReverseSplitNoPrimal, Val(2)), Const{typeof(f)}, Active, Active{ComplexF64})
# Compute the reverse pass seeded with a differential return of 1.0 + 0.0im and 0.0 + 1.0im in one go!
rev(Const(f), Active(z), (1.0 + 0.0im, 0.0 + 1.0im), fwd(Const(f), Active(z))[1])[1][1]
# output
(6.2 - 5.4im, 5.4 + 6.2im)
```

In contrast, Forward mode differentiation computes the derivative of all outputs with respect to a single input by providing a differential input. Thus we need to seed the shadow input with either 1.0 or 1.0im, respectively. This will compute the transpose of the matrix we found earlier.

```
d/dx f(x, y) = d/dx [u(x,y), v(x,y)] = d/dx [x*x-y*y, 2*x*y] = [ 2*x, 2*y];
d/dy f(x, y) = d/dy [u(x,y), v(x,y)] = d/dy [x*x-y*y, 2*x*y] = [-2*y, 2*x];
```

```jldoctest complex
d_dx = Enzyme.autodiff(Forward, f, Duplicated(z, 1.0+0.0im))[1]
d_dy = Enzyme.autodiff(Forward, f, Duplicated(z, 0.0+1.0im))[1]
(d_dx, d_dy)
# output
(6.2 + 5.4im, -5.4 + 6.2im)
```

Again, we can go ahead and batch this.
```jldoctest complex
Enzyme.autodiff(Forward, f, BatchDuplicated(z, (1.0+0.0im, 0.0+1.0im)))[1]
# output
(var"1" = 6.2 + 5.4im, var"2" = -5.4 + 6.2im)
```

Taking Jacobians with respect to the real and imaginary results is fine, but for a complex scalar function it would be really nice to have a single complex derivative. More concretely, in this case when differentiating `z*z`, it would be nice to simply return `2*z`. However, there are four independent entries in the 2x2 Jacobian, but only two in a complex number.

Complex differentiation is often viewed through the lens of directional derivatives: for example, what is the derivative of the function as the real input increases, or as the imaginary input increases? Consider the derivative along the real axis, $\lim_{\Delta x \rightarrow 0} \frac{f(x+\Delta x, y)-f(x, y)}{\Delta x}$. This simplifies to $\lim_{\Delta x \rightarrow 0} \frac{u(x+\Delta x, y)-u(x, y) + i \left[ v(x+\Delta x, y)-v(x, y)\right]}{\Delta x} = \frac{\partial}{\partial x} u(x,y) + i\frac{\partial}{\partial x} v(x,y)$. This is exactly what we computed by seeding forward mode with a shadow of `1.0 + 0.0im`.
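We can sanity-check this against a finite-difference approximation along the real axis (the step size and tolerance here are illustrative choices):

```jldoctest complex
# Perturb only the real part of z and compare against the forward-mode result
h = 1e-6
isapprox((f(z + h) - f(z)) / h, d_dx; rtol=1e-4)
# output
true
```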

For completeness, we can also consider the derivative along the imaginary axis, $\lim_{\Delta y \rightarrow 0} \frac{f(x, y+\Delta y)-f(x, y)}{i\Delta y}$. This simplifies to $\lim_{\Delta y \rightarrow 0} \frac{u(x, y+\Delta y)-u(x, y) + i \left[ v(x, y+\Delta y)-v(x, y)\right]}{i\Delta y} = -i\frac{\partial}{\partial y} u(x,y) + \frac{\partial}{\partial y} v(x,y)$. Except for the $i$ in the denominator of the limit, this is the same as the result of Forward mode when seeding `x` with a shadow of `0.0 + 1.0im`. We can thus compute the derivative along the imaginary axis by multiplying our second Forward-mode call by `-im`.

```jldoctest complex
d_real = Enzyme.autodiff(Forward, f, Duplicated(z, 1.0+0.0im))[1]
d_im = -im * Enzyme.autodiff(Forward, f, Duplicated(z, 0.0+1.0im))[1]
(d_real, d_im)
# output
(6.2 + 5.4im, 6.2 + 5.4im)
```

Interestingly, the derivative of `z*z` is the same when computed along either axis. That is because this function belongs to a special class of functions whose derivative is invariant to the input direction, called holomorphic functions.

Thus, for holomorphic functions, we can simply seed Forward-mode AD with a shadow of one for whatever input we are differentiating. This is nice since seeding the shadow with an input of one is exactly what we'd do for real-valued functions as well.

Reverse-mode AD, however, is trickier: holomorphic functions are invariant to the direction of differentiation (i.e., the derivative inputs), not to the direction of the differential return.

However, if a function is holomorphic, the two derivatives we computed above must be the same. As a result, $\frac{\partial}{\partial x} u = \frac{\partial}{\partial y} v$ and $\frac{\partial}{\partial y} u = -\frac{\partial}{\partial x} v$ (the Cauchy-Riemann equations).
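We can check these equations directly on the gradients computed earlier (recall that `grad_u` packs `[d/dx u, d/dy u]` and `grad_v` packs `[d/dx v, d/dy v]` into the real and imaginary parts):

```jldoctest complex
# Cauchy-Riemann: u_x == v_y and u_y == -v_x
(real(grad_u) ≈ imag(grad_v), imag(grad_u) ≈ -real(grad_v))
# output
(true, true)
```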

We saw earlier that performing reverse-mode AD with a return seed of `1.0 + 0.0im` yielded `[d/dx u, d/dy u]`. Thus, for a holomorphic function, real-seeded reverse-mode AD computes `[d/dx u, -d/dx v]`, which is the complex conjugate of the derivative.


```jldoctest complex
conj(grad_u)
# output
6.2 + 5.4im
```

In the case of a scalar-input, scalar-output function, that's sufficient. However, most uses of reverse mode involve several inputs or outputs, perhaps through memory. This case requires additional handling: the partial derivatives from each use of an input must be summed correctly, and the conjugate applied only where relevant to the differential return.

For simplicity, Enzyme provides a helper utility `ReverseHolomorphic`, which performs reverse mode properly here, assuming that the function is indeed holomorphic and thus has a well-defined single derivative.

```jldoctest complex
Enzyme.autodiff(ReverseHolomorphic, f, Active, Active(z))[1][1]
# output
6.2 + 5.4im
```

Even for non-holomorphic functions, complex analysis lets us define the Wirtinger derivative $\frac{\partial}{\partial z} = \frac{1}{2}\left(\frac{\partial}{\partial x} - i \frac{\partial}{\partial y} \right)$, which gives us a notion of `d/dz`. Let's consider `myabs2(z) = z * conj(z)`. We can compute its derivative with respect to `z` in Forward mode as follows, which, as one would expect, yields `conj(z)`:

```jldoctest complex
myabs2(z) = z * conj(z)
dabs2_dx, dabs2_dy = Enzyme.autodiff(Forward, myabs2, BatchDuplicated(z, (1.0 + 0.0im, 0.0 + 1.0im)))[1]
(dabs2_dx - im * dabs2_dy) / 2
# output
3.1 - 2.7im
```

Similarly, we can compute `d/d conj(z) = (d/dx + i d/dy) / 2`.

```jldoctest complex
(dabs2_dx + im * dabs2_dy) / 2
# output
3.1 + 2.7im
```

Computing this in reverse mode is trickier. Let's expand the function in terms of `u` and `v`: $\frac{\partial}{\partial z} f = \frac12 \left( [u_x + i v_x] - i [u_y + i v_y] \right) = \frac12 \left( [u_x + v_y] + i [v_x - u_y] \right)$. Thus `d/dz = (conj(grad_u) + im * conj(grad_v))/2`.

```jldoctest complex
abs2_fwd, abs2_rev = Enzyme.autodiff_thunk(ReverseSplitWidth(ReverseSplitNoPrimal, Val(2)), Const{typeof(myabs2)}, Active, Active{ComplexF64})
# Compute the reverse pass seeded with a differential return of 1.0 + 0.0im and 0.0 + 1.0im in one go!
gradabs2_u, gradabs2_v = abs2_rev(Const(myabs2), Active(z), (1.0 + 0.0im, 0.0 + 1.0im), abs2_fwd(Const(myabs2), Active(z))[1])[1][1]
(conj(gradabs2_u) + im * conj(gradabs2_v)) / 2
# output
3.1 - 2.7im
```

For `d/d conj(z)`, $\frac12 \left( [u_x + i v_x] + i [u_y + i v_y] \right) = \frac12 \left( [u_x - v_y] + i [v_x + u_y] \right)$. Thus `d/d conj(z) = (grad_u + im * grad_v)/2`.

```jldoctest complex
(gradabs2_u + im * gradabs2_v) / 2
# output
3.1 + 2.7im
```

Note: when writing custom rules for complex scalar functions in reverse mode, one needs to conjugate the differential return, and similarly the true result will be the conjugate of that value (in essence, you can think of reverse-mode AD as working in the conjugate space).
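To illustrate this conjugate-space behavior concretely, here is a sketch using the thunk API from above (the names `s` and `r` are ours, not Enzyme API): for a holomorphic `f`, reverse mode seeded with a differential return `s` produces `s * conj(f'(z))`.

```jldoctest complex
fwd1, rev1 = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(f)}, Active, Active{ComplexF64})
# Seed the reverse pass with an arbitrary differential return s
s = 2.0 + 1.0im
r = rev1(Const(f), Active(z), s, fwd1(Const(f), Active(z))[1])[1][1]
# For f(z) = z*z we have f'(z) = 2z, so r should equal s * conj(2z)
r ≈ s * conj(2 * z)
# output
true
```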
2 changes: 1 addition & 1 deletion lib/EnzymeCore/Project.toml
@@ -1,7 +1,7 @@
name = "EnzymeCore"
uuid = "f151be2c-9106-41f4-ab19-57ee4f262869"
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
version = "0.6.5"
version = "0.7.0"

[compat]
Adapt = "3, 4"
17 changes: 11 additions & 6 deletions lib/EnzymeCore/src/EnzymeCore.jl
@@ -1,7 +1,7 @@
module EnzymeCore

export Forward, Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal
-export ReverseSplitModified, ReverseSplitWidth
+export ReverseSplitModified, ReverseSplitWidth, ReverseHolomorphic, ReverseHolomorphicWithPrimal
export Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed
export DefaultABI, FFIABI, InlineABI
export BatchDuplicatedFunc
@@ -49,6 +49,7 @@ struct Active{T} <: Annotation{T}
end

Active(i::Integer) = Active(float(i))
+Active(ci::Complex{T}) where T <: Integer = Active(float(ci))

"""
Duplicated(x, ∂f_∂x)
@@ -178,14 +179,18 @@ Abstract type for what differentiation mode will be used.
abstract type Mode{ABI} end

"""
struct ReverseMode{ReturnPrimal,ABI} <: Mode{ABI}
struct ReverseMode{ReturnPrimal,ABI,Holomorphic} <: Mode{ABI}
Reverse mode differentiation.
- `ReturnPrimal`: Should Enzyme return the primal return value from the augmented-forward.
"""
struct ReverseMode{ReturnPrimal,ABI} <: Mode{ABI} end
const Reverse = ReverseMode{false,DefaultABI}()
const ReverseWithPrimal = ReverseMode{true,DefaultABI}()
- `ABI`: What runtime ABI to use
- `Holomorphic`: Whether the complex result function is holomorphic and we should compute d/dz
"""
struct ReverseMode{ReturnPrimal,ABI,Holomorphic} <: Mode{ABI} end
const Reverse = ReverseMode{false,DefaultABI, false}()
const ReverseWithPrimal = ReverseMode{true,DefaultABI, false}()
const ReverseHolomorphic = ReverseMode{false,DefaultABI, true}()
const ReverseHolomorphicWithPrimal = ReverseMode{true,DefaultABI, true}()

"""
struct ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetween,ABI} <: Mode{ABI}
4 changes: 2 additions & 2 deletions lib/EnzymeTestUtils/Project.toml
@@ -13,8 +13,8 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
ConstructionBase = "1.4.1"
Enzyme = "0.11"
EnzymeCore = "0.5, 0.6"
Enzyme = "0.11, 0.12"
EnzymeCore = "0.5, 0.6, 0.7"
FiniteDifferences = "0.12.12"
MetaTesting = "0.1"
Quaternions = "0.7"

2 comments on commit 599658d

@wsmoses (Member Author)
@JuliaRegistrator register subdir="lib/EnzymeTestUtils"

@JuliaRegistrator
Error while trying to register: Version 0.1.4 already exists
