From 5f1fb52b5be0c46ea295087f9f3644396f239c66 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Tue, 24 Oct 2023 00:06:50 -0400
Subject: [PATCH] refactor rewrite the documentation for the global interfaces

---
 src/AdvancedVI.jl | 73 ++++++++++++++++++++++++-----------------------
 1 file changed, 37 insertions(+), 36 deletions(-)

diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index dd7f10ae..54c2b1eb 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -27,17 +27,15 @@ using StatsBase

 # derivatives
 """
-    value_and_gradient!(
-        ad::ADTypes.AbstractADType,
-        f,
-        θ::AbstractVector{<:Real},
-        out::DiffResults.MutableDiffResult
-    )
-
-Evaluate the value and gradient of a function `f` at `θ` using the automatic differentiation backend `ad`.
-The result is stored in `out`.
-The function `f` must return a scalar value.
-The gradient is stored in `out` as a vector of the same length as `θ`.
+    value_and_gradient!(ad, f, θ, out)
+
+Evaluate the value and gradient of a function `f` at `θ` using the automatic differentiation backend `ad` and store the result in `out`.
+
+# Arguments
+- `ad::ADTypes.AbstractADType`: Automatic differentiation backend.
+- `f`: Function subject to differentiation.
+- `θ`: The point at which to evaluate the gradient.
+- `out::DiffResults.MutableDiffResult`: Buffer to contain the output gradient and function value.
 """
 function value_and_gradient! end

@@ -45,22 +43,26 @@
 """
     AbstractVariationalObjective

-An VI algorithm supported by `AdvancedVI` should implement a subtype of `AbstractVariationalObjective`.
-Furthermore, it should implement the functions `estimate_gradient`.
+Abstract type for the VI algorithms supported by `AdvancedVI`.
+
+# Implementations
+To be supported by `AdvancedVI`, a VI algorithm must implement `AbstractVariationalObjective`.
+It should also provide gradients by implementing the function `estimate_gradient!`.
 If the estimator is stateful, it can implement `init` to initialize the state.
""" abstract type AbstractVariationalObjective end """ - init( - rng::Random.AbstractRNG, - obj::AbstractVariationalObjective, - λ::AbstractVector, - restructure - ) + init(rng, obj, λ, restructure) Initialize a state of the variational objective `obj` given the initial variational parameters `λ`. This function needs to be implemented only if `obj` is stateful. + +# Arguments +- `rng::Random.AbstractRNG`: Random number generator. +- `obj::AbstractVariationalObjective`: Variational objective. +- `λ`: Initial variational parameters. +- `restructure`: Function that reconstructs the variational approximation from `λ`. """ init( ::Random.AbstractRNG, @@ -70,25 +72,24 @@ init( ) = nothing """ - estimate_gradient!( - rng ::Random.AbstractRNG, - obj ::AbstractVariationalObjective, - adbackend ::ADTypes.AbstractADType, - out ::DiffResults.MutableDiffResult - prob, - λ, - restructure, - obj_state, - ) - -Estimate (possibly stochastic) gradients of the objective `obj` targeting `prob` with respect to the variational parameters `λ` using the automatic differentiation backend `adbackend`. -The estimated objective value and gradient are then stored in `out`. -If the objective is stateful, `obj_state` is its previous state, otherwise, it is `nothing`. + estimate_gradient!(rng, obj, adbackend, out, prob, λ, restructure, obj_state) + +Estimate (possibly stochastic) gradients of the variational objective `obj` targeting `prob` with respect to the variational parameters `λ` + +# Arguments +- `rng::Random.AbstractRNG`: Random number generator. +- `obj::AbstractVariationalObjective`: Variational objective. +- `adbackend::ADTypes.AbstractADType`: Automatic differentiation backend. +- `out::DiffResults.MutableDiffResult`: Buffer containing the objective value and gradient estimates. +- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. +- `λ`: Variational parameters to evaluate the gradient on. 
+- `restructure`: Function that reconstructs the variational approximation from `λ`.
+- `obj_state`: Previous state of the objective.

 # Returns
-- `out`: The `MutableDiffResult` containing the objective value and gradient estimates.
-- `obj_state`: The updated state of the objective estimator.
-- `stat`: Statistics and logs generated during estimation. (Type: `<: NamedTuple`)
+- `out::MutableDiffResult`: Buffer containing the objective value and gradient estimates.
+- `obj_state`: The updated state of the objective.
+- `stat::NamedTuple`: Statistics and logs generated during estimation.
 """
 function estimate_gradient! end
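For reviewers, a minimal sketch of how a concrete AD backend might satisfy the `value_and_gradient!` contract documented above. This is an illustration only, assuming the ForwardDiff.jl and DiffResults.jl packages; it is not the implementation shipped in this PR:

```julia
using ADTypes, DiffResults, ForwardDiff

# Hypothetical `value_and_gradient!` method for the ForwardDiff backend.
# Sketch only; the package's actual backend methods may differ.
function value_and_gradient!(
    ::ADTypes.AutoForwardDiff,
    f,
    θ::AbstractVector{<:Real},
    out::DiffResults.MutableDiffResult,
)
    # `ForwardDiff.gradient!` with a DiffResult buffer records both the
    # scalar value f(θ) and the gradient ∇f(θ) into `out` in a single pass.
    ForwardDiff.gradient!(out, f, θ)
    return out
end

# Usage: evaluate f(θ) = ∑ θᵢ² and its gradient 2θ at a random point.
θ   = randn(3)
out = DiffResults.GradientResult(θ)
value_and_gradient!(AutoForwardDiff(), x -> sum(abs2, x), θ, out)
DiffResults.value(out)     # the objective value f(θ)
DiffResults.gradient(out)  # the gradient, here ≈ 2θ
```

Passing the preallocated `DiffResults` buffer through, as the interface requires, lets the optimization loop reuse `out` across iterations without allocating.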