-
Notifications
You must be signed in to change notification settings - Fork 63
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Bidirectional RNN #708
Bidirectional RNN #708
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #708 +/- ##
==========================================
- Coverage 96.30% 87.00% -9.31%
==========================================
Files 57 57
Lines 2789 2801 +12
==========================================
- Hits 2686 2437 -249
- Misses 103 364 +261 ☔ View full report in Codecov by Sentry. |
@avik-pal Hi! I have completed the implementation of |
😭😭😭Any suggestions? I will fix it today |
@avik-pal 🥹🥹Hi, could you please help me review the PR and find the reason why Thank you very much |
@avik-pal Sorry to bother you. I still don't know how to solve (rnn::BidirectionalRNN)(x, ps, st::NamedTuple) = rnn.model(x, ps, st) julia> @report_opt bi_rnn(x, ps, st)
═════ 1 possible error found ═════
┌ (::BidirectionalRNN)(x::Array{…}, ps::@NamedTuple{…}, st::@NamedTuple{…}) @ Lux ./Lux.jl/src/layers/recurrent.jl:690
│ runtime dispatch detected: %1::Parallel(x::Array{Float32, 3}, ps::@NamedTuple{layer_1::@NamedTuple{…}, layer_2::@NamedTuple{…}}, st::@NamedTuple{layer_1::@NamedTuple{…}, layer_2::@NamedTuple{…}})::Any
└──────────────────── |
I will take a look on the weekend |
😥Hi, could you help me review this PR?.. |
b97e29a
to
e7ae725
Compare
Fix the gradient tests and it should be fine. They are probably originating from lazy reverse rrules for Zygote not being defined for GPU arrays |
Thank you for your help! Some gradient tests still failed at here, I have no idea about how to set 19:22:03 | maxrss 20.0% | mem 67.2% | DONE (1/1) test item "Bidirectional" 112.9 secs (68.5% compile, 0.1% recompile, 6.2% GC), 188.79 M allocs (13.948 GB)
Test Summary: | Pass Error Total Time
ReTestItem Tests | 120 12 132 2m08.7s
Bidirectional | 60 6 66 2m03.7s
cpu | 30 3 33 1m08.8s
cell: RNNCell | 10 1 11 52.2s
cell: LSTMCell | 10 1 11 8.0s
cell: GRUCell | 10 1 11 8.6s
cuda | 30 3 33 43.5s
cell: RNNCell | 10 1 11 27.5s
cell: LSTMCell | 10 1 11 7.8s
cell: GRUCell | 10 1 11 8.3s
Lux | 60 6 66 2m05.6s
test | 60 6 66
test/layers | 60 6 66
test/layers/recurrent_tests.jl | 60 6 66
Bidirectional | 60 6 66 2m03.7s
cpu | 30 3 33 1m08.8s
cell: RNNCell | 10 1 11 52.2s
cell: LSTMCell | 10 1 11 8.0s
cell: GRUCell | 10 1 11 8.6s
cuda | 30 3 33 43.5s
cell: RNNCell | 10 1 11 27.5s
cell: LSTMCell | 10 1 11 7.8s
cell: GRUCell | 10 1 11 8.3s
ERROR: LoadError: Some tests did not pass: 120 passed, 0 failed, 12 errored, 0 broken.
in expression starting at /home/nero/Documents/github/Lux.jl/test/runtests.jl:75 |
@avik-pal Hi, I'm sorry to bother you again. I want to continue to push this PR, but it seems that there is a limit to what I can do. Is "lazy reverse rrules" a feature that Zygote is missing? Do I need to open an issue for Zygote? |
Kind of, but not worth opening a Zygote issue for this. If you look at https://buildkite.com/julialang/lux-dot-jl/builds/2994#01906a93-9e6b-4704-a6a1-d3b8e82bb694/350-1748, it is saying that we doing a broadcast The easiest way to resolve this would be to make the Iterators.Reverse into a Vector, since that is effectively materializing a vector of pointers, it is not expensive either. |
@avik-pal Thank you very much for pointing me in the right direction, but I can't really understand "make the Iterators.Reverse into a Vector" here, I think julia> vec = [1,2,3,4,5]
julia> foreach(println, Iterators.reverse(vec))
5
4
3
2
1 Could you give me a few inputs and outputs as examples, and the signature of the function to define? I can implement it if I could |
No that is not what I meant. Try something like x = [cu(rand(5)) for _ in 1:5]
x_rev = Iterators.reverse(x)
vcat(x, x_rev) Now try to differentiate the result of vcat wrt |
@avik-pal So sorry to bother you... Forgive my poor understanding... I totally don't understand the math meaning here! So I don't know how to differentiate this Let's say: x = [1,2,3,4,5]
cat_x = [1,2,3,4,5,1,2,3,4,5]
y = cat_x
# how to differentiate the result??? Thanks all your OSS works, It's great. And I really need this feature. |
b0e692e
to
5692dfe
Compare
issue #687
Please confirm whether the interface meets the requirements. Thank you.