Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev #88

Merged
merged 392 commits into from
Oct 15, 2022
Merged

Dev #88

Show file tree
Hide file tree
Changes from 250 commits
Commits
Show all changes
392 commits
Select commit Hold shift + click to select a range
1d2b66c
drop debugging code
cscherrer Dec 30, 2021
b5891c5
small update for Likelihood, and a test
cscherrer Dec 30, 2021
21e2c19
fixing up likelihoods
cscherrer Dec 30, 2021
622fa32
improve `basemeasure_depth` dispatch
cscherrer Dec 30, 2021
630067b
still some trouble with inferred basemeasure_depth
cscherrer Dec 30, 2021
b0a0f66
clean up `For` dispatch
cscherrer Dec 30, 2021
8f28472
simplify _logdensityof
cscherrer Dec 30, 2021
80be728
optimize for Returns{True} case
cscherrer Dec 31, 2021
4de454c
Merge branch 'dev' of github.com:cscherrer/MeasureBase.jl into dev
cscherrer Dec 31, 2021
1b65a57
rework basemeasure_depth
cscherrer Dec 31, 2021
18e5605
aggressive tests passing!!
cscherrer Dec 31, 2021
9b7e8aa
drop type-level stuff
cscherrer Dec 31, 2021
f907b8f
drop help
cscherrer Dec 31, 2021
d047fc1
license
cscherrer Jan 3, 2022
7fb384a
affero
cscherrer Jan 5, 2022
68dd214
copyright notice
cscherrer Jan 5, 2022
143d7a7
merge
cscherrer Jan 22, 2022
7e3dab1
Merge branch 'dev' of github.com:cscherrer/MeasureBase.jl into dev
cscherrer Jan 22, 2022
040c815
Drop Create Commons
cscherrer Jan 22, 2022
d24550f
Merge branch 'master' into dev
cscherrer Jan 22, 2022
ae80dc1
cleanup after merge
cscherrer Jan 22, 2022
d6c12f5
update support computations
cscherrer Jan 24, 2022
2f5c31d
insupport(d::SuperpositionMeasure, x)
cscherrer Jan 24, 2022
35271fc
dorp ParamWeighted
cscherrer Jan 24, 2022
de9ccdd
insupport(d::FactoredBase, x)
cscherrer Jan 24, 2022
1e8615d
export unsafe_logdensityof
cscherrer Jan 24, 2022
305ed0e
call promote_type instead of promote_rule
cscherrer Jan 24, 2022
0d35552
logdensity_def for named tuple product measures
cscherrer Jan 24, 2022
cf45278
type annotation for now
cscherrer Jan 24, 2022
e6705af
debugging
cscherrer Jan 25, 2022
093282c
drop shows
cscherrer Jan 26, 2022
7843aea
speed up mapped arrays
cscherrer Feb 4, 2022
641d222
throw an error for `Union{}` types
cscherrer Feb 4, 2022
396405b
MT tests passing
cscherrer Feb 7, 2022
b880018
updates
cscherrer Feb 9, 2022
5afed52
get tests passing
cscherrer Mar 16, 2022
5139be2
MIT license for MeasureBase
cscherrer Mar 16, 2022
101a5fe
bump version
cscherrer Mar 16, 2022
191c5ef
cleanup
cscherrer Mar 29, 2022
636b734
spacing
cscherrer Mar 29, 2022
66ea65c
Merge branch 'dev' of github.com:cscherrer/MeasureBase.jl into dev
cscherrer Mar 30, 2022
6036ec2
Move ConditionalMeasure to MeasureBase
cscherrer Mar 31, 2022
03a7aba
add LogarithmicNumbers
cscherrer Apr 1, 2022
d45e91e
export basemeasure_sequence
cscherrer Apr 1, 2022
b6cf835
update superpose
cscherrer Apr 1, 2022
c483099
fix logdensity_rel
cscherrer Apr 1, 2022
9da96fa
remove FIXME (it's fixed!!)
cscherrer Apr 1, 2022
c8b8822
logdensityof(d::Density, x)
cscherrer Apr 1, 2022
66cee3c
Merge branch 'dev' of github.com:cscherrer/MeasureBase.jl into dev
cscherrer Apr 1, 2022
dc82f20
simplify insupport(::Lebesgue, ::Real)
cscherrer Apr 1, 2022
2f9e5f0
clean up
cscherrer Apr 1, 2022
4e1861f
assume insupport yields Bool
cscherrer Apr 1, 2022
e1b03de
change logdensity_rel fall-through to warning and return NaN
cscherrer Apr 1, 2022
08b40e2
update logdensity_rel
cscherrer Apr 1, 2022
3881dda
drop old code
cscherrer Apr 1, 2022
719be40
fix warning
cscherrer Apr 1, 2022
002bd57
export logdensity_rel
cscherrer Apr 1, 2022
6c00bae
logdensity_def(μ::Dirac, ν::Dirac, x)
cscherrer Apr 1, 2022
9168a2c
logdensity_def methods
cscherrer Apr 1, 2022
2524aa1
drop `static`
cscherrer Apr 1, 2022
b6f5292
]add StatsFuns
cscherrer Apr 1, 2022
c3e2f13
Fixing up superposition
cscherrer Apr 1, 2022
027bc12
[compat] entries
cscherrer Apr 1, 2022
91aa7be
trying to speed things up
cscherrer Apr 1, 2022
b069640
bugfixes
cscherrer Apr 1, 2022
b0078d2
logdensity_rel tests
cscherrer Apr 2, 2022
4b1b49e
logdensity_rel tests
cscherrer Apr 2, 2022
1c16e2c
drop qualifier, and add a test
cscherrer Apr 2, 2022
f469ae0
more tests
cscherrer Apr 2, 2022
aec6841
type constraint in "logdensityof(μ::AbstractMeasure, x)" (was piracy,…
cscherrer Apr 3, 2022
9e45cf2
add some docs
cscherrer Apr 3, 2022
f4ca677
docs
cscherrer Apr 3, 2022
a6cffe5
docs
cscherrer Apr 3, 2022
47ee084
typo
cscherrer Apr 3, 2022
46f8a9b
moar speed
cscherrer Apr 4, 2022
380d696
Merge branch 'dev' of github.com:cscherrer/MeasureBase.jl into dev
cscherrer Apr 5, 2022
b3703bd
don't export Test
cscherrer Apr 5, 2022
b3d3afc
some more updates
cscherrer Apr 5, 2022
1893ae3
logdensity_rel for products
cscherrer Apr 6, 2022
ec64718
`kleisli` docs
cscherrer Apr 6, 2022
09f1c40
update instance_type
cscherrer Apr 9, 2022
a355832
instance_type => Core.Typeof
cscherrer Apr 10, 2022
866b466
`powermeasure` bug fix
cscherrer Apr 11, 2022
74777dc
fix logdensity_rel bug
cscherrer Apr 11, 2022
5436d32
get `commonbase` to take x type into account
cscherrer Apr 11, 2022
fdbd12b
test powers
cscherrer Apr 11, 2022
c3ec497
commonbase docstring
cscherrer Apr 12, 2022
725d2dd
deprecate instance_type
cscherrer Apr 12, 2022
09d5196
avoid breakage
cscherrer Apr 13, 2022
f6f776e
switch || terms
cscherrer Apr 13, 2022
a316828
@ifelse macro
cscherrer Apr 13, 2022
2ca07e7
simplify logdensity_rel
cscherrer Apr 13, 2022
15d8ad4
give up on this @ifelse business
cscherrer Apr 13, 2022
535dcbc
bump version
cscherrer Apr 13, 2022
b30963b
Merge branch 'master' into dev
cscherrer Apr 13, 2022
77156ad
Make `instance` non-generated
cscherrer Apr 21, 2022
a733431
working on likelihoods
cscherrer Apr 26, 2022
f494086
update likelihood
cscherrer Apr 28, 2022
999df37
powerweightedmeasure
cscherrer Apr 28, 2022
5815d17
powerweighted update
cscherrer Apr 28, 2022
9016423
more powerweighted methods
cscherrer Apr 28, 2022
c3fff7e
bugfix
cscherrer Apr 28, 2022
a38da82
dropFactoredBase
cscherrer May 2, 2022
7199c22
drop FactoredBase
cscherrer May 2, 2022
a66c070
(::ProductMeasure) | constraint
cscherrer May 3, 2022
8eda66d
update conditional measure
cscherrer May 3, 2022
fa19d73
update Dirac
cscherrer May 3, 2022
9b86a8d
move conditional.jl down in the `include`s
cscherrer May 3, 2022
ceb310b
Kleisli => TransitionKernel
cscherrer May 4, 2022
e4bb7ff
simplify logdensity_def(::PowerMeasure, x)
cscherrer May 4, 2022
a4b3f44
rename kleisli.jl to kernel.jl
cscherrer May 4, 2022
ed5a61d
update Dirac tests
cscherrer May 4, 2022
1f9eeaa
update Half
cscherrer May 4, 2022
f98c7db
get tests passing
cscherrer May 4, 2022
3bb68fe
update kernel
cscherrer May 5, 2022
fafb3a2
Update Project.toml
cscherrer May 5, 2022
c9bafa5
Merge remote-tracking branch 'origin/master' into dev
cscherrer May 5, 2022
536992f
no call-site inlining
cscherrer May 5, 2022
3ae0f65
restrict single-arg `kernel` to <:ParameterizedMeasure
cscherrer May 5, 2022
7364006
export log_likelihood_ratio
cscherrer May 6, 2022
3eebcfb
Drop DensityKind(::Likelihood), at least for now
cscherrer May 6, 2022
38b26c7
isfinite(x) instead of x>-Inf
cscherrer May 6, 2022
ba1eb20
add `condition` constructor
cscherrer May 6, 2022
417b2d0
EOF newline
cscherrer May 6, 2022
c81ca07
simplify logdensity_def for power measures
cscherrer May 9, 2022
11adfcf
finishing up
cscherrer May 9, 2022
a76a9ed
Merge branch 'master' into dev
cscherrer May 12, 2022
4998111
updates
cscherrer May 18, 2022
7408e6b
kernel stuff
cscherrer May 18, 2022
2210d3d
kernel stuff
cscherrer May 23, 2022
c4ea293
update showe methods
cscherrer May 23, 2022
a7ba363
ass a TODO
cscherrer May 23, 2022
bc08f16
use `dot` instead of `sum`
cscherrer May 23, 2022
20dafa9
drop old code
cscherrer May 23, 2022
11f78c4
typo
cscherrer May 23, 2022
0db846c
Merge branch 'kernels' into dev
cscherrer May 23, 2022
8f9ab6b
formatting
cscherrer May 23, 2022
0e1ecf1
cleanup
cscherrer May 23, 2022
e2a1d80
kernel updates
cscherrer May 23, 2022
2392d8a
uncomment
cscherrer May 23, 2022
90d4908
bugfix
cscherrer May 23, 2022
dc74704
drop old code
cscherrer May 23, 2022
71be3c5
pretty printing
cscherrer May 23, 2022
51d15a9
exports, cleanup
cscherrer May 24, 2022
79dd234
drop old for.jl
cscherrer May 24, 2022
370c8d1
Make DensityKind(::AbstractLikelihood) = IsDensity()
cscherrer May 24, 2022
c2fa848
update Compat version
cscherrer May 24, 2022
e7ec78d
Merge remote-tracking branch 'origin/master' into dev
cscherrer May 24, 2022
5b32581
Make likelihoods work with Distributions
cscherrer May 24, 2022
4cb7e16
_map(f, x::MappedArrays.ReadonlyMappedArray)
cscherrer Jun 1, 2022
8a6ef5c
export productmeasure
cscherrer Jun 1, 2022
407a8f7
Merge branch 'master' into dev
cscherrer Aug 4, 2022
e8a1be5
AbstractMeasure(::AbstractMeasure)
cscherrer Aug 4, 2022
c1b52cd
Merge branch 'master' into dev
cscherrer Sep 2, 2022
a08119f
Merge branch 'master' into dev
cscherrer Sep 11, 2022
d260e61
fixedrng
cscherrer Sep 11, 2022
3bd74a0
StdNormal
cscherrer Sep 11, 2022
d309637
add SpecialFunctions
cscherrer Sep 11, 2022
1b9cec9
no need to qualify
cscherrer Sep 11, 2022
aa22bef
update basemeasure
cscherrer Sep 11, 2022
249b74e
include stdnormal
cscherrer Sep 11, 2022
ed83e12
include fixedrng
cscherrer Sep 11, 2022
f60f85e
update tests
cscherrer Sep 11, 2022
4765887
using SpecialFunctions
cscherrer Sep 12, 2022
88e1ced
fixing transport_def
cscherrer Sep 12, 2022
023f3a4
transport_def bugfix
cscherrer Sep 12, 2022
7c6e430
StdMeasure(::typeof(randn))
cscherrer Sep 12, 2022
d3f0076
checked_arg for LebesgueMeasure
cscherrer Sep 12, 2022
7bfc189
NoTransformOrigin => NoTransportOrigin
cscherrer Sep 13, 2022
07d2de8
transport interface for pushforwards
cscherrer Sep 13, 2022
1d52b13
transporting pushforwards
cscherrer Sep 13, 2022
af88a14
Use LebesgueMeasure for basemeasure
cscherrer Sep 13, 2022
03fa9fa
updates
cscherrer Sep 13, 2022
0ee3c5a
make testvalue fall back on FixedRNG approach
cscherrer Sep 13, 2022
853013c
un-break testvalue
cscherrer Sep 13, 2022
00fdfd7
CI for Juila 1.8
cscherrer Sep 13, 2022
3132af5
fixes
cscherrer Sep 13, 2022
9bbf833
`rand` on a pushforward calls rand on its parent
cscherrer Sep 13, 2022
d8b52c5
LebesgueMeasure => LebesgueBase
cscherrer Sep 14, 2022
9b6e066
tests passing!
cscherrer Sep 14, 2022
ad279ce
change `invoke` type
cscherrer Sep 14, 2022
0bf97b5
Change `test_interface` to check for 2-arg testvalue
cscherrer Sep 14, 2022
4720e21
manually-specifed inverses
cscherrer Sep 14, 2022
d2c3858
more pushfwd stuff
cscherrer Sep 16, 2022
beb7618
A little less wrong
cscherrer Sep 16, 2022
806f314
add mass interface
cscherrer Sep 19, 2022
d05c044
pullback
cscherrer Sep 19, 2022
c0d5ba3
mass interface
cscherrer Sep 19, 2022
c3e8a22
working on mass interface
cscherrer Sep 19, 2022
617ce41
add some `massof` methods
cscherrer Sep 19, 2022
0490a25
Maybe <:Number is better for invalidations?
cscherrer Sep 19, 2022
f68a932
float instead of Int
cscherrer Sep 19, 2022
c05b5c4
logmassof
cscherrer Sep 19, 2022
a2d3a3b
transports for proxies
cscherrer Sep 20, 2022
ff7af02
drop latent-joint.jl
cscherrer Sep 20, 2022
8482cf3
drop exports
cscherrer Sep 21, 2022
cdeb2a0
Drop `logmassof` for now
cscherrer Sep 21, 2022
b15190c
reorganize Lebesgue measure
cscherrer Sep 21, 2022
cccdaf7
IntervalSets
cscherrer Sep 21, 2022
dffa2ec
proxy(::Lebesgue{MeasureBase.RealNumbers}) = LebesgueBase()
cscherrer Sep 21, 2022
44a1512
calling a "useproxy" measure calls its proxy
cscherrer Sep 21, 2022
590c933
StdUniform()(s::Interval)
cscherrer Sep 21, 2022
d46afbf
typo
cscherrer Sep 21, 2022
0838fa2
(m::AbstractMeasure)(s::Interval)
cscherrer Sep 21, 2022
78cf303
bugfix
cscherrer Sep 21, 2022
20a8378
comment
cscherrer Sep 22, 2022
a16d487
IntervalSets version constraint
cscherrer Sep 22, 2022
fcca309
update dynamic_basemeasure_depth
cscherrer Sep 22, 2022
c6c2bff
format
cscherrer Sep 22, 2022
ef7f1ab
Calling a measure calls `massof`
cscherrer Sep 23, 2022
0cfacfd
work on massof
cscherrer Sep 23, 2022
504936c
AbstractSuperpositionMeasure
cscherrer Sep 23, 2022
0e9a563
fix typo
cscherrer Sep 26, 2022
e4147cf
typo
cscherrer Sep 26, 2022
3eba519
format
cscherrer Sep 26, 2022
2393222
docstrings
cscherrer Sep 27, 2022
8ba4f5b
remove massof(::PowerWeightedMeasure) method
cscherrer Sep 27, 2022
d413544
make `massof` better
cscherrer Sep 27, 2022
2164d71
update testvalue
cscherrer Sep 28, 2022
8e63526
formatting
cscherrer Sep 28, 2022
4bc9d93
update _massof
cscherrer Sep 28, 2022
782c412
Update transports for weighted measures
cscherrer Sep 28, 2022
bb092d5
add chain rules
cscherrer Sep 29, 2022
be8251a
invariant mass under transport
cscherrer Sep 29, 2022
38f0fe2
typo
cscherrer Sep 29, 2022
93804cf
bugfix
cscherrer Sep 30, 2022
4aa2b44
hasmethod => Tricks.static_hasmethod
cscherrer Sep 30, 2022
6e5526f
`massof` methods
cscherrer Oct 1, 2022
93ad7d4
Merge remote-tracking branch 'origin/master' into dev
cscherrer Oct 2, 2022
d0530c3
roll back tranports for WeightedMeasure
cscherrer Oct 3, 2022
e9cffcd
Improve transport implementation and add product support (#97)
oschulz Oct 3, 2022
57ecece
`@useproxy` delegates `massof`
cscherrer Oct 3, 2022
2201dc9
drop CI for nightly
cscherrer Oct 4, 2022
f58dfd4
callable densities (#85)
cscherrer Oct 4, 2022
304163d
Pushfwd-inverses (#98)
cscherrer Oct 6, 2022
5d1807d
remove duplicate method
cscherrer Oct 10, 2022
29e0040
remove duplicate `include`
cscherrer Oct 10, 2022
d44ead0
simplify getdof(::PushforwardMeasure)
cscherrer Oct 11, 2022
6d59555
Stieltjes measure function (#100)
cscherrer Oct 14, 2022
aa50a0d
drop redundant `transport_def`s
cscherrer Oct 15, 2022
3dbbddf
update `pushfwd`
cscherrer Oct 15, 2022
d245978
change name
cscherrer Oct 15, 2022
868c348
add type
cscherrer Oct 15, 2022
00b128a
formatting
cscherrer Oct 15, 2022
97cbb33
fix docstring
cscherrer Oct 15, 2022
b6c039f
depend on FunctinoChains
cscherrer Oct 15, 2022
f96b546
Use fchain
cscherrer Oct 15, 2022
0a21672
simplify transport_def for StdLogistic
cscherrer Oct 15, 2022
046b48c
simplify transport_def for StdNormal
cscherrer Oct 15, 2022
18a9f99
drop redundant method
cscherrer Oct 15, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
version:
- '1.6'
- '1.7'
- 'nightly'
- '1.8'
os:
- ubuntu-latest
arch:
Expand Down
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -21,6 +22,7 @@ NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
PrettyPrinting = "54e16d92-306c-5ea0-a30b-337be88ac337"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand All @@ -34,14 +36,16 @@ ConstructionBase = "1.3"
DensityInterface = "0.4"
FillArrays = "0.12, 0.13"
IfElse = "0.1"
InverseFunctions = "0.1.7"
IntervalSets = "0.7"
InverseFunctions = "0.1.8"
IrrationalConstants = "0.1"
LogExpFunctions = "0.3"
LogarithmicNumbers = "1"
MappedArrays = "0.4"
NaNMath = "0.3, 1"
PrettyPrinting = "0.3, 0.4"
Reexport = "1"
SpecialFunctions = "2"
Static = "0.5, 0.6"
Tricks = "0.1"
julia = "1.3"
Expand Down
15 changes: 13 additions & 2 deletions src/MeasureBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,16 @@ import Random: gentype
using Statistics
using LinearAlgebra

import IntervalSets
# This seems harder than it should be to get `IntervalSets.:(..)`
@eval (using IntervalSets: $(Symbol(IntervalSets.:(..))))

using IntervalSets: Interval, width

import DensityInterface: logdensityof
import DensityInterface: densityof
import DensityInterface: DensityKind
using DensityInterface: FuncDensity, LogFuncDensity
using DensityInterface

using InverseFunctions
Expand All @@ -19,6 +26,7 @@ using ChangesOfVariables
import Base.iterate
import ConstructionBase
using ConstructionBase: constructorof
using IntervalSets

using PrettyPrinting
const Pretty = PrettyPrinting
Expand Down Expand Up @@ -108,17 +116,18 @@ using Compat

using IrrationalConstants

include("smf.jl")
include("getdof.jl")
include("transport.jl")
include("schema.jl")
include("splat.jl")
include("proxies.jl")
include("kernel.jl")
include("parameterized.jl")
include("combinators/half.jl")
include("domains.jl")
include("primitive.jl")
include("utils.jl")
include("mass-interface.jl")
# include("absolutecontinuity.jl")

include("primitives/counting.jl")
Expand All @@ -144,9 +153,11 @@ include("standard/stdmeasure.jl")
include("standard/stduniform.jl")
include("standard/stdexponential.jl")
include("standard/stdlogistic.jl")
include("latent-joint.jl")
include("standard/stdnormal.jl")
include("combinators/half.jl")

include("rand.jl")
include("fixedrng.jl")

include("density.jl")
include("density-core.jl")
Expand Down
16 changes: 15 additions & 1 deletion src/combinators/half.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,18 @@ logdensity_def(μ::Half, x) = logdensity_def(unhalf(μ), x)
insupport(unhalf(d), x)
end

testvalue(::Half) = 1.0
testvalue(::Type{T}, ::Half) where {T} = one(T)

massof(μ::Half) = massof(unhalf(μ))

function smf(μ::Half, x)
2 * smf(μ.parent, max(x, zero(x))) - 1
end

function invsmf(μ::Half, p)
@assert zero(p) ≤ p ≤ one(p)
invsmf(μ.parent, (p + 1) / 2)
end

transport_def(μ::Half, ::StdUniform, p) = invsmf(μ, p)
transport_def(::StdUniform, μ::Half, x) = smf(μ, x)
4 changes: 3 additions & 1 deletion src/combinators/likelihood.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,4 +202,6 @@ more efficient than

logdensityof(ℓ.k(p), ℓ.x) - logdensityof(ℓ.k(q), ℓ.x)
"""
likelihood_ratio(ℓ::Likelihood, p, q) = exp(ULogarithmic, logdensity_rel(ℓ.k(p), ℓ.k(q), ℓ.x))
function likelihood_ratio(ℓ::Likelihood, p, q)
exp(ULogarithmic, logdensity_rel(ℓ.k(p), ℓ.k(q), ℓ.x))
end
2 changes: 2 additions & 0 deletions src/combinators/power.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,5 @@ end
function checked_arg(μ::PowerMeasure, x::Any)
throw(ArgumentError("Size of variate doesn't match size of power measure"))
end

massof(m::PowerMeasure) = massof(m.parent)^prod(m.axes)
19 changes: 18 additions & 1 deletion src/combinators/product.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ function Pretty.tile(μ::AbstractProductMeasure)
result *= Pretty.literal(")")
end

massof(m::AbstractProductMeasure) = prod(massof, marginals(m))

export marginals

function Base.:(==)(a::AbstractProductMeasure, b::AbstractProductMeasure)
Expand Down Expand Up @@ -161,7 +163,9 @@ marginals(μ::ProductMeasure) = μ.marginals
_map(f, args...) = map(f, args...)
_map(f, x::MappedArrays.ReadonlyMappedArray) = mappedarray(f ∘ x.f, x.data)

testvalue(d::AbstractProductMeasure) = _map(testvalue, marginals(d))
function testvalue(::Type{T}, d::AbstractProductMeasure) where {T}
_map(m -> testvalue(T, m), marginals(d))
end

export ⊗

Expand Down Expand Up @@ -220,3 +224,16 @@ end
end
return true
end

getdof(d::AbstractProductMeasure) = mapreduce(getdof, +, marginals(d))

function checked_arg(μ::ProductMeasure{<:NTuple{N,Any}}, x::NTuple{N,Any}) where {N}
map(checked_arg, marginals(μ), x)
end

function checked_arg(
μ::ProductMeasure{<:NamedTuple{names}},
x::NamedTuple{names},
) where {names}
NamedTuple{names}(map(checked_arg, values(marginals(μ)), values(x)))
end
2 changes: 1 addition & 1 deletion src/combinators/smart-constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ superpose(nt::NamedTuple) = SuperpositionMeasure(nt)

function superpose(μ::T, ν::T) where {T<:AbstractMeasure}
if μ == ν
return weightedmeasure(logtwo, μ)
return weightedmeasure(static(float(logtwo)), μ)
else
return superpose((μ, ν))
end
Expand Down
2 changes: 1 addition & 1 deletion src/combinators/spikemixture.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,6 @@ function Base.rand(rng::AbstractRNG, T::Type, μ::SpikeMixture)
return (rand(rng, T) < μ.w) * rand(rng, T, μ.m)
end

testvalue(μ::SpikeMixture) = testvalue(μ.m)
testvalue(::Type{T}, μ::SpikeMixture) where {T} = zero(T)

insupport(μ::SpikeMixture, x) = dynamic(insupport(μ.m, x)) || iszero(x)
8 changes: 6 additions & 2 deletions src/combinators/superpose.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ using LogExpFunctions

export SuperpositionMeasure

abstract type AbstractSuperpositionMeasure <: AbstractMeasure end

@doc raw"""
struct SuperpositionMeasure{NT} <: AbstractMeasure
components :: NT
Expand All @@ -24,17 +26,19 @@ Superposition measures satisfy
\end{aligned}
```
"""
struct SuperpositionMeasure{C} <: AbstractMeasure
struct SuperpositionMeasure{C} <: AbstractSuperpositionMeasure
components::C
end

massof(m::SuperpositionMeasure) = sum(massof, m.components)

function Pretty.tile(d::SuperpositionMeasure)
result = Pretty.literal("SuperpositionMeasure(")
result *= Pretty.list_layout([Pretty.tile.(d.components)...])
result *= Pretty.literal(")")
end

testvalue(μ::SuperpositionMeasure) = testvalue(first(μ.components))
testvalue(::Type{T}, μ::SuperpositionMeasure) where {T} = testvalue(T, first(μ.components))

# SuperpositionMeasure(ms :: AbstractMeasure...) = SuperpositionMeasure{X,length(ms)}(ms)

Expand Down
126 changes: 91 additions & 35 deletions src/combinators/transformedmeasure.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# TODO: Compare with ChangesOfVariables.jl

using InverseFunctions: FunctionWithInverse

abstract type AbstractTransformedMeasure <: AbstractMeasure end

abstract type AbstractPushforward <: AbstractTransformedMeasure end
Expand All @@ -17,16 +19,15 @@ function parent(::AbstractTransformedMeasure) end
export PushforwardMeasure

"""
struct PushforwardMeasure{FF,IF,MU,VC<:TransformVolCorr} <: AbstractPushforward
f :: FF
inv_f :: IF
origin :: MU
volcorr :: VC
end
struct PushforwardMeasure{F,I,MU,VC<:TransformVolCorr} <:
AbstractPushforward f :: F finv :: I origin :: MU volcorr :: VC end

Users should not call `PushforwardMeasure` directly. Instead call or add
methods to `pushfwd`.
"""
struct PushforwardMeasure{FF,IF,M,VC<:TransformVolCorr} <: AbstractPushforward
f::FF
inv_f::IF
struct PushforwardMeasure{F,I,M,VC<:TransformVolCorr} <: AbstractPushforward
f::F
finv::I
oschulz marked this conversation as resolved.
Show resolved Hide resolved
origin::M
volcorr::VC
end
Expand All @@ -35,14 +36,25 @@ gettransform(ν::PushforwardMeasure) = ν.f
parent(ν::PushforwardMeasure) = ν.origin

function Pretty.tile(ν::PushforwardMeasure)
Pretty.list_layout(Pretty.tile.([ν.f, ν.inv_f, ν.origin]); prefix = :PushforwardMeasure)
Pretty.list_layout(Pretty.tile.([ν.f, ν.origin]); prefix = :PushforwardMeasure)
end

@inline function logdensity_def(
ν::PushforwardMeasure{FF,IF,M,<:WithVolCorr},
y,
) where {FF,IF,M}
x_orig, inv_ladj = with_logabsdet_jacobian(ν.inv_f, y)
# TODO: THIS IS ALMOST CERTAINLY WRONG
# @inline function logdensity_rel(
# ν::PushforwardMeasure{FF1,IF1,M1,<:WithVolCorr},
# β::PushforwardMeasure{FF2,IF2,M2,<:WithVolCorr},
# y,
# ) where {FF1,IF1,M1,FF2,IF2,M2}
# x = β.inv_f(y)
# f = ν.inv_f ∘ β.f
# inv_f = β.inv_f ∘ ν.f
# logdensity_rel(pushfwd(f, inv_f, ν.origin, WithVolCorr()), β.origin, x)
# end

@inline function logdensity_def(ν::PushforwardMeasure{F,I,M,<:WithVolCorr}, y) where {F,I,M}
f = ν.f
finv = ν.finv
x_orig, inv_ladj = with_logabsdet_jacobian(finv, y)
logd_orig = logdensity_def(ν.origin, x_orig)
logd = float(logd_orig + inv_ladj)
neginf = oftype(logd, -Inf)
Expand All @@ -57,49 +69,93 @@ end
)
end

@inline function logdensity_def(
ν::PushforwardMeasure{FF,IF,M,<:NoVolCorr},
y,
) where {FF,IF,M}
x_orig = to_origin(ν, y)
return logdensity_def(ν.origin, x_orig)
@inline function logdensity_def(ν::PushforwardMeasure{F,I,M,<:NoVolCorr}, y) where {F,I,M}
x = ν.finv(y)
return logdensity_def(ν.origin, x)
end

insupport(ν::PushforwardMeasure, y) = insupport(transport_origin(ν), to_origin(ν, y))
insupport(ν::PushforwardMeasure, y) = insupport(ν.origin, ν.finv(y))

testvalue(ν::PushforwardMeasure) = from_origin(ν, testvalue(transport_origin(ν)))
function testvalue(::Type{T}, ν::PushforwardMeasure) where {T}
ν.f(testvalue(T, parent(ν)))
end

@inline function basemeasure(ν::PushforwardMeasure)
PushforwardMeasure(ν.f, ν.inv_f, basemeasure(transport_origin(ν)), NoVolCorr())
pushfwd(ν.f, basemeasure(parent(ν)), NoVolCorr())
end

_pushfwd_dof(::Type{MU}, ::Type, dof) where {MU} = NoDOF{MU}()
_pushfwd_dof(::Type{MU}, ::Type{<:Tuple{Any,Real}}, dof) where {MU} = dof

# Assume that DOF are preserved if with_logabsdet_jacobian is functional:
@inline function getdof(ν::MU) where {MU<:PushforwardMeasure}
T = Core.Compiler.return_type(testvalue, Tuple{typeof(ν.origin)})
R = Core.Compiler.return_type(with_logabsdet_jacobian, Tuple{typeof(ν.f),T})
_pushfwd_dof(MU, R, getdof(ν.origin))
end
@inline getdof(ν::MU) where {MU<:PushforwardMeasure} = getdof(ν.origin)

# Bypass `checked_arg`, would require potentially costly transformation:
@inline checked_arg(::PushforwardMeasure, x) = x

@inline transport_origin(ν::PushforwardMeasure) = ν.origin
@inline from_origin(ν::PushforwardMeasure, x) = ν.f(x)
@inline to_origin(ν::PushforwardMeasure, y) = ν.inv_f(y)
@inline to_origin(ν::PushforwardMeasure, y) = ν.finv(y)

function Base.rand(rng::AbstractRNG, ::Type{T}, ν::PushforwardMeasure) where {T}
return from_origin(ν, rand(rng, T, transport_origin(ν)))
return ν.f(rand(rng, T, parent(ν)))
end

###############################################################################
# pushfwd

export pushfwd



"""
pushfwd(f, μ, volcorr = WithVolCorr())

Return the [pushforward measure](https://en.wikipedia.org/wiki/Pushforward_measure)
from `μ` the [measurable function](https://en.wikipedia.org/wiki/Measurable_function) `f`.
Return the [pushforward
measure](https://en.wikipedia.org/wiki/Pushforward_measure) from `μ` the
[measurable function](https://en.wikipedia.org/wiki/Measurable_function) `f`.

To manually specify an inverse, call
`pushfwd(InverseFunctions.setinverse(f, finv), μ, volcorr)`.
"""
pushfwd(f, μ, volcorr = WithVolCorr()) = PushforwardMeasure(f, inverse(f), μ, volcorr)
function pushfwd(f, μ, volcorr::TransformVolCorr = WithVolCorr())
PushforwardMeasure(f, inverse(f), μ, volcorr)
end

function pushfwd(f, μ::PushforwardMeasure, volcorr::TransformVolCorr = WithVolCorr())
_pushfwd_of_pushfwd(f, μ, μ.volcorr, volcorr)
end

function pushfwd(f, μ::PushforwardMeasure, ::WithVolCorr)
cscherrer marked this conversation as resolved.
Show resolved Hide resolved
_pushfwd_of_pushfwd(f, μ, μ.volcorr, WithVolCorr())
end

# Either both WithVolCorr or both NoVolCorr, so we can merge them
function _pushfwd_of_pushfwd(f, μ::PushforwardMeasure, ::V, v::V) where {V}
pushfwd(f ∘ μ.f, μ.origin, v)
cscherrer marked this conversation as resolved.
Show resolved Hide resolved
end

function _pushfwd_of_pushfwd(f, μ::PushforwardMeasure, _, v)
PushforwardMeasure(f, inverse(f), μ, v)
end

###############################################################################
# pullback

"""
pullback(f, μ, volcorr = WithVolCorr())

A _pullback_ is a dual concept to a _pushforward_. While a pushforward needs a
map _from_ the support of a measure, a pullback requires a map _into_ the
support of a measure. The log-density is then computed through function
composition, together with a volume correction as needed.

This can be useful, since the log-density of a `PushforwardMeasure` is computing
in terms of the inverse function; the "forward" function is not used at all. In
some cases, we may be focusing on log-density (and not, for example, sampling).

To manually specify an inverse, call
`pullback(InverseFunctions.setinverse(f, finv), μ, volcorr)`.
"""
function pullback(f, μ, volcorr::TransformVolCorr = WithVolCorr())
pushfwd(setinverse(inverse(f), f), μ, volcorr)
end
Loading