Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add variate transport #62

Merged
merged 70 commits into from
Jun 19, 2022
Merged
Show file tree
Hide file tree
Changes from 60 commits
Commits
Show all changes
70 commits
Select commit Hold shift + click to select a range
25fa2e1
Add ChainRulesCore to dependencies
oschulz Jun 15, 2022
a3e1bbf
Add InverseFunctions and ChangesOfVariables to deps
oschulz Jun 15, 2022
2015ccd
Add require_insupport
oschulz Jun 15, 2022
e86cf2c
Add effndof and require_same_effndof
oschulz Jun 15, 2022
2e8e7d6
Add check_varshape
oschulz Jun 15, 2022
5019b14
Add vartransform
oschulz Jun 15, 2022
6ee4e3d
Remove "measure-like" terminology
oschulz Jun 15, 2022
2217a42
Remove requirement for vartransform to return a Function
oschulz Jun 15, 2022
bb0257d
Remove check_varshape
oschulz Jun 15, 2022
15952bb
Rename effndof to getdof
oschulz Jun 15, 2022
d5635fe
Separate vartransform and vartransform_def
oschulz Jun 15, 2022
57c83b3
Fix default vartransform_def
oschulz Jun 15, 2022
b2817f0
Remove select_vartransform_intermediate
oschulz Jun 15, 2022
66f5990
Fix check_dof
oschulz Jun 16, 2022
de96842
Export getdof
oschulz Jun 16, 2022
9101015
Export vartransform
oschulz Jun 16, 2022
45fe220
Implement getdof for measures
oschulz Jun 16, 2022
e7fff14
Remove StdNormal
oschulz Jun 16, 2022
680e8f8
FIXUP implement getdof
oschulz Jun 16, 2022
7d9e11e
FIXUP implement getdof
oschulz Jun 16, 2022
c6368bb
Add StdLogistic
oschulz Jun 16, 2022
9372871
Implement vartransform_def for StdMeasure
oschulz Jun 16, 2022
70f1e30
FIXUP vartransform_def for StdMeasure
oschulz Jun 16, 2022
147ba2a
Add _vartransform_intermediate
oschulz Jun 16, 2022
0f26323
FIXUP Implement vartransform_def for StdMeasure
oschulz Jun 16, 2022
6ffa92c
Fix insupport for StdLogistic
oschulz Jun 16, 2022
101f947
Fix StdLogistic
oschulz Jun 16, 2022
1193e05
FIXUP StdMeasure vartransform
oschulz Jun 16, 2022
d0509e9
Fix check_dof
oschulz Jun 16, 2022
9556e31
Add checked_var
oschulz Jun 16, 2022
acde08b
WIP Add vartransform tests
oschulz Jun 16, 2022
5a16bef
FIXUP vartransform tests
oschulz Jun 16, 2022
ebb7ddb
Fix rand for StdUniform
oschulz Jun 16, 2022
7b56f08
FIXUP vartransform tests
oschulz Jun 16, 2022
23e41c7
Use checked_var at VarTransformation input stage
oschulz Jun 16, 2022
6f5b246
FIX vartransform tests
oschulz Jun 16, 2022
f5ebe6d
Add defaults for check_dof and checked_var
oschulz Jun 16, 2022
cbb6873
Add vartransform_origin for WeightedMeasure
oschulz Jun 16, 2022
1c52ed0
Fix deps
oschulz Jun 16, 2022
7b954a5
Fix tests
oschulz Jun 16, 2022
5a12523
WIP Add PushforwardMeasure
oschulz Jun 16, 2022
520562c
WIP improve PushforwardMeasure
oschulz Jun 16, 2022
10b12dc
WIP improve PushforwardMeasure
oschulz Jun 16, 2022
bfda82b
WIP improve PushforwardMeasure
oschulz Jun 16, 2022
62eabb0
WIP improve PushforwardMeasure
oschulz Jun 16, 2022
a3a7b00
FIX PushforwardMeasure
oschulz Jun 16, 2022
39bf7b0
Allow PushforwardMeasure to bypass checked_var
oschulz Jun 17, 2022
da7ecc6
Test PushforwardMeasure
oschulz Jun 17, 2022
75e1fb3
Fix docstring of NoDOF
oschulz Jun 17, 2022
6250b20
Add test_vartransform to Interface
oschulz Jun 17, 2022
9bfa9f9
FIXUP _default_checked_var
oschulz Jun 17, 2022
d4f0246
FIXUP vartransform_origin docs and defaults
oschulz Jun 17, 2022
0512930
Run vartransform tests
oschulz Jun 17, 2022
14890bb
Improve vartransform_origin def for WeightedMeasure
oschulz Jun 17, 2022
b32f34d
Add vartransform stdmeasure autodim
oschulz Jun 17, 2022
f12f1b5
Specialize equality for VarTransformation
oschulz Jun 18, 2022
b3dfe13
Don't call check_dof so often
oschulz Jun 18, 2022
1088fc1
Improve checked_var for PowerMeasure
oschulz Jun 18, 2022
1c2311c
Fix check_dof and require_insupport rrules
oschulz Jun 18, 2022
1fb805c
Test getdof
oschulz Jun 18, 2022
24affe6
Document TransformVolCorr
oschulz Jun 18, 2022
35fae34
Fix transform variable naming inconsistencies
oschulz Jun 18, 2022
d673692
Specialize gotdof for inferrably empty power measures
oschulz Jun 18, 2022
3965732
Add trafos for Dirac
oschulz Jun 18, 2022
5b5eb7f
Support logdensity calculation on empty power measures
oschulz Jun 19, 2022
e70f294
Improve test_vartransform
oschulz Jun 19, 2022
e85e54f
Fix and test vartransform for Dirac
oschulz Jun 19, 2022
8f2da10
Rename vartransform to transport_to
oschulz Jun 19, 2022
343d20b
Rename vartransform_origin
oschulz Jun 19, 2022
a8faa66
Increase package version to v0.11.0
oschulz Jun 19, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@ authors = ["Chad Scherrer <[email protected]> and contributors"]
version = "0.10.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Expand All @@ -24,11 +27,14 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775"

