Skip to content

Commit

Permalink
refactor rewrite the documentation for the global interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
Red-Portal committed Oct 24, 2023
1 parent 8af8a5f commit 5f1fb52
Showing 1 changed file with 37 additions and 36 deletions.
73 changes: 37 additions & 36 deletions src/AdvancedVI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,40 +27,42 @@ using StatsBase

# derivatives
"""
value_and_gradient!(
ad::ADTypes.AbstractADType,
f,
θ::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult
)
Evaluate the value and gradient of a function `f` at `θ` using the automatic differentiation backend `ad`.
The result is stored in `out`.
The function `f` must return a scalar value.
The gradient is stored in `out` as a vector of the same length as `θ`.
value_and_gradient!(ad, f, θ, out)
Evaluate the value and gradient of a function `f` at `θ` using the automatic differentiation backend `ad` and store the result in `out`.
# Arguments
- `ad::ADTypes.AbstractADType`: Automatic differentiation backend.
- `f`: Function subject to differentiation.
- `θ`: The point to evaluate the gradient.
- `out::DiffResults.MutableDiffResult`: Buffer to contain the output gradient and function value.
"""
function value_and_gradient! end

# estimators
"""
AbstractVariationalObjective
An VI algorithm supported by `AdvancedVI` should implement a subtype of `AbstractVariationalObjective`.
Furthermore, it should implement the functions `estimate_gradient`.
Abstract type for the VI algorithms supported by `AdvancedVI`.
# Implementations
To be supported by `AdvancedVI`, a VI algorithm must implement `AbstractVariationalObjective`.
Also, it should provide gradients by implementing the function `estimate_gradient!`.
If the estimator is stateful, it can implement `init` to initialize the state.
"""
abstract type AbstractVariationalObjective end

"""
init(
rng::Random.AbstractRNG,
obj::AbstractVariationalObjective,
λ::AbstractVector,
restructure
)
init(rng, obj, λ, restructure)
Initialize a state of the variational objective `obj` given the initial variational parameters `λ`.
This function needs to be implemented only if `obj` is stateful.
# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
- `obj::AbstractVariationalObjective`: Variational objective.
- `λ`: Initial variational parameters.
- `restructure`: Function that reconstructs the variational approximation from `λ`.
"""
init(
::Random.AbstractRNG,
Expand All @@ -70,25 +72,24 @@ init(
) = nothing

"""
estimate_gradient!(
rng ::Random.AbstractRNG,
obj ::AbstractVariationalObjective,
adbackend ::ADTypes.AbstractADType,
out ::DiffResults.MutableDiffResult
prob,
λ,
restructure,
obj_state,
)
Estimate (possibly stochastic) gradients of the objective `obj` targeting `prob` with respect to the variational parameters `λ` using the automatic differentiation backend `adbackend`.
The estimated objective value and gradient are then stored in `out`.
If the objective is stateful, `obj_state` is its previous state, otherwise, it is `nothing`.
estimate_gradient!(rng, obj, adbackend, out, prob, λ, restructure, obj_state)
Estimate (possibly stochastic) gradients of the variational objective `obj` targeting `prob` with respect to the variational parameters `λ`
# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
- `obj::AbstractVariationalObjective`: Variational objective.
- `adbackend::ADTypes.AbstractADType`: Automatic differentiation backend.
- `out::DiffResults.MutableDiffResult`: Buffer containing the objective value and gradient estimates.
- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.
- `λ`: Variational parameters to evaluate the gradient on.
- `restructure`: Function that reconstructs the variational approximation from `λ`.
- `obj_state`: Previous state of the objective.
# Returns
- `out`: The `MutableDiffResult` containing the objective value and gradient estimates.
- `obj_state`: The updated state of the objective estimator.
- `stat`: Statistics and logs generated during estimation. (Type: `<: NamedTuple`)
- `out::MutableDiffResult`: Buffer containing the objective value and gradient estimates.
- `obj_state`: The updated state of the objective.
- `stat::NamedTuple`: Statistics and logs generated during estimation.
"""
function estimate_gradient! end

Expand Down

0 comments on commit 5f1fb52

Please sign in to comment.