Skip to content

Commit

Permalink
feat: compile training loop automatically using reactant (#969)
Browse files Browse the repository at this point in the history
* feat: compile training loop automatically using reactant

* refactor: add a level of indirection for the train_step

* feat: directly compile step + grad function

* fix: make note of current issue with inplace update

* chore: bump minimum reactant version

* test: setup specific reactant test group

* ci: temporarily disable other tests (drop me)

* test: fix installation of Reactant

* test: start adding loss function tests

* fix: xlogx and xlogy now work with Reactant scalars

* feat: support regression losses + tests

* test: classification losses

* fix: more specialization

* fix: support all loss functions

* chore: comments

* fix: bump reactant version

* test: don't run reactant tests on windows

* test: temporarily disable more tests

* fix: reactant GPU support

* fix: remove old LossFunctions.jl dispatches

* test: try using MSELoss directly

* ci: reactivate all tests

* ci(windows): don't test Reactant on windows
  • Loading branch information
avik-pal authored Oct 9, 2024
1 parent 77eb5fb commit 1b0d6f8
Show file tree
Hide file tree
Showing 15 changed files with 487 additions and 33 deletions.
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ jobs:
- "recurrent_layers"
- "eltype_match"
- "fluxcompat"
- "reactant"
include:
- version: "1.10"
os: macos-latest
Expand Down
21 changes: 11 additions & 10 deletions .github/workflows/CIPreRelease.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,17 @@ jobs:
os:
- ubuntu-latest
test_group:
- "core_layers"
- "contrib"
- "helpers"
- "distributed"
- "normalize_layers"
- "others"
- "autodiff"
- "recurrent_layers"
- "eltype_match"
- "fluxcompat"
# - "core_layers"
# - "contrib"
# - "helpers"
# - "distributed"
# - "normalize_layers"
# - "others"
# - "autodiff"
# - "recurrent_layers"
# - "eltype_match"
# - "fluxcompat"
- "reactant"
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
Expand Down
7 changes: 5 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Expand All @@ -59,6 +60,7 @@ LuxLossFunctionsExt = "LossFunctions"
LuxMLUtilsExt = "MLUtils"
LuxMPIExt = "MPI"
LuxMPINCCLExt = ["CUDA", "MPI", "NCCL"]
LuxReactantExt = ["Enzyme", "Reactant"]
LuxReverseDiffExt = ["FunctionWrappers", "ReverseDiff"]
LuxSimpleChainsExt = "SimpleChains"
LuxTrackerExt = "Tracker"
Expand All @@ -68,7 +70,7 @@ LuxZygoteExt = "Zygote"
ADTypes = "1.8.1"
Adapt = "4"
ArgCheck = "2.3"
ArrayInterface = "7.9"
ArrayInterface = "7.10"
CUDA = "5.3.2"
ChainRulesCore = "1.24"
Compat = "4.15"
Expand All @@ -87,7 +89,7 @@ LinearAlgebra = "1.10"
LossFunctions = "0.11.1"
LuxCore = "1"
LuxLib = "1.3"
MLDataDevices = "1.1"
MLDataDevices = "1.2"
MLUtils = "0.4.4"
MPI = "0.20.19"
MacroTools = "0.5.13"
Expand All @@ -97,6 +99,7 @@ NNlib = "0.9.24"
Optimisers = "0.3.3"
Preferences = "1.4.3"
Random = "1.10"
Reactant = "0.2.3"
Reexport = "1.2.2"
ReverseDiff = "1.15"
SIMDTypes = "0.1"
Expand Down
9 changes: 4 additions & 5 deletions ext/LuxEnzymeExt/training.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function Lux.Training.compute_gradients(
function Lux.Training.compute_gradients_impl(
ad::AutoEnzyme, obj_fn::F, data, ts::TrainState) where {F}
dps = Lux.recursive_make_zero(ts.parameters)

Expand All @@ -20,9 +20,8 @@ end
const AUTODIFF_CACHE_TYPE = TrainingBackendCache{
<:AutoEnzyme, False, PS, <:NamedTuple{(:obj_fn, :st_wrap, :stats_wrap)}} where {PS}

function Lux.Training.compute_gradients(
function Lux.Training.compute_gradients_impl(
::AutoEnzyme, obj_fn::F, data, ts::TrainState{<:AUTODIFF_CACHE_TYPE, F}) where {F}
# dps = Lux.recursive_make_zero!!(ts.cache.dparameters)
Enzyme.make_zero!(ts.cache.dparameters)
dps = ts.cache.dparameters

Expand All @@ -36,7 +35,7 @@ function Lux.Training.compute_gradients(
return dps, loss, ts.cache.extras.stats_wrap[], ts
end

function Lux.Training.compute_gradients(ad::AutoEnzyme, obj_fn::F, data,
function Lux.Training.compute_gradients_impl(ad::AutoEnzyme, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{<:AutoEnzyme, False}}) where {F}
@warn "Detected calls to `compute_gradients(::AutoEnzyme, ...)` with objective \
function that is changing across function calls. This can lead to the \
Expand All @@ -56,7 +55,7 @@ end
const AUTODIFF_THUNK_CACHE_TYPE = TrainingBackendCache{
<:AutoEnzyme, False, PS, <:NamedTuple{(:forward, :reverse)}} where {PS}

function Lux.Training.compute_gradients(::AutoEnzyme, obj_fn::F, data,
function Lux.Training.compute_gradients_impl(::AutoEnzyme, obj_fn::F, data,
ts::TrainState{<:AUTODIFF_THUNK_CACHE_TYPE, F}) where {F}
dps = Lux.recursive_make_zero!!(ts.cache.dparameters)
params = Duplicated(ts.parameters, dps)
Expand Down
14 changes: 14 additions & 0 deletions ext/LuxReactantExt/LuxReactantExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
module LuxReactantExt

using Enzyme: Enzyme, Const, Duplicated, Active
using Optimisers: Optimisers
using Reactant: Reactant, @compile, TracedRArray
using Setfield: @set!
using Static: False

using Lux: Lux, LuxOps, Training
using Lux.Training: TrainingBackendCache, ReactantBackend

include("training.jl")

end
92 changes: 92 additions & 0 deletions ext/LuxReactantExt/training.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
function Lux.Training.compute_gradients_impl(
backend::ReactantBackend, objective_function::F,
data, ts::Training.TrainState) where {F}
compiled_gradient_function = @compile compute_gradients_internal(
objective_function, ts.model, data, ts.parameters, ts.states)

grads, loss, stats, st = compiled_gradient_function(
objective_function, ts.model, data, ts.parameters, ts.states)

cache = TrainingBackendCache(backend, False(), nothing, (; compiled_gradient_function))
@set! ts.cache = cache
@set! ts.objective_function = objective_function
@set! ts.states = st
return grads, loss, stats, ts
end

function Lux.Training.compute_gradients_impl(::ReactantBackend, obj_fn::F, data,
ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}, F}) where {F}
grads, loss, stats, st = ts.cache.extras.compiled_gradient_function(
obj_fn, ts.model, data, ts.parameters, ts.states)
@set! ts.states = st
return grads, loss, stats, ts
end

function compute_gradients_internal(objective_function::F, model, data, ps, st) where {F}
dps = Enzyme.make_zero(ps)
_, (loss, stₙ, stats) = Enzyme.autodiff(
Enzyme.ReverseWithPrimal, Const(objective_function), Active, Const(model),
Duplicated(ps, dps), Const(st), Const(data))
return dps, loss, stats, stₙ
end

for inplace in ("!", "")
fname = Symbol(:single_train_step_impl, inplace)
internal_fn = Symbol(:compute_gradients_internal_and_step, inplace)

@eval function Lux.Training.$(fname)(backend::ReactantBackend, objective_function::F,
data, ts::Training.TrainState) where {F}
compiled_grad_and_step_function = @compile $(internal_fn)(
objective_function, ts.model, data, ts.parameters, ts.states,
ts.optimizer_state)

grads, ps, loss, stats, st, opt_state = compiled_grad_and_step_function(
objective_function, ts.model, data, ts.parameters, ts.states,
ts.optimizer_state)

cache = TrainingBackendCache(
backend, False(), nothing, (; compiled_grad_and_step_function))
@set! ts.cache = cache
@set! ts.objective_function = objective_function
@set! ts.states = st
@set! ts.parameters = ps
@set! ts.optimizer_state = opt_state
@set! ts.step = ts.step + 1

return grads, loss, stats, ts
end

@eval function Lux.Training.$(fname)(::ReactantBackend, obj_fn::F, data,
ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}, F}) where {F}
grads, ps, loss, stats, st, opt_state = ts.cache.extras.compiled_grad_and_step_function(
obj_fn, ts.model, data, ts.parameters, ts.states, ts.optimizer_state)

@set! ts.states = st
@set! ts.parameters = ps
@set! ts.optimizer_state = opt_state
@set! ts.step = ts.step + 1

return grads, loss, stats, ts
end
end

function compute_gradients_internal_and_step(objective_function::F, model, data, ps,
st, opt_state) where {F}
dps = Enzyme.make_zero(ps)
_, (loss, stₙ, stats) = Enzyme.autodiff(
Enzyme.ReverseWithPrimal, Const(objective_function), Active, Const(model),
Duplicated(ps, dps), Const(st), Const(data))
opt_state, ps = Optimisers.update(opt_state, ps, dps)
return dps, ps, loss, stats, stₙ, opt_state
end

function compute_gradients_internal_and_step!(objective_function::F, model, data, ps,
st, opt_state) where {F}
dps = Enzyme.make_zero(ps)
_, (loss, stₙ, stats) = Enzyme.autodiff(
Enzyme.ReverseWithPrimal, Const(objective_function), Active, Const(model),
Duplicated(ps, dps), Const(st), Const(data))
# XXX: Inplace updates not actually inplace
opt_state, ps = Optimisers.update!(opt_state, ps, dps)
return dps, ps, loss, stats, stₙ, opt_state
end
10 changes: 5 additions & 5 deletions ext/LuxReverseDiffExt/training.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# Uncompiled ReverseDiff
function Lux.Training.compute_gradients(
function Lux.Training.compute_gradients_impl(
ad::AutoReverseDiff{false}, obj_fn::F, data, ts::TrainState) where {F}
@set! ts.cache = TrainingBackendCache(
ad, True(), Lux.recursive_make_zero(ts.parameters), nothing)
@set! ts.objective_function = obj_fn
return Lux.Training.compute_gradients(ad, obj_fn, data, ts)
end

function Lux.Training.compute_gradients(::AutoReverseDiff{false}, obj_fn::F, data,
function Lux.Training.compute_gradients_impl(::AutoReverseDiff{false}, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{AutoReverseDiff{false}}}) where {F}
dparams = Training.dparameters(ts.cache)
tape = ReverseDiff.InstructionTape()
Expand All @@ -24,7 +24,7 @@ function Lux.Training.compute_gradients(::AutoReverseDiff{false}, obj_fn::F, dat
end

# Compiled ReverseDiff
function Lux.Training.compute_gradients(
function Lux.Training.compute_gradients_impl(
ad::AutoReverseDiff{true}, obj_fn::F, data, ts::TrainState) where {F}
@set! ts.cache = TrainingBackendCache(
ad, True(), Lux.recursive_make_zero(ts.parameters),
Expand All @@ -35,7 +35,7 @@ function Lux.Training.compute_gradients(
end

## Tape hasn't been compiled yet / Function mismatch so recompile
function Lux.Training.compute_gradients(ad::AutoReverseDiff{true}, obj_fn::F, data,
function Lux.Training.compute_gradients_impl(ad::AutoReverseDiff{true}, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{AutoReverseDiff{true}}}) where {F}
if LuxCore.statelength(ts.states) != 0
throw(ArgumentError("AutoReverseDiff(; compile=true) is not supported for Lux \
Expand Down Expand Up @@ -82,7 +82,7 @@ function Lux.Training.compute_gradients(ad::AutoReverseDiff{true}, obj_fn::F, da
return dparams, ReverseDiff.value(loss), NamedTuple(), ts
end

function Lux.Training.compute_gradients(::AutoReverseDiff{true}, obj_fn::F, data,
function Lux.Training.compute_gradients_impl(::AutoReverseDiff{true}, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{AutoReverseDiff{true}}, F}) where {F}
(; ps_cache, data_cache, output) = ts.cache.extras

Expand Down
4 changes: 2 additions & 2 deletions ext/LuxTrackerExt/training.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function Lux.Training.compute_gradients(::AutoTracker, obj_fn::F, data,
function Lux.Training.compute_gradients_impl(::AutoTracker, obj_fn::F, data,
ts::TrainState{<:TrainingBackendCache{AutoTracker}}) where {F}
dps = Training.dparameters(ts.cache)
ps_tracked = construct_tracked_params(ts.parameters, dps)
Expand All @@ -13,7 +13,7 @@ function Lux.Training.compute_gradients(::AutoTracker, obj_fn::F, data,
return dps, loss.data, stats, ts
end

function Lux.Training.compute_gradients(
function Lux.Training.compute_gradients_impl(
ad::AutoTracker, obj_fn::F, data, ts::TrainState) where {F}
grads = Lux.recursive_make_zero(ts.parameters)
cache = TrainingBackendCache(ad, True(), grads, nothing)
Expand Down
2 changes: 1 addition & 1 deletion ext/LuxZygoteExt/training.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function Lux.Training.compute_gradients(
function Lux.Training.compute_gradients_impl(
::AutoZygote, objective_function::F, data, ts::Lux.Training.TrainState) where {F}
(loss, st, stats), back = Zygote.pullback(
objective_function, ts.model, ts.parameters, ts.states, data)
Expand Down
5 changes: 3 additions & 2 deletions src/helpers/losses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ function huber_loss(x::T1, y::T2, δ::T3) where {T1, T2, T3}
T = promote_type(T1, T2, T3)
diff = x - y
abs_diff = abs(diff)
return ifelse(abs_diff δ, T(0.5) * abs2(diff), δ * (abs_diff - T(0.5) * δ))
return ifelse(
abs_diff δ, convert(T, 0.5) * abs2(diff), δ * (abs_diff - convert(T, 0.5) * δ))
end
has_custom_derivative(::typeof(huber_loss)) = true
function derivative(::typeof(huber_loss), x::T, y::T2, δ::T3) where {T, T2, T3}
Expand Down Expand Up @@ -148,7 +149,7 @@ function derivative(::typeof(l2_hinge_loss), x::T1, y::T2) where {T1, T2}
end

function siamese_contrastive_loss(x::T1, y::T2, margin=true) where {T1, T2}
return (true - y) * x^2 + y * max(promote_type(T1, T2)(false), margin - x)^2
return (true - y) * x^2 + y * max(convert(promote_type(T1, T2), false), margin - x)^2
end

poisson_loss(x::T1, y::T2, ϵ) where {T1, T2} = x - xlogy(y, x + get_ϵ(T1, ϵ))
Expand Down
43 changes: 38 additions & 5 deletions src/helpers/training.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ using Static: StaticBool, Static, False, True

using ..Lux: Lux
using LuxCore: LuxCore, AbstractLuxLayer
using MLDataDevices: XLADevice, get_device_type, get_device, cpu_device

"""
TrainState
Expand Down Expand Up @@ -61,7 +62,13 @@ Constructor for [`TrainState`](@ref).
[`TrainState`](@ref) object.
"""
function TrainState(model::AbstractLuxLayer, ps, st, optimizer::Optimisers.AbstractRule)
st_opt = Optimisers.setup(optimizer, ps)
dev = get_device(ps)
st_opt = if dev isa XLADevice
ps_cpu = ps |> cpu_device()
Optimisers.setup(optimizer, ps_cpu) |> dev
else
Optimisers.setup(optimizer, ps)
end
return TrainState(nothing, nothing, model, ps, st, optimizer, st_opt, 0)
end

Expand Down Expand Up @@ -96,6 +103,8 @@ function Base.show(io::IO, ::MIME"text/plain", ts::TrainState)
print(io, "\n objective_function: ", nameof(typeof(ts.objective_function)))
end

struct ReactantBackend end

const APPLY_GRAD_DOCSTRING = """
## Arguments
Expand Down Expand Up @@ -183,7 +192,20 @@ A 4-Tuple containing:
returned in step `i + 1` might be aliased by the old gradients. If you want to prevent
this, simply use `copy(grads)` or `deepcopy(grads)` to make a copy of the gradients.
"""
function compute_gradients(ad::AbstractADType, ::F, _, ::TrainState) where {F}
function compute_gradients(ad, obj_fn::F, data, ts::TrainState) where {F}
dev_type = get_device_type((ts.parameters, ts.states))
return compute_gradients_impl(maybe_wrap_adtype(ad, dev_type), obj_fn, data, ts)
end

maybe_wrap_adtype(backend::ReactantBackend, _) = backend
maybe_wrap_adtype(ad::AbstractADType, _) = ad
function maybe_wrap_adtype(ad::AbstractADType, ::Type{XLADevice})
ad isa AutoEnzyme && return ReactantBackend()
throw(ArgumentError("Computing gradients for models on XLA is supported only with \
Enzyme.jl (`AutoEnzyme`)."))
end

function compute_gradients_impl(ad, ::F, _, ts::TrainState) where {F}
return check_if_compute_gradients_implemented(ad)
end

Expand All @@ -192,6 +214,10 @@ function check_if_compute_gradients_implemented(::T) where {T <: AbstractADType}
yet!"))
end

function check_if_compute_gradients_implemented(::ReactantBackend)
throw(ArgumentError("Load `Reactant` with `using Reactant` before using this function!"))
end

for package in (:Zygote, :Tracker, :ReverseDiff, :Enzyme)
adtype = Symbol(:Auto, package)
msg = "Load `$(package)` with `using $(package)`/`import $(package)` before using this \
Expand Down Expand Up @@ -244,7 +270,10 @@ only the parameters in `ts` are updated inplace. Users should be using the retur
object for further training steps, else there is no caching and performance will be
suboptimal (and absolutely terrible for backends like `AutoReactant`).
"""
function single_train_step! end
function single_train_step!(backend, obj_fn::F, data, ts::TrainState) where {F}
backend = maybe_wrap_adtype(backend, get_device_type((ts.parameters, ts.states)))
return single_train_step_impl!(backend, obj_fn, data, ts)
end

"""
single_train_step(backend, obj_fn::F, data, ts::TrainState)
Expand All @@ -259,10 +288,14 @@ In most cases you should use [`single_train_step!`](@ref) instead of this functi
Returned values are the same as [`compute_gradients`](@ref).
"""
function single_train_step end
function single_train_step(backend, obj_fn::F, data, ts::TrainState) where {F}
backend = maybe_wrap_adtype(backend, get_device_type((ts.parameters, ts.states)))
return single_train_step_impl(backend, obj_fn, data, ts)
end

for inplace in ("!", "")
step, apply_fn = Symbol(:single_train_step, inplace), Symbol(:apply_gradients, inplace)
step = Symbol(:single_train_step_impl, inplace)
apply_fn = Symbol(:apply_gradients, inplace)
@eval function $(step)(backend, obj_fn::F, data, ts::TrainState) where {F}
grads, loss, stats, ts = compute_gradients(backend, obj_fn, data, ts)
ts = $(apply_fn)(ts, grads)
Expand Down
Loading

1 comment on commit 1b0d6f8

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lux Benchmarks

Benchmark suite Current: 1b0d6f8 Previous: 77eb5fb Ratio
Dense(512 => 512, identity)(512 x 128)/forward/CPU/2 thread(s) 412250 ns 414958 ns 0.99
Dense(512 => 512, identity)(512 x 128)/forward/CPU/4 thread(s) 244083 ns 322541 ns 0.76
Dense(512 => 512, identity)(512 x 128)/forward/CPU/8 thread(s) 322041 ns 323167 ns 1.00
Dense(512 => 512, identity)(512 x 128)/forward/CPU/1 thread(s) 739625 ns 739562.5 ns 1.00
Dense(512 => 512, identity)(512 x 128)/forward/GPU/CUDA 43576 ns 44543 ns 0.98
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/2 thread(s) 1368688 ns 1335729 ns 1.02
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/4 thread(s) 1198625 ns 485000 ns 2.47
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/8 thread(s) 13918417 ns 14073833 ns 0.99
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/1 thread(s) 929312.5 ns 2211312.5 ns 0.42
Dense(512 => 512, identity)(512 x 128)/zygote/GPU/CUDA 190464 ns 194175 ns 0.98
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/2 thread(s) 1348750 ns 1374959 ns 0.98
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/4 thread(s) 1282083 ns 596188 ns 2.15
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/8 thread(s) 13837312.5 ns 13290875.5 ns 1.04
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/1 thread(s) 987250 ns 2199270.5 ns 0.45
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s) 1655917 ns 1665666 ns 0.99
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s) 1089000 ns 1186833.5 ns 0.92
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s) 1532499.5 ns 1536854.5 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s) 2439708 ns 2912062.5 ns 0.84
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/GPU/CUDA 211500 ns 213313 ns 0.99
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s) 12136437.5 ns 12145187.5 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s) 8847479 ns 9486083 ns 0.93
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s) 9240938 ns 9213083 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s) 17956208 ns 18563708 ns 0.97
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/GPU/CUDA 1905747 ns 1921274.5 ns 0.99
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s) 17305250 ns 17291000 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s) 13985416 ns 14310062.5 ns 0.98
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s) 14505584 ns 14535333 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s) 21107833 ns 21812208 ns 0.97
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s) 249894083 ns 250754270.5 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s) 148856208 ns 174424541 ns 0.85
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s) 115718875 ns 115532521 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s) 101619125 ns 446573667 ns 0.23
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/GPU/CUDA 5485492 ns 5489738 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s) 1228009625 ns 1222307709 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s) 931338167 ns 543403209 ns 1.71
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s) 829169479 ns 832977124.5 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s) 628483479 ns 1653507000 ns 0.38
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/GPU/CUDA 38151835 ns 34972271 ns 1.09
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s) 1134889125 ns 1142743917 ns 0.99
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s) 992066062.5 ns 686139667 ns 1.45
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s) 1309459854 ns 1325824667 ns 0.99
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s) 745440771 ns 1748793708.5 ns 0.43
lenet(28, 28, 1, 32)/forward/CPU/2 thread(s) 1092042 ns 1120083 ns 0.97
lenet(28, 28, 1, 32)/forward/CPU/4 thread(s) 1645709 ns 820062 ns 2.01
lenet(28, 28, 1, 32)/forward/CPU/8 thread(s) 3466333 ns 3738667 ns 0.93
lenet(28, 28, 1, 32)/forward/CPU/1 thread(s) 957250 ns 785458.5 ns 1.22
lenet(28, 28, 1, 32)/forward/GPU/CUDA 270549.5 ns 280004 ns 0.97
lenet(28, 28, 1, 32)/zygote/CPU/2 thread(s) 2979042 ns 2992209 ns 1.00
lenet(28, 28, 1, 32)/zygote/CPU/4 thread(s) 4110542 ns 2457292 ns 1.67
lenet(28, 28, 1, 32)/zygote/CPU/8 thread(s) 10529229 ns 10152708 ns 1.04
lenet(28, 28, 1, 32)/zygote/CPU/1 thread(s) 3308833 ns 3200812.5 ns 1.03
lenet(28, 28, 1, 32)/zygote/GPU/CUDA 1070477 ns 1093024 ns 0.98
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s) 2350792 ns 2350792 ns 1
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s) 1364187.5 ns 1546500 ns 0.88
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s) 1709000 ns 1703875 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s) 3666666.5 ns 4314041 ns 0.85
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/GPU/CUDA 210396 ns 214178 ns 0.98
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s) 20275459 ns 20293542 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s) 16981437 ns 17667520.5 ns 0.96
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s) 18162375 ns 18215687.5 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s) 26198500 ns 26742770.5 ns 0.98
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/GPU/CUDA 1979369 ns 1989179 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s) 46206895.5 ns 44338833 ns 1.04
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s) 41017187.5 ns 29803125 ns 1.38
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s) 41176208.5 ns 41253750.5 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s) 47588917 ns 49627062.5 ns 0.96
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s) 4669000 ns 4666292 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s) 2603916 ns 2867729.5 ns 0.91
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s) 2999833 ns 3027042 ns 0.99
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s) 7252188 ns 8637375 ns 0.84
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/GPU/CUDA 517525.5 ns 512379 ns 1.01
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s) 40878729.5 ns 40744375 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s) 33994250 ns 34861000 ns 0.98
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s) 33958333 ns 34130125 ns 0.99
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s) 51263292 ns 53656458 ns 0.96
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/GPU/CUDA 3013320.5 ns 3039460 ns 0.99
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s) 113392541.5 ns 109854709 ns 1.03
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s) 136850541 ns 60211500 ns 2.27
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s) 250011854.5 ns 244742208.5 ns 1.02
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s) 95314208 ns 100222291.5 ns 0.95
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s) 270234083 ns 270538750 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s) 157676542 ns 187102791.5 ns 0.84
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s) 128100708 ns 128131083.5 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s) 144520145.5 ns 496544584 ns 0.29
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/GPU/CUDA 7091283 ns 7095920 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s) 1503173291.5 ns 1493043979 ns 1.01
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s) 1201978125 ns 820794750 ns 1.46
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s) 1103595666.5 ns 1089880791.5 ns 1.01
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s) 1028790125.5 ns 2057983854 ns 0.50
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/GPU/CUDA 33654931 ns 33958491 ns 0.99
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s) 2089411062.5 ns 2031772896 ns 1.03
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s) 1851532083 ns 1169902125 ns 1.58
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s) 2117297604.5 ns 2031263062.5 ns 1.04
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s) 1605439208 ns 2621950583 ns 0.61
lenet(28, 28, 1, 128)/forward/CPU/2 thread(s) 2066438 ns 2080791 ns 0.99
lenet(28, 28, 1, 128)/forward/CPU/4 thread(s) 3005354 ns 1266458.5 ns 2.37
lenet(28, 28, 1, 128)/forward/CPU/8 thread(s) 7102958.5 ns 7120666.5 ns 1.00
lenet(28, 28, 1, 128)/forward/CPU/1 thread(s) 2151875 ns 2469583 ns 0.87
lenet(28, 28, 1, 128)/forward/GPU/CUDA 270072.5 ns 271494.5 ns 0.99
lenet(28, 28, 1, 128)/zygote/CPU/2 thread(s) 9657334 ns 9645292 ns 1.00
lenet(28, 28, 1, 128)/zygote/CPU/4 thread(s) 11945459 ns 6578875 ns 1.82
lenet(28, 28, 1, 128)/zygote/CPU/8 thread(s) 23020875 ns 23723729.5 ns 0.97
lenet(28, 28, 1, 128)/zygote/CPU/1 thread(s) 10467750 ns 11743000 ns 0.89
lenet(28, 28, 1, 128)/zygote/GPU/CUDA 1095059 ns 1108810 ns 0.99
vgg16(32, 32, 3, 32)/forward/CPU/2 thread(s) 381251625 ns 378620500 ns 1.01
vgg16(32, 32, 3, 32)/forward/CPU/4 thread(s) 309062375 ns 148222625 ns 2.09
vgg16(32, 32, 3, 32)/forward/CPU/8 thread(s) 241236375 ns 232625416.5 ns 1.04
vgg16(32, 32, 3, 32)/forward/CPU/1 thread(s) 180294333.5 ns 452981958.5 ns 0.40
vgg16(32, 32, 3, 32)/forward/GPU/CUDA 4847355 ns 4877366.5 ns 0.99
vgg16(32, 32, 3, 32)/zygote/CPU/2 thread(s) 1146004375 ns 1151789959 ns 0.99
vgg16(32, 32, 3, 32)/zygote/CPU/4 thread(s) 966522375 ns 608267042 ns 1.59
vgg16(32, 32, 3, 32)/zygote/CPU/8 thread(s) 1026283833 ns 958784875 ns 1.07
vgg16(32, 32, 3, 32)/zygote/CPU/1 thread(s) 662156542 ns 1398102250 ns 0.47
vgg16(32, 32, 3, 32)/zygote/GPU/CUDA 17798543 ns 17573518 ns 1.01
lenet(28, 28, 1, 64)/forward/CPU/2 thread(s) 1050458 ns 1044000 ns 1.01
lenet(28, 28, 1, 64)/forward/CPU/4 thread(s) 1656750 ns 966666.5 ns 1.71
lenet(28, 28, 1, 64)/forward/CPU/8 thread(s) 6491250 ns 5483500 ns 1.18
lenet(28, 28, 1, 64)/forward/CPU/1 thread(s) 1312792 ns 1369000 ns 0.96
lenet(28, 28, 1, 64)/forward/GPU/CUDA 270319.5 ns 277684.5 ns 0.97
lenet(28, 28, 1, 64)/zygote/CPU/2 thread(s) 6504813 ns 6395583 ns 1.02
lenet(28, 28, 1, 64)/zygote/CPU/4 thread(s) 13132417 ns 4649000 ns 2.82
lenet(28, 28, 1, 64)/zygote/CPU/8 thread(s) 19754250 ns 18457937.5 ns 1.07
lenet(28, 28, 1, 64)/zygote/CPU/1 thread(s) 5741521 ns 6087375 ns 0.94
lenet(28, 28, 1, 64)/zygote/GPU/CUDA 1124270 ns 1153126 ns 0.97
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s) 70469479 ns 70616187.5 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s) 43706291.5 ns 34338895.5 ns 1.27
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s) 39518625 ns 39546146 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s) 35367542 ns 132480333 ns 0.27
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/GPU/CUDA 1851430 ns 1837859.5 ns 1.01
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s) 356004604 ns 354650895.5 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s) 270290792 ns 158687854 ns 1.70
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s) 254164750 ns 253835396.5 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s) 271950333.5 ns 535065979.5 ns 0.51
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/GPU/CUDA 16539357 ns 16493787 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s) 395899500 ns 394789208 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s) 372060292 ns 245506292 ns 1.52
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s) 713782625 ns 682534167 ns 1.05
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s) 447779125 ns 711436834 ns 0.63
vgg16(32, 32, 3, 128)/forward/CPU/2 thread(s) 1190490459 ns 1186860667 ns 1.00
vgg16(32, 32, 3, 128)/forward/CPU/4 thread(s) 832670062.5 ns 435266750 ns 1.91
vgg16(32, 32, 3, 128)/forward/CPU/8 thread(s) 629944291 ns 628427791 ns 1.00
vgg16(32, 32, 3, 128)/forward/CPU/1 thread(s) 681507396 ns 1780484229 ns 0.38
vgg16(32, 32, 3, 128)/forward/GPU/CUDA 12475051 ns 12477417 ns 1.00
vgg16(32, 32, 3, 128)/zygote/CPU/2 thread(s) 3708044854 ns 3652621271 ns 1.02
vgg16(32, 32, 3, 128)/zygote/CPU/4 thread(s) 2828581542 ns 1639329875 ns 1.73
vgg16(32, 32, 3, 128)/zygote/CPU/8 thread(s) 2698925958 ns 2709465041 ns 1.00
vgg16(32, 32, 3, 128)/zygote/CPU/1 thread(s) 2137669604.5 ns 5075123916 ns 0.42
vgg16(32, 32, 3, 128)/zygote/GPU/CUDA 49415932 ns 49797376 ns 0.99
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s) 3423125 ns 3423958 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s) 2078500 ns 2097249.5 ns 0.99
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s) 2518458 ns 2534854 ns 0.99
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s) 4870375 ns 6018625 ns 0.81
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/GPU/CUDA 586699.5 ns 580639 ns 1.01
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s) 25989500 ns 25949208 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s) 19069958.5 ns 20274833 ns 0.94
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s) 19259312 ns 19561271 ns 0.98
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s) 36800833 ns 39224583 ns 0.94
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/GPU/CUDA 2993892 ns 2980196 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s) 54216125 ns 55399562.5 ns 0.98
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s) 83642959 ns 28378292 ns 2.95
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s) 174413208.5 ns 172128937.5 ns 1.01
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s) 42857708.5 ns 45636375 ns 0.94
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s) 1784458 ns 1783542 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s) 1095646 ns 1197959 ns 0.91
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s) 1575292 ns 1577021.5 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s) 2364687 ns 3027083.5 ns 0.78
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/GPU/CUDA 216504.5 ns 218302 ns 0.99
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s) 12531833 ns 12541854.5 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s) 9200375 ns 9966458 ns 0.92
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s) 9626292 ns 9641875 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s) 18391667 ns 18982208 ns 0.97
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/GPU/CUDA 1950268 ns 1943759 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s) 17650333.5 ns 17642208 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s) 14301166 ns 14745479 ns 0.97
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s) 14560250.5 ns 14622083 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s) 21506145.5 ns 22196749.5 ns 0.97
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s) 70470354 ns 70503791.5 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s) 43665542 ns 34154375 ns 1.28
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s) 39582249.5 ns 39724625.5 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s) 35175625 ns 133426312.5 ns 0.26
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/GPU/CUDA 1838843 ns 1855504 ns 0.99
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s) 360077895.5 ns 357508542 ns 1.01
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s) 349062958.5 ns 236762959 ns 1.47
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s) 305213917 ns 305563667 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s) 462206583 ns 731068500 ns 0.63
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/GPU/CUDA 13925027 ns 13898567 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s) 417720542 ns 418493479 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s) 426193583 ns 253429500 ns 1.68
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s) 717833375.5 ns 696829083.5 ns 1.03
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s) 394045333.5 ns 717012750 ns 0.55
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/2 thread(s) 1908458 ns 1657812.5 ns 1.15
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/4 thread(s) 1382145.5 ns 1559500 ns 0.89
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/8 thread(s) 1574208 ns 1547979 ns 1.02
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/1 thread(s) 2658583 ns 2615500 ns 1.02
mlp7layer_bn(gelu)(32 x 256)/forward/GPU/CUDA 567560 ns 579410 ns 0.98
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/2 thread(s) 9263291 ns 8948521 ns 1.04
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/4 thread(s) 15741709 ns 5918125 ns 2.66
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/8 thread(s) 30677874.5 ns 30404791 ns 1.01
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/1 thread(s) 6782125 ns 10061562 ns 0.67
mlp7layer_bn(gelu)(32 x 256)/zygote/GPU/CUDA 1355856 ns 1389319 ns 0.98
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/2 thread(s) 23068125 ns 22293791 ns 1.03
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/4 thread(s) 28298875 ns 19118125 ns 1.48
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/8 thread(s) 49366125 ns 50278833 ns 0.98
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/1 thread(s) 15664541 ns 19441208.5 ns 0.81
Dense(512 => 512, relu)(512 x 128)/forward/CPU/2 thread(s) 787000 ns 687479 ns 1.14
Dense(512 => 512, relu)(512 x 128)/forward/CPU/4 thread(s) 613416 ns 71083 ns 8.63
Dense(512 => 512, relu)(512 x 128)/forward/CPU/8 thread(s) 1014937.5 ns 1021000 ns 0.99
Dense(512 => 512, relu)(512 x 128)/forward/CPU/1 thread(s) 67541.5 ns 725458.5 ns 0.09310181078586852
Dense(512 => 512, relu)(512 x 128)/forward/GPU/CUDA 47213.5 ns 48336 ns 0.98
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/2 thread(s) 1547187.5 ns 1568500 ns 0.99
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/4 thread(s) 1017917 ns 284021 ns 3.58
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/8 thread(s) 1412645.5 ns 1426229 ns 0.99
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/1 thread(s) 321542 ns 2289417 ns 0.14
Dense(512 => 512, relu)(512 x 128)/zygote/GPU/CUDA 211309 ns 213525.5 ns 0.99
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/2 thread(s) 1571042 ns 1518000 ns 1.03
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/4 thread(s) 1020042 ns 446709 ns 2.28
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/8 thread(s) 1402125.5 ns 1398833.5 ns 1.00
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/1 thread(s) 343812 ns 2227979 ns 0.15
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s) 3408000.5 ns 3424312.5 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s) 2049583.5 ns 2076145.5 ns 0.99
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s) 2491583.5 ns 2517375 ns 0.99
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s) 4842271 ns 6002625 ns 0.81
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/GPU/CUDA 580126 ns 580319 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s) 24112333.5 ns 24064958.5 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s) 17188792 ns 18099937.5 ns 0.95
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s) 17119042 ns 17179812.5 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s) 34987687 ns 37498749.5 ns 0.93
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/GPU/CUDA 2894570.5 ns 2895115 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s) 52602166 ns 53787500 ns 0.98
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s) 83256812 ns 27724333.5 ns 3.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s) 173355916.5 ns 165600125 ns 1.05
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s) 42228833 ns 44506604 ns 0.95
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s) 250172041.5 ns 250628729 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s) 148659167 ns 174546375 ns 0.85
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s) 115831270.5 ns 115593979.5 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s) 106484375 ns 447286479 ns 0.24
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/GPU/CUDA 5471067 ns 5483866.5 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s) 1103002500 ns 1100854958 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s) 857541375 ns 467966979 ns 1.83
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s) 826884708.5 ns 825353979.5 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s) 740474770.5 ns 1759896083 ns 0.42
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/GPU/CUDA 35136266 ns 32267946 ns 1.09
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s) 1006767188 ns 1018635708.5 ns 0.99
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s) 974529458 ns 665400833 ns 1.46
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s) 1286053500 ns 1204699750 ns 1.07
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s) 727101250 ns 1733389813 ns 0.42
mlp7layer_bn(relu)(32 x 256)/forward/CPU/2 thread(s) 1308583 ns 1226521 ns 1.07
mlp7layer_bn(relu)(32 x 256)/forward/CPU/4 thread(s) 664854.5 ns 961167 ns 0.69
mlp7layer_bn(relu)(32 x 256)/forward/CPU/8 thread(s) 906375 ns 918083 ns 0.99
mlp7layer_bn(relu)(32 x 256)/forward/CPU/1 thread(s) 2049458 ns 2051083.5 ns 1.00
mlp7layer_bn(relu)(32 x 256)/forward/GPU/CUDA 565223.5 ns 578415 ns 0.98
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/2 thread(s) 5804687.5 ns 5660417 ns 1.03
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/4 thread(s) 8913625 ns 2618917 ns 3.40
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/8 thread(s) 24320125 ns 23019583 ns 1.06
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/1 thread(s) 3694792 ns 7086333 ns 0.52
mlp7layer_bn(relu)(32 x 256)/zygote/GPU/CUDA 1307349 ns 1349133 ns 0.97
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/2 thread(s) 9459208 ns 9707250 ns 0.97
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/4 thread(s) 15996021 ns 6502604.5 ns 2.46
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/8 thread(s) 31660167 ns 30901167 ns 1.02
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/1 thread(s) 4429208.5 ns 7612917 ns 0.58
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/2 thread(s) 433416.5 ns 383916 ns 1.13
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/4 thread(s) 466208 ns 31791 ns 14.66
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/8 thread(s) 1932812 ns 2087375 ns 0.93
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/1 thread(s) 54000 ns 91083 ns 0.59
Dense(128 => 128, gelu)(128 x 128)/forward/GPU/CUDA 27617 ns 28712 ns 0.96
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/2 thread(s) 370958.5 ns 406208 ns 0.91
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/4 thread(s) 459083 ns 175875 ns 2.61
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/8 thread(s) 4366749.5 ns 4346812.5 ns 1.00
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/1 thread(s) 193875 ns 272959 ns 0.71
Dense(128 => 128, gelu)(128 x 128)/zygote/GPU/CUDA 216603.5 ns 216512 ns 1.00
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/2 thread(s) 684292 ns 678958 ns 1.01
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/4 thread(s) 731125 ns 442584 ns 1.65
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/8 thread(s) 4502166 ns 4679166 ns 0.96
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/1 thread(s) 435458 ns 543542 ns 0.80
Dense(128 => 128, relu)(128 x 128)/forward/CPU/2 thread(s) 377416 ns 329792 ns 1.14
Dense(128 => 128, relu)(128 x 128)/forward/CPU/4 thread(s) 405042 ns 13125 ns 30.86
Dense(128 => 128, relu)(128 x 128)/forward/CPU/8 thread(s) 718500 ns 603208 ns 1.19
Dense(128 => 128, relu)(128 x 128)/forward/CPU/1 thread(s) 12834 ns 54708 ns 0.23
Dense(128 => 128, relu)(128 x 128)/forward/GPU/CUDA 27924.5 ns 27935 ns 1.00
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/2 thread(s) 303979.5 ns 354458 ns 0.86
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/4 thread(s) 340916.5 ns 25792 ns 13.22
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/8 thread(s) 858875 ns 719333 ns 1.19
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/1 thread(s) 26333 ns 151792 ns 0.17
Dense(128 => 128, relu)(128 x 128)/zygote/GPU/CUDA 206665 ns 206354.5 ns 1.00
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/2 thread(s) 320916.5 ns 370041 ns 0.87
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/4 thread(s) 355500 ns 45958 ns 7.74
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/8 thread(s) 900792 ns 469708 ns 1.92
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/1 thread(s) 28875 ns 151167 ns 0.19
vgg16(32, 32, 3, 64)/forward/CPU/2 thread(s) 603792041 ns 600144917 ns 1.01
vgg16(32, 32, 3, 64)/forward/CPU/4 thread(s) 430597750 ns 239512791.5 ns 1.80
vgg16(32, 32, 3, 64)/forward/CPU/8 thread(s) 375897687.5 ns 368675770.5 ns 1.02
vgg16(32, 32, 3, 64)/forward/CPU/1 thread(s) 321301750 ns 878611917 ns 0.37
vgg16(32, 32, 3, 64)/forward/GPU/CUDA 7676185 ns 7673480 ns 1.00
vgg16(32, 32, 3, 64)/zygote/CPU/2 thread(s) 2002056937.5 ns 2001190937.5 ns 1.00
vgg16(32, 32, 3, 64)/zygote/CPU/4 thread(s) 1637403750 ns 950953500.5 ns 1.72
vgg16(32, 32, 3, 64)/zygote/CPU/8 thread(s) 1658326812.5 ns 1611271729.5 ns 1.03
vgg16(32, 32, 3, 64)/zygote/CPU/1 thread(s) 1181133416 ns 2652863625 ns 0.45
vgg16(32, 32, 3, 64)/zygote/GPU/CUDA 27018077.5 ns 27099744 ns 1.00
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/2 thread(s) 527292 ns 532625 ns 0.99
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/4 thread(s) 402500 ns 175791 ns 2.29
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/8 thread(s) 1773874.5 ns 1765937 ns 1.00
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/1 thread(s) 217896 ns 873854 ns 0.25
Dense(512 => 512, gelu)(512 x 128)/forward/GPU/CUDA 47539 ns 48010 ns 0.99
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/2 thread(s) 1972750 ns 1862104.5 ns 1.06
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/4 thread(s) 1830041 ns 1105875 ns 1.65
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/8 thread(s) 14502542 ns 14941916 ns 0.97
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/1 thread(s) 1511084 ns 2753625 ns 0.55
Dense(512 => 512, gelu)(512 x 128)/zygote/GPU/CUDA 222835 ns 222255 ns 1.00
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/2 thread(s) 3104000 ns 2909834 ns 1.07
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/4 thread(s) 5000208 ns 2231271 ns 2.24
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/8 thread(s) 15174146 ns 15268500 ns 0.99
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/1 thread(s) 2515479.5 ns 3879541.5 ns 0.65
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/2 thread(s) 1599584 ns 1471375 ns 1.09
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/4 thread(s) 933250 ns 1256917 ns 0.74
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/8 thread(s) 1233959 ns 1257021 ns 0.98
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/1 thread(s) 2349500 ns 2211792 ns 1.06
mlp7layer_bn(tanh)(32 x 256)/forward/GPU/CUDA 564727.5 ns 567447 ns 1.00
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/2 thread(s) 5989584 ns 5894166 ns 1.02
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/4 thread(s) 8876479.5 ns 2856229.5 ns 3.11
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/8 thread(s) 25076041 ns 24506146 ns 1.02
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/1 thread(s) 3931104 ns 7294791 ns 0.54
mlp7layer_bn(tanh)(32 x 256)/zygote/GPU/CUDA 1312718 ns 1317261 ns 1.00
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/2 thread(s) 11659958.5 ns 11665375 ns 1.00
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/4 thread(s) 18499562.5 ns 8766624.5 ns 2.11
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/8 thread(s) 34871271.5 ns 34961958 ns 1.00
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/1 thread(s) 6354542 ns 9529500 ns 0.67
Dense(16 => 16, relu)(16 x 128)/forward/CPU/2 thread(s) 4666.5 ns 2959 ns 1.58
Dense(16 => 16, relu)(16 x 128)/forward/CPU/4 thread(s) 2625 ns 2542 ns 1.03
Dense(16 => 16, relu)(16 x 128)/forward/CPU/8 thread(s) 4333 ns 2959 ns 1.46
Dense(16 => 16, relu)(16 x 128)/forward/CPU/1 thread(s) 2292 ns 2520.5 ns 0.91
Dense(16 => 16, relu)(16 x 128)/forward/GPU/CUDA 24932 ns 24587 ns 1.01
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/2 thread(s) 7209 ns 7167 ns 1.01
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/4 thread(s) 9792 ns 7084 ns 1.38
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/8 thread(s) 7375 ns 7458 ns 0.99
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/1 thread(s) 7208 ns 7167 ns 1.01
Dense(16 => 16, relu)(16 x 128)/zygote/GPU/CUDA 190569.5 ns 185970 ns 1.02
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/2 thread(s) 8167 ns 8208 ns 1.00
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/4 thread(s) 8416 ns 8375 ns 1.00
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/8 thread(s) 8375 ns 8416 ns 1.00
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/1 thread(s) 5917 ns 6000 ns 0.99
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/2 thread(s) 10437.5 ns 10208 ns 1.02
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/4 thread(s) 13583 ns 14937.5 ns 0.91
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/8 thread(s) 11104.5 ns 10916 ns 1.02
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/1 thread(s) 7250 ns 8833 ns 0.82
Dense(16 => 16, gelu)(16 x 128)/forward/GPU/CUDA 24757 ns 24920 ns 0.99
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/2 thread(s) 21708 ns 21542 ns 1.01
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/4 thread(s) 21625 ns 21625 ns 1
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/8 thread(s) 21750 ns 21916 ns 0.99
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/1 thread(s) 21709 ns 21916.5 ns 0.99
Dense(16 => 16, gelu)(16 x 128)/zygote/GPU/CUDA 195121 ns 195496 ns 1.00
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/2 thread(s) 57500 ns 53500 ns 1.07
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/4 thread(s) 53500 ns 56875 ns 0.94
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/8 thread(s) 53583 ns 53625 ns 1.00
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/1 thread(s) 55083 ns 55083 ns 1
Dense(128 => 128, identity)(128 x 128)/forward/CPU/2 thread(s) 28583 ns 28541 ns 1.00
Dense(128 => 128, identity)(128 x 128)/forward/CPU/4 thread(s) 28667 ns 28437.5 ns 1.01
Dense(128 => 128, identity)(128 x 128)/forward/CPU/8 thread(s) 29000 ns 28250 ns 1.03
Dense(128 => 128, identity)(128 x 128)/forward/CPU/1 thread(s) 46334 ns 46083 ns 1.01
Dense(128 => 128, identity)(128 x 128)/forward/GPU/CUDA 25674 ns 25773 ns 1.00
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/2 thread(s) 227125 ns 224458 ns 1.01
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/4 thread(s) 276125 ns 44229.5 ns 6.24
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/8 thread(s) 4228416.5 ns 4410250 ns 0.96
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/1 thread(s) 63084 ns 145625 ns 0.43
Dense(128 => 128, identity)(128 x 128)/zygote/GPU/CUDA 166940.5 ns 167043 ns 1.00
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/2 thread(s) 246687 ns 242125 ns 1.02
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/4 thread(s) 293708 ns 68875 ns 4.26
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/8 thread(s) 4174375 ns 4299458 ns 0.97
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/1 thread(s) 68833 ns 145667 ns 0.47
Dense(16 => 16, identity)(16 x 128)/forward/CPU/2 thread(s) 1979.5 ns 2125 ns 0.93
Dense(16 => 16, identity)(16 x 128)/forward/CPU/4 thread(s) 2042 ns 1750 ns 1.17
Dense(16 => 16, identity)(16 x 128)/forward/CPU/8 thread(s) 2583.5 ns 2583 ns 1.00
Dense(16 => 16, identity)(16 x 128)/forward/CPU/1 thread(s) 2000 ns 1917 ns 1.04
Dense(16 => 16, identity)(16 x 128)/forward/GPU/CUDA 22856 ns 22918.5 ns 1.00
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/2 thread(s) 5416 ns 5250 ns 1.03
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/4 thread(s) 5291 ns 5292 ns 1.00
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/8 thread(s) 5375 ns 5417 ns 0.99
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/1 thread(s) 5291 ns 5250 ns 1.01
Dense(16 => 16, identity)(16 x 128)/zygote/GPU/CUDA 171204 ns 171129 ns 1.00
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/2 thread(s) 7500 ns 7583 ns 0.99
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/4 thread(s) 7542 ns 8125 ns 0.93
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/8 thread(s) 7750 ns 7500 ns 1.03
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/1 thread(s) 5708 ns 5125 ns 1.11
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s) 80930834 ns 81032541 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s) 48596833 ns 39920458 ns 1.22
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s) 45693208 ns 45590917 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s) 56260583.5 ns 153513167 ns 0.37
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/GPU/CUDA 2631409 ns 2660470 ns 0.99
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s) 622112500 ns 675206709 ns 0.92
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s) 426582750 ns 319221521 ns 1.34
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s) 411799708 ns 412689584 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s) 506749771 ns 704326792 ns 0.72
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/GPU/CUDA 15162045 ns 15217384 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s) 882246666 ns 875714312.5 ns 1.01
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s) 844291292 ns 502738834 ns 1.68
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s) 1135779771 ns 1160733354 ns 0.98
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s) 925012854.5 ns 1210583500 ns 0.76

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.