diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 87cc04fc8..aed689adb 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -31,6 +31,7 @@ jobs: - 'ext/nnlib' - 'ext/special_functions' - 'integration_testing/array' + - 'integration_testing/bijectors' - 'integration_testing/diff_tests' - 'integration_testing/distributions' - 'integration_testing/gp' diff --git a/test/integration_testing/bijectors/Project.toml b/test/integration_testing/bijectors/Project.toml new file mode 100644 index 000000000..54a732057 --- /dev/null +++ b/test/integration_testing/bijectors/Project.toml @@ -0,0 +1,5 @@ +[deps] +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/test/integration_testing/bijectors/bijectors.jl b/test/integration_testing/bijectors/bijectors.jl new file mode 100644 index 000000000..9a2a7fa4b --- /dev/null +++ b/test/integration_testing/bijectors/bijectors.jl @@ -0,0 +1,134 @@ +using Pkg +Pkg.activate(@__DIR__) +Pkg.develop(; path = joinpath(@__DIR__, "..", "..", "..")) + +using Bijectors: Bijectors +using LinearAlgebra: LinearAlgebra +using Random: randn + +""" +Type for specifying a test case for `test_rule`. +""" +struct TestCase + func::Function + arg::Any + name::Union{String,Nothing} + broken::Bool +end + +TestCase(f, arg; name = nothing, broken=false) = TestCase(f, arg, name, broken) + +""" +A helper function that returns a TestCase that evaluates bijector(inverse(bijector)(x)) +""" +function b_binv_test_case(bijector, dim; name = nothing, rng = Xoshiro(23)) + if name === nothing + name = string(bijector) + end + b_inv = Bijectors.inverse(bijector) + return TestCase(x -> bijector(b_inv(x)), randn(rng, dim); name = name) +end + +@testset "Bijectors integration tests" begin + test_cases = TestCase[ + b_binv_test_case(Bijectors.VecCorrBijector(), 3), + b_binv_test_case(Bijectors.VecCorrBijector(), 0), + b_binv_test_case(Bijectors.CorrBijector(), (3, 3)), + b_binv_test_case(Bijectors.CorrBijector(), (0, 0)), + b_binv_test_case(Bijectors.VecCholeskyBijector(:L), 3), + b_binv_test_case(Bijectors.VecCholeskyBijector(:L), 0), + b_binv_test_case(Bijectors.VecCholeskyBijector(:U), 3), + b_binv_test_case(Bijectors.VecCholeskyBijector(:U), 0), + b_binv_test_case( + Bijectors.Coupling(Bijectors.Shift, Bijectors.PartitionMask(3, [1], [2])), + 3, + ), + b_binv_test_case(Bijectors.InvertibleBatchNorm(3), (3, 3)), + b_binv_test_case(Bijectors.LeakyReLU(0.2), 3), + b_binv_test_case(Bijectors.Logit(0.1, 0.3), 3), + b_binv_test_case(Bijectors.PDBijector(), (3, 3)), + b_binv_test_case(Bijectors.PDVecBijector(), 3), + b_binv_test_case(Bijectors.Permute([ + 0 1 0 + 1 0 0 + 0 0 1 + ]), (3, 3)), + b_binv_test_case(Bijectors.PlanarLayer(3), (3, 3)), + b_binv_test_case(Bijectors.RadialLayer(3), 3), + b_binv_test_case(Bijectors.Reshape((2, 3), (3, 2)), (2, 3)), + b_binv_test_case(Bijectors.Scale(0.2), 3), + b_binv_test_case(Bijectors.Shift(-0.4), 3), + b_binv_test_case(Bijectors.SignFlip(), 3), + b_binv_test_case(Bijectors.SimplexBijector(), 3), + b_binv_test_case(Bijectors.TruncatedBijector(-0.2, 0.5), 3), + + # Below, some test cases that don't fit the b_binv_test_case mold. + + TestCase( + function (x) + b = Bijectors.RationalQuadraticSpline( + [-0.2, 0.1, 0.5], + [-0.3, 0.3, 0.9], + [1.0, 0.2, 1.0], + ) + binv = Bijectors.inverse(b) + return binv(b(x)) + end, + randn(Xoshiro(23)); + name = "RationalQuadraticSpline on scalar", + ), + TestCase( + function (x) + b = Bijectors.OrderedBijector() + binv = Bijectors.inverse(b) + return binv(b(x)) + end, + randn(Xoshiro(23), 7); + name = "OrderedBijector", + ), + TestCase( + function (x) + layer = Bijectors.PlanarLayer(x[1:2], x[3:4], x[5:5]) + flow = Bijectors.transformed( + Bijectors.MvNormal(zeros(2), LinearAlgebra.I), + layer, + ) + x = x[6:7] + return Bijectors.logpdf(flow.dist, x) - + Bijectors.logabsdetjac(flow.transform, x) + end, + randn(Xoshiro(23), 7); + name = "PlanarLayer7", + # TODO(mhauru) Broken on v1.11 due to + # https://github.com/compintell/Mooncake.jl/issues/319 + broken=(VERSION >= v"1.11"), + ), + TestCase( + function (x) + layer = Bijectors.PlanarLayer(x[1:2], x[3:4], x[5:5]) + flow = Bijectors.transformed( + Bijectors.MvNormal(zeros(2), LinearAlgebra.I), + layer, + ) + x = reshape(x[6:end], 2, :) + return sum( + Bijectors.logpdf(flow.dist, x) - + Bijectors.logabsdetjac(flow.transform, x), + ) + end, + randn(Xoshiro(23), 11); + name = "PlanarLayer11", + ), + ] + + @testset "$(case.name)" for case in test_cases + if case.broken + @test_broken begin + test_rule(Xoshiro(123456), case.func, case.arg; is_primitive=false) + true + end + else + test_rule(Xoshiro(123456), case.func, case.arg; is_primitive=false) + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index e28bbd470..dd9d955ea 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -69,6 +69,8 @@ include("front_matter.jl") include(joinpath("ext", "special_functions", "special_functions.jl")) elseif test_group == "integration_testing/array" include(joinpath("integration_testing", "array.jl")) + elseif test_group == "integration_testing/bijectors" + include(joinpath("integration_testing", "bijectors", "bijectors.jl")) elseif test_group == "integration_testing/diff_tests" include(joinpath("integration_testing", "diff_tests.jl")) elseif test_group == "integration_testing/distributions"