Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
test: add tests comparing the fused op with unfused op
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 10, 2024
1 parent 40d9192 commit f592a66
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.2.0"
version = "1.2.1-DEV"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
17 changes: 17 additions & 0 deletions test/common_ops/dense_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, StableRNGs

anonact = x -> x^3

dense_simple(act, w, x, ::Nothing) = act.(w * x)
dense_simple(act, w, x, b) = act.(w * x .+ b)

function run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu)
rng = StableRNG(1234)

Expand Down Expand Up @@ -44,6 +47,20 @@ function run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu
(w, x, b) -> __f(activation, w, x, b)
end
test_gradients(__f_grad, w, x, bias; atol, rtol, skip_backends, soft_fail=fp16)

y_simple = dense_simple(activation, w, x, bias)
y_zyg = fused_dense_bias_activation(activation, w, x, bias)
@test y_simpley_zyg atol=atol rtol=rtol

_, ∂w_true, ∂x_true, ∂b_true = Zygote.gradient(
sum dense_simple, activation, w, x, bias)
_, ∂w_zyg, ∂x_zyg, ∂b_zyg = Zygote.gradient(
sum fused_dense_bias_activation, activation, w, x, bias)
@test ∂w_true∂w_zyg atol=atol rtol=rtol
@test ∂x_true∂x_zyg atol=atol rtol=rtol
if bias !== nothing
@test ∂b_true∂b_zyg atol=atol rtol=rtol
end
end

const ALL_TEST_CONFIGS = Iterators.product(
Expand Down

0 comments on commit f592a66

Please sign in to comment.