Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

test: add tests comparing the fused op with unfused op #157

Merged
merged 1 commit into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.2.0"
version = "1.2.1-DEV"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
27 changes: 21 additions & 6 deletions test/common_ops/dense_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, StableRNGs

anonact = x -> x^3

dense_simple(act, w, x, ::Nothing) = act.(w * x)
dense_simple(act, w, x, b) = act.(w * x .+ b)

function run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu)
rng = StableRNG(1234)

Expand Down Expand Up @@ -44,6 +47,20 @@ function run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu
(w, x, b) -> __f(activation, w, x, b)
end
test_gradients(__f_grad, w, x, bias; atol, rtol, skip_backends, soft_fail=fp16)

y_simple = dense_simple(activation, w, x, bias)
y_zyg = fused_dense_bias_activation(activation, w, x, bias)
@test y_simple≈y_zyg atol=atol rtol=rtol

_, ∂w_true, ∂x_true, ∂b_true = Zygote.gradient(
sum ∘ dense_simple, activation, w, x, bias)
_, ∂w_zyg, ∂x_zyg, ∂b_zyg = Zygote.gradient(
sum ∘ fused_dense_bias_activation, activation, w, x, bias)
@test ∂w_true≈∂w_zyg atol=atol rtol=rtol
@test ∂x_true≈∂x_zyg atol=atol rtol=rtol
if bias !== nothing
@test ∂b_true≈∂b_zyg atol=atol rtol=rtol
end
end

const ALL_TEST_CONFIGS = Iterators.product(
Expand Down Expand Up @@ -149,14 +166,12 @@ end
@testitem "Enzyme.Forward patch: dense" tags=[:dense] setup=[SharedTestSetup] skip=:(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED) begin
using LuxLib, Random, LuxTestUtils, Enzyme

if LuxTestUtils.ENZYME_TESTING_ENABLED
x = rand(Float32, 2, 2)
x = rand(Float32, 2, 2)

f(x) = sum(abs2, LuxLib.Impl.matmul(x, x))
f(x) = sum(abs2, LuxLib.Impl.matmul(x, x))

# Just test that we don't crash
@test length(Enzyme.gradient(Forward, f, x)) == 4
end
# Just test that we don't crash
@test length(Enzyme.gradient(Forward, f, x)) == 4
end

@testitem "Enzyme rules for fused dense" tags=[:dense] setup=[SharedTestSetup] skip=:(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED) begin
Expand Down
Loading