diff --git a/Project.toml b/Project.toml
index 37d7a25b..0517a3bf 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "LuxLib"
 uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
 authors = ["Avik Pal and contributors"]
-version = "1.2.0"
+version = "1.2.1-DEV"

 [deps]
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
diff --git a/test/common_ops/dense_tests.jl b/test/common_ops/dense_tests.jl
index f139928d..6a53922d 100644
--- a/test/common_ops/dense_tests.jl
+++ b/test/common_ops/dense_tests.jl
@@ -3,6 +3,9 @@ using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, StableRNGs

 anonact = x -> x^3

+dense_simple(act, w, x, ::Nothing) = act.(w * x)
+dense_simple(act, w, x, b) = act.(w * x .+ b)
+
 function run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu)
     rng = StableRNG(1234)

@@ -44,6 +47,20 @@
         (w, x, b) -> __f(activation, w, x, b)
     end
     test_gradients(__f_grad, w, x, bias; atol, rtol, skip_backends, soft_fail=fp16)
+
+    y_simple = dense_simple(activation, w, x, bias)
+    y_zyg = fused_dense_bias_activation(activation, w, x, bias)
+    @test y_simple≈y_zyg atol=atol rtol=rtol
+
+    _, ∂w_true, ∂x_true, ∂b_true = Zygote.gradient(
+        sum ∘ dense_simple, activation, w, x, bias)
+    _, ∂w_zyg, ∂x_zyg, ∂b_zyg = Zygote.gradient(
+        sum ∘ fused_dense_bias_activation, activation, w, x, bias)
+    @test ∂w_true≈∂w_zyg atol=atol rtol=rtol
+    @test ∂x_true≈∂x_zyg atol=atol rtol=rtol
+    if bias !== nothing
+        @test ∂b_true≈∂b_zyg atol=atol rtol=rtol
+    end
 end

 const ALL_TEST_CONFIGS = Iterators.product(
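
For reference, a minimal standalone sketch of the check the new test performs: the fused kernel is expected to agree with a naive broadcasted dense layer, both in the forward pass and in the Zygote gradients. The array sizes, element type, and the gelu activation below are illustrative assumptions, not part of the patch.

# Standalone sketch of the equivalence asserted by the added test
# (illustrative sizes and activation; not part of the patch itself).
using LuxLib, NNlib, Zygote, StableRNGs

rng = StableRNG(1234)
w = randn(rng, Float32, 8, 4)   # weight matrix
x = randn(rng, Float32, 4, 16)  # input batch
b = randn(rng, Float32, 8)      # bias vector

naive_dense(act, w, x, b) = act.(w * x .+ b)

# Forward pass: the fused kernel should match the naive implementation.
y_naive = naive_dense(gelu, w, x, b)
y_fused = fused_dense_bias_activation(gelu, w, x, b)
@assert y_naive ≈ y_fused

# Gradients via Zygote should also match (the first returned value is the
# gradient w.r.t. the activation function and is discarded).
_, ∂w_n, ∂x_n, ∂b_n = Zygote.gradient(sum ∘ naive_dense, gelu, w, x, b)
_, ∂w_f, ∂x_f, ∂b_f = Zygote.gradient(sum ∘ fused_dense_bias_activation, gelu, w, x, b)
@assert ∂w_n ≈ ∂w_f && ∂x_n ≈ ∂x_f && ∂b_n ≈ ∂b_f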