-
Notifications
You must be signed in to change notification settings - Fork 62
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: compile training loop automatically using reactant (#969)
* feat: compile training loop automatically using reactant
* refactor: add a level of indirection for the train_step
* feat: directly compile step + grad function
* fix: make note of current issue with inplace update
* chore: bump minimum reactant version
* test: setup specific reactant test group
* ci: temporarily disable other tests (drop me)
* test: fix installation of Reactant
* test: start adding loss function tests
* fix: xlogx and xlogy now work with Reactant scalars
* feat: support regression losses + tests
* test: classification losses
* fix: more specialization
* fix: support all loss functions
* chore: comments
* fix: bump reactant version
* test: don't run reactant tests on windows
* test: temporarily disable more tests
* fix: reactant GPU support
* fix: remove old LossFunctions.jl dispatches
* test: try using MSELoss directly
* ci: reactivate all tests
* ci(windows): don't test Reactant on windows
- Loading branch information
Showing
15 changed files
with
487 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Extension module that wires Reactant-compiled training into `Lux.Training`.
# All method definitions live in `training.jl`, included at the bottom.
module LuxReactantExt

# Host package pieces that the training methods extend or construct.
using Lux: Lux, LuxOps, Training
using Lux.Training: TrainingBackendCache, ReactantBackend

# Differentiation, compilation and optimizer dependencies.
using Enzyme: Enzyme, Active, Const, Duplicated
using Reactant: Reactant, @compile, TracedRArray
using Optimisers: Optimisers

# Small utilities: immutable field updates and static booleans.
using Setfield: @set!
using Static: False

include("training.jl")

end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# Generic (uncached) gradient computation for the Reactant backend.
#
# Compiles `compute_gradients_internal` for the given objective, model, data,
# parameters and states, runs the compiled function once, and stores it in a new
# `TrainingBackendCache` on the returned train state so later calls with the same
# objective function can reuse it.
#
# Returns `(grads, loss, stats, ts)` where `ts` carries the new cache, the
# objective function, and the states produced by the objective.
function Lux.Training.compute_gradients_impl(
        backend::ReactantBackend, objective_function::F,
        data, ts::Training.TrainState) where {F}
    compiled_fn = @compile compute_gradients_internal(
        objective_function, ts.model, data, ts.parameters, ts.states)

    grads, loss, stats, new_states = compiled_fn(
        objective_function, ts.model, data, ts.parameters, ts.states)

    # The compiled function is stored under `compiled_gradient_function`, the key
    # the cached method looks up.
    @set! ts.cache = TrainingBackendCache(
        backend, False(), nothing, (; compiled_gradient_function=compiled_fn))
    @set! ts.objective_function = objective_function
    @set! ts.states = new_states
    return grads, loss, stats, ts
end
|
||
# Cached gradient computation for the Reactant backend.
#
# Selected when the train state already holds a Reactant `TrainingBackendCache`
# whose stored objective function has the same type `F` as `obj_fn`.
#
# The cache may have been created by a different entry point (for example the
# `single_train_step_impl` methods store only `compiled_grad_and_step_function`),
# so the compiled gradient function is not guaranteed to be present. When it is
# missing it is compiled here and merged into the existing `extras`, keeping any
# other compiled functions intact.
# NOTE(review): assumes `ts.cache.extras` is a NamedTuple, which is what every
# cache constructed in this file uses.
#
# Returns `(grads, loss, stats, ts)` with `ts.states` replaced by the states the
# objective returned.
function Lux.Training.compute_gradients_impl(::ReactantBackend, obj_fn::F, data,
        ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}, F}) where {F}
    if hasfield(typeof(ts.cache.extras), :compiled_gradient_function)
        compiled_gradient_function = ts.cache.extras.compiled_gradient_function
    else
        compiled_gradient_function = @compile compute_gradients_internal(
            obj_fn, ts.model, data, ts.parameters, ts.states)
        @set! ts.cache.extras = merge(ts.cache.extras, (; compiled_gradient_function))
    end

    grads, loss, stats, st = compiled_gradient_function(
        obj_fn, ts.model, data, ts.parameters, ts.states)
    @set! ts.states = st
    return grads, loss, stats, ts
end
|
||
# Function that gets compiled for gradient computation.
#
# Differentiates `objective_function(model, ps, st, data)` with Enzyme in
# reverse-with-primal mode. Only `ps` is marked `Duplicated`; the objective,
# model, states and data are passed as `Const`. The gradient is accumulated into
# a zeroed copy of `ps`.
#
# Returns `(gradient, loss, stats, updated_states)`.
function compute_gradients_internal(objective_function::F, model, data, ps, st) where {F}
    grad_ps = Enzyme.make_zero(ps)
    _, primal = Enzyme.autodiff(
        Enzyme.ReverseWithPrimal, Const(objective_function), Active,
        Const(model), Duplicated(ps, grad_ps), Const(st), Const(data))
    # The primal value is unpacked as (loss, updated states, stats).
    loss, new_st, stats = primal
    return grad_ps, loss, stats, new_st