[compat]
ChainRulesCore = "1"
ChangesOfVariables = "0.1.3"
Compat = "3.35, 4"
ConstructionBase = "1.3"
DensityInterface = "0.4"
FillArrays = "0.12, 0.13"
IfElse = "0.1"
InverseFunctions = "0.1.7"
IrrationalConstants = "0.1"
LogExpFunctions = "0.3"
LogarithmicNumbers = "1"
Expand All @@ -42,6 +48,7 @@ julia = "1.3"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"

[targets]
test = ["Aqua"]
test = ["Aqua", "ChainRulesTestUtils"]
29 changes: 14 additions & 15 deletions src/MeasureBase.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
module MeasureBase

using Base: @propagate_inbounds

using Random
import Random: rand!
import Random: gentype
Expand All @@ -11,13 +13,17 @@ import DensityInterface: densityof
import DensityInterface: DensityKind
using DensityInterface

using InverseFunctions
using ChangesOfVariables

import Base.iterate
import ConstructionBase
using ConstructionBase: constructorof

using PrettyPrinting
const Pretty = PrettyPrinting

using ChainRulesCore
using FillArrays
using Static

Expand All @@ -32,20 +38,11 @@ export logdensity_def
export basemeasure
export basekernel
export productmeasure

"""
inssupport(m, x)
insupport(m)

`insupport(m,x)` computes whether `x` is in the support of `m`.

`insupport(m)` returns a function, and satisfies

insupport(m)(x) == insupport(m, x)
"""
function insupport end

export insupport
export getdof
export vartransform

include("insupport.jl")

abstract type AbstractMeasure end

Expand All @@ -63,7 +60,7 @@ gentype(μ::AbstractMeasure) = typeof(testvalue(μ))
# gentype(μ::AbstractMeasure) = gentype(basemeasure(μ))

using NaNMath
using LogExpFunctions: logsumexp
using LogExpFunctions: logsumexp, logistic, logit

@deprecate instance_type(x) Core.Typeof(x) false

Expand Down Expand Up @@ -94,6 +91,8 @@ using Compat

using IrrationalConstants

include("getdof.jl")
include("vartransform.jl")
include("schema.jl")
include("splat.jl")
include("proxies.jl")
Expand Down Expand Up @@ -125,9 +124,9 @@ include("combinators/powerweighted.jl")
include("combinators/conditional.jl")

include("standard/stdmeasure.jl")
include("standard/stdnormal.jl")
include("standard/stduniform.jl")
include("standard/stdexponential.jl")
include("standard/stdlogistic.jl")

include("rand.jl")

Expand Down
18 changes: 18 additions & 0 deletions src/combinators/power.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,21 @@ end
dynamic(insupport(p, xj))
end
end


@inline getdof(μ::PowerMeasure) = getdof(μ.parent) * prod(map(length, μ.axes))
cscherrer marked this conversation as resolved.
Show resolved Hide resolved

