From edb7157bb08119637f9292a8d72b93924be5e4fa Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 7 Nov 2024 12:42:16 +0000 Subject: [PATCH 1/4] Add Bijectors integration tests --- .../bijectors/Project.toml | 5 + .../bijectors/bijectors.jl | 123 ++++++++++++++++++ test/runtests.jl | 2 + 3 files changed, 130 insertions(+) create mode 100644 test/integration_testing/bijectors/Project.toml create mode 100644 test/integration_testing/bijectors/bijectors.jl diff --git a/test/integration_testing/bijectors/Project.toml b/test/integration_testing/bijectors/Project.toml new file mode 100644 index 000000000..54a732057 --- /dev/null +++ b/test/integration_testing/bijectors/Project.toml @@ -0,0 +1,5 @@ +[deps] +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/test/integration_testing/bijectors/bijectors.jl b/test/integration_testing/bijectors/bijectors.jl new file mode 100644 index 000000000..70f515dc8 --- /dev/null +++ b/test/integration_testing/bijectors/bijectors.jl @@ -0,0 +1,123 @@ +using Pkg +Pkg.activate(@__DIR__) +Pkg.develop(; path = joinpath(@__DIR__, "..", "..", "..")) + +using Bijectors: Bijectors +using LinearAlgebra: LinearAlgebra +using Random: randn + +""" +Type for specifying a test case for `test_rule`. +""" +struct TestCase + func::Function + arg::Any + name::Union{String,Nothing} +end + +TestCase(f, arg; name = nothing) = TestCase(f, arg, name) + +""" +A helper function that returns a TestCase that evaluates sum(bijector(inverse(bijector)(x))) +""" +function sum_b_binv_test_case(bijector, dim; name = nothing, rng = Xoshiro(23)) + if name === nothing + name = string(bijector) + end + b_inv = Bijectors.inverse(bijector) + return TestCase(x -> sum(bijector(b_inv(x))), randn(rng, dim); name = name) +end + +@testset "Bijectors integration tests" begin + test_cases = TestCase[ + sum_b_binv_test_case(Bijectors.VecCorrBijector(), 3), + sum_b_binv_test_case(Bijectors.VecCorrBijector(), 0), + sum_b_binv_test_case(Bijectors.CorrBijector(), (3, 3)), + sum_b_binv_test_case(Bijectors.CorrBijector(), (0, 0)), + sum_b_binv_test_case(Bijectors.VecCholeskyBijector(:L), 3), + sum_b_binv_test_case(Bijectors.VecCholeskyBijector(:L), 0), + sum_b_binv_test_case(Bijectors.VecCholeskyBijector(:U), 3), + sum_b_binv_test_case(Bijectors.VecCholeskyBijector(:U), 0), + sum_b_binv_test_case( + Bijectors.Coupling(Bijectors.Shift, Bijectors.PartitionMask(3, [1], [2])), + 3, + ), + sum_b_binv_test_case(Bijectors.InvertibleBatchNorm(3), (3, 3)), + sum_b_binv_test_case(Bijectors.LeakyReLU(0.2), 3), + sum_b_binv_test_case(Bijectors.Logit(0.1, 0.3), 3), + sum_b_binv_test_case(Bijectors.PDBijector(), (3, 3)), + sum_b_binv_test_case(Bijectors.PDVecBijector(), 3), + sum_b_binv_test_case(Bijectors.Permute([ + 0 1 0 + 1 0 0 + 0 0 1 + ]), (3, 3)), + sum_b_binv_test_case(Bijectors.PlanarLayer(3), (3, 3)), + sum_b_binv_test_case(Bijectors.RadialLayer(3), 3), + sum_b_binv_test_case(Bijectors.Reshape((2, 3), (3, 2)), (2, 3)), + sum_b_binv_test_case(Bijectors.Scale(0.2), 3), + sum_b_binv_test_case(Bijectors.Shift(-0.4), 3), + sum_b_binv_test_case(Bijectors.SignFlip(), 3), + sum_b_binv_test_case(Bijectors.SimplexBijector(), 3), + sum_b_binv_test_case(Bijectors.TruncatedBijector(-0.2, 0.5), 3), + + # Below, some test cases that don't fit the sum_b_binv_test_case mold. + + TestCase( + function (x) + b = Bijectors.RationalQuadraticSpline( + [-0.2, 0.1, 0.5], + [-0.3, 0.3, 0.9], + [1.0, 0.2, 1.0], + ) + binv = Bijectors.inverse(b) + return sum(binv(b(x))) + end, + randn(Xoshiro(23)); + name = "RationalQuadraticSpline on scalar", + ), + TestCase( + function (x) + b = Bijectors.OrderedBijector() + binv = Bijectors.inverse(b) + return sum(binv(b(x))) + end, + randn(Xoshiro(23), 7); + name = "OrderedBijector", + ), + TestCase( + function (x) + layer = Bijectors.PlanarLayer(x[1:2], x[3:4], x[5:5]) + flow = Bijectors.transformed( + Bijectors.MvNormal(zeros(2), LinearAlgebra.I), + layer, + ) + x = x[6:7] + return Bijectors.logpdf(flow.dist, x) - + Bijectors.logabsdetjac(flow.transform, x) + end, + randn(Xoshiro(23), 7); + name = "PlanarLayer7", + ), + TestCase( + function (x) + layer = Bijectors.PlanarLayer(x[1:2], x[3:4], x[5:5]) + flow = Bijectors.transformed( + Bijectors.MvNormal(zeros(2), LinearAlgebra.I), + layer, + ) + x = reshape(x[6:end], 2, :) + return sum( + Bijectors.logpdf(flow.dist, x) - + Bijectors.logabsdetjac(flow.transform, x), + ) + end, + randn(Xoshiro(23), 11); + name = "PlanarLayer11", + ), + ] + + @testset "$(case.name)" for case in test_cases + test_rule(Xoshiro(123456), case.func, case.arg; is_primitive = false) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index e28bbd470..dd9d955ea 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -69,6 +69,8 @@ include("front_matter.jl") include(joinpath("ext", "special_functions", "special_functions.jl")) elseif test_group == "integration_testing/array" include(joinpath("integration_testing", "array.jl")) + elseif test_group == "integration_testing/bijectors" + include(joinpath("integration_testing", "bijectors", "bijectors.jl")) elseif test_group == "integration_testing/diff_tests" include(joinpath("integration_testing", "diff_tests.jl")) elseif test_group == "integration_testing/distributions" From d1156d7e58b90cef26459e6a6142bd3cb3c7e02e Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 7 Nov 2024 12:45:32 +0000 Subject: [PATCH 2/4] Add bijectors integration_tests to CI --- .github/workflows/CI.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 87cc04fc8..aed689adb 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -31,6 +31,7 @@ jobs: - 'ext/nnlib' - 'ext/special_functions' - 'integration_testing/array' + - 'integration_testing/bijectors' - 'integration_testing/diff_tests' - 'integration_testing/distributions' - 'integration_testing/gp' From 20642cd56558e640adb7ba8bdc45fa927283c91c Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 7 Nov 2024 13:31:55 +0000 Subject: [PATCH 3/4] Remove unnecessary summing in Bijectors tests --- .../bijectors/bijectors.jl | 58 +++++++++---------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/test/integration_testing/bijectors/bijectors.jl b/test/integration_testing/bijectors/bijectors.jl index 70f515dc8..614fd0efb 100644 --- a/test/integration_testing/bijectors/bijectors.jl +++ b/test/integration_testing/bijectors/bijectors.jl @@ -18,50 +18,50 @@ end TestCase(f, arg; name = nothing) = TestCase(f, arg, name) """ -A helper function that returns a TestCase that evaluates sum(bijector(inverse(bijector)(x))) +A helper function that returns a TestCase that evaluates bijector(inverse(bijector)(x)) """ -function sum_b_binv_test_case(bijector, dim; name = nothing, rng = Xoshiro(23)) +function b_binv_test_case(bijector, dim; name = nothing, rng = Xoshiro(23)) if name === nothing name = string(bijector) end b_inv = Bijectors.inverse(bijector) - return TestCase(x -> sum(bijector(b_inv(x))), randn(rng, dim); name = name) + return TestCase(x -> bijector(b_inv(x)), randn(rng, dim); name = name) end @testset "Bijectors integration tests" begin test_cases = TestCase[ - sum_b_binv_test_case(Bijectors.VecCorrBijector(), 3), - sum_b_binv_test_case(Bijectors.VecCorrBijector(), 0), - sum_b_binv_test_case(Bijectors.CorrBijector(), (3, 3)), - sum_b_binv_test_case(Bijectors.CorrBijector(), (0, 0)), - sum_b_binv_test_case(Bijectors.VecCholeskyBijector(:L), 3), - sum_b_binv_test_case(Bijectors.VecCholeskyBijector(:L), 0), - sum_b_binv_test_case(Bijectors.VecCholeskyBijector(:U), 3), - sum_b_binv_test_case(Bijectors.VecCholeskyBijector(:U), 0), - sum_b_binv_test_case( + b_binv_test_case(Bijectors.VecCorrBijector(), 3), + b_binv_test_case(Bijectors.VecCorrBijector(), 0), + b_binv_test_case(Bijectors.CorrBijector(), (3, 3)), + b_binv_test_case(Bijectors.CorrBijector(), (0, 0)), + b_binv_test_case(Bijectors.VecCholeskyBijector(:L), 3), + b_binv_test_case(Bijectors.VecCholeskyBijector(:L), 0), + b_binv_test_case(Bijectors.VecCholeskyBijector(:U), 3), + b_binv_test_case(Bijectors.VecCholeskyBijector(:U), 0), + b_binv_test_case( Bijectors.Coupling(Bijectors.Shift, Bijectors.PartitionMask(3, [1], [2])), 3, ), - sum_b_binv_test_case(Bijectors.InvertibleBatchNorm(3), (3, 3)), - sum_b_binv_test_case(Bijectors.LeakyReLU(0.2), 3), - sum_b_binv_test_case(Bijectors.Logit(0.1, 0.3), 3), - sum_b_binv_test_case(Bijectors.PDBijector(), (3, 3)), - sum_b_binv_test_case(Bijectors.PDVecBijector(), 3), - sum_b_binv_test_case(Bijectors.Permute([ + b_binv_test_case(Bijectors.InvertibleBatchNorm(3), (3, 3)), + b_binv_test_case(Bijectors.LeakyReLU(0.2), 3), + b_binv_test_case(Bijectors.Logit(0.1, 0.3), 3), + b_binv_test_case(Bijectors.PDBijector(), (3, 3)), + b_binv_test_case(Bijectors.PDVecBijector(), 3), + b_binv_test_case(Bijectors.Permute([ 0 1 0 1 0 0 0 0 1 ]), (3, 3)), - sum_b_binv_test_case(Bijectors.PlanarLayer(3), (3, 3)), - sum_b_binv_test_case(Bijectors.RadialLayer(3), 3), - sum_b_binv_test_case(Bijectors.Reshape((2, 3), (3, 2)), (2, 3)), - sum_b_binv_test_case(Bijectors.Scale(0.2), 3), - sum_b_binv_test_case(Bijectors.Shift(-0.4), 3), - sum_b_binv_test_case(Bijectors.SignFlip(), 3), - sum_b_binv_test_case(Bijectors.SimplexBijector(), 3), - sum_b_binv_test_case(Bijectors.TruncatedBijector(-0.2, 0.5), 3), + b_binv_test_case(Bijectors.PlanarLayer(3), (3, 3)), + b_binv_test_case(Bijectors.RadialLayer(3), 3), + b_binv_test_case(Bijectors.Reshape((2, 3), (3, 2)), (2, 3)), + b_binv_test_case(Bijectors.Scale(0.2), 3), + b_binv_test_case(Bijectors.Shift(-0.4), 3), + b_binv_test_case(Bijectors.SignFlip(), 3), + b_binv_test_case(Bijectors.SimplexBijector(), 3), + b_binv_test_case(Bijectors.TruncatedBijector(-0.2, 0.5), 3), - # Below, some test cases that don't fit the sum_b_binv_test_case mold. + # Below, some test cases that don't fit the b_binv_test_case mold. TestCase( function (x) @@ -71,7 +71,7 @@ end [1.0, 0.2, 1.0], ) binv = Bijectors.inverse(b) - return sum(binv(b(x))) + return binv(b(x)) end, randn(Xoshiro(23)); name = "RationalQuadraticSpline on scalar", @@ -80,7 +80,7 @@ end function (x) b = Bijectors.OrderedBijector() binv = Bijectors.inverse(b) - return sum(binv(b(x))) + return binv(b(x)) end, randn(Xoshiro(23), 7); name = "OrderedBijector", From b496154a0099da5e36d6435dec6046cb0c0d2148 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 7 Nov 2024 13:32:20 +0000 Subject: [PATCH 4/4] Mark on Bijector test as broken on v1.11 --- test/integration_testing/bijectors/bijectors.jl | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/test/integration_testing/bijectors/bijectors.jl b/test/integration_testing/bijectors/bijectors.jl index 614fd0efb..9a2a7fa4b 100644 --- a/test/integration_testing/bijectors/bijectors.jl +++ b/test/integration_testing/bijectors/bijectors.jl @@ -13,9 +13,10 @@ struct TestCase func::Function arg::Any name::Union{String,Nothing} + broken::Bool end -TestCase(f, arg; name = nothing) = TestCase(f, arg, name) +TestCase(f, arg; name = nothing, broken=false) = TestCase(f, arg, name, broken) """ A helper function that returns a TestCase that evaluates bijector(inverse(bijector)(x)) @@ -98,6 +99,9 @@ end end, randn(Xoshiro(23), 7); name = "PlanarLayer7", + # TODO(mhauru) Broken on v1.11 due to + # https://github.com/compintell/Mooncake.jl/issues/319 + broken=(VERSION >= v"1.11"), ), TestCase( function (x) @@ -118,6 +122,13 @@ end ] @testset "$(case.name)" for case in test_cases - test_rule(Xoshiro(123456), case.func, case.arg; is_primitive = false) + if case.broken + @test_broken begin + test_rule(Xoshiro(123456), case.func, case.arg; is_primitive=false) + true + end + else + test_rule(Xoshiro(123456), case.func, case.arg; is_primitive=false) + end end end