end
|
||
# Generate `single_train_step_impl!` and `single_train_step_impl` for the Reactant
# backend. Each variant compiles the matching
# `compute_gradients_internal_and_step[!]` function, which fuses the gradient
# computation and the optimizer update.
for inplace in ("!", "")
    fname = Symbol(:single_train_step_impl, inplace)
    internal_fn = Symbol(:compute_gradients_internal_and_step, inplace)

    # Generic (uncached) method: compile, run once, and store the compiled function
    # in a new cache on the returned train state.
    @eval function Lux.Training.$(fname)(backend::ReactantBackend, objective_function::F,
            data, ts::Training.TrainState) where {F}
        compiled_grad_and_step_function = @compile $(internal_fn)(
            objective_function, ts.model, data, ts.parameters, ts.states,
            ts.optimizer_state)

        grads, ps, loss, stats, st, opt_state = compiled_grad_and_step_function(
            objective_function, ts.model, data, ts.parameters, ts.states,
            ts.optimizer_state)

        cache = TrainingBackendCache(
            backend, False(), nothing, (; compiled_grad_and_step_function))
        @set! ts.cache = cache
        @set! ts.objective_function = objective_function
        @set! ts.states = st
        @set! ts.parameters = ps
        @set! ts.optimizer_state = opt_state
        @set! ts.step = ts.step + 1

        return grads, loss, stats, ts
    end

    # Cached method: selected when the train state already holds a Reactant cache
    # whose objective function has the same type as `obj_fn`.
    #
    # The cache may have been created by `compute_gradients_impl`, which stores only
    # `compiled_gradient_function`. In that case the fused function is compiled here
    # and merged into the existing `extras` instead of failing on a missing field.
    # NOTE(review): assumes `ts.cache.extras` is a NamedTuple, which is what every
    # cache constructed in this file uses.
    @eval function Lux.Training.$(fname)(::ReactantBackend, obj_fn::F, data,
            ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}, F}) where {F}
        if hasfield(typeof(ts.cache.extras), :compiled_grad_and_step_function)
            compiled_grad_and_step_function = ts.cache.extras.compiled_grad_and_step_function
        else
            compiled_grad_and_step_function = @compile $(internal_fn)(
                obj_fn, ts.model, data, ts.parameters, ts.states, ts.optimizer_state)
            @set! ts.cache.extras = merge(
                ts.cache.extras, (; compiled_grad_and_step_function))
        end

        grads, ps, loss, stats, st, opt_state = compiled_grad_and_step_function(
            obj_fn, ts.model, data, ts.parameters, ts.states, ts.optimizer_state)

        @set! ts.states = st
        @set! ts.parameters = ps
        @set! ts.optimizer_state = opt_state
        @set! ts.step = ts.step + 1

        return grads, loss, stats, ts
    end
end
|
||
# Function that gets compiled for the out-of-place training step.
#
# Computes the gradient of `objective_function(model, ps, st, data)` with respect
# to `ps` using Enzyme (reverse mode with primal), then applies
# `Optimisers.update` to obtain the new optimizer state and parameters.
#
# Returns `(gradient, new_parameters, loss, stats, updated_states, new_opt_state)`.
function compute_gradients_internal_and_step(objective_function::F, model, data, ps,
        st, opt_state) where {F}
    grad_ps = Enzyme.make_zero(ps)
    _, primal = Enzyme.autodiff(
        Enzyme.ReverseWithPrimal, Const(objective_function), Active,
        Const(model), Duplicated(ps, grad_ps), Const(st), Const(data))
    loss, new_st, stats = primal
    new_opt_state, new_ps = Optimisers.update(opt_state, ps, grad_ps)
    return grad_ps, new_ps, loss, stats, new_st, new_opt_state
