From cb3b8380ad03ed8e84f7a8bc679cc076ba4ede2e Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Jun 2024 20:27:18 -0400 Subject: [PATCH] fix avoid re-defining the differentiation objective to support AD pre-compilation (#66) * update interface for objective initialization * improve `RepGradELBO` to not redefine AD forward path * add auxiliary argument to `value_and_gradient!` --- ext/AdvancedVIBijectorsExt.jl | 6 ++--- ext/AdvancedVIForwardDiffExt.jl | 23 ++++++++++++++----- ext/AdvancedVIReverseDiffExt.jl | 19 +++++++++++++--- ext/AdvancedVIZygoteExt.jl | 23 +++++++++++++++---- src/AdvancedVI.jl | 32 ++++++++++++++++++++------ src/objectives/elbo/repgradelbo.jl | 36 +++++++++++++++++------------- src/optimize.jl | 2 +- src/utils.jl | 17 ++++++++++---- test/Project.toml | 3 +-- test/interface/repgradelbo.jl | 29 ++++++++++++++++++++++++ 10 files changed, 146 insertions(+), 44 deletions(-) diff --git a/ext/AdvancedVIBijectorsExt.jl b/ext/AdvancedVIBijectorsExt.jl index 4a88d6fb..a227fdf2 100644 --- a/ext/AdvancedVIBijectorsExt.jl +++ b/ext/AdvancedVIBijectorsExt.jl @@ -42,9 +42,9 @@ function AdvancedVI.reparam_with_entropy( n_samples::Int, ent_est ::AdvancedVI.AbstractEntropyEstimator ) - transform = q.transform - q_unconst = q.dist - q_unconst_stop = q_stop.dist + transform = q.transform + q_unconst = q.dist + q_unconst_stop = q_stop.dist # Draw samples and compute entropy of the uncontrained distribution unconstr_samples, unconst_entropy = AdvancedVI.reparam_with_entropy( diff --git a/ext/AdvancedVIForwardDiffExt.jl b/ext/AdvancedVIForwardDiffExt.jl index 5949bdf8..a8afd031 100644 --- a/ext/AdvancedVIForwardDiffExt.jl +++ b/ext/AdvancedVIForwardDiffExt.jl @@ -14,16 +14,29 @@ end getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize function AdvancedVI.value_and_gradient!( - ad::ADTypes.AutoForwardDiff, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult -) where {T<:Real} + ad ::ADTypes.AutoForwardDiff, + f, + x 
::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult +) chunk_size = getchunksize(ad) config = if isnothing(chunk_size) - ForwardDiff.GradientConfig(f, θ) + ForwardDiff.GradientConfig(f, x) else - ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size)) + ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk(length(x), chunk_size)) end - ForwardDiff.gradient!(out, f, θ, config) + ForwardDiff.gradient!(out, f, x, config) return out end +function AdvancedVI.value_and_gradient!( + ad ::ADTypes.AutoForwardDiff, + f, + x ::AbstractVector, + aux, + out::DiffResults.MutableDiffResult +) + AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out) +end + end diff --git a/ext/AdvancedVIReverseDiffExt.jl b/ext/AdvancedVIReverseDiffExt.jl index 520cd9ff..392f5cea 100644 --- a/ext/AdvancedVIReverseDiffExt.jl +++ b/ext/AdvancedVIReverseDiffExt.jl @@ -13,11 +13,24 @@ end # ReverseDiff without compiled tape function AdvancedVI.value_and_gradient!( - ad::ADTypes.AutoReverseDiff, f, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult + ad::ADTypes.AutoReverseDiff, + f, + x::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult ) - tp = ReverseDiff.GradientTape(f, θ) - ReverseDiff.gradient!(out, tp, θ) + tp = ReverseDiff.GradientTape(f, x) + ReverseDiff.gradient!(out, tp, x) return out end +function AdvancedVI.value_and_gradient!( + ad::ADTypes.AutoReverseDiff, + f, + x::AbstractVector{<:Real}, + aux, + out::DiffResults.MutableDiffResult +) + AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out) +end + end diff --git a/ext/AdvancedVIZygoteExt.jl b/ext/AdvancedVIZygoteExt.jl index 7b8f8817..806c08e4 100644 --- a/ext/AdvancedVIZygoteExt.jl +++ b/ext/AdvancedVIZygoteExt.jl @@ -4,21 +4,36 @@ module AdvancedVIZygoteExt if isdefined(Base, :get_extension) using AdvancedVI using AdvancedVI: ADTypes, DiffResults + using ChainRulesCore using Zygote else using ..AdvancedVI using ..AdvancedVI: ADTypes, DiffResults + using ..ChainRulesCore using 
..Zygote end function AdvancedVI.value_and_gradient!( - ad::ADTypes.AutoZygote, f, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult + ::ADTypes.AutoZygote, + f, + x::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult ) - y, back = Zygote.pullback(f, θ) - ∇θ = back(one(y)) + y, back = Zygote.pullback(f, x) + ∇x = back(one(y)) DiffResults.value!(out, y) - DiffResults.gradient!(out, only(∇θ)) + DiffResults.gradient!(out, only(∇x)) return out end +function AdvancedVI.value_and_gradient!( + ad::ADTypes.AutoZygote, + f, + x::AbstractVector{<:Real}, + aux, + out::DiffResults.MutableDiffResult +) + AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out) +end + end diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 7c7a1fc8..7a09030b 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -25,18 +25,35 @@ using StatsBase # derivatives """ - value_and_gradient!(ad, f, θ, out) + value_and_gradient!(ad, f, x, out) + value_and_gradient!(ad, f, x, aux, out) -Evaluate the value and gradient of a function `f` at `θ` using the automatic differentiation backend `ad` and store the result in `out`. +Evaluate the value and gradient of a function `f` at `x` using the automatic differentiation backend `ad` and store the result in `out`. +`f` may receive auxiliary input as `f(x, aux)`. # Arguments - `ad::ADTypes.AbstractADType`: Automatic differentiation backend. - `f`: Function subject to differentiation. -- `θ`: The point to evaluate the gradient. +- `x`: The point at which to evaluate the gradient. +- `aux`: Auxiliary input passed to `f`. - `out::DiffResults.MutableDiffResult`: Buffer to contain the output gradient and function value. """ function value_and_gradient! end +""" + stop_gradient(x) + +Stop the gradient from propagating to `x` if the selected AD backend supports it. +Otherwise, it is equivalent to `identity`. + +# Arguments +- `x`: Input + +# Returns +- `x`: Same value as the input. 
+""" +function stop_gradient end + # Update for gradient descent step """ update_variational_params!(family_type, opt_st, params, restructure, grad) @@ -78,7 +95,7 @@ If the estimator is stateful, it can implement `init` to initialize the state. abstract type AbstractVariationalObjective end """ - init(rng, obj, λ, restructure) + init(rng, obj, prob, params, restructure) Initialize a state of the variational objective `obj` given the initial variational parameters `λ`. This function needs to be implemented only if `obj` is stateful. @@ -86,14 +103,15 @@ This function needs to be implemented only if `obj` is stateful. # Arguments - `rng::Random.AbstractRNG`: Random number generator. - `obj::AbstractVariationalObjective`: Variational objective. -- `λ`: Initial variational parameters. +- `params`: Initial variational parameters. - `restructure`: Function that reconstructs the variational approximation from `λ`. """ init( ::Random.AbstractRNG, ::AbstractVariationalObjective, - ::AbstractVector, - ::Any + ::Any, + ::Any, + ::Any, ) = nothing """ diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl index 2d95d076..27a937e8 100644 --- a/src/objectives/elbo/repgradelbo.jl +++ b/src/objectives/elbo/repgradelbo.jl @@ -56,14 +56,13 @@ function estimate_energy_with_samples(prob, samples) end """ - reparam_with_entropy(rng, q, q_stop, n_samples, ent_est) + reparam_with_entropy(rng, q, n_samples, ent_est) Draw `n_samples` from `q` and compute its entropy. # Arguments - `rng::Random.AbstractRNG`: Random number generator. - `q`: Variational approximation. -- `q_stop`: `q` but with its gradient stopped. - `n_samples::Int`: Number of Monte Carlo samples - `ent_est`: The entropy estimation strategy. (See `estimate_entropy`.) @@ -72,7 +71,11 @@ Draw `n_samples` from `q` and compute its entropy. - `entropy`: An estimate (or exact value) of the differential entropy of `q`. 
""" function reparam_with_entropy( - rng::Random.AbstractRNG, q, q_stop, n_samples::Int, ent_est::AbstractEntropyEstimator + rng ::Random.AbstractRNG, + q, + q_stop, + n_samples::Int, + ent_est ::AbstractEntropyEstimator ) samples = rand(rng, q, n_samples) entropy = estimate_entropy_maybe_stl(ent_est, samples, q, q_stop) @@ -94,28 +97,31 @@ end estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int = obj.n_samples) = estimate_objective(Random.default_rng(), obj, q, prob; n_samples) +function estimate_repgradelbo_ad_forward(params′, aux) + @unpack rng, obj, problem, restructure, q_stop = aux + q = restructure(params′) + samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy) + energy = estimate_energy_with_samples(problem, samples) + elbo = energy + entropy + -elbo +end + function estimate_gradient!( rng ::Random.AbstractRNG, obj ::RepGradELBO, adtype::ADTypes.AbstractADType, out ::DiffResults.MutableDiffResult, prob, - λ, + params, restructure, state, ) - q_stop = restructure(λ) - function f(λ′) - q = restructure(λ′) - samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy) - energy = estimate_energy_with_samples(prob, samples) - elbo = energy + entropy - -elbo - end - value_and_gradient!(adtype, f, λ, out) - + q_stop = restructure(params) + aux = (rng=rng, obj=obj, problem=prob, restructure=restructure, q_stop=q_stop) + value_and_gradient!( + adtype, estimate_repgradelbo_ad_forward, params, aux, out + ) nelbo = DiffResults.value(out) stat = (elbo=-nelbo,) - out, nothing, stat end diff --git a/src/optimize.jl b/src/optimize.jl index 4eb6644a..e5fe374d 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -66,7 +66,7 @@ function optimize( ) params, restructure = Optimisers.destructure(deepcopy(q_init)) opt_st = maybe_init_optimizer(state_init, optimizer, params) - obj_st = maybe_init_objective(state_init, rng, objective, params, restructure) + obj_st = maybe_init_objective(state_init, rng, objective, 
problem, params, restructure) grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params)) stats = NamedTuple[] diff --git a/src/utils.jl b/src/utils.jl index 98b79b2d..3ae59a78 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -6,19 +6,28 @@ end function maybe_init_optimizer( state_init::NamedTuple, optimizer ::Optimisers.AbstractRule, - params ::AbstractVector + params ) - haskey(state_init, :optimizer) ? state_init.optimizer : Optimisers.setup(optimizer, params) + if haskey(state_init, :optimizer) + state_init.optimizer + else + Optimisers.setup(optimizer, params) + end end function maybe_init_objective( state_init::NamedTuple, rng ::Random.AbstractRNG, objective ::AbstractVariationalObjective, - params ::AbstractVector, + problem, + params, restructure ) - haskey(state_init, :objective) ? state_init.objective : init(rng, objective, params, restructure) + if haskey(state_init, :objective) + state_init.objective + else + init(rng, objective, problem, params, restructure) + end end eachsample(samples::AbstractMatrix) = eachcol(samples) diff --git a/test/Project.toml b/test/Project.toml index f3acedea..a0dba17f 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,9 +1,9 @@ [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" @@ -26,7 +26,6 @@ ADTypes = "0.2.1, 1" Bijectors = "0.13" Distributions = "0.25.100" DistributionsAD = "0.6.45" -Enzyme = "0.12" FillArrays = "1.6.1" ForwardDiff = "0.10.36" Functors = "0.4.5" diff --git a/test/interface/repgradelbo.jl b/test/interface/repgradelbo.jl index 61ff0111..ac9bfeca 100644 --- 
a/test/interface/repgradelbo.jl +++ b/test/interface/repgradelbo.jl @@ -26,3 +26,32 @@ using Test @test elbo ≈ elbo_ref rtol=0.1 end end + +@testset "interface RepGradELBO STL variance reduction" begin + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + + modelstats = normal_meanfield(rng, Float64) + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + + @testset for ad in [ + ADTypes.AutoForwardDiff(), + ADTypes.AutoReverseDiff(), + ADTypes.AutoZygote() + ] + q_true = MeanFieldGaussian( + Vector{eltype(μ_true)}(μ_true), + Diagonal(Vector{eltype(L_true)}(diag(L_true))) + ) + params, re = Optimisers.destructure(q_true) + obj = RepGradELBO(10; entropy=StickingTheLandingEntropy()) + out = DiffResults.DiffResult(zero(eltype(params)), similar(params)) + + aux = (rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true) + AdvancedVI.value_and_gradient!( + ad, AdvancedVI.estimate_repgradelbo_ad_forward, params, aux, out + ) + grad = DiffResults.gradient(out) + @test norm(grad) ≈ 0 atol=1e-5 + end +end