From fbfc851fd260d86e9b9caad29b1b01804d76720b Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Tue, 31 Dec 2024 10:48:06 -0500
Subject: [PATCH] feat: BF16 training + inference

---
 examples/CIFAR10/Project.toml  |  4 +++-
 examples/CIFAR10/common.jl     | 31 +++++++++++++++++++------------
 examples/CIFAR10/conv_mixer.jl |  4 ++--
 examples/CIFAR10/simple_cnn.jl |  4 ++--
 4 files changed, 26 insertions(+), 17 deletions(-)

diff --git a/examples/CIFAR10/Project.toml b/examples/CIFAR10/Project.toml
index c1b785b18..c0dffde55 100644
--- a/examples/CIFAR10/Project.toml
+++ b/examples/CIFAR10/Project.toml
@@ -1,4 +1,5 @@
 [deps]
+BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
 Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
 ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
 DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e"
@@ -21,6 +22,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
+BFloat16s = "0.5.0"
 Comonicon = "1.0.8"
 ConcreteStructs = "0.2.3"
 DataAugmentation = "0.3"
@@ -36,7 +38,7 @@ OneHotArrays = "0.2.5"
 Optimisers = "0.4.1"
 Printf = "1.10"
 Random = "1.10"
-Reactant = "0.2.5"
+Reactant = "0.2.12"
 StableRNGs = "1.0.2"
 Statistics = "1.10"
 Zygote = "0.6.70"
diff --git a/examples/CIFAR10/common.jl b/examples/CIFAR10/common.jl
index 095a9b8d2..7457306d7 100644
--- a/examples/CIFAR10/common.jl
+++ b/examples/CIFAR10/common.jl
@@ -1,5 +1,5 @@
 using ConcreteStructs, DataAugmentation, ImageShow, Lux, MLDatasets, MLUtils, OneHotArrays,
-    Printf, ProgressTables, Random
+    Printf, ProgressTables, Random, BFloat16s
 using Reactant, LuxCUDA
 
 @concrete struct TensorDataset
@@ -15,21 +15,22 @@ function Base.getindex(ds::TensorDataset, idxs::Union{Vector{<:Integer}, Abstrac
     return stack(parent ∘ itemdata ∘ Base.Fix1(apply, ds.transform), img), y
 end
 
-function get_cifar10_dataloaders(batchsize; kwargs...)
-    cifar10_mean = (0.4914, 0.4822, 0.4465)
-    cifar10_std = (0.2471, 0.2435, 0.2616)
+function get_cifar10_dataloaders(::Type{T}, batchsize; kwargs...) where {T}
+    cifar10_mean = (0.4914, 0.4822, 0.4465) .|> T
+    cifar10_std = (0.2471, 0.2435, 0.2616) .|> T
 
     train_transform = RandomResizeCrop((32, 32)) |>
                       Maybe(FlipX{2}()) |>
                       ImageToTensor() |>
-                      Normalize(cifar10_mean, cifar10_std)
+                      Normalize(cifar10_mean, cifar10_std) |>
+                      ToEltype(T)
 
-    test_transform = ImageToTensor() |> Normalize(cifar10_mean, cifar10_std)
+    test_transform = ImageToTensor() |> Normalize(cifar10_mean, cifar10_std) |> ToEltype(T)
 
-    trainset = TensorDataset(CIFAR10(:train), train_transform)
+    trainset = TensorDataset(CIFAR10(; Tx=T, split=:train), train_transform)
     trainloader = DataLoader(trainset; batchsize, shuffle=true, kwargs...)
 
-    testset = TensorDataset(CIFAR10(:test), test_transform)
+    testset = TensorDataset(CIFAR10(; Tx=T, split=:test), test_transform)
     testloader = DataLoader(testset; batchsize, shuffle=false, kwargs...)
 
     return trainloader, testloader
@@ -64,24 +65,30 @@ end
 
 function train_model(
     model, opt, scheduler=nothing;
-    backend::String, batchsize::Int=512, seed::Int=1234, epochs::Int=25
+    backend::String, batchsize::Int=512, seed::Int=1234, epochs::Int=25,
+    bfloat16::Bool=false
 )
     rng = Random.default_rng()
     Random.seed!(rng, seed)
 
+    prec = bfloat16 ? bf16 : f32
+    prec_jl = bfloat16 ? BFloat16 : Float32
+    prec_str = bfloat16 ? "BFloat16" : "Float32"
+    @printf "[Info] Using %s precision\n" prec_str
+
     accelerator_device = get_accelerator_device(backend)
     kwargs = accelerator_device isa ReactantDevice ?
         (; partial=false) : ()
-    trainloader, testloader = get_cifar10_dataloaders(batchsize; kwargs...) |>
+    trainloader, testloader = get_cifar10_dataloaders(prec_jl, batchsize; kwargs...) |>
                               accelerator_device
 
-    ps, st = Lux.setup(rng, model) |> accelerator_device
+    ps, st = Lux.setup(rng, model) |> prec |> accelerator_device
 
     train_state = Training.TrainState(model, ps, st, opt)
     adtype = backend == "reactant" ? AutoEnzyme() : AutoZygote()
 
     if backend == "reactant"
-        x_ra = rand(rng, Float32, size(first(trainloader)[1])) |> accelerator_device
+        x_ra = rand(rng, prec_jl, size(first(trainloader)[1])) |> accelerator_device
         @printf "[Info] Compiling model with Reactant.jl\n"
         st_test = Lux.testmode(st)
         model_compiled = Reactant.compile(model, (x_ra, ps, st_test))
diff --git a/examples/CIFAR10/conv_mixer.jl b/examples/CIFAR10/conv_mixer.jl
index 6981bf1e6..8c261b822 100644
--- a/examples/CIFAR10/conv_mixer.jl
+++ b/examples/CIFAR10/conv_mixer.jl
@@ -35,7 +35,7 @@ Comonicon.@main function main(;
     batchsize::Int=512, hidden_dim::Int=256, depth::Int=8,
     patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=0.0001,
     clip_norm::Bool=false, seed::Int=1234, epochs::Int=25, lr_max::Float64=0.05,
-    backend::String="reactant"
+    backend::String="reactant", bfloat16::Bool=false
 )
 
     model = ConvMixer(; dim=hidden_dim, depth, kernel_size, patch_size)
@@ -46,5 +46,5 @@
         [0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1],
         [0, lr_max, lr_max / 20, 0]
     )
-    return train_model(model, opt, lr_schedule; backend, batchsize, seed, epochs)
+    return train_model(model, opt, lr_schedule; backend, batchsize, seed, epochs, bfloat16)
 end
diff --git a/examples/CIFAR10/simple_cnn.jl b/examples/CIFAR10/simple_cnn.jl
index 075b1e28d..9eed26f19 100644
--- a/examples/CIFAR10/simple_cnn.jl
+++ b/examples/CIFAR10/simple_cnn.jl
@@ -23,12 +23,12 @@ end
 Comonicon.@main function main(;
     batchsize::Int=512, weight_decay::Float64=0.0001,
     clip_norm::Bool=false, seed::Int=1234, epochs::Int=50, lr::Float64=0.003,
-    backend::String="reactant"
+    backend::String="reactant", bfloat16::Bool=false
 )
     model = SimpleCNN()
 
     opt = AdamW(; eta=lr, lambda=weight_decay)
     clip_norm && (opt = OptimiserChain(ClipNorm(), opt))
 
-    return train_model(model, opt, nothing; backend, batchsize, seed, epochs)
+    return train_model(model, opt, nothing; backend, batchsize, seed, epochs, bfloat16)
 end
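
Usage note (not part of the patch): Comonicon turns the new `bfloat16::Bool=false`
keyword into a `--bfloat16` CLI flag, so BF16 runs should be as simple as, e.g.,
`julia --project=examples/CIFAR10 examples/CIFAR10/simple_cnn.jl --bfloat16`.
The sketch below illustrates the precision plumbing the patch relies on. It is a
minimal standalone example, assuming the `bf16`/`f32` converters used above are
exported by the pinned Lux version; `Dense(4 => 2)` is just a stand-in layer:

    using Lux, BFloat16s, Random

    rng = Random.default_rng()
    model = Dense(4 => 2)

    # `bf16` recursively converts the floating-point leaves of the (ps, st)
    # tuple to BFloat16, mirroring `Lux.setup(rng, model) |> prec` above.
    ps, st = Lux.setup(rng, model) |> bf16

    x = rand(rng, BFloat16, 4, 8)  # inputs should match the parameter eltype
    y, _ = model(x, ps, st)
    @assert eltype(y) == BFloat16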