-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add score estimator with baseline variance reduction. * run formatter Co-authored-by: Kyurae Kim <[email protected]> --------- Co-authored-by: Kyurae Kim <[email protected]>
- Loading branch information
1 parent
be5d7b2
commit 4eab1ac
Showing
18 changed files
with
618 additions
and
84 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,5 @@ | ||
|
||
# AdvancedVI.jl Continuous Benchmarking | ||
|
||
This subdirectory contains code for continuous benchmarking of the performance of `AdvancedVI.jl`. | ||
The initial version was heavily inspired by the setup of [Lux.jl](https://github.com/LuxDL/Lux.jl/tree/main). | ||
The Github action and pages integration is provided by https://github.com/benchmark-action/github-action-benchmark/ and [BenchmarkTools.jl](https://github.com/JuliaCI/BenchmarkTools.jl). | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,20 @@ | ||
|
||
function variational_standard_mvnormal(type::Type, n_dims::Int, family::Symbol) | ||
if family == :meanfield | ||
AdvancedVI.MeanFieldGaussian( | ||
zeros(type, n_dims), Diagonal(ones(type, n_dims)) | ||
) | ||
AdvancedVI.MeanFieldGaussian(zeros(type, n_dims), Diagonal(ones(type, n_dims))) | ||
else | ||
AdvancedVI.FullRankGaussian( | ||
zeros(type, n_dims), Matrix(type, I, n_dims, n_dims) | ||
) | ||
AdvancedVI.FullRankGaussian(zeros(type, n_dims), Matrix(type, I, n_dims, n_dims)) | ||
end | ||
end | ||
|
||
function variational_objective(objective::Symbol; kwargs...) | ||
if objective == :RepGradELBO | ||
AdvancedVI.RepGradELBO(kwargs[:n_montecarlo]) | ||
elseif objective == :RepGradELBOSTL | ||
AdvancedVI.RepGradELBO(kwargs[:n_montecarlo], entropy=StickingTheLandingEntropy()) | ||
AdvancedVI.RepGradELBO(kwargs[:n_montecarlo]; entropy=StickingTheLandingEntropy()) | ||
elseif objective == :ScoreGradELBO | ||
throw("ScoreGradELBO not supported yet. Please use ScoreGradELBOSTL instead.") | ||
elseif objective == :ScoreGradELBOSTL | ||
AdvancedVI.ScoreGradELBO(kwargs[:n_montecarlo]; entropy=StickingTheLandingEntropy()) | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,42 +1,50 @@ | ||
|
||
# [General Usage](@id general) | ||
|
||
Each VI algorithm provides the followings: | ||
1. Variational families supported by each VI algorithm. | ||
2. A variational objective corresponding to the VI algorithm. | ||
Note that each variational family is subject to its own constraints. | ||
Thus, please refer to the documentation of the variational inference algorithm of interest. | ||
|
||
1. Variational families supported by each VI algorithm. | ||
2. A variational objective corresponding to the VI algorithm. | ||
Note that each variational family is subject to its own constraints. | ||
Thus, please refer to the documentation of the variational inference algorithm of interest. | ||
|
||
## Optimizing a Variational Objective | ||
|
||
After constructing a *variational objective* `objective` and initializing a *variational approximation*, one can optimize `objective` by calling `optimize`: | ||
|
||
```@docs | ||
optimize | ||
``` | ||
|
||
## Estimating the Objective | ||
|
||
In some cases, it is useful to directly estimate the objective value. | ||
This can be done by the following funciton: | ||
|
||
```@docs | ||
estimate_objective | ||
``` | ||
|
||
!!! info | ||
Note that `estimate_objective` is not expected to be differentiated through, and may not result in optimal statistical performance. | ||
|
||
Note that `estimate_objective` is not expected to be differentiated through, and may not result in optimal statistical performance. | ||
|
||
## Advanced Usage | ||
|
||
Each variational objective is a subtype of the following abstract type: | ||
|
||
```@docs | ||
AdvancedVI.AbstractVariationalObjective | ||
``` | ||
|
||
Furthermore, `AdvancedVI` only interacts with each variational objective by querying gradient estimates. | ||
Therefore, to create a new custom objective to be optimized through `AdvancedVI`, it suffices to implement the following function: | ||
|
||
```@docs | ||
AdvancedVI.estimate_gradient! | ||
``` | ||
|
||
If an objective needs to be stateful, one can implement the following function to inialize the state. | ||
|
||
```@docs | ||
AdvancedVI.init | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.