Skip to content

Commit

Permalink
update birnn docs
Browse files (browse the repository at this point in the history)
  • Loading branch information
NeroBlackstone authored and avik-pal committed Jun 30, 2024
1 parent 22f562f commit af336ad
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
8 changes: 4 additions & 4 deletions src/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ end
merge_mode::Union{Function, Nothing}=vcat,
ordering::AbstractTimeSeriesDataBatchOrdering=BatchLastIndex())
BidirectionalRNN wrapper for RNNs.
Bidirectional RNN wrapper.
## Arguments
Expand Down Expand Up @@ -682,8 +682,8 @@ BidirectionalRNN wrapper for RNNs.
- Same as `cell` and `backward_cell`.
"""
struct BidirectionalRNN <: AbstractExplicitContainerLayer{(:model,)}
model::Parallel
@concrete struct BidirectionalRNN <: AbstractExplicitContainerLayer{(:model,)}
model <: Parallel
end

(rnn::BidirectionalRNN)(x, ps, st::NamedTuple) = rnn.model(x, ps, st)

Check warning on line 689 in src/layers/recurrent.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/recurrent.jl#L689

Added line #L689 was not covered by tests
Expand All @@ -693,7 +693,7 @@ function BidirectionalRNN(cell::AbstractRecurrentCell;
merge_mode::Union{Function, Nothing}=vcat,
ordering::AbstractTimeSeriesDataBatchOrdering=BatchLastIndex())
layer = Recurrence(cell; return_sequence=true, ordering)
backward_rnn_layer = backward_cell=== nothing ? layer :
backward_rnn_layer = backward_cell === nothing ? layer :

Check warning on line 696 in src/layers/recurrent.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/recurrent.jl#L695-L696

Added lines #L695 - L696 were not covered by tests
Recurrence(backward_cell; return_sequence=true, ordering)
fuse_op = merge_mode === nothing ? nothing : Broadcast.BroadcastFunction(merge_mode)
return BidirectionalRNN(Parallel(

Check warning on line 699 in src/layers/recurrent.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/recurrent.jl#L698-L699

Added lines #L698 - L699 were not covered by tests
Expand Down
10 changes: 6 additions & 4 deletions test/layers/recurrent_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -382,16 +382,18 @@ end
@testset "$mode" for (mode, aType, device, ongpu) in MODES
x = randn(rng, 2, 3) |> aType
@test_throws ErrorException Lux._eachslice(x, BatchLastIndex())
end
end

@testitem "Bidirectional" timeout=3000 setup=[SharedTestSetup] tags=[:recurrent_layers] begin
rng = get_stable_rng(12345)
rng = StableRNG(12345)

@testset "$mode" for (mode, aType, device, ongpu) in MODES
@testset "cell: $_cell" for _cell in (RNNCell, LSTMCell, GRUCell)
cell = _cell(3 => 5)
bi_rnn = BidirectionalRNN(cell)
bi_rnn_no_merge = BidirectionalRNN(cell; merge_mode=nothing)
__display(bi_rnn)
display(bi_rnn)

# Batched Time Series
x = randn(rng, Float32, 3, 4, 2) |> aType
Expand All @@ -411,7 +413,6 @@ end
@test all(x -> size(x) == (5, 2), y_[1])

if mode != "AMDGPU"

# gradients test failed after vcat
# __f = p -> sum(Base.Fix1(sum, abs2), first(bi_rnn(x, p, st)))
# @eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu
Expand All @@ -422,6 +423,7 @@ end
else
# This is just added as a stub to remember about this broken test
end

@testset "backward_cell: $_backward_cell" for _backward_cell in (
RNNCell, LSTMCell, GRUCell)
cell = _cell(3 => 5)
Expand All @@ -430,7 +432,7 @@ end
bi_rnn_no_merge = BidirectionalRNN(
cell; backward_cell=backward_cell, merge_mode=nothing)
println("BidirectionalRNN:")
__display(bi_rnn)
display(bi_rnn)

# Batched Time Series
x = randn(rng, Float32, 3, 4, 2) |> aType
Expand Down

0 comments on commit af336ad

Please sign in to comment.