feat: implement resnet20 baseline
avik-pal committed Jan 1, 2025
1 parent d17a14d commit 1cd8fc6
Showing 4 changed files with 95 additions and 10 deletions.
10 changes: 9 additions & 1 deletion examples/CIFAR10/README.md
@@ -31,7 +31,15 @@ julia --startup-file=no \
On an RTX 4050 6GB Laptop GPU the training takes approximately 3 mins and the final training
and test accuracies are 97% and 65%, respectively.

## MLP-Mixer
## ResNet 20

```bash
julia --startup-file=no \
--project=. \
--threads=auto \
resnet20.jl \
--backend=reactant
```

## ConvMixer

14 changes: 8 additions & 6 deletions examples/CIFAR10/conv_mixer.jl
@@ -1,4 +1,4 @@
using Comonicon, Interpolations, Lux, Optimisers, Printf, Random, Statistics, Zygote, Enzyme
using Comonicon, Interpolations, Lux, Optimisers, Printf, Random, Statistics, Zygote

include("common.jl")

@@ -39,12 +39,14 @@ Comonicon.@main function main(;
)
model = ConvMixer(; dim=hidden_dim, depth, kernel_size, patch_size)

opt = AdamW(; eta=lr_max, lambda=weight_decay)
clip_norm && (opt = OptimiserChain(ClipNorm(), opt))
opt = Adam(0.001f0)
# opt = AdamW(; eta=lr_max, lambda=weight_decay)
# clip_norm && (opt = OptimiserChain(ClipNorm(), opt))

lr_schedule = linear_interpolation(
[0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0]
)
# lr_schedule = linear_interpolation(
# [0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0]
# )
lr_schedule = nothing

return train_model(model, opt, lr_schedule; backend, batchsize, seed, epochs, bfloat16)
end
3 changes: 0 additions & 3 deletions examples/CIFAR10/mlp_mixer.jl

This file was deleted.

78 changes: 78 additions & 0 deletions examples/CIFAR10/resnet20.jl
@@ -0,0 +1,78 @@
using Comonicon, Lux, Optimisers, Printf, Random, Statistics, Zygote

include("common.jl")

function ConvBN(kernel_size, (in_chs, out_chs), act; kwargs...)
return Chain(
Conv(kernel_size, in_chs => out_chs, act; kwargs...),
BatchNorm(out_chs)
)
end

function BasicBlock(in_channels, out_channels; stride=1)
connection = if (stride == 1 && in_channels == out_channels)
NoOpLayer()
else
Conv((3, 3), in_channels => out_channels, identity; stride=stride, pad=SamePad())
end
return Chain(
Parallel(
+,
connection,
Chain(
ConvBN((3, 3), in_channels => out_channels, relu; stride, pad=SamePad()),
ConvBN((3, 3), out_channels => out_channels, identity; pad=SamePad())
)
),
Base.BroadcastFunction(relu)
)
end

function ResNet20(; num_classes=10)
layers = []

# Initial Conv Layer
push!(layers, Chain(
Conv((3, 3), 3 => 16, relu; pad=SamePad()),
BatchNorm(16)
))

# Residual Blocks
block_configs = [
(16, 16, 3, 1), # (in_channels, out_channels, num_blocks, stride)
(16, 32, 3, 2),
(32, 64, 3, 2)
]

for (in_channels, out_channels, num_blocks, stride) in block_configs
for i in 1:num_blocks
push!(layers,
BasicBlock(
i == 1 ? in_channels : out_channels, out_channels;
stride=(i == 1 ? stride : 1)
))
end
end

# Global Pooling and Final Dense Layer
push!(layers, GlobalMeanPool())
push!(layers, FlattenLayer())
push!(layers, Dense(64 => num_classes))

return Chain(layers...)
end

Comonicon.@main function main(;
batchsize::Int=512, weight_decay::Float64=0.0001,
clip_norm::Bool=false, seed::Int=1234, epochs::Int=100, lr::Float64=0.001,
backend::String="reactant", bfloat16::Bool=false
)
model = ResNet20()

opt = AdamW(; eta=lr, lambda=weight_decay)
clip_norm && (opt = OptimiserChain(ClipNorm(), opt))

lr_schedule = nothing

return train_model(model, opt, lr_schedule; backend, batchsize, seed, epochs, bfloat16)
end
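
As a quick smoke test (a sketch, not part of the commit; it assumes the `ResNet20`, `BasicBlock`, and `ConvBN` definitions from resnet20.jl above are in scope), the new chain can be exercised through Lux's explicit-parameter API on a CIFAR-10-shaped batch:

```julia
using Lux, Random

rng = Xoshiro(1234)

model = ResNet20()
ps, st = Lux.setup(rng, model)            # initialize parameters and layer states

x = randn(rng, Float32, 32, 32, 3, 4)     # WHCN batch of four CIFAR-10-sized images
y, _ = model(x, ps, st)                   # Lux layers return (output, updated_state)

@assert size(y) == (10, 4)                # logits: (num_classes, batchsize)
```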
