Commit
feat: BF16 training + inference
avik-pal committed Jan 1, 2025
1 parent e02d9ea commit fbfc851
Showing 4 changed files with 26 additions and 17 deletions.
4 changes: 3 additions & 1 deletion examples/CIFAR10/Project.toml
@@ -1,4 +1,5 @@
 [deps]
+BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
 Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
 ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
 DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e"
@@ -21,6 +22,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [compat]
+BFloat16s = "0.5.0"
 Comonicon = "1.0.8"
 ConcreteStructs = "0.2.3"
 DataAugmentation = "0.3"
@@ -36,7 +38,7 @@ OneHotArrays = "0.2.5"
 Optimisers = "0.4.1"
 Printf = "1.10"
 Random = "1.10"
-Reactant = "0.2.5"
+Reactant = "0.2.12"
 StableRNGs = "1.0.2"
 Statistics = "1.10"
 Zygote = "0.6.70"
31 changes: 19 additions & 12 deletions examples/CIFAR10/common.jl
@@ -1,5 +1,5 @@
 using ConcreteStructs, DataAugmentation, ImageShow, Lux, MLDatasets, MLUtils, OneHotArrays,
-    Printf, ProgressTables, Random
+    Printf, ProgressTables, Random, BFloat16s
 using Reactant, LuxCUDA

 @concrete struct TensorDataset
@@ -15,21 +15,22 @@ function Base.getindex(ds::TensorDataset, idxs::Union{Vector{<:Integer}, Abstrac
     return stack(parent ∘ itemdata ∘ Base.Fix1(apply, ds.transform), img), y
 end

-function get_cifar10_dataloaders(batchsize; kwargs...)
-    cifar10_mean = (0.4914, 0.4822, 0.4465)
-    cifar10_std = (0.2471, 0.2435, 0.2616)
+function get_cifar10_dataloaders(::Type{T}, batchsize; kwargs...) where {T}
+    cifar10_mean = (0.4914, 0.4822, 0.4465) .|> T
+    cifar10_std = (0.2471, 0.2435, 0.2616) .|> T

     train_transform = RandomResizeCrop((32, 32)) |>
                       Maybe(FlipX{2}()) |>
                       ImageToTensor() |>
-                      Normalize(cifar10_mean, cifar10_std)
+                      Normalize(cifar10_mean, cifar10_std) |>
+                      ToEltype(T)

-    test_transform = ImageToTensor() |> Normalize(cifar10_mean, cifar10_std)
+    test_transform = ImageToTensor() |> Normalize(cifar10_mean, cifar10_std) |> ToEltype(T)

-    trainset = TensorDataset(CIFAR10(:train), train_transform)
+    trainset = TensorDataset(CIFAR10(; Tx=T, split=:train), train_transform)
     trainloader = DataLoader(trainset; batchsize, shuffle=true, kwargs...)

-    testset = TensorDataset(CIFAR10(:test), test_transform)
+    testset = TensorDataset(CIFAR10(; Tx=T, split=:test), test_transform)
     testloader = DataLoader(testset; batchsize, shuffle=false, kwargs...)

     return trainloader, testloader
@@ -64,24 +65,30 @@ end
 function train_model(
         model, opt, scheduler=nothing;
-        backend::String, batchsize::Int=512, seed::Int=1234, epochs::Int=25
+        backend::String, batchsize::Int=512, seed::Int=1234, epochs::Int=25,
+        bfloat16::Bool=false
 )
     rng = Random.default_rng()
     Random.seed!(rng, seed)

+    prec = bfloat16 ? bf16 : f32
+    prec_jl = bfloat16 ? BFloat16 : Float32
+    prec_str = bfloat16 ? "BFloat16" : "Float32"
+    @printf "[Info] Using %s precision\n" prec_str
+
     accelerator_device = get_accelerator_device(backend)
     kwargs = accelerator_device isa ReactantDevice ? (; partial=false) : ()
-    trainloader, testloader = get_cifar10_dataloaders(batchsize; kwargs...) |>
+    trainloader, testloader = get_cifar10_dataloaders(prec_jl, batchsize; kwargs...) |>
                               accelerator_device

-    ps, st = Lux.setup(rng, model) |> accelerator_device
+    ps, st = Lux.setup(rng, model) |> prec |> accelerator_device

     train_state = Training.TrainState(model, ps, st, opt)

     adtype = backend == "reactant" ? AutoEnzyme() : AutoZygote()

     if backend == "reactant"
-        x_ra = rand(rng, Float32, size(first(trainloader)[1])) |> accelerator_device
+        x_ra = rand(rng, prec_jl, size(first(trainloader)[1])) |> accelerator_device
         @printf "[Info] Compiling model with Reactant.jl\n"
         st_test = Lux.testmode(st)
         model_compiled = Reactant.compile(model, (x_ra, ps, st_test))
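For orientation, a minimal sketch of the precision plumbing added above, assuming the `bf16` adapter used in the diff is exported by Lux (alongside `f32`) and that `BFloat16` comes from BFloat16s.jl; the layer and sizes are hypothetical stand-ins for the CIFAR10 models:

using Lux, BFloat16s, Random

model = Dense(32 => 10)                                   # hypothetical layer
ps, st = Lux.setup(Random.default_rng(), model) |> bf16   # parameters and state converted to BFloat16, mirroring `|> prec`
x = rand(BFloat16, 32, 4)                                 # inputs generated at the matching eltype, mirroring `prec_jl`
y, _ = model(x, ps, st)                                   # the forward pass then runs in BF16 throughout

The same two knobs drive the data pipeline: `get_cifar10_dataloaders(prec_jl, batchsize)` loads and normalizes the images at the requested eltype (via `Tx=T` and `ToEltype(T)`), so the batches already match the converted parameters.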
4 changes: 2 additions & 2 deletions examples/CIFAR10/conv_mixer.jl
@@ -35,7 +35,7 @@ Comonicon.@main function main(;
         batchsize::Int=512, hidden_dim::Int=256, depth::Int=8,
         patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=0.0001,
         clip_norm::Bool=false, seed::Int=1234, epochs::Int=25, lr_max::Float64=0.05,
-        backend::String="reactant"
+        backend::String="reactant", bfloat16::Bool=false
 )
     model = ConvMixer(; dim=hidden_dim, depth, kernel_size, patch_size)

@@ -46,5 +46,5 @@ Comonicon.@main function main(;
         [0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0]
     )

-    return train_model(model, opt, lr_schedule; backend, batchsize, seed, epochs)
+    return train_model(model, opt, lr_schedule; backend, batchsize, seed, epochs, bfloat16)
 end
4 changes: 2 additions & 2 deletions examples/CIFAR10/simple_cnn.jl
@@ -23,12 +23,12 @@ end
 Comonicon.@main function main(;
         batchsize::Int=512, weight_decay::Float64=0.0001,
         clip_norm::Bool=false, seed::Int=1234, epochs::Int=50, lr::Float64=0.003,
-        backend::String="reactant"
+        backend::String="reactant", bfloat16::Bool=false
 )
     model = SimpleCNN()

     opt = AdamW(; eta=lr, lambda=weight_decay)
     clip_norm && (opt = OptimiserChain(ClipNorm(), opt))

-    return train_model(model, opt, nothing; backend, batchsize, seed, epochs)
+    return train_model(model, opt, nothing; backend, batchsize, seed, epochs, bfloat16)
 end
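Taken together, the new `bfloat16` keyword threads from the Comonicon entry points in conv_mixer.jl and simple_cnn.jl down into `train_model`, which selects `bf16`/`BFloat16` or `f32`/`Float32` accordingly. A hedged Julia-level usage sketch following the signatures in this diff (argument values are illustrative; Comonicon also derives a command-line flag from the new Bool keyword, whose exact spelling is an assumption here):

# after including simple_cnn.jl
main(; backend="reactant", bfloat16=true)   # BF16 training + inference
main(; backend="reactant")                  # unchanged Float32 default
# from a shell, the generated CLI would expose something like `--bfloat16` (assumed form)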
