From 548a281cd8628ef4c5dc757cc740296d0f6d66bc Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 11 Aug 2024 22:39:06 +0100 Subject: [PATCH 1/7] add subsampling objective --- src/objectives/subsampling.jl | 72 +++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 src/objectives/subsampling.jl diff --git a/src/objectives/subsampling.jl b/src/objectives/subsampling.jl new file mode 100644 index 00000000..323b51fd --- /dev/null +++ b/src/objectives/subsampling.jl @@ -0,0 +1,72 @@ + +# This function/signature will be moved to src/AdvancedVI.jl +""" + subsample(model, batch) + +# Arguments +- `model`: Model subject to subsampling. Could be the target model or the variational approximation. +- `batch`: Data points or indices corresponding to the subsampled "batch." + +# Returns +- `sub`: Subsampled model. +""" +subsample(model::Any, ::Any) = model + +struct Subsampling{ + O<:AbstractVariationalObjective,D<:AbstractVector +} <: AbstractVariationalObjective + batchsize::Int + objective::O + data::D +end + +function init_batch(rng::Random.AbstractRNG, data::AbstractVector, batchsize::Int) + shuffled = Random.shuffle(rng, data) + batches = Iterators.partition(shuffled, batchsize) + return enumerate(batches) +end + +function AdvancedVI.init(rng::Random.AbstractRNG, sub::Subsampling, params, restructure) + @unpack batchsize, objective, indices = sub + epoch = 1 + sub_state = (epoch, init_batch(rng, indices, batchsize)) + obj_state = AdvancedVI.init(rng, objective, params, restructure) + return (sub_state, obj_state) +end + +function next_batch(rng::Random.AbstractRNG, sub::Subsampling, sub_state) + epoch, batch_itr = sub_state + (step, batch), batch_itr′ = Iterators.peel(batch_itr) + epoch′, batch_itr′′ = if isempty(batch_itr′) + epoch + 1, init_batch(rng, sub.data, sub.batchsize) + else + epoch, batch_itr′ + end + stat = (epoch=epoch, step=step) + return batch, (epoch′, batch_itr′′), stat +end + +function estimate_gradient!( + rng::Random.AbstractRNG, + sub::Subsampling, + adtype::ADTypes.AbstractADType, + out::DiffResults.MutableDiffResult, + prob, + params, + restructure, + state, +) + obj = sub.objective + sub_st, obj_st = state + q = restructure(params) + + batch, sub_st′, sub_stat = next_batch(rng, sub, sub_st) + prob_sub = subsample(prob, batch) + q_sub = subsample(q, batch) + params_sub, re_sub = Optimisers.destructure(q_sub) + + out, obj_st′, obj_stat = AdvancedVI.estimate_gradient( + rng, obj, adtype, out, prob_sub, params, params_sub, re_sub, obj_st + ) + return out, (sub_st′, obj_st′), merge(sub_stat, obj_stat) +end From 644d314ea66653c19629a2d1c1ee1cd48c153749 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sun, 11 Aug 2024 23:08:03 +0100 Subject: [PATCH 2/7] fix wrong function name --- src/objectives/subsampling.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/objectives/subsampling.jl b/src/objectives/subsampling.jl index 323b51fd..9dfddc65 100644 --- a/src/objectives/subsampling.jl +++ b/src/objectives/subsampling.jl @@ -65,7 +65,7 @@ function estimate_gradient!( q_sub = subsample(q, batch) params_sub, re_sub = Optimisers.destructure(q_sub) - out, obj_st′, obj_stat = AdvancedVI.estimate_gradient( + out, obj_st′, obj_stat = AdvancedVI.estimate_gradient!( rng, obj, adtype, out, prob_sub, params, params_sub, re_sub, obj_st ) return out, (sub_st′, obj_st′), merge(sub_stat, obj_stat) From da01ffc35dcced7ef17891282347d435857731f0 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 9 Sep 2024 22:43:42 -0700 Subject: [PATCH 3/7] add 
`subsampled` objective with tests --- src/AdvancedVI.jl | 18 +++++++ src/objectives/subsampled.jl | 90 +++++++++++++++++++++++++++++++++ src/objectives/subsampling.jl | 72 --------------------------- test/Project.toml | 4 ++ test/interface/subsampled.jl | 94 +++++++++++++++++++++++++++++++++++ test/runtests.jl | 2 + 6 files changed, 208 insertions(+), 72 deletions(-) create mode 100644 src/objectives/subsampled.jl delete mode 100644 src/objectives/subsampling.jl create mode 100644 test/interface/subsampled.jl diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 8ac1b645..0684ca5f 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -130,6 +130,24 @@ function estimate_objective end export estimate_objective +# Oejectives + +""" + subsample(model, batch) + +# Arguments +- `model`: Model subject to subsampling. Could be the target model or the variational approximation. +- `batch`: Data points or indices corresponding to the subsampled "batch." + +# Returns +- `sub`: Subsampled model. +""" +subsample(model::Any, ::Any) = model + +include("objectives/subsampled.jl") + +export Subsampled + """ estimate_gradient!(rng, obj, adtype, out, prob, λ, restructure, obj_state) diff --git a/src/objectives/subsampled.jl b/src/objectives/subsampled.jl new file mode 100644 index 00000000..fb8c36a6 --- /dev/null +++ b/src/objectives/subsampled.jl @@ -0,0 +1,90 @@ + +struct Subsampled{O<:AbstractVariationalObjective,D<:AbstractVector} <: + AbstractVariationalObjective + objective::O + batchsize::Int + data::D +end + +function init_batch(rng::Random.AbstractRNG, data::AbstractVector, batchsize::Int) + shuffled = Random.shuffle(rng, data) + batches = Iterators.partition(shuffled, batchsize) + return enumerate(batches) +end + +function AdvancedVI.init( + rng::Random.AbstractRNG, sub::Subsampled, prob, params, restructure +) + @unpack batchsize, objective, data = sub + epoch = 1 + sub_state = (epoch, init_batch(rng, data, batchsize)) + obj_state = AdvancedVI.init(rng, objective, prob, params, restructure) + return (sub_state, obj_state) +end + +function next_batch(rng::Random.AbstractRNG, sub::Subsampled, sub_state) + epoch, batch_itr = sub_state + (step, batch), batch_itr′ = Iterators.peel(batch_itr) + epoch′, batch_itr′′ = if isempty(batch_itr′) + epoch + 1, init_batch(rng, sub.data, sub.batchsize) + else + epoch, batch_itr′ + end + stat = (epoch=epoch, step=step) + return batch, (epoch′, batch_itr′′), stat +end + +function estimate_objective( + rng::Random.AbstractRNG, + sub::Subsampled, + q, + prob; + n_batches::Int=ceil(Int, length(sub.data) / sub.batchsize), + kwargs..., +) + @unpack objective, batchsize, data = sub + sub_st = (1, init_batch(rng, data, batchsize)) + return mean(1:n_batches) do _ + batch, sub_st, _ = next_batch(rng, sub, sub_st) + prob_sub = subsample(prob, batch) + q_sub = subsample(q, batch) + estimate_objective(rng, objective, q_sub, prob_sub; kwargs...) + end +end + +function estimate_objective( + sub::Subsampled, + q, + prob; + n_batches::Int=ceil(Int, length(sub.data) / sub.batchsize), + kwargs..., +) + return estimate_objective(Random.default_rng(), sub, q, prob; n_batches, kwargs...) 
+end + +function estimate_gradient!( + rng::Random.AbstractRNG, + sub::Subsampled, + adtype::ADTypes.AbstractADType, + out::DiffResults.MutableDiffResult, + prob, + params, + restructure, + state, + objargs...; + kwargs..., +) + obj = sub.objective + sub_st, obj_st = state + q = restructure(params) + + batch, sub_st′, sub_stat = next_batch(rng, sub, sub_st) + prob_sub = subsample(prob, batch) + q_sub = subsample(q, batch) + params_sub, re_sub = Optimisers.destructure(q_sub) + + out, obj_st′, obj_stat = AdvancedVI.estimate_gradient!( + rng, obj, adtype, out, prob_sub, params_sub, re_sub, obj_st, objargs...; kwargs... + ) + return out, (sub_st′, obj_st′), merge(sub_stat, obj_stat) +end diff --git a/src/objectives/subsampling.jl b/src/objectives/subsampling.jl deleted file mode 100644 index 9dfddc65..00000000 --- a/src/objectives/subsampling.jl +++ /dev/null @@ -1,72 +0,0 @@ - -# This function/signature will be moved to src/AdvancedVI.jl -""" - subsample(model, batch) - -# Arguments -- `model`: Model subject to subsampling. Could be the target model or the variational approximation. -- `batch`: Data points or indices corresponding to the subsampled "batch." - -# Returns -- `sub`: Subsampled model. -""" -subsample(model::Any, ::Any) = model - -struct Subsampling{ - O<:AbstractVariationalObjective,D<:AbstractVector -} <: AbstractVariationalObjective - batchsize::Int - objective::O - data::D -end - -function init_batch(rng::Random.AbstractRNG, data::AbstractVector, batchsize::Int) - shuffled = Random.shuffle(rng, data) - batches = Iterators.partition(shuffled, batchsize) - return enumerate(batches) -end - -function AdvancedVI.init(rng::Random.AbstractRNG, sub::Subsampling, params, restructure) - @unpack batchsize, objective, indices = sub - epoch = 1 - sub_state = (epoch, init_batch(rng, indices, batchsize)) - obj_state = AdvancedVI.init(rng, objective, params, restructure) - return (sub_state, obj_state) -end - -function next_batch(rng::Random.AbstractRNG, sub::Subsampling, sub_state) - epoch, batch_itr = sub_state - (step, batch), batch_itr′ = Iterators.peel(batch_itr) - epoch′, batch_itr′′ = if isempty(batch_itr′) - epoch + 1, init_batch(rng, sub.data, sub.batchsize) - else - epoch, batch_itr′ - end - stat = (epoch=epoch, step=step) - return batch, (epoch′, batch_itr′′), stat -end - -function estimate_gradient!( - rng::Random.AbstractRNG, - sub::Subsampling, - adtype::ADTypes.AbstractADType, - out::DiffResults.MutableDiffResult, - prob, - params, - restructure, - state, -) - obj = sub.objective - sub_st, obj_st = state - q = restructure(params) - - batch, sub_st′, sub_stat = next_batch(rng, sub, sub_st) - prob_sub = subsample(prob, batch) - q_sub = subsample(q, batch) - params_sub, re_sub = Optimisers.destructure(q_sub) - - out, obj_st′, obj_stat = AdvancedVI.estimate_gradient!( - rng, obj, adtype, out, prob_sub, params, params_sub, re_sub, obj_st - ) - return out, (sub_st′, obj_st′), merge(sub_stat, obj_stat) -end diff --git a/test/Project.toml b/test/Project.toml index 251869e7..4e7f86c9 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -25,6 +26,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ADTypes = "0.2.1, 1" +Accessors = "0.1" Bijectors = "0.13" DiffResults = "1.0" Distributions = 
"0.25.100" @@ -37,11 +39,13 @@ LinearAlgebra = "1" LogDensityProblems = "2.1.1" Optimisers = "0.2.16, 0.3" PDMats = "0.11.7" +Pkg = "1" Random = "1" ReverseDiff = "1.15.1" SimpleUnPack = "1.1.0" StableRNGs = "1.0.0" Statistics = "1" +StatsBase = "0.34" Test = "1" Tracker = "0.2.20" Zygote = "0.6.63" diff --git a/test/interface/subsampled.jl b/test/interface/subsampled.jl new file mode 100644 index 00000000..00ecb4df --- /dev/null +++ b/test/interface/subsampled.jl @@ -0,0 +1,94 @@ + +using Test + +struct SubsampledNormals{D <: Normal, F <: Real} + dists::Vector{D} + likeadj::F +end + +function SubsampledNormals(rng::Random.AbstractRNG, n_normals::Int) + μs = randn(rng, n_normals) + σs = ones(n_normals) + dists = Normal.(μs, σs) + SubsampledNormals{eltype(dists), Float64}(dists, 1.0) +end + +function LogDensityProblems.logdensity(m::SubsampledNormals, x) + @unpack likeadj, dists = m + likeadj*mapreduce(Base.Fix2(logpdf, only(x)), +, dists) +end + +function AdvancedVI.subsample(m::SubsampledNormals, idx) + n_data = length(m.dists) + SubsampledNormals(m.dists[idx], n_data/length(idx)) +end + +@testset "interface Subsampled" begin + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + + n_data = 16 + prob = SubsampledNormals(rng, n_data) + + q0 = MeanFieldGaussian(zeros(Float64, 1), Diagonal(ones(Float64, 1))) + full_obj = RepGradELBO(10) + sub_obj = Subsampled(full_obj, 1, 1:n_data) + + adtype = AutoForwardDiff() + optimizer = Optimisers.Adam(1e-2) + averager = PolynomialAveraging() + + T = 128 + @testset "determinism" begin + rng = StableRNG(seed) + q_avg, q, _, _ = optimize( + rng, prob, sub_obj, q0, T; optimizer, averager, show_progress=false, adtype + ) + + rng = StableRNG(seed) + q_avg_ref, q_ref, _, _ = optimize( + rng, prob, sub_obj, q0, T; optimizer, averager, show_progress=false, adtype + ) + + @test q_avg == q_avg_ref + @test q == q_ref + + rng = StableRNG(seed) + sub_objval_ref = estimate_objective(rng, sub_obj, q0, prob) + + rng = StableRNG(seed) + sub_objval = estimate_objective(rng, sub_obj, q0, prob) + @test sub_objval == sub_objval_ref + end + + @testset "exactness estimate_objective batchsize=$(batchsize)" for batchsize in [1, 3, 4] + sub_obj′ = @set sub_obj.batchsize = batchsize + full_objval = estimate_objective(full_obj, q0, prob; n_samples=10^6) + sub_objval = estimate_objective(sub_obj′, q0, prob; n_samples=10^6) + @test full_objval ≈ sub_objval rtol=0.1 + end + + @testset "exactness estimate_gradient batchsize=$(batchsize)" for batchsize in [1, 3, 4] + params, restructure = Optimisers.destructure(q0) + out = DiffResults.DiffResult(zero(eltype(params)), similar(params)) + n_batches_per_epoch = ceil(Int, n_data/batchsize) + sub_obj = Subsampled(full_obj, 1, 1:n_data) + + full_state = AdvancedVI.init(rng, full_obj, prob, params, restructure) + AdvancedVI.estimate_gradient!( + rng, full_obj, adtype, out, prob, params, restructure, full_state + ) + grad_ref = DiffResults.gradient(out) + + sub_state = AdvancedVI.init(rng, sub_obj, prob, params, restructure) + grad = mean(1:n_batches_per_epoch) do _ + # Using a fixed RNG so that the same Monte Carlo samples are used across the batches + rng = StableRNG(seed) + AdvancedVI.estimate_gradient!( + rng, sub_obj, adtype, out, prob, params, restructure, sub_state + ) + DiffResults.gradient(out) + end + @test grad ≈ grad_ref + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 5d0d2c8d..b71f0e0a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,7 @@ using Test using Test: @testset, @test +using Accessors 
using Base.Iterators using Bijectors using Distributions @@ -53,6 +54,7 @@ if GROUP == "All" || GROUP == "Interface" include("interface/rules.jl") include("interface/averaging.jl") include("interface/location_scale.jl") + include("interface/subsampled.jl") end const PROGRESS = haskey(ENV, "PROGRESS") From aa5fd83aa1c0babba57a2d161871652e438a5fce Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Mon, 9 Sep 2024 22:44:28 -0700 Subject: [PATCH 4/7] add optional arguments and keyword arguments to `RepGradELBO` --- src/AdvancedVI.jl | 2 +- src/objectives/elbo/repgradelbo.jl | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 0684ca5f..00261b79 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -149,7 +149,7 @@ include("objectives/subsampled.jl") export Subsampled """ - estimate_gradient!(rng, obj, adtype, out, prob, λ, restructure, obj_state) + estimate_gradient!(rng, obj, adtype, out, prob, λ, restructure, obj_state, objargs...; kwargs...) Estimate (possibly stochastic) gradients of the variational objective `obj` targeting `prob` with respect to the variational parameters `λ` diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl index e6f04ae8..87da60a0 100644 --- a/src/objectives/elbo/repgradelbo.jl +++ b/src/objectives/elbo/repgradelbo.jl @@ -110,6 +110,8 @@ function estimate_gradient!( params, restructure, state, + objargs...; + kwargs... ) q_stop = restructure(params) aux = ( From 1c1b289acab2285ce8ce19bde6094c6055afdf22 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 10 Sep 2024 00:38:12 -0700 Subject: [PATCH 5/7] add elapsed time measurement in stat --- src/optimize.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/optimize.jl b/src/optimize.jl index eb462ff5..cd749953 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -69,6 +69,7 @@ function optimize( obj_st = maybe_init_objective(state_init, rng, objective, problem, params, restructure) avg_st = maybe_init_averager(state_init, averager, params) grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params)) + start_time = time() stats = NamedTuple[] for t in 1:max_iter @@ -93,6 +94,8 @@ function optimize( ) avg_st = apply(averager, avg_st, params) + stat = merge(stat, (elapsed_time=time() - start_time,)) + if !isnothing(callback) averaged_params = value(averager, avg_st) stat′ = callback(; From d8d6f560bb441817ad4115a09fe73af55acc8ee4 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 10 Sep 2024 00:38:32 -0700 Subject: [PATCH 6/7] add docs for subsampling --- docs/make.jl | 1 + docs/src/subsampling.md | 145 +++++++++++++++++++++++++++++++++++ src/AdvancedVI.jl | 4 +- src/objectives/subsampled.jl | 10 +++ 4 files changed, 159 insertions(+), 1 deletion(-) create mode 100644 docs/src/subsampling.md diff --git a/docs/make.jl b/docs/make.jl index b71d9a4f..3220849d 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -19,6 +19,7 @@ makedocs(; "Location-Scale Variational Family" => "locscale.md", ], "Optimization" => "optimization.md", + "Subsampling" => "subsampling.md", ], ) diff --git a/docs/src/subsampling.md b/docs/src/subsampling.md new file mode 100644 index 00000000..af03d83d --- /dev/null +++ b/docs/src/subsampling.md @@ -0,0 +1,145 @@ + +# [Subsampling](@id subsampling) + +## Introduction +For problems with large datasets, evaluating the objective may become computationally too expensive. 
+In this regime, many variational inference algorithms can readily incorporate datapoint subsampling to reduce the per-iteration computation cost[^HBWP2013][^TL2014].
+Notice that many variational objectives require only *gradients* of the log target.
+In many cases, the log target can be replaced with an *unbiased estimate* computed from a random subsample (minibatch) of the data, which in turn yields unbiased gradient estimates.
+This section describes how to do this in `AdvancedVI`.
+
+[^HBWP2013]: Hoffman, M. D., Blei, D. M., Wang, C., & Paisley, J. (2013). Stochastic variational inference. *Journal of Machine Learning Research*.
+[^TL2014]: Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In *International Conference on Machine Learning.*
+
+## API
+Subsampling is performed by wrapping the desired variational objective with the following objective:
+
+```@docs
+Subsampled
+```
+Furthermore, the target distribution `prob` must implement the following function:
+```@docs
+AdvancedVI.subsample
+```
+The subsampling strategy used by `Subsampled` is what is known as "random reshuffling".
+That is, the full dataset is shuffled and then partitioned into batches.
+The batches are picked one at a time in a "sampling without replacement" fashion, which results in faster convergence than independently subsampling batches.[^KKMG2024]
+
+[^KKMG2024]: Kim, K., Ko, J., Ma, Y., & Gardner, J. R. (2024). Demystifying SGD with Doubly Stochastic Gradients. In *International Conference on Machine Learning.*
+
+!!! note
+    For the subsampled log target to be a valid unbiased estimate of the full-batch log target, the log-likelihood summed over the minibatch must be adjusted by a constant factor ``n/b``, where ``n`` is the number of datapoints and ``b`` is the size of the minibatch (`length(batch)`). See the [example](@ref subsampling_example) for a demonstration of how to do this.
+
+## [Example](@id subsampling_example)
+
+We will consider a sum of multivariate Gaussians and subsample over the components of the sum:
+
+```@example subsampling
+using SimpleUnPack, LogDensityProblems, Distributions, Random, LinearAlgebra
+
+struct SubsampledMvNormals{D <: MvNormal, F <: Real}
+    dists::Vector{D}
+    likeadj::F
+end
+
+function SubsampledMvNormals(rng::Random.AbstractRNG, n_dims, n_normals::Int)
+    μs = randn(rng, n_dims, n_normals)
+    Σ = I
+    dists = MvNormal.(eachcol(μs), Ref(Σ))
+    SubsampledMvNormals{eltype(dists), Float64}(dists, 1.0)
+end
+
+function LogDensityProblems.logdensity(m::SubsampledMvNormals, x)
+    @unpack likeadj, dists = m
+    likeadj*mapreduce(Base.Fix2(logpdf, x), +, dists)
+end
+```
+
+Notice that, when computing the log-density, we multiply by the constant `likeadj`.
+This adjusts the strength of the likelihood when minibatching is used.
+
+To use subsampling, we need to implement `subsample`, where we also compute the likelihood adjustment `likeadj`:
+```@example subsampling
+using AdvancedVI
+
+function AdvancedVI.subsample(m::SubsampledMvNormals, idx)
+    n_data = length(m.dists)
+    SubsampledMvNormals(m.dists[idx], n_data/length(idx))
+end
+```
+
+The objective is constructed as follows:
+```@example subsampling
+n_dims = 10
+n_data = 1024
+prob = SubsampledMvNormals(Random.default_rng(), n_dims, n_data);
+```
+We will use a dataset with `1024` datapoints.
+
+For the objective, we will use `RepGradELBO`.
+To apply subsampling, it suffices to wrap it with `Subsampled`:
+```@example subsampling
+batchsize = 8
+full_obj = RepGradELBO(1)
+sub_obj = Subsampled(full_obj, batchsize, 1:n_data);
+```
+We can now invoke `optimize` to perform inference.
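+For illustration, a minimal call could look like the following sketch (not executed as part of this page; the optimizer, averager, and AD backend below are arbitrary choices):
+```julia
+using ADTypes, ForwardDiff, Optimisers
+
+# Initial variational approximation: a mean-field Gaussian over the `n_dims` dimensions above.
+q0 = MeanFieldGaussian(zeros(n_dims), Diagonal(ones(n_dims)))
+
+# `optimize` returns the averaged approximation, the last iterate,
+# the per-iteration statistics, and the final state.
+q_avg, q, stats, _ = optimize(
+    prob, sub_obj, q0, 10^3;
+    adtype=AutoForwardDiff(),
+    optimizer=Optimisers.Adam(1e-2),
+    averager=PolynomialAveraging(),
+    show_progress=false,
+)
+```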
+```@setup subsampling +using ForwardDiff, ADTypes, Optimisers, Plots + +Σ_true = Diagonal(fill(1/n_data, n_dims)) +μ_true = mean([mean(component) for component in prob.dists]) +Σsqrt_true = sqrt(Σ_true) + +q0 = MeanFieldGaussian(zeros(n_dims), Diagonal(ones(n_dims))) + +adtype = AutoForwardDiff() +optimizer = Adam(0.01) +averager = PolynomialAveraging() + +function callback(; averaged_params, restructure, kwargs...) + q = restructure(averaged_params) + μ, Σ = mean(q), cov(q) + dist2 = sum(abs2, μ - μ_true) + tr(Σ + Σ_true - 2*sqrt(Σsqrt_true*Σ*Σsqrt_true)) + (dist = sqrt(dist2),) +end + +n_iters = 3*10^2 +_, q, stats_full, _ = optimize( + prob, full_obj, q0, n_iters; optimizer, averager, show_progress=false, adtype, callback, +) + +n_iters = 10^3 +_, _, stats_sub, _ = optimize( + prob, sub_obj, q0, n_iters; optimizer, averager, show_progress=false, adtype, callback, +) + +x = [stat.iteration for stat in stats_full] +y = [stat.dist for stat in stats_full] +Plots.plot(x, y, xlabel="Iterations", ylabel="Wasserstein-2 Distance", label="Full Batch") + +x = [stat.iteration for stat in stats_sub] +y = [stat.dist for stat in stats_sub] +Plots.plot!(x, y, xlabel="Iterations", ylabel="Wasserstein-2 Distance", label="Subsampling (Random Reshuffling)") +savefig("subsampling_iteration.svg") + +x = [stat.elapsed_time for stat in stats_full] +y = [stat.dist for stat in stats_full] +Plots.plot(x, y, xlabel="Wallclock Time (sec)", ylabel="Wasserstein-2 Distance", label="Full Batch") + +x = [stat.elapsed_time for stat in stats_sub] +y = [stat.dist for stat in stats_sub] +Plots.plot!(x, y, xlabel="Wallclock Time (sec)", ylabel="Wasserstein-2 Distance", label="Subsampling (Random Reshuffling)") +savefig("subsampling_wallclocktime.svg") +``` +Let's first compare the convergence of full-batch `RepGradELBO` versus subsampled `RepGradELBO` with respect to the number of iterations: + +![](subsampling_iteration.svg) + +While it seems that subsampling results in slower convergence, the real power of subsampling is revealed when comparing with respect to the wallclock time: + +![](subsampling_wallclocktime.svg) + +Clearly, subsampling results in a vastly faster convergence speed. diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 00261b79..37cd2e47 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -135,9 +135,11 @@ export estimate_objective """ subsample(model, batch) +Subsample `model` to use only the datapoints designated by the iterable collection `batch`. + # Arguments - `model`: Model subject to subsampling. Could be the target model or the variational approximation. -- `batch`: Data points or indices corresponding to the subsampled "batch." +- `batch`: Iterable collection of datapoints or indices corresponding to the subsampled "batch." # Returns - `sub`: Subsampled model. diff --git a/src/objectives/subsampled.jl b/src/objectives/subsampled.jl index fb8c36a6..99d42415 100644 --- a/src/objectives/subsampled.jl +++ b/src/objectives/subsampled.jl @@ -1,4 +1,14 @@ +""" + Subsampled(objective, batchsize, data) + +Subsample `objective` over the dataset represented by `data` with minibatches of size `batchsize`. + +# Arguments +- `objective::AbstractVariationalObjective`: A variational objective that is compatible with subsampling. +- `batchsize::Int`: Size of minibatches. +- `data`: An iterator over the datapoints or indices representing the datapoints. 
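+
+# Example
+A minimal sketch of constructing the objective (the batch size and index range below are illustrative):
+```julia
+full_obj = RepGradELBO(10)                 # objective to be subsampled
+sub_obj  = Subsampled(full_obj, 8, 1:1024) # minibatches of size 8 over 1024 datapoints
+```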
+""" struct Subsampled{O<:AbstractVariationalObjective,D<:AbstractVector} <: AbstractVariationalObjective objective::O From 2fe0061b2231e656cf6d3451c243c4bfbf8f4f83 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 10 Sep 2024 12:55:48 -0700 Subject: [PATCH 7/7] fix subsampling example to work with DoG (tweaked `scale_eps`) --- docs/src/subsampling.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/src/subsampling.md b/docs/src/subsampling.md index af03d83d..291673d3 100644 --- a/docs/src/subsampling.md +++ b/docs/src/subsampling.md @@ -93,10 +93,10 @@ using ForwardDiff, ADTypes, Optimisers, Plots μ_true = mean([mean(component) for component in prob.dists]) Σsqrt_true = sqrt(Σ_true) -q0 = MeanFieldGaussian(zeros(n_dims), Diagonal(ones(n_dims))) +q0 = MvLocationScale(zeros(n_dims), Diagonal(ones(n_dims)), Normal(); scale_eps=1e-3) adtype = AutoForwardDiff() -optimizer = Adam(0.01) +optimizer = DoG() averager = PolynomialAveraging() function callback(; averaged_params, restructure, kwargs...) @@ -106,7 +106,7 @@ function callback(; averaged_params, restructure, kwargs...) (dist = sqrt(dist2),) end -n_iters = 3*10^2 +n_iters = 10^3 _, q, stats_full, _ = optimize( prob, full_obj, q0, n_iters; optimizer, averager, show_progress=false, adtype, callback, )