Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Dec 11, 2024
1 parent b9eab55 commit f96bd58
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 4 deletions.
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Expand Down Expand Up @@ -55,7 +54,6 @@ MLDataDevices = "1.4.2"
MLUtils = "0.4"
MPI = "0.20.19"
MacroTools = "0.5"
Metal = "1.4.2"
NCCL = "0.1.1"
NNlib = "0.9.22"
OneHotArrays = "0.2.4"
Expand Down
4 changes: 2 additions & 2 deletions test/ext_common/recurrent_gpu_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ out_from_state(state) = state
function recurrent_cell_loss(cell, seq, state)
out = []
for xt in seq
state = Flux.scan(cell, x, state)
state = cell(xt, state)
yt = out_from_state(state)
out = vcat(out, [yt])
end
return mean(stack(y, dims = 2))
return mean(stack(out, dims = 2))
end

@testset "RNNCell GPU AD" begin
Expand Down

0 comments on commit f96bd58

Please sign in to comment.