end
|
||
# Function that gets compiled for the "in-place" training step.
#
# Same as `compute_gradients_internal_and_step`, except the optimizer update goes
# through `Optimisers.update!`.
#
# Returns `(gradient, new_parameters, loss, stats, updated_states, new_opt_state)`.
function compute_gradients_internal_and_step!(objective_function::F, model, data, ps,
        st, opt_state) where {F}
    grad_ps = Enzyme.make_zero(ps)
    _, primal = Enzyme.autodiff(
        Enzyme.ReverseWithPrimal, Const(objective_function), Active,
        Const(model), Duplicated(ps, grad_ps), Const(st), Const(data))
    loss, new_st, stats = primal
    # XXX: Inplace updates not actually inplace
    new_opt_state, new_ps = Optimisers.update!(opt_state, ps, grad_ps)
    return grad_ps, new_ps, loss, stats, new_st, new_opt_state
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
1b0d6f8
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lux Benchmarks
Dense(512 => 512, identity)(512 x 128)/forward/CPU/2 thread(s)
412250
ns414958
ns0.99
Dense(512 => 512, identity)(512 x 128)/forward/CPU/4 thread(s)
244083
ns322541
ns0.76
Dense(512 => 512, identity)(512 x 128)/forward/CPU/8 thread(s)
322041
ns323167
ns1.00
Dense(512 => 512, identity)(512 x 128)/forward/CPU/1 thread(s)
739625
ns739562.5
ns1.00
Dense(512 => 512, identity)(512 x 128)/forward/GPU/CUDA
43576
ns44543
ns0.98
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/2 thread(s)
1368688
ns1335729
ns1.02
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/4 thread(s)
1198625
ns485000
ns2.47
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/8 thread(s)
13918417
ns14073833
ns0.99
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/1 thread(s)
929312.5
ns2211312.5
ns0.42
Dense(512 => 512, identity)(512 x 128)/zygote/GPU/CUDA
190464
ns194175
ns0.98
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/2 thread(s)
1348750
ns1374959
ns0.98
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/4 thread(s)
1282083
ns596188
ns2.15
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/8 thread(s)
13837312.5
ns13290875.5
ns1.04
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/1 thread(s)
987250
ns2199270.5
ns0.45
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s)
1655917
ns1665666
ns0.99
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s)
1089000
ns1186833.5
ns0.92
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s)
1532499.5
ns1536854.5
ns1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s)
2439708
ns2912062.5
ns0.84
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/GPU/CUDA
211500
ns213313
ns0.99
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s)
12136437.5
ns12145187.5
ns1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s)
8847479
ns9486083
ns0.93
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s)
9240938
ns9213083
ns1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s)
17956208
ns18563708
ns0.97
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/GPU/CUDA
1905747
ns1921274.5
ns0.99
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s)
17305250
ns17291000
ns1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s)
13985416
ns14310062.5
ns0.98
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s)
14505584
ns14535333
ns1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s)
21107833
ns21812208
ns0.97
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s)
249894083
ns250754270.5
ns1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s)
148856208
ns174424541
ns0.85
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s)
115718875
ns115532521
ns1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s)
101619125
ns446573667
ns0.23
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/GPU/CUDA
5485492
ns5489738
ns1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s)
1228009625
ns1222307709
ns1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s)
931338167
ns543403209
ns1.71
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s)
829169479
ns832977124.5
ns1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s)
628483479
ns1653507000
ns0.38
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/GPU/CUDA
38151835
ns34972271
ns1.09
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s)
1134889125
ns1142743917
ns0.99
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s)
992066062.5
ns686139667
ns1.45
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s)
1309459854
ns1325824667
ns0.99
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s)
745440771
ns1748793708.5
ns0.43
lenet(28, 28, 1, 32)/forward/CPU/2 thread(s)
1092042
ns1120083
ns0.97
lenet(28, 28, 1, 32)/forward/CPU/4 thread(s)
1645709
ns820062
ns2.01
lenet(28, 28, 1, 32)/forward/CPU/8 thread(s)
3466333
ns3738667
ns0.93
lenet(28, 28, 1, 32)/forward/CPU/1 thread(s)
957250
ns785458.5
ns1.22
lenet(28, 28, 1, 32)/forward/GPU/CUDA
270549.5
ns280004
ns0.97
lenet(28, 28, 1, 32)/zygote/CPU/2 thread(s)
2979042
ns2992209
ns1.00
lenet(28, 28, 1, 32)/zygote/CPU/4 thread(s)
4110542
ns2457292
ns1.67
lenet(28, 28, 1, 32)/zygote/CPU/8 thread(s)
10529229
ns10152708
ns1.04
lenet(28, 28, 1, 32)/zygote/CPU/1 thread(s)
3308833
ns3200812.5
ns1.03
lenet(28, 28, 1, 32)/zygote/GPU/CUDA
1070477
ns1093024
ns0.98
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s)
2350792
ns2350792
ns1
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s)
1364187.5
ns1546500
ns0.88
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s)
1709000
ns1703875
ns1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s)
3666666.5
ns4314041
ns0.85
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/GPU/CUDA
210396
ns214178
ns0.98
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s)
20275459
ns20293542
ns1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s)
16981437
ns17667520.5
ns0.96
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s)
18162375
ns18215687.5
ns1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s)
26198500
ns26742770.5
ns0.98
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/GPU/CUDA
1979369
ns1989179
ns1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s)
46206895.5
ns44338833
ns1.04
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s)
41017187.5
ns29803125
ns1.38
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s)
41176208.5
ns41253750.5
ns1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s)
47588917
ns49627062.5
ns0.96
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s)
4669000
ns4666292
ns1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s)
2603916
ns2867729.5
ns0.91
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s)
2999833
ns3027042
ns0.99
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s)
7252188
ns8637375
ns0.84
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/GPU/CUDA
517525.5
ns512379
ns1.01
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s)
40878729.5
ns40744375
ns1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s)
33994250
ns34861000
ns0.98
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s)
33958333
ns34130125
ns0.99
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s)
51263292
ns53656458
ns0.96
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/GPU/CUDA
3013320.5
ns3039460
ns0.99
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s)
113392541.5
ns109854709
ns1.03
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s)
136850541
ns60211500
ns2.27
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s)
250011854.5
ns244742208.5
ns1.02
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s)
95314208
ns100222291.5
ns0.95
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s)
270234083
ns270538750
ns1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s)
157676542
ns187102791.5
ns0.84
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s)
128100708
ns128131083.5
ns1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s)
144520145.5
ns496544584
ns0.29
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/GPU/CUDA
7091283
ns7095920
ns1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s)
1503173291.5
ns1493043979
ns1.01
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s)
1201978125
ns820794750
ns1.46
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s)
1103595666.5
ns1089880791.5
ns1.01
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s)
1028790125.5
ns2057983854
ns0.50
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/GPU/CUDA
33654931
ns33958491
ns0.99
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s)
2089411062.5
ns2031772896
ns1.03
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s)
1851532083
ns1169902125
ns1.58
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s)
2117297604.5
ns2031263062.5
ns1.04
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s)
1605439208
ns2621950583
ns0.61
lenet(28, 28, 1, 128)/forward/CPU/2 thread(s)
2066438
ns2080791
ns0.99
lenet(28, 28, 1, 128)/forward/CPU/4 thread(s)
3005354
ns1266458.5
ns2.37
lenet(28, 28, 1, 128)/forward/CPU/8 thread(s)
7102958.5
ns7120666.5
ns1.00
lenet(28, 28, 1, 128)/forward/CPU/1 thread(s)
2151875
ns2469583
ns0.87
lenet(28, 28, 1, 128)/forward/GPU/CUDA
270072.5
ns271494.5
ns0.99
lenet(28, 28, 1, 128)/zygote/CPU/2 thread(s)
9657334
ns9645292
ns1.00
lenet(28, 28, 1, 128)/zygote/CPU/4 thread(s)
11945459
ns6578875
ns1.82
lenet(28, 28, 1, 128)/zygote/CPU/8 thread(s)
23020875
ns23723729.5
ns0.97
lenet(28, 28, 1, 128)/zygote/CPU/1 thread(s)
10467750
ns11743000
ns0.89
lenet(28, 28, 1, 128)/zygote/GPU/CUDA
1095059
ns1108810
ns0.99
vgg16(32, 32, 3, 32)/forward/CPU/2 thread(s)
381251625
ns378620500
ns1.01
vgg16(32, 32, 3, 32)/forward/CPU/4 thread(s)
309062375
ns148222625
ns2.09
vgg16(32, 32, 3, 32)/forward/CPU/8 thread(s)
241236375
ns232625416.5
ns1.04
vgg16(32, 32, 3, 32)/forward/CPU/1 thread(s)
180294333.5
ns452981958.5
ns0.40
vgg16(32, 32, 3, 32)/forward/GPU/CUDA
4847355
ns4877366.5
ns0.99
vgg16(32, 32, 3, 32)/zygote/CPU/2 thread(s)
1146004375
ns1151789959
ns0.99
vgg16(32, 32, 3, 32)/zygote/CPU/4 thread(s)
966522375
ns608267042
ns1.59
vgg16(32, 32, 3, 32)/zygote/CPU/8 thread(s)
1026283833
ns958784875
ns1.07
vgg16(32, 32, 3, 32)/zygote/CPU/1 thread(s)
662156542
ns1398102250
ns0.47
vgg16(32, 32, 3, 32)/zygote/GPU/CUDA
17798543
ns17573518
ns1.01
lenet(28, 28, 1, 64)/forward/CPU/2 thread(s)
1050458
ns1044000
ns1.01
lenet(28, 28, 1, 64)/forward/CPU/4 thread(s)
1656750
ns966666.5
ns1.71
lenet(28, 28, 1, 64)/forward/CPU/8 thread(s)
6491250
ns5483500
ns1.18
lenet(28, 28, 1, 64)/forward/CPU/1 thread(s)
1312792
ns1369000
ns0.96
lenet(28, 28, 1, 64)/forward/GPU/CUDA
270319.5
ns277684.5
ns0.97
lenet(28, 28, 1, 64)/zygote/CPU/2 thread(s)
6504813
ns6395583
ns1.02
lenet(28, 28, 1, 64)/zygote/CPU/4 thread(s)
13132417
ns4649000
ns2.82
lenet(28, 28, 1, 64)/zygote/CPU/8 thread(s)
19754250
ns18457937.5
ns1.07
lenet(28, 28, 1, 64)/zygote/CPU/1 thread(s)
5741521
ns6087375
ns0.94
lenet(28, 28, 1, 64)/zygote/GPU/CUDA
1124270
ns1153126
ns0.97
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s)
70469479
ns70616187.5
ns1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s)
43706291.5
ns34338895.5
ns1.27
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s)
39518625
ns39546146
ns1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s)
35367542
ns132480333
ns0.27
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/GPU/CUDA
1851430
ns1837859.5
ns1.01
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s)
356004604
ns354650895.5
ns1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s)
270290792
ns158687854
ns1.70
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s)
254164750
ns253835396.5
ns1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s)
271950333.5
ns535065979.5
ns0.51
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/GPU/CUDA
16539357
ns16493787
ns1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s)
395899500
ns394789208
ns1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s)
372060292
ns245506292
ns1.52
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s)
713782625
ns682534167
ns1.05
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s)
447779125
ns711436834
ns0.63
vgg16(32, 32, 3, 128)/forward/CPU/2 thread(s)
1190490459
ns1186860667
ns1.00
vgg16(32, 32, 3, 128)/forward/CPU/4 thread(s)
832670062.5
ns435266750
ns1.91
vgg16(32, 32, 3, 128)/forward/CPU/8 thread(s)
629944291
ns628427791
ns1.00
vgg16(32, 32, 3, 128)/forward/CPU/1 thread(s)
681507396
ns1780484229
ns0.38
vgg16(32, 32, 3, 128)/forward/GPU/CUDA
12475051
ns12477417
ns1.00
vgg16(32, 32, 3, 128)/zygote/CPU/2 thread(s)
3708044854
ns3652621271
ns1.02
vgg16(32, 32, 3, 128)/zygote/CPU/4 thread(s)
2828581542
ns1639329875
ns1.73
vgg16(32, 32, 3, 128)/zygote/CPU/8 thread(s)
2698925958
ns2709465041
ns1.00
vgg16(32, 32, 3, 128)/zygote/CPU/1 thread(s)
2137669604.5
ns5075123916
ns0.42
vgg16(32, 32, 3, 128)/zygote/GPU/CUDA
49415932
ns49797376
ns0.99
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s)
3423125
ns3423958
ns1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s)
2078500
ns2097249.5
ns0.99
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s)
2518458
ns2534854
ns0.99
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s)
4870375
ns6018625
ns0.81
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/GPU/CUDA
586699.5
ns580639
ns1.01
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s)
25989500
ns25949208
ns1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s)
19069958.5
ns20274833
ns0.94
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s)
19259312
ns19561271
ns0.98
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s)
36800833
ns39224583
ns0.94
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/GPU/CUDA
2993892
ns2980196
ns1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s)
54216125
ns55399562.5
ns0.98
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s)
83642959
ns28378292
ns2.95
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s)
174413208.5
ns172128937.5
ns1.01
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s)
42857708.5
ns45636375
ns0.94
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s)
1784458
ns1783542
ns1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s)
1095646
ns1197959
ns0.91
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s)
1575292
ns1577021.5
ns1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s)
2364687
ns3027083.5
ns0.78
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/GPU/CUDA
216504.5
ns218302
ns0.99
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s)
12531833
ns12541854.5
ns1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s)
9200375
ns9966458
ns0.92
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s)
9626292
ns9641875
ns1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s)
18391667
ns18982208
ns0.97
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/GPU/CUDA
1950268
ns1943759
ns1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s)
17650333.5
ns17642208
ns1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s)
14301166
ns14745479
ns0.97
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s)
14560250.5
ns14622083
ns1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s)
21506145.5
ns22196749.5
ns0.97
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s)
70470354
ns70503791.5
ns1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s)
43665542
ns34154375
ns1.28
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s)
39582249.5
ns39724625.5
ns1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s)
35175625
ns133426312.5
ns0.26
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/GPU/CUDA
1838843
ns1855504
ns0.99
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s)
360077895.5
ns357508542
ns1.01
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s)
349062958.5
ns236762959
ns1.47
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s)
305213917
ns305563667
ns1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s)
462206583
ns731068500
ns0.63
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/GPU/CUDA
13925027
ns13898567
ns1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s)
417720542
ns418493479
ns1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s)
426193583
ns253429500
ns1.68
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s)
717833375.5
ns696829083.5
ns1.03
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s)
394045333.5
ns717012750
ns0.55
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/2 thread(s)
1908458
ns1657812.5
ns1.15
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/4 thread(s)
1382145.5
ns1559500
ns0.89
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/8 thread(s)
1574208
ns1547979
ns1.02
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/1 thread(s)
2658583
ns2615500
ns1.02
mlp7layer_bn(gelu)(32 x 256)/forward/GPU/CUDA
567560
ns579410
ns0.98
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/2 thread(s)
9263291
ns8948521
ns1.04
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/4 thread(s)
15741709
ns5918125
ns2.66
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/8 thread(s)
30677874.5
ns30404791
ns1.01
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/1 thread(s)
6782125
ns10061562
ns0.67
mlp7layer_bn(gelu)(32 x 256)/zygote/GPU/CUDA
1355856
ns1389319
ns0.98
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/2 thread(s)
23068125
ns22293791
ns1.03
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/4 thread(s)
28298875
ns19118125
ns1.48
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/8 thread(s)
49366125
ns50278833
ns0.98
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/1 thread(s)
15664541
ns19441208.5
ns0.81
Dense(512 => 512, relu)(512 x 128)/forward/CPU/2 thread(s)
787000
ns687479
ns1.14
Dense(512 => 512, relu)(512 x 128)/forward/CPU/4 thread(s)
613416
ns71083
ns8.63
Dense(512 => 512, relu)(512 x 128)/forward/CPU/8 thread(s)
1014937.5
ns1021000
ns0.99
Dense(512 => 512, relu)(512 x 128)/forward/CPU/1 thread(s)
67541.5
ns725458.5
ns0.09310181078586852
Dense(512 => 512, relu)(512 x 128)/forward/GPU/CUDA
47213.5
ns48336
ns0.98
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/2 thread(s)
1547187.5
ns1568500
ns0.99
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/4 thread(s)
1017917
ns284021
ns3.58
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/8 thread(s)
1412645.5
ns1426229
ns0.99
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/1 thread(s)
321542
ns2289417
ns0.14
Dense(512 => 512, relu)(512 x 128)/zygote/GPU/CUDA
211309
ns213525.5
ns0.99
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/2 thread(s)
1571042
ns1518000
ns1.03
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/4 thread(s)
1020042
ns446709
ns2.28
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/8 thread(s)
1402125.5
ns1398833.5
ns1.00
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/1 thread(s)
343812
ns2227979
ns0.15
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s)
3408000.5
ns3424312.5
ns1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s)
2049583.5
ns2076145.5
ns0.99
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s)
2491583.5
ns2517375
ns0.99
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s)
4842271
ns6002625
ns0.81
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/GPU/CUDA
580126
ns580319
ns1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s)
24112333.5
ns24064958.5
ns1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s)
17188792
ns18099937.5
ns0.95
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s)
17119042
ns17179812.5
ns1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s)
34987687
ns37498749.5
ns0.93
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/GPU/CUDA
2894570.5
ns2895115
ns1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s)
52602166
ns53787500
ns0.98
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s)
83256812
ns27724333.5
ns3.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s)
173355916.5
ns165600125
ns1.05
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s)
42228833
ns44506604
ns0.95
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s)
250172041.5
ns250628729
ns1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s)
148659167
ns174546375
ns0.85
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s)
115831270.5
ns115593979.5
ns1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s)
106484375
ns447286479
ns0.24
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/GPU/CUDA
5471067
ns5483866.5
ns1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s)
1103002500
ns1100854958
ns1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s)
857541375
ns467966979
ns1.83
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s)
826884708.5
ns825353979.5
ns1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s)
740474770.5
ns1759896083
ns0.42
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/GPU/CUDA
35136266
ns32267946
ns1.09
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s)
1006767188
ns1018635708.5
ns0.99
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s)
974529458
ns665400833
ns1.46
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s)
1286053500
ns1204699750
ns1.07
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s)
727101250
ns1733389813
ns0.42
mlp7layer_bn(relu)(32 x 256)/forward/CPU/2 thread(s)
1308583
ns1226521
ns1.07
mlp7layer_bn(relu)(32 x 256)/forward/CPU/4 thread(s)
664854.5
ns961167
ns0.69
mlp7layer_bn(relu)(32 x 256)/forward/CPU/8 thread(s)
906375
ns918083
ns0.99
mlp7layer_bn(relu)(32 x 256)/forward/CPU/1 thread(s)
2049458
ns2051083.5
ns1.00
mlp7layer_bn(relu)(32 x 256)/forward/GPU/CUDA
565223.5
ns578415
ns0.98
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/2 thread(s)
5804687.5
ns5660417
ns1.03
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/4 thread(s)
8913625
ns2618917
ns3.40
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/8 thread(s)
24320125
ns23019583
ns1.06
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/1 thread(s)
3694792
ns7086333
ns0.52
mlp7layer_bn(relu)(32 x 256)/zygote/GPU/CUDA
1307349
ns1349133
ns0.97
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/2 thread(s)
9459208
ns9707250
ns0.97
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/4 thread(s)
15996021
ns6502604.5
ns2.46
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/8 thread(s)
31660167
ns30901167
ns1.02
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/1 thread(s)
4429208.5
ns7612917
ns0.58
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/2 thread(s)
433416.5
ns383916
ns1.13
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/4 thread(s)
466208
ns31791
ns14.66
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/8 thread(s)
1932812
ns2087375
ns0.93
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/1 thread(s)
54000
ns91083
ns0.59
Dense(128 => 128, gelu)(128 x 128)/forward/GPU/CUDA
27617
ns28712
ns0.96
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/2 thread(s)
370958.5
ns406208
ns0.91
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/4 thread(s)
459083
ns175875
ns2.61
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/8 thread(s)
4366749.5
ns4346812.5
ns1.00
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/1 thread(s)
193875
ns272959
ns0.71
Dense(128 => 128, gelu)(128 x 128)/zygote/GPU/CUDA
216603.5
ns216512
ns1.00
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/2 thread(s)
684292
ns678958
ns1.01
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/4 thread(s)
731125
ns442584
ns1.65
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/8 thread(s)
4502166
ns4679166
ns0.96
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/1 thread(s)
435458
ns543542
ns0.80
Dense(128 => 128, relu)(128 x 128)/forward/CPU/2 thread(s)
377416
ns329792
ns1.14
Dense(128 => 128, relu)(128 x 128)/forward/CPU/4 thread(s)
405042
ns13125
ns30.86
Dense(128 => 128, relu)(128 x 128)/forward/CPU/8 thread(s)
718500
ns603208
ns1.19
Dense(128 => 128, relu)(128 x 128)/forward/CPU/1 thread(s)
12834
ns54708
ns0.23
Dense(128 => 128, relu)(128 x 128)/forward/GPU/CUDA
27924.5
ns27935
ns1.00
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/2 thread(s)
303979.5
ns354458
ns0.86
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/4 thread(s)
340916.5
ns25792
ns13.22
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/8 thread(s)
858875
ns719333
ns1.19
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/1 thread(s)
26333
ns151792
ns0.17
Dense(128 => 128, relu)(128 x 128)/zygote/GPU/CUDA
206665
ns206354.5
ns1.00
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/2 thread(s)
320916.5
ns370041
ns0.87
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/4 thread(s)
355500
ns45958
ns7.74
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/8 thread(s)
900792
ns469708
ns1.92
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/1 thread(s)
28875
ns151167
ns0.19
vgg16(32, 32, 3, 64)/forward/CPU/2 thread(s)
603792041
ns600144917
ns1.01
vgg16(32, 32, 3, 64)/forward/CPU/4 thread(s)
430597750
ns239512791.5
ns1.80
vgg16(32, 32, 3, 64)/forward/CPU/8 thread(s)
375897687.5
ns368675770.5
ns1.02
vgg16(32, 32, 3, 64)/forward/CPU/1 thread(s)
321301750
ns878611917
ns0.37
vgg16(32, 32, 3, 64)/forward/GPU/CUDA
7676185
ns7673480
ns1.00
vgg16(32, 32, 3, 64)/zygote/CPU/2 thread(s)
2002056937.5
ns2001190937.5
ns1.00
vgg16(32, 32, 3, 64)/zygote/CPU/4 thread(s)
1637403750
ns950953500.5
ns1.72
vgg16(32, 32, 3, 64)/zygote/CPU/8 thread(s)
1658326812.5
ns1611271729.5
ns1.03
vgg16(32, 32, 3, 64)/zygote/CPU/1 thread(s)
1181133416
ns2652863625
ns0.45
vgg16(32, 32, 3, 64)/zygote/GPU/CUDA
27018077.5
ns27099744
ns1.00
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/2 thread(s)
527292
ns532625
ns0.99
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/4 thread(s)
402500
ns175791
ns2.29
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/8 thread(s)
1773874.5
ns1765937
ns1.00
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/1 thread(s)
217896
ns873854
ns0.25
Dense(512 => 512, gelu)(512 x 128)/forward/GPU/CUDA
47539
ns48010
ns0.99
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/2 thread(s)
1972750
ns1862104.5
ns1.06
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/4 thread(s)
1830041
ns1105875
ns1.65
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/8 thread(s)
14502542
ns14941916
ns0.97
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/1 thread(s)
1511084
ns2753625
ns0.55
Dense(512 => 512, gelu)(512 x 128)/zygote/GPU/CUDA
222835
ns222255
ns1.00
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/2 thread(s)
3104000
ns2909834
ns1.07
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/4 thread(s)
5000208
ns2231271
ns2.24
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/8 thread(s)
15174146
ns15268500
ns0.99
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/1 thread(s)
2515479.5
ns3879541.5
ns0.65
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/2 thread(s)
1599584
ns1471375
ns1.09
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/4 thread(s)
933250
ns1256917
ns0.74
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/8 thread(s)
1233959
ns1257021
ns0.98
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/1 thread(s)
2349500
ns2211792
ns1.06
mlp7layer_bn(tanh)(32 x 256)/forward/GPU/CUDA
564727.5
ns567447
ns1.00
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/2 thread(s)
5989584
ns5894166
ns1.02
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/4 thread(s)
8876479.5
ns2856229.5
ns3.11
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/8 thread(s)
25076041
ns24506146
ns1.02
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/1 thread(s)
3931104
ns7294791
ns0.54
mlp7layer_bn(tanh)(32 x 256)/zygote/GPU/CUDA
1312718
ns1317261
ns1.00
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/2 thread(s)
11659958.5
ns11665375
ns1.00
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/4 thread(s)
18499562.5
ns8766624.5
ns2.11
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/8 thread(s)
34871271.5
ns34961958
ns1.00
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/1 thread(s)
6354542
ns9529500
ns0.67
Dense(16 => 16, relu)(16 x 128)/forward/CPU/2 thread(s)
4666.5
ns2959
ns1.58
Dense(16 => 16, relu)(16 x 128)/forward/CPU/4 thread(s)
2625
ns2542
ns1.03
Dense(16 => 16, relu)(16 x 128)/forward/CPU/8 thread(s)
4333
ns2959
ns1.46
Dense(16 => 16, relu)(16 x 128)/forward/CPU/1 thread(s)
2292
ns2520.5
ns0.91
Dense(16 => 16, relu)(16 x 128)/forward/GPU/CUDA
24932
ns24587
ns1.01
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/2 thread(s)
7209
ns7167
ns1.01
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/4 thread(s)
9792
ns7084
ns1.38
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/8 thread(s)
7375
ns7458
ns0.99
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/1 thread(s)
7208
ns7167
ns1.01
Dense(16 => 16, relu)(16 x 128)/zygote/GPU/CUDA
190569.5
ns185970
ns1.02
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/2 thread(s)
8167
ns8208
ns1.00
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/4 thread(s)
8416
ns8375
ns1.00
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/8 thread(s)
8375
ns8416
ns1.00
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/1 thread(s)
5917
ns6000
ns0.99
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/2 thread(s)
10437.5
ns10208
ns1.02
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/4 thread(s)
13583
ns14937.5
ns0.91
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/8 thread(s)
11104.5
ns10916
ns1.02
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/1 thread(s)
7250
ns8833
ns0.82
Dense(16 => 16, gelu)(16 x 128)/forward/GPU/CUDA
24757
ns24920
ns0.99
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/2 thread(s)
21708
ns21542
ns1.01
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/4 thread(s)
21625
ns21625
ns1
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/8 thread(s)
21750
ns21916
ns0.99
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/1 thread(s)
21709
ns21916.5
ns0.99
Dense(16 => 16, gelu)(16 x 128)/zygote/GPU/CUDA
195121
ns195496
ns1.00
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/2 thread(s)
57500
ns53500
ns1.07
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/4 thread(s)
53500
ns56875
ns0.94
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/8 thread(s)
53583
ns53625
ns1.00
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/1 thread(s)
55083
ns55083
ns1
Dense(128 => 128, identity)(128 x 128)/forward/CPU/2 thread(s)
28583
ns28541
ns1.00
Dense(128 => 128, identity)(128 x 128)/forward/CPU/4 thread(s)
28667
ns28437.5
ns1.01
Dense(128 => 128, identity)(128 x 128)/forward/CPU/8 thread(s)
29000
ns28250
ns1.03
Dense(128 => 128, identity)(128 x 128)/forward/CPU/1 thread(s)
46334
ns46083
ns1.01
Dense(128 => 128, identity)(128 x 128)/forward/GPU/CUDA
25674
ns25773
ns1.00
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/2 thread(s)
227125
ns224458
ns1.01
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/4 thread(s)
276125
ns44229.5
ns6.24
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/8 thread(s)
4228416.5
ns4410250
ns0.96
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/1 thread(s)
63084
ns145625
ns0.43
Dense(128 => 128, identity)(128 x 128)/zygote/GPU/CUDA
166940.5
ns167043
ns1.00
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/2 thread(s)
246687
ns242125
ns1.02
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/4 thread(s)
293708
ns68875
ns4.26
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/8 thread(s)
4174375
ns4299458
ns0.97
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/1 thread(s)
68833
ns145667
ns0.47
Dense(16 => 16, identity)(16 x 128)/forward/CPU/2 thread(s)
1979.5
ns2125
ns0.93
Dense(16 => 16, identity)(16 x 128)/forward/CPU/4 thread(s)
2042
ns1750
ns1.17
Dense(16 => 16, identity)(16 x 128)/forward/CPU/8 thread(s)
2583.5
ns2583
ns1.00
Dense(16 => 16, identity)(16 x 128)/forward/CPU/1 thread(s)
2000
ns1917
ns1.04
Dense(16 => 16, identity)(16 x 128)/forward/GPU/CUDA
22856
ns22918.5
ns1.00
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/2 thread(s)
5416
ns5250
ns1.03
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/4 thread(s)
5291
ns5292
ns1.00
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/8 thread(s)
5375
ns5417
ns0.99
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/1 thread(s)
5291
ns5250
ns1.01
Dense(16 => 16, identity)(16 x 128)/zygote/GPU/CUDA
171204
ns171129
ns1.00
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/2 thread(s)
7500
ns7583
ns0.99
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/4 thread(s)
7542
ns8125
ns0.93
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/8 thread(s)
7750
ns7500
ns1.03
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/1 thread(s)
5708
ns5125
ns1.11
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s)
80930834
ns81032541
ns1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s)
48596833
ns39920458
ns1.22
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s)
45693208
ns45590917
ns1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s)
56260583.5
ns153513167
ns0.37
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/GPU/CUDA
2631409
ns2660470
ns0.99
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s)
622112500
ns675206709
ns0.92
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s)
426582750
ns319221521
ns1.34
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s)
411799708
ns412689584
ns1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s)
506749771
ns704326792
ns0.72
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/GPU/CUDA
15162045
ns15217384
ns1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s)
882246666
ns875714312.5
ns1.01
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s)
844291292
ns502738834
ns1.68
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s)
1135779771
ns1160733354
ns0.98
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s)
925012854.5
ns1210583500
ns0.76
This comment was automatically generated by workflow using github-action-benchmark.