Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance wrapped distributions #414

Closed
wants to merge 10 commits into from
2 changes: 2 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ pointwise_loglikelihoods
```

```@docs
WrappedDistribution
NamedDist
NoDist
```

## Testing Utilities
Expand Down
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ export AbstractVarInfo,
dot_tilde_assume,
dot_tilde_observe,
# Pseudo distributions
WrappedDistribution,
NamedDist,
NoDist,
# Prob macros
Expand Down
43 changes: 30 additions & 13 deletions src/distribution_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,41 @@ using Distributions: Distributions
using Bijectors: Bijectors
using Distributions: Univariate, Multivariate, Matrixvariate

"""
Base type for distribution wrappers.
"""
abstract type WrappedDistribution{variate,support,Td<:Distribution{variate,support}} <:
Distribution{variate,support} end

wrapped_dist_type(::Type{<:WrappedDistribution{<:Any,<:Any,Td}}) where {Td} = Td
wrapped_dist_type(d::WrappedDistribution) = wrapped_dist_type(d)

wrapped_dist(d::WrappedDistribution) = d.dist

Base.length(d::WrappedDistribution{<:Multivariate}) = length(wrapped_dist(d))
Base.size(d::WrappedDistribution{<:Multivariate}) = size(wrapped_dist(d))
Base.eltype(::Type{T}) where {T<:WrappedDistribution} = eltype(wrapped_dist_type(T))
Base.eltype(d::WrappedDistribution) = eltype(wrapped_dist_type(d))

function Distributions.rand(rng::Random.AbstractRNG, d::WrappedDistribution)
return rand(rng, wrapped_dist(d))
end
Distributions.minimum(d::WrappedDistribution) = minimum(wrapped_dist(d))
Distributions.maximum(d::WrappedDistribution) = maximum(wrapped_dist(d))

Bijectors.bijector(d::WrappedDistribution) = bijector(wrapped_dist(d))

"""
A named distribution that carries the name of the random variable with it.
"""
struct NamedDist{variate,support,Td<:Distribution{variate,support},Tv<:VarName} <:
Distribution{variate,support}
WrappedDistribution{variate,support,Td}
dist::Td
name::Tv
end

NamedDist(dist::Distribution, name::Symbol) = NamedDist(dist, VarName{name}())

Base.length(dist::NamedDist) = Base.length(dist.dist)
Base.size(dist::NamedDist) = Base.size(dist.dist)

Distributions.logpdf(dist::NamedDist, x::Real) = Distributions.logpdf(dist.dist, x)
function Distributions.logpdf(dist::NamedDist, x::AbstractArray{<:Real})
return Distributions.logpdf(dist.dist, x)
Expand All @@ -27,29 +48,27 @@ function Distributions.loglikelihood(dist::NamedDist, x::AbstractArray{<:Real})
return Distributions.loglikelihood(dist.dist, x)
end

Bijectors.bijector(d::NamedDist) = Bijectors.bijector(d.dist)
"""
Wrapper around distribution `Td` that suppresses `logpdf()` calculation.

Note that *SampleFromPrior* would still sample from `Td`.
"""
struct NoDist{variate,support,Td<:Distribution{variate,support}} <:
Distribution{variate,support}
WrappedDistribution{variate,support,Td}
dist::Td
end
NoDist(dist::NamedDist) = NamedDist(NoDist(dist.dist), dist.name)

nodist(dist::Distribution) = NoDist(dist)
nodist(dists::AbstractArray) = nodist.(dists)

Base.length(dist::NoDist) = Base.length(dist.dist)
Base.size(dist::NoDist) = Base.size(dist.dist)

Distributions.rand(rng::Random.AbstractRNG, d::NoDist) = rand(rng, d.dist)
Distributions.logpdf(d::NoDist{<:Univariate}, ::Real) = 0
Distributions.logpdf(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0
function Distributions.logpdf(d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real})
return zeros(Int, size(x, 2))
end
Distributions.logpdf(d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}) = 0
Distributions.minimum(d::NoDist) = minimum(d.dist)
Distributions.maximum(d::NoDist) = maximum(d.dist)

Bijectors.logpdf_with_trans(d::NoDist{<:Univariate}, ::Real, ::Bool) = 0
function Bijectors.logpdf_with_trans(
Expand All @@ -67,5 +86,3 @@ function Bijectors.logpdf_with_trans(
)
return 0
end

Bijectors.bijector(d::NoDist) = Bijectors.bijector(d.dist)
37 changes: 37 additions & 0 deletions test/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,42 @@
end
end
end

@testset "multivariate NoDist" begin
@model function genmodel()
x ~ NoDist(Product(fill(Uniform(-20, 20), 5)))
for i in eachindex(x)
x[i] ~ Normal(0, 1)
end
Comment on lines +74 to +77
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems quite surprising, I have never seen anyone using NoDist in a model. I'm also not sure, why would you want to do that? When would such a model as the example here be useful?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems quite surprising, I have never seen anyone using NoDist in a model. I'm also not sure, why would you want to do that? When would such a model as the example here be useful?

a) This is a MWE
b) In the real usecase the length of the variable is ~500 elements. When I'm using x[i] ~ ... (or dot_tilde_assume()), the profiling indicates that with the current state of DynamicPPL ~50% of time is spent on indexing individual elements. That's why I've switched to multivariate distribution. With multivariate distribution the indexing overhead is resolved.
c) In the real usecase the prior is logpdf.(Ref(Normal(mean(x), sigma)), x) |> sum |> addlogp!!, so NoDist helps to declare x and its domain (also see d).
d) In the real usecase I'm switching between the evolutionary programming (BlackBoxOptim.jl) and gradient-based methods to get the MAP estimates. So while the model allows alternative parametrization, e.g. xmean ~ Normal(0, 1), xdelta .~ Normal(0, sigma), x = xmean .+ xdelta, it would be suboptimal for crossover operations; also it would introduce one extra degree of freedom.
e) I appreciate your concerns regarding the usability of MWE, but I think the problem of wrapped distributions not supporting all necessary Distributions.jl API is there, and the tests do cover that.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NoDist is an internal workaround/implementation detail but as NamedDist it's no "proper" user-facing distribution. Therefore it was not supposed to be used in a model directly, and not tested and implemented to support such use cases.

More generally, your workarounds and use of internal functionality (also addlogp!! is somewhat internal, the user-facing alternative is @addlogprob! which is still somewhat dangerous - IIRC in some cases it leads to incorrect or at least surprising results) make me wonder if there is some other functionality missing or some part of DynamicPPL that should be changed. I don't think the best solution is to start promoting and supporting such workarounds but rather we should better support the actual use cases and models in the first place. I think ideally you just implement your model in the most natural way and it works.

One thing is still not clear to me (also in your real usecase): Why do you want to declare x with a NoDist?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rather we should better support the actual use cases and models in the first place

I guess what I'm trying to achieve here with NoDist() is to declare x first, and define its prior later.

Why do you want to declare x with a NoDist?

It's not necessary, but I wanted to avoid calculating Uniform priors, both for performance and for having meaningful probabilities.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess what I'm trying to achieve here with NoDist() is to declare x first, and define its prior later.

But what I don't understand is why do you add a statement with NoDist first? You could just provide x as data to the model (if it is not sampled) or sample it from the actual priors (and here just preallocate the array first).

Having different statements for x where one is basically wrong seems a bit strange.

It's not necessary, but I wanted to avoid calculating Uniform priors, both for performance and for having meaningful probabilities.

But if x has a uniform prior, you should use it properly, shouldn't you? If you don't want to include the prior in your log density calculations you could condition on x or only evaluate the loglikelihood (you can even just do it for a subset of parameters).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Would it work properly if I declare truncated(Flat(), a, b) distribution?

Yeah that should work.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@devmotion I'm a bit confused as to whether or not your saying that the fact that @alyst has to do this to achieve the desired performance is undesirable or if you're suggesting that he can achieve the same performance by writing it in a for-loop and pre-allocating? Because if you're saying the former, I think we're all on the same page.

Yes, I meant that it's undesirable that apparently workarounds such as two tilde statements for the same variable are needed to achieve performance.

Maybe we should add an offical way for declaring a variable in the model (i.e., registering it without distribution)? Possibly an official macro (similar to @addlogprob!) that would then make sure that it ends up in the variable structure. I just don't know how it would be implemented exactly. Maybe it would be easiest to only support SimpleVarInfo? I assume it could be useful in cases where you would like to loop but don't want to end up with n different variables x[1], ..., x[n] in the resulting named tuple or dictionary. Alternatively, maybe we could add something like a (arguably also a bit hacky) For/Map distribution that would allow one to write something like

@model function ...
    ...
    x ~ For(1:n) do i
        f(i)
    end
    ...
end

The main difference to the existing possibilities would be that 1) it does not require preallocating an array etc. (such as .~), 2) it does not create n different variables x[1], ..., x[n] (such as a regular for loop), 3) it does not require allocating an array of distributions (such as arraydist/product_distribution) but only create the individual distributions on the fly.

Maybe the better approach would be to not introduce a new distribution but just support something like arraydist(f, xs).

I guess one of the main challenges would be to figure out what the type of arraydist(f, xs) should be. It might not be possible to infer if it is a MultivariateDistribution, MatrixDistribution etc. in general I assume.

Copy link
Contributor Author

@alyst alyst Jun 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it work properly if I declare truncated(Flat(), a, b) distribution?

Yeah that should work.

Actually, Flat() doesn't define cdf(), which is required for truncated(). But even if we define cdf(d::Flat, x) = one(x), then P(a <= d <= b) would be zero. So it would trigger an error in truncated(), and most likely in many other places.
One can define the new FlatBounded(a, b) pseudodistribution, but it looks very similar to NoDist(Uniform(a, b)) to me (except the transformation).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, Flat() doesn't define cdf(), which is required for truncated(). But even if we define cdf(d::Flat, x) = one(x), then P(a <= d <= b) would be zero. So it would trigger an error in truncated(), and most likely in many other places.
One can define the new FlatBounded(a, b) pseudodistribution, but it looks very similar to NoDist(Uniform(a, b)) to me (except the transformation).

Ah I guess this is why we have the FlatPositive rather than just using truncated. But yes, it ends up being very similar to NoDist but not quite: the logpdf_with_transform is going to be different. For NoDist we want no correction but for something like FlatPositive we do we want correction.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I've added bijector for NoDist in #415 now because it's useful for the new getindex(vi, vn, dist) methods introduced (also found a pretty significant bug when combining NoDist + transformed VarInfo) 👍

But, as I said previously, this will produce different results than something like FlatPositive which will, unlike NoDist, also include the log-absdet-jacobian correction.

end
gen_model = genmodel()
vi_gen = VarInfo(gen_model)
@test isfinite(logjoint(gen_model, vi_gen))
# test for bijector
link!(vi_gen, DynamicPPL.SampleFromPrior())
invlink!(vi_gen, DynamicPPL.SampleFromPrior())

# explicit model specification
expl_model = DynamicPPL.Model(NamedTuple()) do model, varinfo, context
DynamicPPL.tilde_assume!!(
context,
NoDist(Product(fill(Uniform(-20, 20), 5))),
@varname(x),
varinfo,
)
x = varinfo[@varname(x)]
@test x isa Vector{<:Real}
@test length(x) == 5
return (
nothing,
DynamicPPL.acclogp!!(varinfo, sum(logpdf.(Ref(Normal(0, 1)), x))),
)
end
vi_expl = VarInfo(expl_model)
@test isfinite(logjoint(expl_model, vi_expl))
# test for bijector
link!(vi_expl, DynamicPPL.SampleFromPrior())
invlink!(vi_expl, DynamicPPL.SampleFromPrior())
end
end
end
38 changes: 29 additions & 9 deletions test/distribution_wrappers.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,33 @@
@testset "distribution_wrappers.jl" begin
d = Normal()
nd = DynamicPPL.NoDist(d)
@testset "univariate" begin
d = Normal()
nd = DynamicPPL.NoDist(d)

# Smoke test
rand(nd)
# Smoke test
rand(nd)

# Actual tests
@test minimum(nd) == -Inf
@test maximum(nd) == Inf
@test logpdf(nd, 15.0) == 0
@test Bijectors.logpdf_with_trans(nd, 30, true) == 0
# Actual tests
@test minimum(nd) == -Inf
@test maximum(nd) == Inf
@test logpdf(nd, 15.0) == 0
@test Bijectors.logpdf_with_trans(nd, 30, true) == 0
@test Bijectors.bijector(nd) == Bijectors.bijector(d)
end

@testset "multivariate" begin
d = Product([Normal(), Uniform()])
nd = DynamicPPL.NoDist(d)

# Smoke test
@test length(rand(nd)) == 2

# Actual tests
@test length(nd) == 2
@test size(nd) == (2,)
@test minimum(nd) == [-Inf, 0.0]
@test maximum(nd) == [Inf, 1.0]
@test logpdf(nd, [15.0, 0.5]) == 0
@test Bijectors.logpdf_with_trans(nd, [0, 1]) == 0
@test Bijectors.bijector(nd) == Bijectors.bijector(d)
end
end
Loading