
Enable Enzyme #67

Merged
merged 20 commits (Aug 22, 2024)
Commits
b23428b
fix enzyme to match new interface, enable enzyme tests
Red-Portal Jun 21, 2024
ed38340
fix type instability, tighten Enzyme compat
Red-Portal Jul 1, 2024
2af7695
add indirection to enforce type stability of `restructure`
Red-Portal Jul 15, 2024
a5be89d
minor formatting changes, fix DistributionsAD inference test
Red-Portal Jul 17, 2024
1727f5f
tighten compat bound for Enzyme
Red-Portal Jul 17, 2024
69a4754
fix remove trailing whitespace in docs
Red-Portal Aug 3, 2024
65a5fc9
Merge remote-tracking branch 'origin/master' into enable_enzyme
Red-Portal Aug 3, 2024
e063d67
Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into ena…
Red-Portal Aug 5, 2024
ef43c6c
tighten compat bound for enzyme
Red-Portal Aug 5, 2024
156398c
remove unused comment
Red-Portal Aug 7, 2024
1061ae8
Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into ena…
Red-Portal Aug 9, 2024
4978d1f
Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into ena…
Red-Portal Aug 10, 2024
d821554
compat tighten julia version
Red-Portal Aug 10, 2024
36d22cf
fix formatting
Red-Portal Aug 10, 2024
c7f60a5
fix formatting
Red-Portal Aug 10, 2024
b0cd0b5
try disable `DistributionsAD`
Red-Portal Aug 10, 2024
72ed3fb
fix tests enable Enzyme inference tests only on 1.10
Red-Portal Aug 11, 2024
32a53a5
Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into ena…
Red-Portal Aug 21, 2024
5b46a80
add tests on Julia 1.10
Red-Portal Aug 21, 2024
18ce79a
fix tighten compat bound for Enzyme
Red-Portal Aug 22, 2024
5 changes: 1 addition & 4 deletions .github/workflows/Benchmark.yml
@@ -1,8 +1,5 @@
name: Benchmarks
on:
-  push:
-    branches:
-      - master
  pull_request:
    branches:
      - master
@@ -52,4 +49,4 @@ jobs:
alert-threshold: "200%"
fail-on-alert: true
benchmark-data-dir-path: benchmarks
-auto-push: ${{ github.event_name != 'pull_request' }}
+auto-push: false
13 changes: 8 additions & 5 deletions .github/workflows/DocNav.yml
@@ -32,13 +32,16 @@ jobs:

# Define the URL of the navbar to be used
NAVBAR_URL="https://raw.githubusercontent.com/TuringLang/turinglang.github.io/main/assets/scripts/TuringNavbar.html"

-# Update all HTML files in the current directory (gh-pages root)
-./insert_navbar.sh . $NAVBAR_URL

+# Define files & folders to exclude (comma-separated list); un-comment the line below to exclude anything!
+EXCLUDE_PATHS="benchmarks"
+
+# Update all HTML files in the current directory (gh-pages root); use `--exclude` only if required!
+./insert_navbar.sh . $NAVBAR_URL --exclude "$EXCLUDE_PATHS"

# Remove the insert_navbar.sh file
rm insert_navbar.sh

# Check if there are any changes
if [[ -n $(git status -s) ]]; then
git add .
33 changes: 0 additions & 33 deletions .github/workflows/JuliaNightly.yml

This file was deleted.

2 changes: 1 addition & 1 deletion Project.toml
@@ -42,7 +42,7 @@ ChainRulesCore = "1.16"
DiffResults = "1"
Distributions = "0.25.87"
DocStringExtensions = "0.8, 0.9"
-Enzyme = "0.12"
+Enzyme = "0.12.28"
FillArrays = "1.3"
ForwardDiff = "0.10.36"
Functors = "0.4"
1 change: 0 additions & 1 deletion README.md
@@ -1,7 +1,6 @@
[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://turinglang.org/AdvancedVI.jl/stable/)
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://turinglang.org/AdvancedVI.jl/dev/)
[![Build Status](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/CI.yml/badge.svg?branch=master)](https://github.com/TuringLang/AdvancedVI.jl/actions/workflows/CI.yml?query=branch%3Amaster)
-[![JuliaNightly](https://github.com/TuringLang/AdvancedVI.jl/workflows/JuliaNightly/badge.svg?branch=master)](https://github.com/TuringLang/AdvancedVI.jl/actions?query=workflow%3AJuliaNightly+branch%3Amaster)
[![Coverage](https://codecov.io/gh/TuringLang/AdvancedVI.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/TuringLang/AdvancedVI.jl)

# AdvancedVI.jl
5 changes: 3 additions & 2 deletions docs/make.jl
@@ -15,8 +15,9 @@ makedocs(;
        "General Usage" => "general.md",
        "Examples" => "examples.md",
        "ELBO Maximization" => [
            "Overview" => "elbo/overview.md",
            "Reparameterization Gradient Estimator" => "elbo/repgradelbo.md",
+           "Sample Average Approximation" => "elbo/saa.md",
            "Location-Scale Variational Family" => "locscale.md",
        ]],
    )
47 changes: 40 additions & 7 deletions ext/AdvancedVIEnzymeExt.jl
@@ -11,15 +11,48 @@ else
using ..AdvancedVI: ADTypes, DiffResults
end

# Enzyme doesn't support f::Bijectors (see https://github.com/EnzymeAD/Enzyme.jl/issues/916)

+AdvancedVI.restructure_ad_forward(
+    ::ADTypes.AutoEnzyme, restructure, params
+) = restructure(params)::typeof(restructure.model)
+
+function AdvancedVI.value_and_gradient!(
+    ::ADTypes.AutoEnzyme,
+    f,
+    x::AbstractVector{<:Real},
+    out::DiffResults.MutableDiffResult
+)
+    Enzyme.API.runtimeActivity!(true)
+    ∇x = DiffResults.gradient(out)
+    fill!(∇x, zero(eltype(∇x)))
+    _, y = Enzyme.autodiff(
+        Enzyme.ReverseWithPrimal,
+        f,
+        Enzyme.Active,
+        Enzyme.Duplicated(x, ∇x)
+    )
+    DiffResults.value!(out, y)
+    return out
+end
+
 function AdvancedVI.value_and_gradient!(
-    ad::ADTypes.AutoEnzyme, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult
-) where {T<:Real}
-    y = f(θ)
+    ::ADTypes.AutoEnzyme,
+    f,
+    x::AbstractVector{<:Real},
+    aux,
+    out::DiffResults.MutableDiffResult
+)
+    Enzyme.API.runtimeActivity!(true)
+    ∇x = DiffResults.gradient(out)
+    fill!(∇x, zero(eltype(∇x)))
+    _, y = Enzyme.autodiff(
+        Enzyme.ReverseWithPrimal,
+        f,
+        Enzyme.Active,
+        Enzyme.Duplicated(x, ∇x),
+        Enzyme.Const(aux)
+    )
     DiffResults.value!(out, y)
-    ∇θ = DiffResults.gradient(out)
-    fill!(∇θ, zero(T))
-    Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, ∇θ))
     return out
 end
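For readers who want to try the new extension, the sketch below exercises the two `value_and_gradient!` methods added in this PR. It is modeled on the test added in `test/interface/ad.jl` and assumes an environment with AdvancedVI (including this PR) and Enzyme installed; it is not part of the diff itself.

```julia
using ADTypes, DiffResults, Enzyme, LinearAlgebra
using AdvancedVI

# Quadratic objective taking auxiliary data, matching the new five-argument interface.
f(λ, aux) = λ' * aux.A * λ / 2

A   = Matrix{Float64}(I, 3, 3)
λ   = [1.0, 2.0, 3.0]
out = DiffResults.GradientResult(λ)

# Dispatches to the Enzyme method defined in ext/AdvancedVIEnzymeExt.jl,
# which loads automatically once Enzyme is imported.
AdvancedVI.value_and_gradient!(ADTypes.AutoEnzyme(), f, λ, (A=A,), out)

DiffResults.value(out)     # λ'Aλ/2 = 7.0
DiffResults.gradient(out)  # (A + A')λ/2 = [1.0, 2.0, 3.0]
```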
15 changes: 7 additions & 8 deletions src/AdvancedVI.jl
@@ -41,18 +41,17 @@ Evaluate the value and gradient of a function `f` at `x` using the automatic dif
function value_and_gradient! end

"""
-    stop_gradient(x)
+    restructure_ad_forward(adtype, restructure, params)

-Stop the gradient from propagating to `x` if the selected ad backend supports it.
-Otherwise, it is equivalent to `identity`.
+Apply `restructure` to `params`.
+This is an indirection for handling the type stability of `restructure`, as some AD backends require strict type stability in the AD path.

# Arguments
-- `x`: Input
-
-# Returns
-- `x`: Same value as the input.
+- `ad::ADTypes.AbstractADType`: Automatic differentiation backend.
+- `restructure`: Callable for restructuring the variational distribution from `params`.
+- `params`: Variational parameters.
"""
-function stop_gradient end
+restructure_ad_forward(::ADTypes.AbstractADType, restructure, params) = restructure(params)

# Update for gradient descent step
"""
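The point of the indirection is that backends can specialize it; the Enzyme extension in this PR does exactly that by asserting the return type so inference stays concrete in the AD path. A self-contained toy sketch of the pattern (the `ToyModel`/`Restructure` names are hypothetical stand-ins, not the package's actual types):

```julia
using ADTypes

# Generic fallback, as defined in src/AdvancedVI.jl: just apply the callable.
restructure_ad_forward(::ADTypes.AbstractADType, restructure, params) = restructure(params)

# Toy stand-ins for a variational family and its restructuring closure.
struct ToyModel{V}
    params::V
end
struct Restructure{M}
    model::M
end
(re::Restructure)(params) = typeof(re.model)(params)

# Backend-specific override: the return-type assertion tells inference the
# concrete type, which stricter backends such as Enzyme need in the AD path.
restructure_ad_forward(::ADTypes.AutoEnzyme, re::Restructure, params) =
    re(params)::typeof(re.model)

re = Restructure(ToyModel([0.0, 0.0]))
q  = restructure_ad_forward(ADTypes.AutoEnzyme(), re, [1.0, 2.0])
q.params  # [1.0, 2.0]
```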
4 changes: 2 additions & 2 deletions src/families/location_scale.jl
@@ -39,14 +39,14 @@ Functors.@functor MvLocationScale (location, scale)
# is very inefficient.
# begin
struct RestructureMeanField{S <: Diagonal, D, L}
-    q::MvLocationScale{S, D, L}
+    model::MvLocationScale{S, D, L}
end

function (re::RestructureMeanField)(flat::AbstractVector)
    n_dims = div(length(flat), 2)
    location = first(flat, n_dims)
    scale = Diagonal(last(flat, n_dims))
-    MvLocationScale(location, scale, re.q.dist, re.q.scale_eps)
+    MvLocationScale(location, scale, re.model.dist, re.model.scale_eps)
end

function Optimisers.destructure(
13 changes: 10 additions & 3 deletions src/objectives/elbo/repgradelbo.jl
@@ -98,8 +98,8 @@ estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int = obj.n_samples) =
estimate_objective(Random.default_rng(), obj, q, prob; n_samples)

function estimate_repgradelbo_ad_forward(params′, aux)
-    @unpack rng, obj, problem, restructure, q_stop = aux
-    q = restructure(params′)
+    @unpack rng, obj, problem, adtype, restructure, q_stop = aux
+    q = restructure_ad_forward(adtype, restructure, params′)
    samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy)
    energy = estimate_energy_with_samples(problem, samples)
    elbo = energy + entropy
Expand All @@ -117,7 +117,14 @@ function estimate_gradient!(
state,
)
    q_stop = restructure(params)
-    aux = (rng=rng, obj=obj, problem=prob, restructure=restructure, q_stop=q_stop)
+    aux = (
+        rng = rng,
+        adtype = adtype,
+        obj = obj,
+        problem = prob,
+        restructure = restructure,
+        q_stop = q_stop
+    )
    value_and_gradient!(
        adtype, estimate_repgradelbo_ad_forward, params, aux, out
    )
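The reason `adtype` now rides along in `aux` is that the function being differentiated can itself dispatch on the backend via the new indirection, without capturing anything in a closure. A toy sketch of this plumbing (the `restructure_fwd` and `objective_ad_forward` names are hypothetical, chosen only to illustrate the pattern):

```julia
using ADTypes

# Backend-aware indirection, mirroring restructure_ad_forward in src/AdvancedVI.jl.
restructure_fwd(::ADTypes.AbstractADType, restructure, params) = restructure(params)

# The differentiated function receives everything through `aux`, including
# the backend, so it can pick the right restructuring path per backend.
function objective_ad_forward(params, aux)
    q = restructure_fwd(aux.adtype, aux.restructure, params)
    sum(abs2, q)  # stand-in for the ELBO estimate
end

aux = (adtype=ADTypes.AutoForwardDiff(), restructure=identity)
objective_ad_forward([1.0, 2.0], aux)  # 5.0
```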
3 changes: 3 additions & 0 deletions test/Project.toml
@@ -4,6 +4,7 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Expand All @@ -24,8 +25,10 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[compat]
ADTypes = "0.2.1, 1"
Bijectors = "0.13"
+DiffResults = "1.0"
Distributions = "0.25.100"
DistributionsAD = "0.6.45"
+Enzyme = "0.12.28"
FillArrays = "1.6.1"
ForwardDiff = "0.10.36"
Functors = "0.4.5"
10 changes: 5 additions & 5 deletions test/inference/repgradelbo_distributionsad.jl
@@ -3,7 +3,7 @@
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for
realtype ∈ [Float64, Float32],
(modelname, modelconstr) ∈ Dict(
-        :Normal=> normal_meanfield,
+        :Normal => normal_meanfield,
),
n_montecarlo in [1, 10],
(objname, objective) in Dict(
Expand All @@ -12,9 +12,9 @@
),
(adbackname, adtype) ∈ Dict(
:ForwarDiff => AutoForwardDiff(),
-        #:ReverseDiff => AutoReverseDiff(),
+        #:ReverseDiff => AutoReverseDiff(), # DistributionsAD doesn't support ReverseDiff at the moment
        :Zygote => AutoZygote(),
-        #:Enzyme => AutoEnzyme(),
+        :Enzyme => AutoEnzyme(),
)

seed = (0x38bef07cf9cc549d)
Expand All @@ -32,8 +32,8 @@
# where ρ = 1 - ημ, μ is the strong convexity constant.
contraction_rate = 1 - η*strong_convexity

-    μ0 = Zeros(realtype, n_dims)
-    L0 = Diagonal(Ones(realtype, n_dims))
+    μ0 = zeros(realtype, n_dims)
+    L0 = Diagonal(ones(realtype, n_dims))
q0 = TuringDiagMvNormal(μ0, diag(L0))

@testset "convergence" begin
2 changes: 1 addition & 1 deletion test/inference/repgradelbo_locationscale.jl
@@ -15,7 +15,7 @@
:ForwarDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
:Zygote => AutoZygote(),
-        #:Enzyme => AutoEnzyme(),
+        :Enzyme => AutoEnzyme(),
)

seed = (0x38bef07cf9cc549d)
4 changes: 2 additions & 2 deletions test/inference/repgradelbo_locationscale_bijectors.jl
@@ -13,8 +13,8 @@
(adbackname, adtype) in Dict(
:ForwarDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
-        #:Zygote => AutoZygote(),
-        #:Enzyme => AutoEnzyme(),
+        :Zygote => AutoZygote(),
+        :Enzyme => AutoEnzyme(),
)

seed = (0x38bef07cf9cc549d)
21 changes: 20 additions & 1 deletion test/interface/ad.jl
@@ -6,7 +6,7 @@ using Test
:ForwardDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
:Zygote => AutoZygote(),
-    # :Enzyme => AutoEnzyme() # Currently not tested against
+    :Enzyme => AutoEnzyme(),
)
D = 10
A = randn(D, D)
Expand All @@ -19,4 +19,23 @@ using Test
@test ∇ ≈ (A + A')*λ/2
@test f ≈ λ'*A*λ / 2
end

@testset "$(adname) with auxiliary input" for (adname, adsymbol) ∈ Dict(
:ForwardDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
:Zygote => AutoZygote(),
:Enzyme => AutoEnzyme(),
)
D = 10
A = randn(D, D)
λ = randn(D)
b = randn(D)
grad_buf = DiffResults.GradientResult(λ)
f(λ′, aux) = λ′'*A*λ′ / 2 + dot(aux.b, λ′)
AdvancedVI.value_and_gradient!(adsymbol, f, λ, (b=b,), grad_buf)
∇ = DiffResults.gradient(grad_buf)
f = DiffResults.value(grad_buf)
@test ∇ ≈ (A + A')*λ/2 + b
@test f ≈ λ'*A*λ / 2 + dot(b, λ)
end
end
5 changes: 3 additions & 2 deletions test/interface/repgradelbo.jl
@@ -37,7 +37,8 @@ end
@testset for ad in [
ADTypes.AutoForwardDiff(),
ADTypes.AutoReverseDiff(),
-        ADTypes.AutoZygote()
+        ADTypes.AutoZygote(),
+        ADTypes.AutoEnzyme()
]
q_true = MeanFieldGaussian(
Vector{eltype(μ_true)}(μ_true),
Expand All @@ -47,7 +48,7 @@ end
obj = RepGradELBO(10; entropy=StickingTheLandingEntropy())
out = DiffResults.DiffResult(zero(eltype(params)), similar(params))

-    aux = (rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true)
+    aux = (rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true, adtype=ad)
AdvancedVI.value_and_gradient!(
ad, AdvancedVI.estimate_repgradelbo_ad_forward, params, aux, out
)
2 changes: 0 additions & 2 deletions test/models/normal.jl
@@ -35,9 +35,7 @@ function normal_meanfield(rng::Random.AbstractRNG, realtype::Type)

σ0 = realtype(0.3)
    μ = Fill(realtype(5), n_dims)
-    #randn(rng, realtype, n_dims)
    σ = Fill(σ0, n_dims)
-    #log.(exp.(randn(rng, realtype, n_dims)) .+ 1)

model = TestNormal(μ, Diagonal(σ.^2))

2 changes: 1 addition & 1 deletion test/runtests.jl
@@ -18,7 +18,7 @@ using DistributionsAD
using LogDensityProblems
using Optimisers
using ADTypes
-using ForwardDiff, ReverseDiff, Zygote
+using ForwardDiff, ReverseDiff, Zygote, Enzyme

using AdvancedVI
