Skip to content

Commit

Permalink
remove signature with user-defined restructure (#64)
Browse files Browse the repository at this point in the history
  • Loading branch information
Red-Portal authored Jun 7, 2024
1 parent 314eacf commit 75eb334
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 79 deletions.
6 changes: 3 additions & 3 deletions src/families/location_scale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,17 @@ Base.eltype(::Type{<:MvLocationScale{S, D, L}}) where {S, D, L} = eltype(D)
function StatsBase.entropy(q::MvLocationScale)
@unpack location, scale, dist = q
n_dims = length(location)
n_dims*convert(eltype(location), entropy(dist)) + first(logabsdet(scale))
n_dims*convert(eltype(location), entropy(dist)) + first(logdet(scale))
end

function Distributions.logpdf(q::MvLocationScale, z::AbstractVector{<:Real})
@unpack location, scale, dist = q
sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale))
sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logdet(scale))
end

function Distributions._logpdf(q::MvLocationScale, z::AbstractVector{<:Real})
@unpack location, scale, dist = q
sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale))
sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logdet(scale))
end

function Distributions.rand(q::MvLocationScale)
Expand Down
69 changes: 15 additions & 54 deletions src/optimize.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@

"""
optimize(problem, objective, restructure, param_init, max_iter, objargs...; kwargs...)
optimize(problem, objective, variational_dist_init, max_iter, objargs...; kwargs...)
optimize(problem, objective, q_init, max_iter, objargs...; kwargs...)
Optimize the variational objective `objective` targeting the problem `problem` by estimating (stochastic) gradients.
The variational approximation can be constructed by passing the variational parameters `param_init` or the initial variational approximation `variational_dist_init` to the function `restructure`.
The trainable parameters in the variational approximation are expected to be extractable through `Optimisers.destructure`.
This requires the variational approximation to be marked as a functor through `Functors.@functor`.
# Arguments
- `objective::AbstractVariationalObjective`: Variational Objective.
- `param_init`: Initial value of the variational parameters.
- `restruct`: Function that reconstructs the variational approximation from the flattened parameters.
- `variational_dist_init`: Initial variational distribution. The variational parameters must be extractable through `Optimisers.destructure`.
- `q_init`: Initial variational distribution. The variational parameters must be extractable through `Optimisers.destructure`.
- `max_iter::Int`: Maximum number of iterations.
- `objargs...`: Arguments to be passed to `objective`.
Expand Down Expand Up @@ -50,8 +48,7 @@ function optimize(
rng ::Random.AbstractRNG,
problem,
objective ::AbstractVariationalObjective,
restructure,
params_init,
q_init,
max_iter ::Int,
objargs...;
adtype ::ADTypes.AbstractADType,
Expand All @@ -65,9 +62,9 @@ function optimize(
barlen = 31,
showspeed = true,
enabled = show_progress
)
),
)
params = copy(params_init)
params, restructure = Optimisers.destructure(deepcopy(q_init))
opt_st = maybe_init_optimizer(state_init, optimizer, params)
obj_st = maybe_init_objective(state_init, rng, objective, params, restructure)
grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params))
Expand All @@ -78,7 +75,7 @@ function optimize(

grad_buf, obj_st, stat′ = estimate_gradient!(
rng, objective, adtype, grad_buf, problem,
params, restructure, obj_st, objargs...
params, restructure, obj_st, objargs...
)
stat = merge(stat, stat′)

Expand All @@ -98,60 +95,24 @@ function optimize(
pm_next!(prog, stat)
push!(stats, stat)
end
state = (optimizer=opt_st, objective=obj_st)
stats = map(identity, stats)
params, stats, state
end

function optimize(
problem,
objective ::AbstractVariationalObjective,
restructure,
params_init,
max_iter ::Int,
objargs...;
kwargs...
)
optimize(
Random.default_rng(),
problem,
objective,
restructure,
params_init,
max_iter,
objargs...;
kwargs...
)
state = (optimizer=opt_st, objective=obj_st)
stats = map(identity, stats)
restructure(params), stats, state
end

function optimize(rng ::Random.AbstractRNG,
problem,
objective ::AbstractVariationalObjective,
variational_dist_init,
n_max_iter ::Int,
objargs...;
kwargs...)
λ, restructure = Optimisers.destructure(variational_dist_init)
λ, logstats, state = optimize(
rng, problem, objective, restructure, λ, n_max_iter, objargs...; kwargs...
)
restructure(λ), logstats, state
end


function optimize(
problem,
objective ::AbstractVariationalObjective,
variational_dist_init,
max_iter ::Int,
objective::AbstractVariationalObjective,
q_init,
max_iter ::Int,
objargs...;
kwargs...
)
optimize(
Random.default_rng(),
problem,
objective,
variational_dist_init,
q_init,
max_iter,
objargs...;
kwargs...
Expand Down
22 changes: 0 additions & 22 deletions test/interface/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,28 +33,6 @@ using Test
show_progress = false,
adtype,
)

λ₀, re = Optimisers.destructure(q0)
optimize(
model, obj, re, λ₀, T;
optimizer,
show_progress = false,
adtype,
)
end

@testset "restructure" begin
λ₀, re = Optimisers.destructure(q0)

rng = StableRNG(seed)
λ, stats, _ = optimize(
rng, model, obj, re, λ₀, T;
optimizer,
show_progress = false,
adtype,
)
@test λ == λ_ref
@test stats == stats_ref
end

@testset "callback" begin
Expand Down

1 comment on commit 75eb334

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Results

Benchmark suite Current: 75eb334 Previous: 314eacf Ratio
normal + bijector/meanfield/ForwardDiff 498137471 ns 500583347 ns 1.00
normal + bijector/meanfield/ReverseDiff 141700614.5 ns 136374500 ns 1.04

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.