From dba63b9ae6cacbc5977f53770eff22e557fc4e0c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 10 Sep 2024 15:42:40 -0400 Subject: [PATCH] test: add tests comparing the fused op with unfused op --- Project.toml | 2 +- test/common_ops/dense_tests.jl | 27 +++++++++++++++++++++------ 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 37d7a25b..0517a3bf 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "1.2.0" +version = "1.2.1-DEV" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/test/common_ops/dense_tests.jl b/test/common_ops/dense_tests.jl index f139928d..69b2ad3f 100644 --- a/test/common_ops/dense_tests.jl +++ b/test/common_ops/dense_tests.jl @@ -3,6 +3,9 @@ using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, StableRNGs anonact = x -> x^3 +dense_simple(act, w, x, ::Nothing) = act.(w * x) +dense_simple(act, w, x, b) = act.(w * x .+ b) + function run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) rng = StableRNG(1234) @@ -44,6 +47,20 @@ function run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu (w, x, b) -> __f(activation, w, x, b) end test_gradients(__f_grad, w, x, bias; atol, rtol, skip_backends, soft_fail=fp16) + + y_simple = dense_simple(activation, w, x, bias) + y_zyg = fused_dense_bias_activation(activation, w, x, bias) + @test y_simple≈y_zyg atol=atol rtol=rtol + + _, ∂w_true, ∂x_true, ∂b_true = Zygote.gradient( + sum ∘ dense_simple, activation, w, x, bias) + _, ∂w_zyg, ∂x_zyg, ∂b_zyg = Zygote.gradient( + sum ∘ fused_dense_bias_activation, activation, w, x, bias) + @test ∂w_true≈∂w_zyg atol=atol rtol=rtol + @test ∂x_true≈∂x_zyg atol=atol rtol=rtol + if bias !== nothing + @test ∂b_true≈∂b_zyg atol=atol rtol=rtol + end end const ALL_TEST_CONFIGS = Iterators.product( @@ -149,14 +166,12 @@ end @testitem "Enzyme.Forward patch: dense" tags=[:dense] setup=[SharedTestSetup] skip=:(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED) begin using LuxLib, Random, LuxTestUtils, Enzyme - if LuxTestUtils.ENZYME_TESTING_ENABLED - x = rand(Float32, 2, 2) + x = rand(Float32, 2, 2) - f(x) = sum(abs2, LuxLib.Impl.matmul(x, x)) + f(x) = sum(abs2, LuxLib.Impl.matmul(x, x)) - # Just test that we don't crash - @test length(Enzyme.gradient(Forward, f, x)) == 4 - end + # Just test that we don't crash + @test length(Enzyme.gradient(Forward, f, x)) == 4 end @testitem "Enzyme rules for fused dense" tags=[:dense] setup=[SharedTestSetup] skip=:(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED) begin