Skip to content

Commit

Permalink
Migrate to DifferentiationInterface (#98)
Browse files Browse the repository at this point in the history
* migrate to DifferentiationInterface
* run formatter
* tighten compat bound for ADTypes
* fix compat bound for docs
  • Loading branch information
Red-Portal authored Sep 30, 2024
1 parent 4eab1ac commit d0efe02
Show file tree
Hide file tree
Showing 22 changed files with 97 additions and 263 deletions.
18 changes: 18 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
steps:
- label: "CUDA with julia {{matrix.julia}}"
plugins:
- JuliaCI/julia#v1:
version: "{{matrix.julia}}"
- JuliaCI/julia-test#v1:
test_args: "--quickfail"
agents:
queue: "juliagpu"
cuda: "*"
timeout_in_minutes: 60
env:
GROUP: "GPU"
ADVANCEDVI_TEST_CUDA: "true"
matrix:
setup:
julia:
- "1.10"
20 changes: 9 additions & 11 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Expand All @@ -24,50 +25,47 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
AdvancedVIBijectorsExt = "Bijectors"
AdvancedVIEnzymeExt = "Enzyme"
AdvancedVIForwardDiffExt = "ForwardDiff"
AdvancedVIReverseDiffExt = "ReverseDiff"
AdvancedVITapirExt = "Tapir"
AdvancedVIZygoteExt = "Zygote"

[compat]
ADTypes = "0.1, 0.2, 1"
ADTypes = "1"
Accessors = "0.1"
Bijectors = "0.13"
ChainRulesCore = "1.16"
DiffResults = "1"
DifferentiationInterface = "0.6"
Distributions = "0.25.111"
DocStringExtensions = "0.8, 0.9"
Enzyme = "0.13"
FillArrays = "1.3"
ForwardDiff = "0.10.36"
ForwardDiff = "0.10"
Functors = "0.4"
LinearAlgebra = "1"
LogDensityProblems = "2"
Mooncake = "0.4"
Optimisers = "0.2.16, 0.3"
ProgressMeter = "1.6"
Random = "1"
Requires = "1.0"
ReverseDiff = "1.15.1"
ReverseDiff = "1"
SimpleUnPack = "1.1.0"
StatsBase = "0.32, 0.33, 0.34"
Tapir = "0.2"
Zygote = "0.6.63"
Zygote = "0.6"
julia = "1.7"

[extras]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[compat]
ADTypes = "0.1.6"
ADTypes = "1"
AdvancedVI = "0.3"
Bijectors = "0.13.6"
Distributions = "0.25"
Expand Down
16 changes: 0 additions & 16 deletions ext/AdvancedVIEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

module AdvancedVIEnzymeExt

if isdefined(Base, :get_extension)
Expand All @@ -15,21 +14,6 @@ function AdvancedVI.restructure_ad_forward(::ADTypes.AutoEnzyme, restructure, pa
return restructure(params)::typeof(restructure.model)
end

function AdvancedVI.value_and_gradient!(
::ADTypes.AutoEnzyme, f, x::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult
)
∇x = DiffResults.gradient(out)
fill!(∇x, zero(eltype(∇x)))
_, y = Enzyme.autodiff(
Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal, true),
Enzyme.Const(f),
Enzyme.Active,
Enzyme.Duplicated(x, ∇x),
)
DiffResults.value!(out, y)
return out
end

function AdvancedVI.value_and_gradient!(
::ADTypes.AutoEnzyme,
f,
Expand Down
42 changes: 0 additions & 42 deletions ext/AdvancedVIForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,42 +0,0 @@

module AdvancedVIForwardDiffExt

if isdefined(Base, :get_extension)
using ForwardDiff
using AdvancedVI
using AdvancedVI: ADTypes, DiffResults
else
using ..ForwardDiff
using ..AdvancedVI
using ..AdvancedVI: ADTypes, DiffResults
end

getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize

function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoForwardDiff,
f,
x::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult,
)
chunk_size = getchunksize(ad)
config = if isnothing(chunk_size)
ForwardDiff.GradientConfig(f, x)
else
ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk(length(x), chunk_size))
end
ForwardDiff.gradient!(out, f, x, config)
return out
end

function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoForwardDiff,
f,
x::AbstractVector,
aux,
out::DiffResults.MutableDiffResult,
)
return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
end

end
Empty file added ext/AdvancedVIMooncakeExt.jl
Empty file.
36 changes: 0 additions & 36 deletions ext/AdvancedVIReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,36 +0,0 @@

module AdvancedVIReverseDiffExt

if isdefined(Base, :get_extension)
using AdvancedVI
using AdvancedVI: ADTypes, DiffResults
using ReverseDiff
else
using ..AdvancedVI
using ..AdvancedVI: ADTypes, DiffResults
using ..ReverseDiff
end

# ReverseDiff without compiled tape
function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoReverseDiff,
f,
x::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult,
)
tp = ReverseDiff.GradientTape(f, x)
ReverseDiff.gradient!(out, tp, x)
return out
end

function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoReverseDiff,
f,
x::AbstractVector{<:Real},
aux,
out::DiffResults.MutableDiffResult,
)
return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
end

