add minibatch subsampling (doubly stochastic) objective #84

Open · wants to merge 8 commits into base: master
1 change: 1 addition & 0 deletions docs/make.jl
@@ -19,6 +19,7 @@ makedocs(;
"Location-Scale Variational Family" => "locscale.md",
],
"Optimization" => "optimization.md",
"Subsampling" => "subsampling.md",
],
)

145 changes: 145 additions & 0 deletions docs/src/subsampling.md
@@ -0,0 +1,145 @@

# [Subsampling](@id subsampling)

## Introduction
For problems with large datasets, evaluating the objective may become computationally too expensive.
In this regime, many variational inference algorithms can readily incorporate datapoint subsampling to reduce the per-iteration computation cost[^HBWP2013][^TL2014].
Notice that many variational objectives require only *gradients* of the log target.
In many cases, the log target can be replaced with an *unbiased estimate*, which in turn yields an unbiased estimate of its gradient.
This section describes how to do this in `AdvancedVI`.
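
For instance, in the common setting where the target is a posterior over i.i.d. data (a sketch; this decomposition is an assumption for illustration, not a requirement of the API), the log target splits into a prior term and a sum of per-datapoint likelihood terms:

```math
\log \pi(\theta) = \log p(\theta) + \sum_{i=1}^{n} \log p(x_i \mid \theta),
```

and subsampling replaces the sum over all ``n`` datapoints with a scaled sum over a random minibatch.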


[^HBWP2013]: Hoffman, M. D., Blei, D. M., Wang, C., & Paisley, J. (2013). Stochastic variational inference. *Journal of Machine Learning Research*.
[^TL2014]: Titsias, M., & Lázaro-Gredilla, M. (2014). Doubly stochastic variational Bayes for non-conjugate inference. In *International Conference on Machine Learning.*

## API
Subsampling is performed by wrapping the desired variational objective with the following objective:

```@docs
Subsampled
```
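
For instance, a hypothetical wrap of a `RepGradELBO` objective over `1024` datapoints with minibatches of size `8` (the numbers are illustrative; the [example](@ref subsampling_example) below walks through a full setup) looks like the following:

```julia
using AdvancedVI

sub_obj = Subsampled(RepGradELBO(1), 8, 1:1024)
```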
Furthermore, the target distribution `prob` must implement the following function:
```@docs
AdvancedVI.subsample
```
The subsampling strategy used by `Subsampled` is what is known as "random reshuffling".
That is, the full dataset is shuffled and then partitioned into batches.
The batches are picked one at a time in a "sampling without replacement" fashion, which results in faster convergence than independently subsampling batches.[^KKMG2024]

[^KKMG2024]: Kim, K., Ko, J., Ma, Y., & Gardner, J. R. (2024). Demystifying SGD with Doubly Stochastic Gradients. In *International Conference on Machine Learning.*
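
As a rough illustration of random reshuffling (a sketch, not the package's internal implementation), one epoch of batches can be generated as follows:

```julia
using Random

data      = 1:1024  # indices of the datapoints (illustrative)
batchsize = 8
shuffled  = Random.shuffle(Random.default_rng(), data)
batches   = Iterators.partition(shuffled, batchsize)  # disjoint batches covering one epoch
```

Once every batch has been visited, the dataset is reshuffled and the next epoch begins.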

!!! note
    For the subsampled log target to yield a valid unbiased estimate of the full-batch log target (and hence of its gradient), the log-likelihood terms summed over the batch must be scaled by a constant factor ``n/b``, where ``n`` is the total number of datapoints and ``b`` is the size of the minibatch (`length(batch)`). See the [example](@ref subsampling_example) for a demonstration of how to do this.
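
Concretely (a sketch under the posterior decomposition sketched in the introduction), the adjusted likelihood term is

```math
\sum_{i=1}^{n} \log p(x_i \mid \theta) \;\approx\; \frac{n}{b} \sum_{i \in \mathcal{B}} \log p(x_i \mid \theta),
```

which is unbiased over the random draw of the minibatch ``\mathcal{B}`` of size ``b``.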


## [Example](@id subsampling_example)

We will consider a target whose log-density is a sum of multivariate Gaussian log-densities, and subsample over the components of the sum:

```@example subsampling
using SimpleUnPack, LogDensityProblems, Distributions, Random, LinearAlgebra

struct SubsampledMvNormals{D <: MvNormal, F <: Real}
dists::Vector{D}
likeadj::F
end

function SubsampledMvNormals(rng::Random.AbstractRNG, n_dims, n_normals::Int)
μs = randn(rng, n_dims, n_normals)
Σ = I
dists = MvNormal.(eachcol(μs), Ref(Σ))
SubsampledMvNormals{eltype(dists), Float64}(dists, 1.0)
end

function LogDensityProblems.logdensity(m::SubsampledMvNormals, x)
@unpack likeadj, dists = m
likeadj*mapreduce(Base.Fix2(logpdf, x), +, dists)
end
```

Notice that, when computing the log-density, we multiple by a constant `likeadj`.
Member
Suggested change:
- Notice that, when computing the log-density, we multiple by a constant `likeadj`.
+ Notice that, when computing the log-density, we multiply by a constant `likeadj`.

This is to adjust the strength of the likelihood when minibatching is used.

To use subsampling, we need to implement `subsample`, where we also compute the likelihood adjustment `likeadj`:
```@example subsampling
using AdvancedVI

function AdvancedVI.subsample(m::SubsampledMvNormals, idx)
n_data = length(m.dists)
SubsampledMvNormals(m.dists[idx], n_data/length(idx))
end
```

The target problem is constructed as follows:
```@example subsampling
n_dims = 10
n_data = 1024
prob = SubsampledMvNormals(Random.default_rng(), n_dims, n_data);
```
We will use a dataset with `1024` datapoints.
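
As a quick sanity check (a hypothetical snippet, not part of the original example), subsampling `8` of the `1024` components should set the likelihood adjustment to ``1024/8 = 128``:

```julia
sub = AdvancedVI.subsample(prob, 1:8)
sub.likeadj       # 128.0
length(sub.dists) # 8
```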

For the objective, we will use `RepGradELBO`.
To apply subsampling, it suffices to wrap it with `Subsampled`:
```@example subsampling
batchsize = 8
full_obj = RepGradELBO(1)
sub_obj = Subsampled(full_obj, batchsize, 1:n_data);
```
We can now invoke `optimize` to perform inference.
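
A minimal sketch of such a call is shown below; the variational family, optimizer, averager, and AD backend mirror the `@setup` block used for the plots and are one reasonable configuration rather than the only choice:

```julia
using ADTypes, ForwardDiff, Optimisers

# mean-field Gaussian initialization, as in the setup block below
q0 = MvLocationScale(zeros(n_dims), Diagonal(ones(n_dims)), Normal(); scale_eps=1e-3)

_, q, stats, _ = optimize(
    prob, sub_obj, q0, 10^3;
    adtype=AutoForwardDiff(), optimizer=DoG(), averager=PolynomialAveraging(),
)
```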
```@setup subsampling
using ForwardDiff, ADTypes, Optimisers, Plots

Σ_true = Diagonal(fill(1/n_data, n_dims))
μ_true = mean([mean(component) for component in prob.dists])
Σsqrt_true = sqrt(Σ_true)

q0 = MvLocationScale(zeros(n_dims), Diagonal(ones(n_dims)), Normal(); scale_eps=1e-3)

adtype = AutoForwardDiff()
optimizer = DoG()
averager = PolynomialAveraging()

function callback(; averaged_params, restructure, kwargs...)
q = restructure(averaged_params)
μ, Σ = mean(q), cov(q)
dist2 = sum(abs2, μ - μ_true) + tr(Σ + Σ_true - 2*sqrt(Σsqrt_true*Σ*Σsqrt_true))
(dist = sqrt(dist2),)
end

n_iters = 10^3
_, q, stats_full, _ = optimize(
prob, full_obj, q0, n_iters; optimizer, averager, show_progress=false, adtype, callback,
)

n_iters = 10^3
_, _, stats_sub, _ = optimize(
prob, sub_obj, q0, n_iters; optimizer, averager, show_progress=false, adtype, callback,
)

x = [stat.iteration for stat in stats_full]
y = [stat.dist for stat in stats_full]
Plots.plot(x, y, xlabel="Iterations", ylabel="Wasserstein-2 Distance", label="Full Batch")

x = [stat.iteration for stat in stats_sub]
y = [stat.dist for stat in stats_sub]
Plots.plot!(x, y, xlabel="Iterations", ylabel="Wasserstein-2 Distance", label="Subsampling (Random Reshuffling)")
savefig("subsampling_iteration.svg")

x = [stat.elapsed_time for stat in stats_full]
y = [stat.dist for stat in stats_full]
Plots.plot(x, y, xlabel="Wallclock Time (sec)", ylabel="Wasserstein-2 Distance", label="Full Batch")

x = [stat.elapsed_time for stat in stats_sub]
y = [stat.dist for stat in stats_sub]
Plots.plot!(x, y, xlabel="Wallclock Time (sec)", ylabel="Wasserstein-2 Distance", label="Subsampling (Random Reshuffling)")
savefig("subsampling_wallclocktime.svg")
```
Let's first compare the convergence of full-batch `RepGradELBO` versus subsampled `RepGradELBO` with respect to the number of iterations:

![](subsampling_iteration.svg)
Member
Are these .svg files missing from the repo? I haven't looked at the built docs, just don't see them in the PR.


While subsampling appears to converge more slowly per iteration, its real advantage is revealed when comparing with respect to wallclock time:

![](subsampling_wallclocktime.svg)

Clearly, subsampling converges much faster in terms of wallclock time.
22 changes: 21 additions & 1 deletion src/AdvancedVI.jl
@@ -130,8 +130,28 @@ function estimate_objective end

export estimate_objective

# Objectives

"""
subsample(model, batch)

Subsample `model` to use only the datapoints designated by the iterable collection `batch`.

# Arguments
- `model`: Model subject to subsampling. Could be the target model or the variational approximation.
- `batch`: Iterable collection of datapoints or indices corresponding to the subsampled "batch."

# Returns
- `sub`: Subsampled model.
"""
subsample(model::Any, ::Any) = model
Member
Do I understand correctly that subsampling is a more general operation than just a VI thing? If that's the case, could this be moved to DynamicPPL, or even AbstractPPL?

Also, I wonder if an empty function without methods would make more sense. Is returning the unmodified original model a reasonable fallback? I could imagine it confusing users, who would call it and get a return value, not realising it's actually just the original model.

Member Author

@yebai Any comments on the current direction on the PPL side?


include("objectives/subsampled.jl")

export Subsampled

"""
estimate_gradient!(rng, obj, adtype, out, prob, λ, restructure, obj_state)
estimate_gradient!(rng, obj, adtype, out, prob, λ, restructure, obj_state, objargs...; kwargs...)
Member
Could the varargs be explained in the docstring?


Estimate (possibly stochastic) gradients of the variational objective `obj` targeting `prob` with respect to the variational parameters `λ`

2 changes: 2 additions & 0 deletions src/objectives/elbo/repgradelbo.jl
@@ -110,6 +110,8 @@ function estimate_gradient!(
params,
restructure,
state,
objargs...;
kwargs...
)
q_stop = restructure(params)
aux = (
100 changes: 100 additions & 0 deletions src/objectives/subsampled.jl
@@ -0,0 +1,100 @@

"""
Subsampled(objective, batchsize, data)

Subsample `objective` over the dataset represented by `data` with minibatches of size `batchsize`.

Member
Could you comment on what happens if batchsize does not divide length(data), or whether that's significant at all?

# Arguments
- `objective::AbstractVariationalObjective`: A variational objective that is compatible with subsampling.
- `batchsize::Int`: Size of minibatches.
- `data`: An iterator over the datapoints or indices representing the datapoints.
"""
struct Subsampled{O<:AbstractVariationalObjective,D<:AbstractVector} <:
AbstractVariationalObjective
objective::O
batchsize::Int
data::D
end

function init_batch(rng::Random.AbstractRNG, data::AbstractVector, batchsize::Int)
shuffled = Random.shuffle(rng, data)
batches = Iterators.partition(shuffled, batchsize)
return enumerate(batches)
end

function AdvancedVI.init(
rng::Random.AbstractRNG, sub::Subsampled, prob, params, restructure
)
@unpack batchsize, objective, data = sub
epoch = 1
sub_state = (epoch, init_batch(rng, data, batchsize))
obj_state = AdvancedVI.init(rng, objective, prob, params, restructure)
return (sub_state, obj_state)
end

function next_batch(rng::Random.AbstractRNG, sub::Subsampled, sub_state)
epoch, batch_itr = sub_state
(step, batch), batch_itr′ = Iterators.peel(batch_itr)
epoch′, batch_itr′′ = if isempty(batch_itr′)
epoch + 1, init_batch(rng, sub.data, sub.batchsize)
else
epoch, batch_itr′
end
stat = (epoch=epoch, step=step)
return batch, (epoch′, batch_itr′′), stat
end

function estimate_objective(
rng::Random.AbstractRNG,
sub::Subsampled,
q,
prob;
n_batches::Int=ceil(Int, length(sub.data) / sub.batchsize),
kwargs...,
)
@unpack objective, batchsize, data = sub
sub_st = (1, init_batch(rng, data, batchsize))
return mean(1:n_batches) do _
batch, sub_st, _ = next_batch(rng, sub, sub_st)
prob_sub = subsample(prob, batch)
q_sub = subsample(q, batch)
estimate_objective(rng, objective, q_sub, prob_sub; kwargs...)
end
end

function estimate_objective(
sub::Subsampled,
q,
prob;
n_batches::Int=ceil(Int, length(sub.data) / sub.batchsize),
kwargs...,
)
return estimate_objective(Random.default_rng(), sub, q, prob; n_batches, kwargs...)
end

function estimate_gradient!(
rng::Random.AbstractRNG,
sub::Subsampled,
adtype::ADTypes.AbstractADType,
out::DiffResults.MutableDiffResult,
prob,
params,
restructure,
state,
objargs...;
kwargs...,
)
obj = sub.objective
sub_st, obj_st = state
q = restructure(params)

batch, sub_st′, sub_stat = next_batch(rng, sub, sub_st)
prob_sub = subsample(prob, batch)
q_sub = subsample(q, batch)
params_sub, re_sub = Optimisers.destructure(q_sub)

out, obj_st′, obj_stat = AdvancedVI.estimate_gradient!(
rng, obj, adtype, out, prob_sub, params_sub, re_sub, obj_st, objargs...; kwargs...
)
return out, (sub_st′, obj_st′), merge(sub_stat, obj_stat)
end
3 changes: 3 additions & 0 deletions src/optimize.jl
@@ -69,6 +69,7 @@ function optimize(
obj_st = maybe_init_objective(state_init, rng, objective, problem, params, restructure)
avg_st = maybe_init_averager(state_init, averager, params)
grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params))
start_time = time()
stats = NamedTuple[]

for t in 1:max_iter
@@ -93,6 +94,8 @@
)
avg_st = apply(averager, avg_st, params)

stat = merge(stat, (elapsed_time=time() - start_time,))

if !isnothing(callback)
averaged_params = value(averager, avg_st)
stat′ = callback(;
4 changes: 4 additions & 0 deletions test/Project.toml
@@ -1,5 +1,6 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -25,6 +26,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ADTypes = "0.2.1, 1"
Accessors = "0.1"
Bijectors = "0.13"
DiffResults = "1.0"
Distributions = "0.25.100"
@@ -37,11 +39,13 @@ LinearAlgebra = "1"
LogDensityProblems = "2.1.1"
Optimisers = "0.2.16, 0.3"
PDMats = "0.11.7"
Pkg = "1"
Random = "1"
ReverseDiff = "1.15.1"
SimpleUnPack = "1.1.0"
StableRNGs = "1.0.0"
Statistics = "1"
StatsBase = "0.34"
Test = "1"
Tracker = "0.2.20"
Zygote = "0.6.63"