Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Bijectors integration tests #353

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ jobs:
- 'ext/nnlib'
- 'ext/special_functions'
- 'integration_testing/array'
- 'integration_testing/bijectors'
- 'integration_testing/diff_tests'
- 'integration_testing/distributions'
- 'integration_testing/gp'
Expand Down
5 changes: 5 additions & 0 deletions test/integration_testing/bijectors/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[deps]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
134 changes: 134 additions & 0 deletions test/integration_testing/bijectors/bijectors.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
using Pkg
Pkg.activate(@__DIR__)
Pkg.develop(; path = joinpath(@__DIR__, "..", "..", ".."))

using Bijectors: Bijectors
using LinearAlgebra: LinearAlgebra
using Random: randn

"""
Type for specifying a test case for `test_rule`.
"""
struct TestCase
func::Function
arg::Any
name::Union{String,Nothing}
broken::Bool
end

TestCase(f, arg; name = nothing, broken=false) = TestCase(f, arg, name, broken)

"""
A helper function that returns a TestCase that evaluates bijector(inverse(bijector)(x))
"""
function b_binv_test_case(bijector, dim; name = nothing, rng = Xoshiro(23))
if name === nothing
name = string(bijector)
end
b_inv = Bijectors.inverse(bijector)
return TestCase(x -> bijector(b_inv(x)), randn(rng, dim); name = name)
end

@testset "Bijectors integration tests" begin
test_cases = TestCase[
b_binv_test_case(Bijectors.VecCorrBijector(), 3),
b_binv_test_case(Bijectors.VecCorrBijector(), 0),
b_binv_test_case(Bijectors.CorrBijector(), (3, 3)),
b_binv_test_case(Bijectors.CorrBijector(), (0, 0)),
b_binv_test_case(Bijectors.VecCholeskyBijector(:L), 3),
b_binv_test_case(Bijectors.VecCholeskyBijector(:L), 0),
b_binv_test_case(Bijectors.VecCholeskyBijector(:U), 3),
b_binv_test_case(Bijectors.VecCholeskyBijector(:U), 0),
b_binv_test_case(
Bijectors.Coupling(Bijectors.Shift, Bijectors.PartitionMask(3, [1], [2])),
3,
),
b_binv_test_case(Bijectors.InvertibleBatchNorm(3), (3, 3)),
b_binv_test_case(Bijectors.LeakyReLU(0.2), 3),
b_binv_test_case(Bijectors.Logit(0.1, 0.3), 3),
b_binv_test_case(Bijectors.PDBijector(), (3, 3)),
b_binv_test_case(Bijectors.PDVecBijector(), 3),
b_binv_test_case(Bijectors.Permute([
0 1 0
1 0 0
0 0 1
]), (3, 3)),
b_binv_test_case(Bijectors.PlanarLayer(3), (3, 3)),
b_binv_test_case(Bijectors.RadialLayer(3), 3),
b_binv_test_case(Bijectors.Reshape((2, 3), (3, 2)), (2, 3)),
b_binv_test_case(Bijectors.Scale(0.2), 3),
b_binv_test_case(Bijectors.Shift(-0.4), 3),
b_binv_test_case(Bijectors.SignFlip(), 3),
b_binv_test_case(Bijectors.SimplexBijector(), 3),
b_binv_test_case(Bijectors.TruncatedBijector(-0.2, 0.5), 3),

# Below, some test cases that don't fit the b_binv_test_case mold.

TestCase(
function (x)
b = Bijectors.RationalQuadraticSpline(
[-0.2, 0.1, 0.5],
[-0.3, 0.3, 0.9],
[1.0, 0.2, 1.0],
)
binv = Bijectors.inverse(b)
return binv(b(x))
end,
randn(Xoshiro(23));
name = "RationalQuadraticSpline on scalar",
),
TestCase(
function (x)
b = Bijectors.OrderedBijector()
binv = Bijectors.inverse(b)
return binv(b(x))
end,
randn(Xoshiro(23), 7);
name = "OrderedBijector",
),
TestCase(
function (x)
layer = Bijectors.PlanarLayer(x[1:2], x[3:4], x[5:5])
flow = Bijectors.transformed(
Bijectors.MvNormal(zeros(2), LinearAlgebra.I),
layer,
)
x = x[6:7]
return Bijectors.logpdf(flow.dist, x) -
Bijectors.logabsdetjac(flow.transform, x)
end,
randn(Xoshiro(23), 7);
name = "PlanarLayer7",
# TODO(mhauru) Broken on v1.11 due to
# https://github.com/compintell/Mooncake.jl/issues/319
broken=(VERSION >= v"1.11"),
),
TestCase(
function (x)
layer = Bijectors.PlanarLayer(x[1:2], x[3:4], x[5:5])
flow = Bijectors.transformed(
Bijectors.MvNormal(zeros(2), LinearAlgebra.I),
layer,
)
x = reshape(x[6:end], 2, :)
return sum(
Bijectors.logpdf(flow.dist, x) -
Bijectors.logabsdetjac(flow.transform, x),
)
end,
randn(Xoshiro(23), 11);
name = "PlanarLayer11",
),
]

@testset "$(case.name)" for case in test_cases
if case.broken
@test_broken begin
test_rule(Xoshiro(123456), case.func, case.arg; is_primitive=false)
true
end
else
test_rule(Xoshiro(123456), case.func, case.arg; is_primitive=false)
end
end
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ include("front_matter.jl")
include(joinpath("ext", "special_functions", "special_functions.jl"))
elseif test_group == "integration_testing/array"
include(joinpath("integration_testing", "array.jl"))
elseif test_group == "integration_testing/bijectors"
include(joinpath("integration_testing", "bijectors", "bijectors.jl"))
elseif test_group == "integration_testing/diff_tests"
include(joinpath("integration_testing", "diff_tests.jl"))
elseif test_group == "integration_testing/distributions"
Expand Down