diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 8db8fec08f..d19fe57a0b 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -642,7 +642,7 @@ end
         merge_mode::Union{Function, Nothing}=vcat,
         ordering::AbstractTimeSeriesDataBatchOrdering=BatchLastIndex())
 
-BidirectionalRNN wrapper for RNNs.
+Bidirectional RNN wrapper.
 
 ## Arguments
 
@@ -682,8 +682,8 @@ BidirectionalRNN wrapper for RNNs.
  - Same as `cell` and `backward_cell`.
 
 """
-struct BidirectionalRNN <: AbstractExplicitContainerLayer{(:model,)}
-    model::Parallel
+@concrete struct BidirectionalRNN <: AbstractExplicitContainerLayer{(:model,)}
+    model <: Parallel
 end
 
 (rnn::BidirectionalRNN)(x, ps, st::NamedTuple) = rnn.model(x, ps, st)
@@ -693,7 +693,7 @@ function BidirectionalRNN(cell::AbstractRecurrentCell;
         merge_mode::Union{Function, Nothing}=vcat,
         ordering::AbstractTimeSeriesDataBatchOrdering=BatchLastIndex())
     layer = Recurrence(cell; return_sequence=true, ordering)
-    backward_rnn_layer = backward_cell=== nothing ? layer :
+    backward_rnn_layer = backward_cell === nothing ? layer :
                          Recurrence(backward_cell; return_sequence=true, ordering)
     fuse_op = merge_mode === nothing ? nothing : Broadcast.BroadcastFunction(merge_mode)
     return BidirectionalRNN(Parallel(
diff --git a/test/layers/recurrent_tests.jl b/test/layers/recurrent_tests.jl
index 2483be1607..9f2b4c3eaa 100644
--- a/test/layers/recurrent_tests.jl
+++ b/test/layers/recurrent_tests.jl
@@ -382,16 +382,18 @@ end
     @testset "$mode" for (mode, aType, device, ongpu) in MODES
         x = randn(rng, 2, 3) |> aType
         @test_throws ErrorException Lux._eachslice(x, BatchLastIndex())
+    end
 end
 
 @testitem "Bidirectional" timeout=3000 setup=[SharedTestSetup] tags=[:recurrent_layers] begin
-    rng = get_stable_rng(12345)
+    rng = StableRNG(12345)
+
     @testset "$mode" for (mode, aType, device, ongpu) in MODES
         @testset "cell: $_cell" for _cell in (RNNCell, LSTMCell, GRUCell)
             cell = _cell(3 => 5)
             bi_rnn = BidirectionalRNN(cell)
             bi_rnn_no_merge = BidirectionalRNN(cell; merge_mode=nothing)
-            __display(bi_rnn)
+            display(bi_rnn)
 
             # Batched Time Series
             x = randn(rng, Float32, 3, 4, 2) |> aType
@@ -411,7 +413,6 @@ end
             @test all(x -> size(x) == (5, 2), y_[1])
 
             if mode != "AMDGPU"
-                # gradients test failed after vcat
                 # __f = p -> sum(Base.Fix1(sum, abs2), first(bi_rnn(x, p, st)))
                 # @eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu
 
@@ -422,6 +423,7 @@ end
             else
                 # This is just added as a stub to remember about this broken test
             end
+
             @testset "backward_cell: $_backward_cell" for _backward_cell in (
                 RNNCell, LSTMCell, GRUCell)
                 cell = _cell(3 => 5)
@@ -430,7 +432,7 @@ end
                 bi_rnn_no_merge = BidirectionalRNN(
                     cell; backward_cell=backward_cell, merge_mode=nothing)
                 println("BidirectionalRNN:")
-                __display(bi_rnn)
+                display(bi_rnn)
 
                 # Batched Time Series
                 x = randn(rng, Float32, 3, 4, 2) |> aType