Skip to content

Commit

Permalink
update birnn docs
Browse files Browse the repository at this point in the history
  • Loading branch information
NeroBlackstone authored and avik-pal committed Jul 4, 2024
1 parent 82d4bd3 commit 541e130
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
8 changes: 4 additions & 4 deletions src/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,7 @@ end
merge_mode::Union{Function, Nothing}=vcat,
ordering::AbstractTimeSeriesDataBatchOrdering=BatchLastIndex())
BidirectionalRNN wrapper for RNNs.
Bidirectional RNN wrapper.
## Arguments
Expand Down Expand Up @@ -688,8 +688,8 @@ BidirectionalRNN wrapper for RNNs.
- Same as `cell` and `backward_cell`.
"""
struct BidirectionalRNN <: AbstractExplicitContainerLayer{(:model,)}
model::Parallel
@concrete struct BidirectionalRNN <: AbstractExplicitContainerLayer{(:model,)}
model <: Parallel
end

(rnn::BidirectionalRNN)(x, ps, st::NamedTuple) = rnn.model(x, ps, st)
Expand All @@ -699,7 +699,7 @@ function BidirectionalRNN(cell::AbstractRecurrentCell;
merge_mode::Union{Function, Nothing}=vcat,
ordering::AbstractTimeSeriesDataBatchOrdering=BatchLastIndex())
layer = Recurrence(cell; return_sequence=true, ordering)
backward_rnn_layer = backward_cell=== nothing ? layer :
backward_rnn_layer = backward_cell === nothing ? layer :
Recurrence(backward_cell; return_sequence=true, ordering)
fuse_op = merge_mode === nothing ? nothing : Broadcast.BroadcastFunction(merge_mode)
return BidirectionalRNN(Parallel(
Expand Down
10 changes: 6 additions & 4 deletions test/layers/recurrent_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -372,16 +372,18 @@ end
@testset "$mode" for (mode, aType, device, ongpu) in MODES
x = randn(rng, 2, 3) |> aType
@test_throws ErrorException Lux._eachslice(x, BatchLastIndex())
end
end

@testitem "Bidirectional" timeout=3000 setup=[SharedTestSetup] tags=[:recurrent_layers] begin
rng = get_stable_rng(12345)
rng = StableRNG(12345)

@testset "$mode" for (mode, aType, device, ongpu) in MODES
@testset "cell: $_cell" for _cell in (RNNCell, LSTMCell, GRUCell)
cell = _cell(3 => 5)
bi_rnn = BidirectionalRNN(cell)
bi_rnn_no_merge = BidirectionalRNN(cell; merge_mode=nothing)
__display(bi_rnn)
display(bi_rnn)

# Batched Time Series
x = randn(rng, Float32, 3, 4, 2) |> aType
Expand All @@ -401,7 +403,6 @@ end
@test all(x -> size(x) == (5, 2), y_[1])

if mode != "AMDGPU"

# gradients test failed after vcat
# __f = p -> sum(Base.Fix1(sum, abs2), first(bi_rnn(x, p, st)))
# @eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu
Expand All @@ -412,6 +413,7 @@ end
else
# This is just added as a stub to remember about this broken test
end

@testset "backward_cell: $_backward_cell" for _backward_cell in (
RNNCell, LSTMCell, GRUCell)
cell = _cell(3 => 5)
Expand All @@ -420,7 +422,7 @@ end
bi_rnn_no_merge = BidirectionalRNN(
cell; backward_cell=backward_cell, merge_mode=nothing)
println("BidirectionalRNN:")
__display(bi_rnn)
display(bi_rnn)

# Batched Time Series
x = randn(rng, Float32, 3, 4, 2) |> aType
Expand Down

0 comments on commit 541e130

Please sign in to comment.