Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add benchmarks #104

Merged
merged 13 commits into from
Oct 5, 2024
2 changes: 1 addition & 1 deletion .github/workflows/CompatHelper.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ jobs:
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
COMPATHELPER_PRIV: ${{ secrets.COMPATHELPER_PRIV }}
run: julia -e 'using CompatHelper; CompatHelper.main()'
run: julia -e 'using CompatHelper; CompatHelper.main(; subdirs=["", "bench", "test", "docs"])'
22 changes: 22 additions & 0 deletions bench/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,35 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ADTypes = "1"
BenchmarkTools = "1"
Bijectors = "0.13"
Distributions = "0.25.111"
DistributionsAD = "0.6"
Enzyme = "0.13.7"
FillArrays = "1"
ForwardDiff = "0.10"
InteractiveUtils = "1"
LogDensityProblems = "2"
Mooncake = "0.4.5"
Optimisers = "0.3"
Random = "1"
ReverseDiff = "1"
SimpleUnPack = "1"
StableRNGs = "1"
Zygote = "0.6"
julia = "1.10"
85 changes: 57 additions & 28 deletions bench/benchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@

using ADTypes, ForwardDiff, ReverseDiff, Zygote
using ADTypes
using AdvancedVI
using BenchmarkTools
using Bijectors
using Distributions
using DistributionsAD
using Enzyme, ForwardDiff, ReverseDiff, Zygote, Mooncake
using FillArrays
using InteractiveUtils
using LinearAlgebra
Expand All @@ -17,37 +18,65 @@ BLAS.set_num_threads(min(4, Threads.nthreads()))
@info sprint(versioninfo)
@info "BLAS threads: $(BLAS.get_num_threads())"

include("utils.jl")
include("normallognormal.jl")
include("unconstrdist.jl")

const SUITES = BenchmarkGroup()

# Comment until https://github.com/TuringLang/Bijectors.jl/pull/315 is merged
# SUITES["normal + bijector"]["meanfield"]["Zygote"] =
# @benchmarkable normallognormal(
# ;
# fptype = Float64,
# adtype = AutoZygote(),
# family = :meanfield,
# objective = :RepGradELBO,
# n_montecarlo = 4,
# )

SUITES["normal + bijector"]["meanfield"]["ReverseDiff"] = @benchmarkable normallognormal(;
fptype=Float64,
adtype=AutoReverseDiff(),
family=:meanfield,
objective=:RepGradELBO,
n_montecarlo=4,
)

SUITES["normal + bijector"]["meanfield"]["ForwardDiff"] = @benchmarkable normallognormal(;
fptype=Float64,
adtype=AutoForwardDiff(),
family=:meanfield,
objective=:RepGradELBO,
n_montecarlo=4,
)
function variational_standard_mvnormal(type::Type, n_dims::Int, family::Symbol)
if family == :meanfield
MeanFieldGaussian(zeros(type, n_dims), Diagonal(ones(type, n_dims)))
else
FullRankGaussian(zeros(type, n_dims), Matrix(type, I, n_dims, n_dims))
end
end

begin
fptype = Float64

for (probname, prob) in [
("normal + bijector", normallognormal(; n_dims=10, fptype))
("normal", normal(; n_dims=10, fptype))
]
max_iter = 10^4
d = LogDensityProblems.dimension(prob)
optimizer = Optimisers.Adam(fptype(1e-3))

for (objname, obj) in [
("RepGradELBO", RepGradELBO(10)),
("RepGradELBO + STL", RepGradELBO(10; entropy=StickingTheLandingEntropy())),
],
(adname, adtype) in [
("Zygote", AutoZygote()),
("ForwardDiff", AutoForwardDiff()),
("ReverseDiff", AutoReverseDiff()),
#("Mooncake", AutoMooncake(; config=nothing)),
#("Enzyme", AutoEnzyme()),
Red-Portal marked this conversation as resolved.
Show resolved Hide resolved
],
(familyname, family) in [
("meanfield", MeanFieldGaussian(zeros(d), Diagonal(ones(d)))),
(
"fullrank",
FullRankGaussian(zeros(d), LowerTriangular(Matrix{fptype}(I, d, d))),
),
]

b = Bijectors.bijector(prob)
binv = inverse(b)
q = Bijectors.TransformedDistribution(family, binv)

SUITES[probname][objname][familyname][adname] = @benchmarkable AdvancedVI.optimize(
$prob,
$obj,
$q,
$max_iter;
adtype=$adtype,
optimizer=$optimizer,
show_progress=false,
)
end
end
end

BenchmarkTools.tune!(SUITES; verbose=true)
results = BenchmarkTools.run(SUITES; verbose=true)
Expand Down
22 changes: 1 addition & 21 deletions bench/normallognormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,30 +27,10 @@ function Bijectors.bijector(model::NormalLogNormal)
)
end

function normallognormal(; fptype, adtype, family, objective, max_iter=10^3, kwargs...)
n_dims = 10
function normallognormal(; n_dims=10, fptype=Float64)
μ_x = fptype(5.0)
σ_x = fptype(0.3)
μ_y = Fill(fptype(5.0), n_dims)
σ_y = Fill(fptype(0.3), n_dims)
model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y .^ 2))

obj = variational_objective(objective; kwargs...)

d = LogDensityProblems.dimension(model)
q = variational_standard_mvnormal(fptype, d, family)

b = Bijectors.bijector(model)
binv = inverse(b)
q_transformed = Bijectors.TransformedDistribution(q, binv)

return AdvancedVI.optimize(
model,
obj,
q_transformed,
max_iter;
adtype,
optimizer=Optimisers.Adam(fptype(1e-3)),
show_progress=false,
)
end
26 changes: 26 additions & 0 deletions bench/unconstrdist.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@

struct UnconstrDist{D <: ContinuousMultivariateDistribution}
dist::D
end

function LogDensityProblems.logdensity(model::UnconstrDist, x)
return logpdf(model.dist, x)
end

function LogDensityProblems.dimension(model::UnconstrDist)
return length(model.dist)
end

function LogDensityProblems.capabilities(::Type{<:UnconstrDist})
return LogDensityProblems.LogDensityOrder{0}()
end

function Bijectors.bijector(model::UnconstrDist)
return identity
end

function normal(; n_dims=10, fptype=Float64)
μ = fill(fptype(5), n_dims)
Σ = Diagonal(ones(fptype, n_dims))
UnconstrDist(MvNormal(μ, Σ))
end
20 changes: 0 additions & 20 deletions bench/utils.jl

This file was deleted.

Loading