diff --git a/src/families/location_scale.jl b/src/families/location_scale.jl
index 152cd15d..e60538a1 100644
--- a/src/families/location_scale.jl
+++ b/src/families/location_scale.jl
@@ -57,17 +57,17 @@ Base.eltype(::Type{<:MvLocationScale{S, D, L}}) where {S, D, L} = eltype(D)
 function StatsBase.entropy(q::MvLocationScale)
     @unpack location, scale, dist = q
     n_dims = length(location)
-    n_dims*convert(eltype(location), entropy(dist)) + first(logabsdet(scale))
+    n_dims*convert(eltype(location), entropy(dist)) + first(logdet(scale))
 end
 
 function Distributions.logpdf(q::MvLocationScale, z::AbstractVector{<:Real})
     @unpack location, scale, dist = q
-    sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale))
+    sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logdet(scale))
 end
 
 function Distributions._logpdf(q::MvLocationScale, z::AbstractVector{<:Real})
     @unpack location, scale, dist = q
-    sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale))
+    sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logdet(scale))
 end
 
 function Distributions.rand(q::MvLocationScale)
diff --git a/src/optimize.jl b/src/optimize.jl
index acb455d2..325de5a2 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -1,17 +1,15 @@
 """
-    optimize(problem, objective, restructure, param_init, max_iter, objargs...; kwargs...)
-    optimize(problem, objective, variational_dist_init, max_iter, objargs...; kwargs...)
+    optimize(problem, objective, q_init, max_iter, objargs...; kwargs...)
 
 Optimize the variational objective `objective` targeting the problem `problem` by estimating (stochastic) gradients.
 
-The variational approximation can be constructed by passing the variational parameters `param_init` or the initial variational approximation `variational_dist_init` to the function `restructure`.
+The trainable parameters in the variational approximation are expected to be extractable through `Optimisers.destructure`.
+This requires the variational approximation to be marked as a functor through `Functors.@functor`.
 
 # Arguments
 - `objective::AbstractVariationalObjective`: Variational Objective.
-- `param_init`: Initial value of the variational parameters.
-- `restruct`: Function that reconstructs the variational approximation from the flattened parameters.
-- `variational_dist_init`: Initial variational distribution. The variational parameters must be extractable through `Optimisers.destructure`.
+- `q_init`: Initial variational distribution. The variational parameters must be extractable through `Optimisers.destructure`.
 - `max_iter::Int`: Maximum number of iterations.
 - `objargs...`: Arguments to be passed to `objective`.
@@ -50,8 +48,7 @@ function optimize(
     rng ::Random.AbstractRNG,
     problem,
     objective ::AbstractVariationalObjective,
-    restructure,
-    params_init,
+    q_init,
     max_iter ::Int,
     objargs...;
     adtype ::ADTypes.AbstractADType,
@@ -65,9 +62,9 @@
         barlen = 31,
         showspeed = true,
         enabled = show_progress
-    )
+    ),
 )
-    params = copy(params_init)
+    params, restructure = Optimisers.destructure(deepcopy(q_init))
     opt_st = maybe_init_optimizer(state_init, optimizer, params)
     obj_st = maybe_init_objective(state_init, rng, objective, params, restructure)
     grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params))
@@ -78,7 +75,7 @@
         grad_buf, obj_st, stat′ = estimate_gradient!(
             rng, objective, adtype, grad_buf, problem,
-            params, restructure, obj_st, objargs...
+            params, restructure, obj_st, objargs...
         )
         stat = merge(stat, stat′)
@@ -98,52 +95,16 @@ function optimize(
         pm_next!(prog, stat)
         push!(stats, stat)
     end
-    state = (optimizer=opt_st, objective=obj_st)
-    stats = map(identity, stats)
-    params, stats, state
-end
-
-function optimize(
-    problem,
-    objective ::AbstractVariationalObjective,
-    restructure,
-    params_init,
-    max_iter ::Int,
-    objargs...;
-    kwargs...
-)
-    optimize(
-        Random.default_rng(),
-        problem,
-        objective,
-        restructure,
-        params_init,
-        max_iter,
-        objargs...;
-        kwargs...
-    )
+    state = (optimizer=opt_st, objective=obj_st)
+    stats = map(identity, stats)
+    restructure(params), stats, state
 end
 
-function optimize(rng ::Random.AbstractRNG,
-    problem,
-    objective ::AbstractVariationalObjective,
-    variational_dist_init,
-    n_max_iter ::Int,
-    objargs...;
-    kwargs...)
-    λ, restructure = Optimisers.destructure(variational_dist_init)
-    λ, logstats, state = optimize(
-        rng, problem, objective, restructure, λ, n_max_iter, objargs...; kwargs...
-    )
-    restructure(λ), logstats, state
-end
-
-
 function optimize(
     problem,
-    objective ::AbstractVariationalObjective,
-    variational_dist_init,
-    max_iter ::Int,
+    objective::AbstractVariationalObjective,
+    q_init,
+    max_iter ::Int,
     objargs...;
     kwargs...
 )
@@ -151,7 +112,7 @@ function optimize(
         Random.default_rng(),
         problem,
         objective,
-        variational_dist_init,
+        q_init,
         max_iter,
         objargs...;
         kwargs...
diff --git a/test/interface/optimize.jl b/test/interface/optimize.jl
index 9666893b..2606851c 100644
--- a/test/interface/optimize.jl
+++ b/test/interface/optimize.jl
@@ -33,28 +33,6 @@ using Test
             show_progress = false,
             adtype,
         )
-
-        λ₀, re = Optimisers.destructure(q0)
-        optimize(
-            model, obj, re, λ₀, T;
-            optimizer,
-            show_progress = false,
-            adtype,
-        )
-    end
-
-    @testset "restructure" begin
-        λ₀, re = Optimisers.destructure(q0)
-
-        rng = StableRNG(seed)
-        λ, stats, _ = optimize(
-            rng, model, obj, re, λ₀, T;
-            optimizer,
-            show_progress = false,
-            adtype,
-        )
-        @test λ == λ_ref
-        @test stats == stats_ref
     end
 
     @testset "callback" begin
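
For context (not part of the patch): a minimal sketch of what the updated docstring requires of a variational family, namely that `Optimisers.destructure` can flatten it into a parameter vector and rebuild it, which marking it with `Functors.@functor` enables. The struct `MyMeanField` and its fields are hypothetical names used only for illustration.

using Functors, Optimisers

# Hypothetical variational family; any struct works once it is marked as a functor.
struct MyMeanField{V<:AbstractVector}
    location::V   # trainable
    log_scale::V  # trainable
end
Functors.@functor MyMeanField

q0 = MyMeanField(zeros(2), zeros(2))
params, restructure = Optimisers.destructure(q0)  # flat Vector of 4 parameters
q = restructure(params)                           # rebuilds an equivalent MyMeanField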
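
Likewise, a sketch of a call site after this change, mirroring the updated test: the initial distribution is passed directly, and since the method now returns `restructure(params), stats, state`, the caller gets the optimized distribution back rather than flat parameters. Here `model`, `obj`, and `q0` are placeholders in the same sense as in the test suite, and the keyword values are illustrative.

using AdvancedVI, ADTypes, Optimisers, StableRNGs

# `model` is the target problem, `obj` an AbstractVariationalObjective,
# and `q0` the initial variational distribution (placeholders, as in the tests).
rng = StableRNG(1)
q, stats, state = optimize(
    rng, model, obj, q0, 10_000;
    optimizer     = Optimisers.Adam(1e-3),
    show_progress = false,
    adtype        = AutoForwardDiff(),
)
# `q` is the re-assembled distribution; no manual restructure(λ) step is needed.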