Skip to content

Commit

Permalink
Improve docs, revamp checks (#449)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle authored Sep 5, 2024
1 parent 1069266 commit a2793bd
Show file tree
Hide file tree
Showing 43 changed files with 152 additions and 217 deletions.
5 changes: 1 addition & 4 deletions DifferentiationInterface/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@ An interface to various automatic differentiation (AD) backends in Julia.

## Goal

This package provides a backend-agnostic syntax to differentiate functions of the following types:

- _one-argument functions_ (allocating): `f(x) = y`
- _two-argument functions_ (mutating): `f!(y, x) = nothing`
This package provides a unified syntax to differentiate functions.

## Features

Expand Down
3 changes: 1 addition & 2 deletions DifferentiationInterface/docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,7 @@ value_gradient_and_hessian!

```@docs
check_available
check_twoarg
check_hessian
check_inplace
DifferentiationInterface.outer
DifferentiationInterface.inner
```
Expand Down
34 changes: 13 additions & 21 deletions DifferentiationInterface/docs/src/backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ This page is about the latter, check out [that page](@ref "Operators") to learn

## List of backends

We support all dense backend choices from [ADTypes.jl](https://github.com/SciML/ADTypes.jl):
We support the following dense backend choices from [ADTypes.jl](https://github.com/SciML/ADTypes.jl):

- [`AutoChainRules`](@extref ADTypes.AutoChainRules)
- [`AutoDiffractor`](@extref ADTypes.AutoDiffractor)
- [`AutoEnzyme`](@extref ADTypes.AutoEnzyme)
- [`AutoFastDifferentiation`](@extref ADTypes.AutoFastDifferentiation)
Expand Down Expand Up @@ -37,11 +38,17 @@ We strongly recommend that users upgrade to Julia 1.10 or above, where all backe

## Features

Given a backend object, you can use:

- [`check_available`](@ref) to know whether the required AD package is loaded
- [`check_inplace`](@ref) to know whether the backend supports in-place functions (all backends support out-of-place functions)

```@setup backends
using ADTypes
using DifferentiationInterface
import Markdown
import ChainRulesCore
import Diffractor
import Enzyme
import FastDifferentiation
Expand All @@ -56,6 +63,7 @@ import Tracker
import Zygote
backend_examples = [
AutoChainRules(; ruleconfig=Zygote.ZygoteRuleConfig()),
AutoDiffractor(),
AutoEnzyme(),
AutoFastDifferentiation(),
Expand All @@ -72,17 +80,16 @@ backend_examples = [
checkmark(x::Bool) = x ? '✅' : '❌'
unicode_check_available(backend) = checkmark(check_available(backend))
unicode_check_hessian(backend) = checkmark(check_hessian(backend; verbose=false))
unicode_check_twoarg(backend) = checkmark(check_twoarg(backend))
unicode_check_inplace(backend) = checkmark(check_inplace(backend))
io = IOBuffer()
# Table header
println(io, "| Backend | Availability | Two-argument functions | Hessian support |")
println(io, "|:--------|:------------:|:----------------------:|:---------------:|")
println(io, "| Backend | Availability | In-place functions |")
println(io, "|:--------|:------------:|:----------------------:|")
for b in backend_examples
join(io, ["`$(nameof(typeof(b)))`", unicode_check_available(b), unicode_check_twoarg(b), unicode_check_hessian(b)], '|')
join(io, ["`$(nameof(typeof(b)))`", unicode_check_available(b), unicode_check_inplace(b)], '|')
println(io, '|' )
end
backend_table = Markdown.parse(String(take!(io)))
Expand All @@ -92,21 +99,6 @@ backend_table = Markdown.parse(String(take!(io)))
backend_table #hide
```

### Availability

You can use [`check_available`](@ref) to verify whether a given backend is loaded.

### Support for two-argument functions

All backends are compatible with one-argument functions `f(x) = y`.
Only some are compatible with two-argument functions `f!(y, x) = nothing`.
You can use [`check_twoarg`](@ref) to verify this compatibility.

### Support for Hessian

Only some backends are able to compute Hessians.
You can use [`check_hessian`](@ref) to verify this feature (beware that it will try to compute a small Hessian, so it is not instantaneous like the other checks).

## Backend switch

The wrapper [`DifferentiateWith`](@ref) allows you to switch between backends.
Expand Down
8 changes: 3 additions & 5 deletions DifferentiationInterface/docs/src/dev_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ Most operators have 4 variants, which look like this in the first order: `operat
## New operator

To implement a new operator for an existing backend, you need to write 5 methods: 1 for [preparation](@ref Preparation) and 4 corresponding to the variants of the operator (see above).
In some cases, a subset of those methods will be enough, but most of the time, forgetting one will trigger errors.
For first-order operators, you may also want to support [two-argument functions](@ref "Mutation and signatures"), which requires another 5 methods (defined on `f!` instead of `f`).
For first-order operators, you may also want to support [in-place functions](@ref "Mutation and signatures"), which requires another 5 methods (defined on `f!` instead of `f`).

The method `prepare_operator` must output an `extras` object of the correct type.
For instance, `prepare_gradient(f, backend, x)` must return a [`DifferentiationInterface.GradientExtras`](@ref).
Expand All @@ -40,7 +39,7 @@ Your AD package needs to be registered first.
### Core code

In the main package, you should define a new struct `SuperDiffBackend` which subtypes [`ADTypes.AbstractADType`](@extref ADTypes), and endow it with the fields you need to parametrize your differentiation routines.
You also have to define [`ADTypes.mode`](@extref) and [`DifferentiationInterface.twoarg_support`](@ref) on `SuperDiffBackend`.
You also have to define [`ADTypes.mode`](@extref) and [`DifferentiationInterface.inplace_support`](@ref) on `SuperDiffBackend`.

!!! info
In the end, this backend struct will need to be contributed to [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
Expand Down Expand Up @@ -79,5 +78,4 @@ GROUP = get(ENV, "JULIA_DI_TEST_GROUP", "Back/SuperDiff")

but don't forget to switch it back before pushing.

Finally, you need to add your backend to the documentation, modifying every page that involves a list of backends.
That includes the README.
Finally, you need to add your backend to the documentation, modifying every page that involves a list of backends (including the `README.md`).
16 changes: 6 additions & 10 deletions DifferentiationInterface/docs/src/implementations.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
# Implementations

DifferentiationInterface.jl provides a handful of [operators](@ref "Operators") like [`gradient`](@ref) or [`jacobian`](@ref), each with several variants:

- **out-of-place** or **in-place** behavior
- **with** or **without primal** output value
- support for **one-argument functions** `y = f(x)` or **two-argument functions** `f!(y, x)`
DifferentiationInterface.jl provides a handful of [operators](@ref "Operators") like [`gradient`](@ref) or [`jacobian`](@ref), each with several variants: out-of-place or in-place behavior, with or without primal output value.

While it is possible to define every operator using just [`pushforward`](@ref) and [`pullback`](@ref), some backends have more efficient implementations of high-level operators.
When they are available, we nearly always call these backend-specific overloads.
Expand All @@ -24,7 +20,7 @@ The cells can have three values:
```@setup overloads
using ADTypes: AbstractADType
using DifferentiationInterface
using DifferentiationInterface: twoarg_support, TwoArgSupported
using DifferentiationInterface: inplace_support, InPlaceSupported
using Markdown: Markdown
using Diffractor: Diffractor
Expand Down Expand Up @@ -152,16 +148,16 @@ function print_overloads(backend, ext::Symbol)
io = IOBuffer()
ext = Base.get_extension(DifferentiationInterface, ext)
println(io, "#### One-argument functions `y = f(x)`")
println(io, "#### Out-of-place functions `f(x) = y`")
println(io)
print_overload_table(io, operators_and_types_f(backend), ext)
println(io, "#### Two-argument functions `f!(y, x)`")
println(io, "#### In-place functions `f!(y, x) = nothing`")
println(io)
if twoarg_support(backend) == TwoArgSupported()
if inplace_support(backend) == InPlaceSupported()
print_overload_table(io, operators_and_types_f!(backend), ext)
else
println(io, "Backend doesn't support mutating functions.")
println(io, "Backend doesn't support in-place functions.")
end
return Markdown.parse(String(take!(io)))
Expand Down
Loading

0 comments on commit a2793bd

Please sign in to comment.