@propagate_inbounds function checked_var(μ::PowerMeasure, x::AbstractArray{<:Any})
@boundscheck begin
sz_μ = map(length, μ.axes)
sz_x = size(x)
if sz_μ != sz_x
throw(ArgumentError("Size of variate doesn't match size of power measure"))
end
end
return x
end

function checked_var(μ::PowerMeasure, x::Any)
throw(ArgumentError("Size of variate doesn't match size of power measure"))
end
96 changes: 96 additions & 0 deletions src/combinators/transformedmeasure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,99 @@ function params(::AbstractTransformedMeasure) end
function paramnames(::AbstractTransformedMeasure) end

function parent(::AbstractTransformedMeasure) end


abstract type TransformVolCorr end
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have a reason to use VolCorr instead of LogJac? That's more familiar I think

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

VolCorr emphasizes the purpose over the mechanism LogJac, but I have no strong opinion on it.

struct WithVolCorr <: TransformVolCorr end
struct NoVolCorr <: TransformVolCorr end


export PushforwardMeasure

"""
struct PushforwardMeasure{FF,IF,MU,VC<:TransformVolCorr} <: AbstractPushforward
f :: FF
inv_f :: IF
origin :: MU
volcorr :: VC
end
"""
struct PushforwardMeasure{FF,IF,M,VC<:TransformVolCorr} <: AbstractPushforward
f::FF
inv_f::IF
origin::M
volcorr::VC
end

gettransform(ν::PushforwardMeasure) = ν.f
parent(ν::PushforwardMeasure) = ν.origin


function Pretty.tile(ν::PushforwardMeasure)
Pretty.list_layout(Pretty.tile.([ν.f, ν.inv_f, ν.origin]); prefix = :PushforwardMeasure)
end


@inline function logdensity_def(ν::PushforwardMeasure{FF,IF,M,<:WithVolCorr}, y) where {FF,IF,M}
x_orig, inv_ladj = with_logabsdet_jacobian(ν.inv_f, y)
logd_orig = logdensity_def(ν.origin, x_orig)
logd = float(logd_orig + inv_ladj)
neginf = oftype(logd, -Inf)
return ifelse(
# Zero density wins against infinite volume:
(isnan(logd) && logd_orig == -Inf && inv_ladj == +Inf) ||
# Maybe also for (logd_orig == -Inf) && isfinite(inv_ladj) ?
# Return constant -Inf to prevent problems with ForwardDiff:
(isfinite(logd_orig) && (inv_ladj == -Inf)),
neginf,
logd
)
end

@inline function logdensity_def(ν::PushforwardMeasure{FF,IF,M,<:NoVolCorr}, y) where {FF,IF,M}
x_orig = to_origin(ν, y)
return logdensity_def(ν.origin, x_orig)
end


insupport(ν::PushforwardMeasure, y) = insupport(vartransform_origin(ν), to_origin(ν, y))

testvalue(ν::PushforwardMeasure) = from_origin(ν, testvalue(vartransform_origin(ν)))

@inline function basemeasure(ν::PushforwardMeasure)
PushforwardMeasure(ν.f, ν.inv_f, basemeasure(vartransform_origin(ν)), NoVolCorr())
end


_pushfwd_dof(::Type{MU}, ::Type, dof) where MU = NoDOF{MU}()
_pushfwd_dof(::Type{MU}, ::Type{<:Tuple{Any,Real}}, dof) where MU = dof

# Assume that DOF are preserved if with_logabsdet_jacobian is functional:
@inline function getdof(ν::MU) where {MU<:PushforwardMeasure}
T = Core.Compiler.return_type(testvalue, Tuple{typeof(ν.origin)})
R = Core.Compiler.return_type(with_logabsdet_jacobian, Tuple{typeof(ν.f), T})
_pushfwd_dof(MU, R, getdof(ν.origin))
end

# Bypass `checked_var`, would require potentially costly transformation:
@inline checked_var(::PushforwardMeasure, x) = x


@inline vartransform_origin(ν::PushforwardMeasure) = ν.origin
@inline to_origin(ν::PushforwardMeasure, x) = ν.inv_f(x)
@inline from_origin(ν::PushforwardMeasure, y) = ν.f(y)

function Base.rand(rng::AbstractRNG, ::Type{T}, ν::PushforwardMeasure) where T
return from_origin(ν, rand(rng, T, vartransform_origin(ν)))
end


export pushfwd