end
37 changes: 0 additions & 37 deletions ext/AdvancedVITapirExt.jl

This file was deleted.

36 changes: 0 additions & 36 deletions ext/AdvancedVIZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -1,36 +0,0 @@

module AdvancedVIZygoteExt

if isdefined(Base, :get_extension)
using AdvancedVI
using AdvancedVI: ADTypes, DiffResults
using ChainRulesCore
using Zygote
else
using ..AdvancedVI
using ..AdvancedVI: ADTypes, DiffResults
using ..ChainRulesCore
using ..Zygote
end

function AdvancedVI.value_and_gradient!(
::ADTypes.AutoZygote, f, x::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult
)
y, back = Zygote.pullback(f, x)
∇x = back(one(y))
DiffResults.value!(out, y)
DiffResults.gradient!(out, only(∇x))
return out
end

function AdvancedVI.value_and_gradient!(
ad::ADTypes.AutoZygote,
f,
x::AbstractVector{<:Real},
aux,
out::DiffResults.MutableDiffResult,
)
return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
end

end
20 changes: 14 additions & 6 deletions src/AdvancedVI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,17 @@ using LinearAlgebra

using LogDensityProblems

using ADTypes, DiffResults
using ADTypes
using DiffResults
using DifferentiationInterface
using ChainRulesCore

using FillArrays

using StatsBase

# derivatives
# Derivatives
"""
value_and_gradient!(ad, f, x, out)
value_and_gradient!(ad, f, x, aux, out)
Evaluate the value and gradient of a function `f` at `x` using the automatic differentiation backend `ad` and store the result in `out`.
Expand All @@ -38,7 +39,14 @@ Evaluate the value and gradient of a function `f` at `x` using the automatic dif
- `aux`: Auxiliary input passed to `f`.
- `out::DiffResults.MutableDiffResult`: Buffer to contain the output gradient and function value.
"""
function value_and_gradient! end
function value_and_gradient!(
ad::ADTypes.AbstractADType, f, x, aux, out::DiffResults.MutableDiffResult
)
grad_buf = DiffResults.gradient(out)
y, _ = DifferentiationInterface.value_and_gradient!(f, grad_buf, ad, x, Constant(aux))
DiffResults.value!(out, y)
return out
end

"""
restructure_ad_forward(adtype, restructure, params)
Expand Down Expand Up @@ -131,7 +139,7 @@ function estimate_objective end
export estimate_objective

"""
estimate_gradient!(rng, obj, adtype, out, prob, λ, restructure, obj_state)
estimate_gradient!(rng, obj, adtype, out, prob, params, restructure, obj_state)
Estimate (possibly stochastic) gradients of the variational objective `obj` targeting `prob` with respect to the variational parameters `λ`
Expand All @@ -141,7 +149,7 @@ Estimate (possibly stochastic) gradients of the variational objective `obj` targ
- `adtype::ADTypes.AbstractADType`: Automatic differentiation backend.
- `out::DiffResults.MutableDiffResult`: Buffer containing the objective value and gradient estimates.
- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.
- `λ`: Variational parameters to evaluate the gradient on.
- `params`: Variational parameters to evaluate the gradient on.
- `restructure`: Function that reconstructs the variational approximation from `λ`.
- `obj_state`: Previous state of the objective.
Expand Down
2 changes: 1 addition & 1 deletion src/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ The arguments are as follows:
- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(param)` reconstructs the variational approximation.
- `gradient`: The estimated (possibly stochastic) gradient.
`cb` can return a `NamedTuple` containing some additional information computed within `cb`.
`callback` can return a `NamedTuple` containing some additional information computed within `cb`.
This will be appended to the statistic of the current corresponding iteration.
Otherwise, just return `nothing`.
Expand Down
6 changes: 4 additions & 2 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Expand All @@ -26,7 +26,8 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[compat]
ADTypes = "0.2.1, 1"
Bijectors = "0.13"
DiffResults = "1.0"
DiffResults = "1"
DifferentiationInterface = "0.6"
Distributions = "0.25.111"
DistributionsAD = "0.6.45"
FillArrays = "1.6.1"
Expand All @@ -41,6 +42,7 @@ ReverseDiff = "1.15.1"
SimpleUnPack = "1.1.0"
StableRNGs = "1.0.0"
Statistics = "1"
StatsBase = "0.34"
Test = "1"
Tracker = "0.2.20"
Zygote = "0.6.63"
Expand Down
8 changes: 5 additions & 3 deletions test/inference/repgradelbo_distributionsad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ AD_distributionsad = Dict(
:Zygote => AutoZygote(),
)

if @isdefined(Tapir)
AD_distributionsad[:Tapir] = AutoTapir(; safe_mode=false)
if @isdefined(Mooncake)
AD_distributionsad[:Mooncake] = AutoMooncake(; config=nothing)
end

if @isdefined(Enzyme)
AD_distributionsad[:Enzyme] = AutoEnzyme()
AD_distributionsad[:Enzyme] = AutoEnzyme(;
mode=set_runtime_activity(ReverseWithPrimal), function_annotation=Const
)
end

@testset "inference RepGradELBO DistributionsAD" begin
Expand Down
Loading

0 comments on commit d0efe02

Please sign in to comment.