Skip to content

Commit

Permalink
add docstrings to elbo objective forward ad paths
Browse files Browse the repository at this point in the history
  • Loading branch information
Red-Portal committed Dec 4, 2024
1 parent 5ff79e3 commit eda4ea0
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 7 deletions.
22 changes: 20 additions & 2 deletions src/objectives/elbo/repgradelbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,27 @@ function estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int=obj.n_samp
return estimate_objective(Random.default_rng(), obj, q, prob; n_samples)
end

function estimate_repgradelbo_ad_forward(params′, aux)
"""
estimate_repgradelbo_ad_forward(params, aux)
AD-guaranteed forward path of the reparameterization gradient objective.
# Arguments
- `params`: Variational parameters.
- `aux`: Auxiliary information excluded from the AD path.
# Auxiliary Information
`aux` should containt the following entries:
- `rng`: Random number generator.
- `obj`: The `RepGradELBO` objective.
- `problem`: The target `LogDensityProblem`.
- `adtype`: The `ADType` used for differentiating the forward path.
- `restructure`: Callable for restructuring the varitional distribution from `params`.
- `q_stop`: A copy of `restructure(params)` with its gradient "stopped" (excluded from the AD path).
"""
function estimate_repgradelbo_ad_forward(params, aux)
(; rng, obj, problem, adtype, restructure, q_stop) = aux
q = restructure_ad_forward(adtype, restructure, params)
q = restructure_ad_forward(adtype, restructure, params)
samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy)
energy = estimate_energy_with_samples(problem, samples)
elbo = energy + entropy
Expand Down
26 changes: 21 additions & 5 deletions src/objectives/elbo/scoregradelbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,27 @@ function estimate_objective(obj::ScoreGradELBO, q, prob; n_samples::Int=obj.n_sa
return estimate_objective(Random.default_rng(), obj, q, prob; n_samples)
end

function estimate_scoregradelbo_ad_forward(params′, aux)
(; logprob, adtype, restructure, samples) = aux
q = restructure_ad_forward(adtype, restructure, params′)
ℓπ = logprob
ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples))
"""
estimate_scoregradelbo_ad_forward(params, aux)
AD-guaranteed forward path of the score gradient objective.
# Arguments
- `params`: Variational parameters.
- `aux`: Auxiliary information excluded from the AD path.
# Auxiliary Information
`aux` should containt the following entries:
- `samples_stop`: Samples drawn from `q = restructure(params)` but with their gradients stopped (excluded from the AD path).
- `logprob_stop`: Log-densities of the target `LogDensityProblem` evaluated over `samples_stop`.
- `adtype`: The `ADType` used for differentiating the forward path.
- `restructure`: Callable for restructuring the varitional distribution from `params`.
"""
function estimate_scoregradelbo_ad_forward(params, aux)
(; samples_stop, logprob_stop, adtype, restructure) = aux
q = restructure_ad_forward(adtype, restructure, params)
ℓπ = logprob_stop
ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples_stop))
f = ℓq - ℓπ
return (mean(abs2, f) - mean(f)^2) / 2
end
Expand Down

0 comments on commit eda4ea0

Please sign in to comment.