feat: bidirectional RNN + debugging RNNs (#708)
* feat: add BidirectionalRNN

* fix: avoid lazy reverse

* fix: handle reverse for operator overloading AD

* fix: allow debug modes for recurrent layers

* fix: soa/aos handling for multigate

---------

Co-authored-by: Avik Pal <[email protected]>
NeroBlackstone and avik-pal authored Jul 10, 2024
1 parent 803e660 commit d9aa5a6
Showing 11 changed files with 200 additions and 20 deletions.
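For context on the "avoid lazy reverse" fix in the commit message: `Iterators.reverse` returns a lazy wrapper rather than materializing a new array, and that wrapper is not an `AbstractArray`, so the operator-overloading AD packages (ReverseDiff, Tracker) cannot dispatch on it. A minimal illustration in plain Julia (independent of this diff):

```julia
x = [1.0, 2.0, 3.0]

reverse(x)            # eager: allocates a new Vector{Float64}, [3.0, 2.0, 1.0]
Iterators.reverse(x)  # lazy: an Iterators.Reverse wrapper, not an AbstractArray
```

The `__reverse` helper added to `src/utils.jl` below always takes the eager path, with AD-aware overloads in the ReverseDiff and Tracker extensions.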
1 change: 1 addition & 0 deletions docs/src/api/Lux/layers.md
@@ -52,6 +52,7 @@ LSTMCell
RNNCell
Recurrence
StatefulRecurrentCell
BidirectionalRNN
```

## Linear Layers
19 changes: 19 additions & 0 deletions ext/LuxReverseDiffExt/LuxReverseDiffExt.jl
@@ -26,6 +26,25 @@ Lux.apply(m::Lux.AbstractExplicitLayer, x::TrackedArray, ps, st) = m(x, ps, st)
@inline Lux.__eltype(::TrackedReal{T}) where {T} = T
@inline Lux.__eltype(::AbstractArray{<:TrackedReal{T}}) where {T} = T

@inline Lux.__reverse(x::TrackedArray; dims=:) = ArrayInterface.aos_to_soa(reverse(x; dims))
@inline function Lux.__reverse(x::AbstractArray{<:TrackedReal}; dims=:)
return ArrayInterface.aos_to_soa(reverse(x; dims))
end

# multigate: avoid soa formation
@inline function Lux._gate(x::TrackedArray{T, R, 1}, h::Int, n::Int) where {T, R}
return x[Lux._gate(h, n)]
end
@inline function Lux._gate(x::AbstractVector{<:TrackedReal}, h::Int, n::Int)
return ArrayInterface.aos_to_soa(view(x, Lux._gate(h, n)))
end
@inline function Lux._gate(x::TrackedArray{T, R, 2}, h::Int, n::Int) where {T, R}
return x[Lux._gate(h, n), :]
end
@inline function Lux._gate(x::AbstractMatrix{<:TrackedReal}, h::Int, n::Int)
return ArrayInterface.aos_to_soa(view(x, Lux._gate(h, n), :))
end

@inline function Lux.__convert_eltype(::Type{T}, x::AbstractArray{<:TrackedReal}) where {T}
@warn "`Lux.__convert_eltype` doesn't support converting element types of ReverseDiff \
`TrackedReal` arrays. Currently this is a no-op." maxlog=1
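The `_gate` overloads above slice a fused multi-gate array (e.g. stacked LSTM gate pre-activations) into per-gate blocks without first forming an `Array{TrackedReal}` (array-of-structs); where a view over `TrackedReal`s is unavoidable, `ArrayInterface.aos_to_soa` converts it back to a `TrackedArray` (struct-of-arrays). A minimal sketch of the slicing logic, assuming the index helper matches Lux's internal `_gate(h, n)`:

```julia
# Assumed definition: gate n of width h spans entries (n - 1) * h + 1 : n * h.
_gate(h::Int, n::Int) = (1:h) .+ h * (n - 1)

fused = collect(1.0:12.0)                  # e.g. 4 gates of width 3 fused into one vector
gates = [fused[_gate(3, n)] for n in 1:4]  # [[1,2,3], [4,5,6], [7,8,9], [10,11,12]]
```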
5 changes: 5 additions & 0 deletions ext/LuxTrackerExt.jl
@@ -65,6 +65,11 @@ Lux.apply(m::Lux.AbstractExplicitLayer, x::TrackedArray, ps, st) = m(x, ps, st)
@inline Lux.__eltype(::TrackedReal{T}) where {T} = T
@inline Lux.__eltype(::AbstractArray{<:TrackedReal{T}}) where {T} = T

@inline Lux.__reverse(x::TrackedArray; dims=:) = ArrayInterface.aos_to_soa(reverse(x; dims))
@inline function Lux.__reverse(x::AbstractArray{<:TrackedReal}; dims=:)
return ArrayInterface.aos_to_soa(reverse(x; dims))
end

# SimpleChains.jl: DON'T REPLACE THESE WITH @grad_from_chainrules
for T1 in (:TrackedArray, :AbstractArray), T2 in (:TrackedArray, :AbstractArray)
T1 === :AbstractArray && T2 === :AbstractArray && continue
8 changes: 4 additions & 4 deletions src/Lux.jl
@@ -42,6 +42,9 @@ include("preferences.jl")
include("custom_errors.jl")
include("utils.jl")

# Experimental
include("contrib/contrib.jl")

# Layer Implementations
include("layers/basic.jl")
include("layers/containers.jl")
@@ -54,9 +57,6 @@ include("layers/extension.jl")
# Pretty Printing
include("layers/display.jl")

# Experimental
include("contrib/contrib.jl")

# Helpful Functionalities
include("helpers/stateful.jl")
include("helpers/compact.jl")
@@ -93,7 +93,7 @@ export AlphaDropout, Dropout, VariationalHiddenDropout
export BatchNorm, GroupNorm, InstanceNorm, LayerNorm
export WeightNorm
export NoOpLayer, ReshapeLayer, SelectDim, FlattenLayer, WrappedFunction, ReverseSequence
export RNNCell, LSTMCell, GRUCell, Recurrence, StatefulRecurrentCell
export RNNCell, LSTMCell, GRUCell, Recurrence, StatefulRecurrentCell, BidirectionalRNN
export SamePad, TimeLastIndex, BatchLastIndex

export StatefulLuxLayer
2 changes: 1 addition & 1 deletion src/contrib/debug.jl
@@ -45,7 +45,7 @@ See [`Lux.Experimental.@debug_mode`](@ref) to construct this layer.
"""
@concrete struct DebugLayer{NaNCheck, ErrorCheck} <:
AbstractExplicitContainerLayer{(:layer,)}
layer
layer <: AbstractExplicitLayer
location::KeyPath
end

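With the `layer` field now constrained to `AbstractExplicitLayer`, recurrent cells wrapped by `Lux.Experimental.@debug_mode` become valid `cell` arguments for `Recurrence` and `StatefulRecurrentCell` (see the `recurrent.jl` changes below). A minimal sketch of the workflow this enables:

```julia
using Lux, Random

rng = Random.default_rng()
debug_cell = Lux.Experimental.@debug_mode RNNCell(3 => 5)  # wraps the cell in a DebugLayer
rnn = Recurrence(debug_cell)                               # accepted after this commit
ps, st = Lux.setup(rng, rnn)
y, _ = rnn(randn(rng, Float32, 3, 4, 2), ps, st)           # runs DebugLayer's checks each step
```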
10 changes: 5 additions & 5 deletions src/layers/basic.jl
@@ -86,22 +86,22 @@ end

@inline function (r::ReverseSequence{Nothing})(
x::AbstractVector{T}, ps, st::NamedTuple) where {T}
return (isbitstype(T) ? reverse(x) : Iterators.reverse(x)), st
return __reverse(x), st
end

@inline function (r::ReverseSequence{Nothing})(
x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N}
return reverse(x; dims=ndims(x) - 1), st
return __reverse(x; dims=ndims(x) - 1), st
end

@inline function (r::ReverseSequence)(x::AbstractVector{T}, ps, st::NamedTuple) where {T}
r.dim == 1 && return reverse(x), st
throw(DimensionMismatch(lazy"Cannot specify a dimension other than 1 for AbstractVector{T}"))
r.dim == 1 && return __reverse(x), st
throw(ArgumentError("Cannot specify a dimension other than 1 for AbstractVector{T}"))
end

@inline function (r::ReverseSequence)(
x::AbstractArray{T, N}, ps, st::NamedTuple) where {T, N}
return reverse(x; dims=r.dim), st
return __reverse(x; dims=r.dim), st
end

"""
Expand Down
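`ReverseSequence` now routes through the eager `__reverse` helper, and vector inputs with `dim != 1` throw an `ArgumentError` instead of a `DimensionMismatch`. A minimal usage sketch, assuming the default constructor:

```julia
using Lux, Random

rs = ReverseSequence()                       # dim = nothing: reverse along ndims(x) - 1
ps, st = Lux.setup(Random.default_rng(), rs)

x = reshape(Float32.(1:6), 1, 3, 2)          # (features, time, batch)
y, _ = rs(x, ps, st)                         # time slices now appear in order 3, 2, 1
```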
6 changes: 6 additions & 0 deletions src/layers/display.jl
@@ -53,6 +53,12 @@ function _printable_children(l::Union{PairwiseFusion, Parallel})
return merge((; l.connection), children.layers)
end
_printable_children(l::SkipConnection) = (; l.connection, l.layers)
function _printable_children(l::BidirectionalRNN)
merge_mode = l.model.connection isa Broadcast.BroadcastFunction ? l.model.connection.f :
nothing
return (; merge_mode, forward_cell=l.model.layers.forward_rnn.cell,
backward_cell=l.model.layers.backward_rnn.rnn.cell)
end

_show_leaflike(x) = Functors.isleaf(x) # mostly follow Functors, except for:
_show_leaflike(x::AbstractExplicitLayer) = false
91 changes: 82 additions & 9 deletions src/layers/recurrent.jl
@@ -1,5 +1,13 @@
abstract type AbstractRecurrentCell{use_bias, train_state} <: AbstractExplicitLayer end

const AbstractDebugRecurrentCell = Experimental.DebugLayer{
<:Any, <:Any, <:AbstractRecurrentCell}

function ConstructionBase.constructorof(::Type{<:AbstractRecurrentCell{
use_bias, train_state}}) where {use_bias, train_state}
return AbstractRecurrentCell{use_bias, train_state}
end
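
# Why the explicit `constructorof`: `use_bias`/`train_state` are type parameters with no
# matching field, so ConstructionBase (used by Accessors.jl/Functors.jl when rebuilding
# layers) cannot infer them from field values. A standalone sketch of the same pattern,
# with a hypothetical `Cell` that is not part of this diff:
#
#     using ConstructionBase
#     struct Cell{use_bias}
#         dims::Int
#     end
#     ConstructionBase.constructorof(::Type{<:Cell{use_bias}}) where {use_bias} =
#         Cell{use_bias}
#     setproperties(Cell{true}(3), (; dims=5))  # -> Cell{true}(5); parameter preserved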

# Fallback for vector inputs
function (rnn::AbstractRecurrentCell)(x::AbstractVector, ps, st::NamedTuple)
(y, carry), st_ = rnn(reshape(x, :, 1), ps, st)
@@ -82,16 +90,16 @@ automatically operate over a sequence of inputs.
For some discussion on this topic, see https://github.com/LuxDL/Lux.jl/issues/472.
"""
struct Recurrence{
R, C <: AbstractRecurrentCell, O <: AbstractTimeSeriesDataBatchOrdering} <:
AbstractExplicitContainerLayer{(:cell,)}
cell::C
ordering::O
@concrete struct Recurrence{R} <: AbstractExplicitContainerLayer{(:cell,)}
cell <: Union{<:AbstractRecurrentCell, <:AbstractDebugRecurrentCell}
ordering <: AbstractTimeSeriesDataBatchOrdering
end

ConstructionBase.constructorof(::Type{<:Recurrence{R}}) where {R} = Recurrence{R}

function Recurrence(cell; ordering::AbstractTimeSeriesDataBatchOrdering=BatchLastIndex(),
return_sequence::Bool=false)
return Recurrence{return_sequence, typeof(cell), typeof(ordering)}(cell, ordering)
return Recurrence{return_sequence}(cell, ordering)
end

_eachslice(x::AbstractArray, ::TimeLastIndex) = _eachslice(x, Val(ndims(x)))
@@ -164,9 +172,8 @@ update the state with `Lux.update_state(st, :carry, nothing)`.
+ `cell`: Same as `cell`.
+ `carry`: The carry state of the `cell`.
"""
struct StatefulRecurrentCell{C <: AbstractRecurrentCell} <:
AbstractExplicitContainerLayer{(:cell,)}
cell::C
@concrete struct StatefulRecurrentCell <: AbstractExplicitContainerLayer{(:cell,)}
cell <: Union{<:AbstractRecurrentCell, <:AbstractDebugRecurrentCell}
end

function initialstates(rng::AbstractRNG, r::StatefulRecurrentCell)
@@ -641,3 +648,69 @@ function Base.show(io::IO, g::GRUCell{use_bias, TS}) where {use_bias, TS}
TS && print(io, ", train_state=true")
return print(io, ")")
end

"""
BidirectionalRNN(cell::AbstractRecurrentCell,
backward_cell::Union{AbstractRecurrentCell, Nothing}=nothing;
merge_mode::Union{Function, Nothing}=vcat,
ordering::AbstractTimeSeriesDataBatchOrdering=BatchLastIndex())
Bidirectional RNN wrapper.
## Arguments
- `cell`: A recurrent cell. See [`RNNCell`](@ref), [`LSTMCell`](@ref), [`GRUCell`](@ref),
for how the inputs/outputs of a recurrent cell must be structured.
- `backward_cell`: A optional backward recurrent cell. If `backward_cell` is `nothing`,
the rnn layer instance passed as the `cell` argument will be used to generate the
backward layer automatically. `in_dims` of `backward_cell` should be consistent with
`in_dims` of `cell`
## Keyword Arguments
- `merge_mode`: Function by which outputs of the forward and backward RNNs will be combined.
default value is `vcat`. If `nothing`, the outputs will not be combined.
- `ordering`: The ordering of the batch and time dimensions in the input. Defaults to
`BatchLastIndex()`. Alternatively can be set to `TimeLastIndex()`.
## Inputs
- If `x` is a
+ Tuple or Vector: Each element is fed to the `cell` sequentially.
+ Array (except a Vector): It is spliced along the penultimate dimension and each
slice is fed to the `cell` sequentially.
## Returns
- Merged output of the `cell` and `backward_cell` for the entire sequence.
- Update state of the `cell` and `backward_cell`.
## Parameters
- `NamedTuple` with `cell` and `backward_cell`.
## States
- Same as `cell` and `backward_cell`.
"""
@concrete struct BidirectionalRNN <: AbstractExplicitContainerLayer{(:model,)}
model <: Parallel
end

(rnn::BidirectionalRNN)(x, ps, st::NamedTuple) = rnn.model(x, ps, st)

function BidirectionalRNN(cell::AbstractRecurrentCell,
backward_cell::Union{AbstractRecurrentCell, Nothing}=nothing;
merge_mode::Union{Function, Nothing}=vcat,
ordering::AbstractTimeSeriesDataBatchOrdering=BatchLastIndex())
layer = Recurrence(cell; return_sequence=true, ordering)
backward_rnn_layer = backward_cell === nothing ? layer :
Recurrence(backward_cell; return_sequence=true, ordering)
fuse_op = merge_mode === nothing ? nothing : Broadcast.BroadcastFunction(merge_mode)
return BidirectionalRNN(Parallel(fuse_op;
forward_rnn=layer,
backward_rnn=Chain(;
rev1=ReverseSequence(), rnn=backward_rnn_layer, rev2=ReverseSequence())))
end
2 changes: 2 additions & 0 deletions src/utils.jl
@@ -375,3 +375,5 @@ end
@inline __set_refval!(x, y) = (x[] = y)

@inline __eltype(x) = eltype(x)

@inline __reverse(x; dims=:) = reverse(x; dims)
2 changes: 1 addition & 1 deletion test/layers/basic_tests.jl
@@ -35,7 +35,7 @@

@test layer(x, ps, st)[1] == aType(xr)
@test layer(x2, ps, st)[1] == aType(x2rd1)
@test_throws DimensionMismatch layer2(x, ps2, st2)[1]
@test_throws ArgumentError layer2(x, ps2, st2)[1]
@test layer3(x, ps3, st3)[1] == aType(xr)
@test layer2(x2, ps2, st2)[1] == aType(x2rd2)

74 changes: 74 additions & 0 deletions test/layers/recurrent_tests.jl
@@ -374,3 +374,77 @@ end
@test_throws ErrorException Lux._eachslice(x, BatchLastIndex())
end
end

@testitem "Bidirectional" timeout=3000 setup=[SharedTestSetup] tags=[:recurrent_layers] begin
rng = StableRNG(12345)

@testset "$mode" for (mode, aType, device, ongpu) in MODES
@testset "cell: $_cell" for _cell in (RNNCell, LSTMCell, GRUCell)
cell = _cell(3 => 5)
bi_rnn = BidirectionalRNN(cell)
bi_rnn_no_merge = BidirectionalRNN(cell; merge_mode=nothing)
display(bi_rnn)

# Batched Time Series
x = randn(rng, Float32, 3, 4, 2) |> aType
ps, st = Lux.setup(rng, bi_rnn) .|> device
y, st_ = bi_rnn(x, ps, st)
y_, st__ = bi_rnn_no_merge(x, ps, st)

@jet bi_rnn(x, ps, st)
@jet bi_rnn_no_merge(x, ps, st)

@test size(y) == (4,)
@test all(x -> size(x) == (10, 2), y)

@test length(y_) == 2
@test size(y_[1]) == size(y_[2])
@test size(y_[1]) == (4,)
@test all(x -> size(x) == (5, 2), y_[1])

__f = p -> sum(Base.Fix1(sum, abs2), first(bi_rnn(x, p, st)))
@eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu

__f = p -> begin
(y1, y2), st_ = bi_rnn_no_merge(x, p, st)
return sum(Base.Fix1(sum, abs2), y1) + sum(Base.Fix1(sum, abs2), y2)
end
@eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu

@testset "backward_cell: $_backward_cell" for _backward_cell in (
RNNCell, LSTMCell, GRUCell)
cell = _cell(3 => 5)
backward_cell = _backward_cell(3 => 5)
bi_rnn = BidirectionalRNN(cell, backward_cell)
bi_rnn_no_merge = BidirectionalRNN(cell, backward_cell; merge_mode=nothing)
display(bi_rnn)

# Batched Time Series
x = randn(rng, Float32, 3, 4, 2) |> aType
ps, st = Lux.setup(rng, bi_rnn) .|> device
y, st_ = bi_rnn(x, ps, st)
y_, st__ = bi_rnn_no_merge(x, ps, st)

@jet bi_rnn(x, ps, st)
@jet bi_rnn_no_merge(x, ps, st)

@test size(y) == (4,)
@test all(x -> size(x) == (10, 2), y)

@test length(y_) == 2
@test size(y_[1]) == size(y_[2])
@test size(y_[1]) == (4,)
@test all(x -> size(x) == (5, 2), y_[1])

__f = p -> sum(Base.Fix1(sum, abs2), first(bi_rnn(x, p, st)))
@eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu

__f = p -> begin
(y1, y2), st_ = bi_rnn_no_merge(x, p, st)
return sum(Base.Fix1(sum, abs2), y1) + sum(Base.Fix1(sum, abs2), y2)
end
@eval @test_gradients $__f $ps atol=1e-2 rtol=1e-2 gpu_testing=$ongpu
end
end
end
end

1 comment on commit d9aa5a6

@github-actions

Benchmark Results

| Benchmark suite | Current: d9aa5a6 | Previous: 849fea6 | Ratio |
| --- | --- | --- | --- |
| Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) | 3643 ns | 3668.125 ns | 0.99 |
| Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) | 7176 ns | 7265.333333333333 ns | 0.99 |
| Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) | 20829 ns | 21330 ns | 0.98 |
| Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) | 9808.2 ns | 9776.4 ns | 1.00 |
| Dense(2 => 2)/cpu/reverse/Flux/(2, 128) | 8944 ns | 9087 ns | 0.98 |
| Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) | 4470.875 ns | 4558.625 ns | 0.98 |
| Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) | 1152.5869565217392 ns | 1163.4855072463768 ns | 0.99 |
| Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) | 1112.0065789473683 ns | 1119.0544871794873 ns | 0.99 |
| Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) | 1169.9583333333333 ns | 1180.2686567164178 ns | 0.99 |
| Dense(2 => 2)/cpu/forward/Flux/(2, 128) | 1773.142857142857 ns | 1766.6949152542372 ns | 1.00 |
| Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) | 179.55182072829132 ns | 178.72408963585434 ns | 1.00 |
| Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) | 17172 ns | 17182 ns | 1.00 |
| Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) | 16852 ns | 17062 ns | 0.99 |
| Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) | 36989 ns | 37560 ns | 0.98 |
| Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) | 28217.5 ns | 28904 ns | 0.98 |
| Dense(20 => 20)/cpu/reverse/Flux/(20, 128) | 19907 ns | 21470 ns | 0.93 |
| Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) | 17743 ns | 17412 ns | 1.02 |
| Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) | 4329.571428571428 ns | 4353.857142857143 ns | 0.99 |
| Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) | 3848.375 ns | 3851 ns | 1.00 |
| Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) | 3965 ns | 3946.125 ns | 1.00 |
| Dense(20 => 20)/cpu/forward/Flux/(20, 128) | 4887.857142857143 ns | 4869.142857142857 ns | 1.00 |
| Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) | 1664.1 ns | 1652.1 ns | 1.01 |
| Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) | 38592071 ns | 48481192 ns | 0.80 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) | 57598840.5 ns | 57589911 ns | 1.00 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) | 70928337 ns | 112051236 ns | 0.63 |
| Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) | 88908288 ns | 107350078 ns | 0.83 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) | 72637494 ns | 107753804 ns | 0.67 |
| Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) | 11652507 ns | 11633974 ns | 1.00 |
| Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) | 6961815 ns | 8368822 ns | 0.83 |
| Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) | 7118503 ns | 6994384 ns | 1.02 |
| Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) | 7056081 ns | 6961674 ns | 1.01 |
| Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) | 10185966 ns | 18289304 ns | 0.56 |
| Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) | 6384646.5 ns | 6377896.5 ns | 1.00 |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) | 693015046 ns | 708146792 ns | 0.98 |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) | 2564503502 ns | 2538464789 ns | 1.01 |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) | 141044477 ns | 130550753 ns | 1.08 |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) | 755933134 ns | 940215170 ns | 0.80 |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) | 2887970897 ns | 3222185770 ns | 0.90 |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) | 209654834 ns | 200541349 ns | 1.05 |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 16) | 652305217 ns | 725094261.5 ns | 0.90 |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 64) | 2592391657 ns | 2700272370 ns | 0.96 |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 2) | 123578859 ns | 133034897.5 ns | 0.93 |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) | 175231208.5 ns | 174289604.5 ns | 1.01 |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) | 652645593.5 ns | 656945864.5 ns | 0.99 |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) | 45509251 ns | 45333461 ns | 1.00 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) | 164808995.5 ns | 164456981 ns | 1.00 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) | 643429956 ns | 639772875 ns | 1.01 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) | 30232427 ns | 30105099 ns | 1.00 |
| vgg16/cpu/forward/Flux/(32, 32, 3, 16) | 187781212 ns | 230482052 ns | 0.81 |
| vgg16/cpu/forward/Flux/(32, 32, 3, 64) | 711012026.5 ns | 896237645 ns | 0.79 |
| vgg16/cpu/forward/Flux/(32, 32, 3, 2) | 36478557.5 ns | 39999991 ns | 0.91 |
| Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) | 1238395972.5 ns | 1229029098.5 ns | 1.01 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) | 1864792034.5 ns | 1857882986.5 ns | 1.00 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) | 2335459386 ns | 2500640456 ns | 0.93 |
| Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) | 2510833048 ns | 2705526315 ns | 0.93 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) | 1834047291 ns | 1857087780.5 ns | 0.99 |
| Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) | 327114030.5 ns | 354871769 ns | 0.92 |
| Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) | 326897983 ns | 321331574 ns | 1.02 |
| Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) | 322709878 ns | 319443241 ns | 1.01 |
| Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) | 426753786 ns | 365388358 ns | 1.17 |
| Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) | 11697996.5 ns | 11707792 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) | 17846230 ns | 17872220 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) | 19079864 ns | 19026670 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) | 23775806 ns | 23710928 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) | 17783235 ns | 17883894 ns | 0.99 |
| Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) | 1158115.5 ns | 1151164 ns | 1.01 |
| Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) | 2115024 ns | 2521988.5 ns | 0.84 |
| Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) | 2128008 ns | 2058845.5 ns | 1.03 |
| Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) | 2080563.5 ns | 2039299 ns | 1.02 |
| Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) | 2067390 ns | 2075005.5 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) | 198766.5 ns | 196078 ns | 1.01 |
| Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) | 293619 ns | 291156 ns | 1.01 |
| Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) | 265135.5 ns | 264076 ns | 1.00 |
| Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) | 365413.5 ns | 364333.5 ns | 1.00 |
| Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) | 408393 ns | 406242 ns | 1.01 |
| Dense(200 => 200)/cpu/reverse/Flux/(200, 128) | 275595 ns | 273453 ns | 1.01 |
| Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) | 408263 ns | 406201 ns | 1.01 |
| Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) | 83396 ns | 83697 ns | 1.00 |
| Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) | 81152 ns | 81232 ns | 1.00 |
| Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) | 82895 ns | 81733 ns | 1.01 |
| Dense(200 => 200)/cpu/forward/Flux/(200, 128) | 86902 ns | 86647.5 ns | 1.00 |
| Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) | 104425 ns | 104416 ns | 1.00 |
| Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) | 197729486 ns | 192747057.5 ns | 1.03 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) | 326930373 ns | 327202184 ns | 1.00 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) | 398584711 ns | 449880668 ns | 0.89 |
| Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) | 443775723.5 ns | 474869215 ns | 0.93 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) | 347290902 ns | 412132513.5 ns | 0.84 |
| Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) | 322621922 ns | 322865631.5 ns | 1.00 |
| Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) | 44383622 ns | 51483926 ns | 0.86 |
| Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) | 44477990 ns | 43968482 ns | 1.01 |
| Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) | 44018296 ns | 43749705 ns | 1.01 |
| Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) | 53487940 ns | 70756328 ns | 0.76 |
| Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) | 27983719 ns | 28106570 ns | 1.00 |
| Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) | 18934728.5 ns | 18818603 ns | 1.01 |
| Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) | 19606172 ns | 19497998 ns | 1.01 |
| Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) | 23436883.5 ns | 23422262 ns | 1.00 |
| Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) | 24235917 ns | 24150485 ns | 1.00 |
| Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) | 19708092.5 ns | 19687573 ns | 1.00 |
| Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) | 6514418 ns | 6494414 ns | 1.00 |
| Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) | 6534616 ns | 6523056 ns | 1.00 |
| Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) | 6507885 ns | 6473309 ns | 1.01 |
| Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) | 6515966 ns | 6503950 ns | 1.00 |

This comment was automatically generated by workflow using github-action-benchmark.