"""
pushfwd(f, μ, volcorr = WithVolCorr())

Return the [pushforward measure](https://en.wikipedia.org/wiki/Pushforward_measure)
from `μ` the [measurable function](https://en.wikipedia.org/wiki/Measurable_function) `f`.
"""
pushfwd(f, μ, volcorr = WithVolCorr()) = PushforwardMeasure(f, inverse(f), μ, volcorr)
4 changes: 4 additions & 0 deletions src/combinators/weighted.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,7 @@ Base.:*(m::AbstractMeasure, k::Real) = k * m
gentype(μ::WeightedMeasure) = gentype(μ.base)

insupport(μ::WeightedMeasure, x) = insupport(μ.base, x)

vartransform_origin(ν::WeightedMeasure) = ν.base
to_origin(::WeightedMeasure, y) = y
from_origin(::WeightedMeasure, x) = x
77 changes: 77 additions & 0 deletions src/getdof.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""
MeasureBase.NoDOF{MU}

Indicates that there is no way to compute degrees of freedom of a measure
of type `MU` with the given information, e.g. because the DOF are not
a global property of the measure.
"""
struct NoDOF{MU} end


"""
getdof(μ)

Returns the effective number of degrees of freedom of variates of
measure `μ`.

The effective NDOF my differ from the length of the variates. For example,
the effective NDOF for a Dirichlet distribution with variates of length `n`
is `n - 1`.

Also see [`check_dof`](@ref).
"""
function getdof end

# Prevent infinite recursion:
@inline _default_getdof(::Type{MU}, ::MU) where MU = NoDOF{MU}
@inline _default_getdof(::Type{MU}, mu_base) where MU = getdof(mu_base)

@inline getdof(μ::MU) where MU = _default_getdof(MU, basemeasure(μ))


"""
MeasureBase.check_dof(ν, μ)::Nothing

Check if `ν` and `μ` have the same effective number of degrees of freedom
according to [`MeasureBase.getdof`](@ref).
"""
function check_dof end

function check_dof(ν, μ)
n_ν = getdof(ν)
n_μ = getdof(μ)
if n_ν != n_μ
throw(ArgumentError("Measure ν of type $(nameof(typeof(ν))) has $(n_ν) DOF but μ of type $(nameof(typeof(μ))) has $(n_μ) DOF"))
end
return nothing
end

_check_dof_pullback(ΔΩ) = NoTangent(), NoTangent(), NoTangent()
ChainRulesCore.rrule(::typeof(check_dof), ν, μ) = check_dof(ν, μ), _check_dof_pullback


"""
MeasureBase.NoVarCheck{MU,T}

Indicates that there is no way to check of a values of type `T` are
variate of measures of type `MU`.
"""
struct NoVarCheck{MU,T} end


"""
MeasureBase.checked_var(μ::MU, x::T)::T

Return `x` if `x` is a valid variate of `μ`, throw an `ArgumentError` if not,
return `NoVarCheck{MU,T}()` if not check can be performed.
"""
function checked_var end

# Prevent infinite recursion:
@propagate_inbounds _default_checked_var(::Type{MU}, ::MU, ::T) where {MU,T} = NoVarCheck{MU,T}
@propagate_inbounds _default_checked_var(::Type{MU}, mu_base, x) where MU = checked_var(mu_base, x)

@propagate_inbounds checked_var(mu::MU, x) where MU = _default_checked_var(MU, basemeasure(mu), x)

_checked_var_pullback(ΔΩ) = NoTangent(), NoTangent(), ΔΩ
ChainRulesCore.rrule(::typeof(checked_var), ν, x) = checked_var(ν, x), _checked_var_pullback
32 changes: 32 additions & 0 deletions src/insupport.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""
inssupport(m, x)
insupport(m)

`insupport(m,x)` computes whether `x` is in the support of `m`.

`insupport(m)` returns a function, and satisfies

insupport(m)(x) == insupport(m, x)
"""
function insupport end


"""
MeasureBase.require_insupport(μ, x)::Nothing

Checks if `x` is in the support of distribution/measure `μ`, throws an
`ArgumentError` if not.
"""
function require_insupport end

_require_insupport_pullback(ΔΩ) = NoTangent(), ZeroTangent()
function ChainRulesCore.rrule(::typeof(require_insupport), μ, x)
return require_insupport(μ, x), _require_insupport_pullback
end

function require_insupport(μ, x)
if !insupport(μ, x)
throw(ArgumentError("x is not within the support of μ"))
end
return nothing
end
Loading