From 39e3703eb187da60fd81a7773019e7acb2c8dd05 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 30 Jun 2024 11:57:26 -0700 Subject: [PATCH] Re-add the tests --- src/layers/recurrent.jl | 6 ++++-- test/layers/recurrent_tests.jl | 32 ++++++++++++-------------------- 2 files changed, 16 insertions(+), 22 deletions(-) diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index e2c1e0ddb8..155c228586 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -702,6 +702,8 @@ function BidirectionalRNN(cell::AbstractRecurrentCell; backward_rnn_layer = backward_cell === nothing ? layer : Recurrence(backward_cell; return_sequence=true, ordering) fuse_op = merge_mode === nothing ? nothing : Broadcast.BroadcastFunction(merge_mode) - return BidirectionalRNN(Parallel( - fuse_op, layer, Chain(ReverseSequence(), backward_rnn_layer, ReverseSequence()))) + return BidirectionalRNN(Parallel(fuse_op; + forward_rnn=layer, + backward_rnn=Chain(; + rev1=ReverseSequence(), rnn=backward_rnn_layer, rev2=ReverseSequence()))) end diff --git a/test/layers/recurrent_tests.jl b/test/layers/recurrent_tests.jl index 3df89c2cd8..8af6cc5755 100644 --- a/test/layers/recurrent_tests.jl +++ b/test/layers/recurrent_tests.jl @@ -402,17 +402,14 @@ end @test size(y_[1]) == (4,) @test all(x -> size(x) == (5, 2), y_[1]) - if mode != "AMDGPU" - # gradients test failed after vcat - # __f = p -> sum(Base.Fix1(sum, abs2), first(bi_rnn(x, p, st))) - # @eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu + __f = p -> sum(Base.Fix1(sum, abs2), first(bi_rnn(x, p, st))) + @eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu - __f = p -> sum( - Base.Fix1(sum, abs2), first(first(bi_rnn_no_merge(x, p, st)))) - @eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu - else - # This is just added as a stub to remember about this broken test + __f = p -> begin + (y1, y2), st_ = bi_rnn_no_merge(x, p, st) + return sum(Base.Fix1(sum, abs2), y1) + sum(Base.Fix1(sum, abs2), y2) end + @eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu @testset "backward_cell: $_backward_cell" for _backward_cell in ( RNNCell, LSTMCell, GRUCell) @@ -421,7 +418,6 @@ end bi_rnn = BidirectionalRNN(cell; backward_cell=backward_cell) bi_rnn_no_merge = BidirectionalRNN( cell; backward_cell=backward_cell, merge_mode=nothing) - println("BidirectionalRNN:") display(bi_rnn) # Batched Time Series @@ -441,18 +437,14 @@ end @test size(y_[1]) == (4,) @test all(x -> size(x) == (5, 2), y_[1]) - if mode != "AMDGPU" - # gradients test failed after vcat - # __f = p -> sum(Base.Fix1(sum, abs2), first(bi_rnn(x, p, st))) - # @eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu + __f = p -> sum(Base.Fix1(sum, abs2), first(bi_rnn(x, p, st))) + @eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu - __f = p -> sum( - Base.Fix1(sum, abs2), first(first(bi_rnn_no_merge(x, p, st)))) - @eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu - else - # This is just added as a stub to remember about this broken test - @test_broken 1 + 1 == 1 + __f = p -> begin + (y1, y2), st_ = bi_rnn_no_merge(x, p, st) + return sum(Base.Fix1(sum, abs2), y1) + sum(Base.Fix1(sum, abs2), y2) end + @eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu end end end