Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: conditional VAE #1157

Merged
merged 8 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pages = [
"tutorials/intermediate/2_BayesianNN.md",
"tutorials/intermediate/3_HyperNet.md",
"tutorials/intermediate/4_PINN2DPDE.md",
"tutorials/intermediate/5_ConditionalVAE.md",
],
"Advanced" => [
"tutorials/advanced/1_GravitationalWaveForm.md"
Expand Down
4 changes: 4 additions & 0 deletions docs/src/.vitepress/config.mts
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,10 @@ export default defineConfig({
text: "Training a PINN on 2D PDE",
link: "/tutorials/intermediate/4_PINN2DPDE",
},
{
text: "Conditional VAE for MNIST using Reactant",
link: "/tutorials/intermediate/5_ConditionalVAE",
}
],
},
{
Expand Down
Binary file added docs/src/public/conditional_vae.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 6 additions & 0 deletions docs/src/tutorials/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ const intermediate = [
src: "../pinn_nested_ad.gif",
caption: "Training a PINN",
desc: "Train a PINN to solve 2D PDEs (using Nested AD)."
},
{
href: "intermediate/5_ConditionalVAE",
src: "../conditional_vae.png",
caption: "Conditional VAE for MNIST using Reactant",
desc: "Train a Conditional VAE to generate images from a latent space."
}
];

Expand Down
14 changes: 11 additions & 3 deletions docs/tutorials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ const INTERMEDIATE_TUTORIALS = [
"BayesianNN/main.jl" => "CPU",
"HyperNet/main.jl" => "CUDA",
"PINN2DPDE/main.jl" => "CUDA",
"ConditionalVAE/main.jl" => "CUDA",
]
const ADVANCED_TUTORIALS = [
"GravitationalWaveForm/main.jl" => "CPU",
Expand Down Expand Up @@ -41,9 +42,16 @@ end

# Partition the tutorials across Buildkite parallel jobs (round-robin) so each
# job builds a disjoint subset; outside Buildkite, build everything in one job.
const TUTORIALS_BUILDING = if BUILDKITE_PARALLEL_JOB_COUNT > 0
    id = parse(Int, ENV["BUILDKITE_PARALLEL_JOB"]) + 1 # Index starts from 0
    # Pre-initialize every slot so no entry is ever `#undef` (the previous
    # `isassigned` checks become unnecessary) and the fallback below can return
    # a correctly-typed empty vector instead of an untyped `Any[]`.
    splits = [Vector{eltype(TUTORIALS_WITH_BACKEND)}()
              for _ in 1:BUILDKITE_PARALLEL_JOB_COUNT]
    for i in eachindex(TUTORIALS_WITH_BACKEND)
        push!(splits[mod1(i, BUILDKITE_PARALLEL_JOB_COUNT)], TUTORIALS_WITH_BACKEND[i])
    end
    id > length(splits) ? eltype(splits)() : splits[id]
else
    TUTORIALS_WITH_BACKEND
end
Expand Down
30 changes: 30 additions & 0 deletions examples/ConditionalVAE/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
[deps]
Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

[compat]
ConcreteStructs = "0.2.3"
DataAugmentation = "0.3.2"
Enzyme = "0.13.20"
ImageShow = "0.3.8"
Images = "0.26.1"
Lux = "1.4.1"
MLDatasets = "0.7.18"
MLUtils = "0.4.4"
OneHotArrays = "0.2.6"
Printf = "1.10"
Random = "1.10"
Reactant = "0.2.9"
287 changes: 287 additions & 0 deletions examples/ConditionalVAE/main.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,287 @@
# # [Conditional VAE for MNIST using Reactant](@id Conditional-VAE-Tutorial)

# Convolutional variational autoencoder (CVAE) implementation in Lux using MNIST. This is
# based on the [CVAE implementation in MLX](https://github.com/ml-explore/mlx-examples/blob/main/cvae/).

using Lux, Reactant, MLDatasets, Random, Statistics, Enzyme, MLUtils, DataAugmentation,
ConcreteStructs, OneHotArrays, ImageShow, Images, Printf, Optimisers, Comonicon,
StableRNGs

## Accelerator device (forced Reactant/XLA backend) and a CPU device for moving
## results back to the host.
const xdev = reactant_device(; force=true)
const cdev = cpu_device()

# ## Model Definition

# First we will define the encoder. It maps the input to a normal distribution in latent
# space and samples a latent vector from that distribution.

function cvae_encoder(
        rng=Random.default_rng(); num_latent_dims::Int,
        image_shape::Dims{3}, max_num_filters::Int
)
    ## Three stride-2 convolutions each halve H and W, so the spatial dims
    ## shrink by a factor of 8 before flattening.
    flattened_dim = prod(image_shape[1:2] .÷ 8) * max_num_filters
    return @compact(;
        embed=Chain(
            Chain(
                Conv((3, 3), image_shape[3] => max_num_filters ÷ 4; stride=2, pad=1),
                BatchNorm(max_num_filters ÷ 4, leakyrelu)
            ),
            Chain(
                Conv((3, 3), max_num_filters ÷ 4 => max_num_filters ÷ 2; stride=2, pad=1),
                BatchNorm(max_num_filters ÷ 2, leakyrelu)
            ),
            Chain(
                Conv((3, 3), max_num_filters ÷ 2 => max_num_filters; stride=2, pad=1),
                BatchNorm(max_num_filters, leakyrelu)
            ),
            FlattenLayer()
        ),
        ## Two heads project the shared embedding to the mean and log-variance
        ## of the latent Gaussian.
        proj_mu=Dense(flattened_dim, num_latent_dims; init_bias=zeros32),
        proj_log_var=Dense(flattened_dim, num_latent_dims; init_bias=zeros32),
        rng) do x
        y = embed(x)

        μ = proj_mu(y)
        logσ² = proj_log_var(y)

        ## Clamp the log-variance for numerical stability before exponentiating.
        T = eltype(logσ²)
        logσ² = clamp.(logσ², -T(20.0f0), T(10.0f0))
        σ = exp.(logσ² .* T(0.5))

        ## Generate a tensor of random values from a normal distribution
        rng = Lux.replicate(rng)
        ϵ = randn_like(rng, σ)

        ## Reparameterization trick to backpropagate through sampling
        z = ϵ .* σ .+ μ

        @return z, μ, logσ²
    end
end

# Similarly we define the decoder.

function cvae_decoder(; num_latent_dims::Int, image_shape::Dims{3}, max_num_filters::Int)
    ## Mirror of the encoder: project the latent vector back to the flattened
    ## 8x-downsampled feature map, then upsample three times back to the
    ## original spatial size.
    flattened_dim = prod(image_shape[1:2] .÷ 8) * max_num_filters
    return @compact(;
        linear=Dense(num_latent_dims, flattened_dim),
        upchain=Chain(
            Chain(
                Upsample(2),
                Conv((3, 3), max_num_filters => max_num_filters ÷ 2; stride=1, pad=1),
                BatchNorm(max_num_filters ÷ 2, leakyrelu)
            ),
            Chain(
                Upsample(2),
                Conv((3, 3), max_num_filters ÷ 2 => max_num_filters ÷ 4; stride=1, pad=1),
                BatchNorm(max_num_filters ÷ 4, leakyrelu)
            ),
            ## Final conv maps back to the image channel count; sigmoid keeps
            ## pixel values in [0, 1].
            Chain(
                Upsample(2),
                Conv((3, 3), max_num_filters ÷ 4 => image_shape[3],
                    sigmoid; stride=1, pad=1)
            )
        ),
        max_num_filters) do x
        y = linear(x)
        ## Reshape the flat projection into a (H/8, W/8, C, batch) feature map.
        img = reshape(y, image_shape[1] ÷ 8, image_shape[2] ÷ 8, max_num_filters, :)
        @return upchain(img)
    end
end

## Container layer bundling the encoder and decoder so that parameters/states
## are namespaced under `:encoder` and `:decoder`. `@concrete` makes the field
## types concrete type parameters for performance.
@concrete struct CVAE <: Lux.AbstractLuxContainerLayer{(:encoder, :decoder)}
    encoder <: Lux.AbstractLuxLayer
    decoder <: Lux.AbstractLuxLayer
end

## Convenience constructor: build both halves from shared hyperparameters.
## Only the encoder needs the RNG (for latent sampling).
function CVAE(rng=Random.default_rng(); num_latent_dims::Int,
        image_shape::Dims{3}, max_num_filters::Int)
    return CVAE(
        cvae_encoder(rng; num_latent_dims, image_shape, max_num_filters),
        cvae_decoder(; num_latent_dims, image_shape, max_num_filters)
    )
end

## Full forward pass: encode to a sampled latent (plus the distribution
## parameters), then decode back to image space.
function (cvae::CVAE)(x, ps, st)
    (latent, mean_vec, logvar), enc_st = cvae.encoder(x, ps.encoder, st.encoder)
    reconstruction, dec_st = cvae.decoder(latent, ps.decoder, st.decoder)
    return (reconstruction, mean_vec, logvar), (; encoder=enc_st, decoder=dec_st)
end

## Run only the encoder; μ and logσ² are discarded, the decoder state is
## passed through untouched.
function encode(cvae::CVAE, x, ps, st)
    (latent, _, _), enc_st = cvae.encoder(x, ps.encoder, st.encoder)
    return latent, (; encoder=enc_st, decoder=st.decoder)
end

## Run only the decoder on latent vectors `z`; the encoder state is passed
## through untouched.
function decode(cvae::CVAE, z, ps, st)
    reconstruction, dec_st = cvae.decoder(z, ps.decoder, st.decoder)
    return reconstruction, (; encoder=st.encoder, decoder=dec_st)
end

# ## Loading MNIST

## Lazily-transformed dataset wrapper: `dataset` provides the raw samples and
## `transform` is a DataAugmentation.jl pipeline applied at access time.
@concrete struct TensorDataset
    dataset
    transform
end

Base.length(ds::TensorDataset) = length(ds.dataset)

## Materialize a batch: convert the selected samples to images, apply the
## augmentation pipeline to each, and stack them into one array along a new
## trailing (batch) dimension.
function Base.getindex(ds::TensorDataset, idxs::Union{Vector{<:Integer}, AbstractRange})
    ## `convert2image` returns an H×W×N image stack; slice dims=3 to obtain
    ## per-sample images wrapped as DataAugmentation `Image` items.
    img = Image.(eachslice(convert2image(ds.dataset, idxs); dims=3))
    return stack(parent ∘ itemdata ∘ Base.Fix1(apply, ds.transform), img)
end

## Build a training `DataLoader` over (resized) MNIST digits.
## Returns batches of size `batchsize` with images resized to `image_size`.
function loadmnist(batchsize, image_size::Dims{2})
    ## Load MNIST: Only 1500 for demonstration purposes on CI
    N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing
    ## Only the training split is used; the previous code also loaded and
    ## sliced the test split without ever using it.
    train_dataset = MNIST(; split=:train)
    if N !== nothing
        train_dataset = train_dataset[1:N]
    end

    ## Resize (preserving aspect ratio) and convert each image to a float tensor.
    train_transform = ScaleKeepAspect(image_size) |> ImageToTensor()
    trainset = TensorDataset(train_dataset, train_transform)

    ## `partial=false` drops the last incomplete batch so every batch has the
    ## same size (the model is compiled for a fixed batch size).
    return DataLoader(trainset; batchsize, shuffle=true, partial=false)
end

# ## Helper Functions

# Generate an Image Grid from a list of images

## Convert a WHCN image batch into colorant matrices (grayscale when single
## channel) and delegate to the `Vector` method to assemble the grid.
function create_image_grid(imgs::AbstractArray, grid_rows::Int, grid_cols::Int)
    n_imgs = grid_rows * grid_cols
    converted = [
        ## Transpose so the first matrix dimension is the displayed height.
        (size(img, 3) == 1 ? colorview(Gray, view(img, :, :, 1)) : colorview(RGB, img))'
        for img in eachslice(imgs[:, :, :, 1:n_imgs]; dims=4)
    ]
    return create_image_grid(converted, grid_rows, grid_cols)
end

## Tile `grid_rows * grid_cols` equally-sized images into one canvas, filling
## the grid in row-major order.
function create_image_grid(images::Vector, grid_rows::Int, grid_cols::Int)
    ## Validate eagerly with a real exception: `@assert` may be compiled out,
    ## which would let a mismatched grid read out of bounds below.
    total_images = grid_rows * grid_cols
    length(images) == total_images ||
        throw(ArgumentError("expected $total_images images for a " *
                            "$(grid_rows)x$(grid_cols) grid, got $(length(images))"))

    ## Get the size of a single image (assuming all images are the same size)
    img_height, img_width = size(images[1])

    ## Allocate the canvas and blit each image into its cell.
    grid_canvas = similar(images[1], img_height * grid_rows, img_width * grid_cols)
    for idx in 1:total_images
        row, col = divrem(idx - 1, grid_cols) .+ 1
        row_range = ((row - 1) * img_height + 1):(row * img_height)
        col_range = ((col - 1) * img_width + 1):(col * img_width)
        grid_canvas[row_range, col_range] .= images[idx]
    end

    return grid_canvas
end

## VAE objective: pixel-wise sum-of-squares reconstruction error plus the
## analytic KL divergence of N(μ, σ²) from the standard-normal prior.
function loss_function(model, ps, st, X)
    (X̂, μ, logσ²), st_new = model(X, ps, st)
    rec_loss = MSELoss(; agg=sum)(X̂, X)
    kl_loss = -sum(1 .+ logσ² .- μ .^ 2 .- exp.(logσ²)) / 2
    return (
        rec_loss + kl_loss,
        st_new,
        (; y=X̂, μ, logσ², reconstruction_loss=rec_loss, kldiv_loss=kl_loss)
    )
end

## Sample `num_samples` latent vectors from the standard-normal prior and
## decode them into an 8-row image grid. When `decode_compiled` (a
## Reactant-compiled `decode`) is supplied, it is used and the result is
## moved back to the CPU.
function generate_images(
        model, ps, st; num_samples::Int=128, num_latent_dims::Int, decode_compiled=nothing)
    ## Place the latent draws on the same device as the parameters/states.
    z = randn(Float32, num_latent_dims, num_samples) |> get_device((ps, st))
    if decode_compiled === nothing
        ## NOTE(review): this branch does not move `images` to the CPU like the
        ## compiled branch does — confirm callers only use it with CPU ps/st.
        images, _ = decode(model, z, ps, Lux.testmode(st))
    else
        images, _ = decode_compiled(model, z, ps, Lux.testmode(st))
        images = images |> cpu_device()
    end
    return create_image_grid(images, 8, num_samples ÷ 8)
end

## Run the full CVAE on `X` in test mode and tile the reconstructions into an
## 8-row image grid.
function reconstruct_images(model, ps, st, X)
    (reconstructions, _, _), _ = model(X, ps, Lux.testmode(st))
    host_images = cpu_device()(reconstructions)
    return create_image_grid(host_images, 8, size(X, ndims(X)) ÷ 8)
end

# ## Training the Model

## Train the CVAE on MNIST and return the final grid of reconstructions and
## prior samples (stacked vertically with a blank separator row).
function main(; batchsize=128, image_size=(64, 64), num_latent_dims=8, max_num_filters=64,
        seed=0, epochs=50, weight_decay=1e-5, learning_rate=1e-3, num_samples=batchsize)
    rng = Xoshiro()
    Random.seed!(rng, seed)

    cvae = CVAE(rng; num_latent_dims, image_shape=(image_size..., 1), max_num_filters)
    ps, st = Lux.setup(rng, cvae) |> xdev

    ## Pre-compile the inference paths with Reactant using dummy inputs of the
    ## exact shapes used later (compiled functions are specialized to them).
    z = randn(Float32, num_latent_dims, num_samples) |> xdev
    decode_compiled = @compile decode(cvae, z, ps, Lux.testmode(st))
    x = randn(Float32, image_size..., 1, batchsize) |> xdev
    cvae_compiled = @compile cvae(x, ps, Lux.testmode(st))

    train_dataloader = loadmnist(batchsize, image_size) |> xdev

    opt = AdamW(; eta=learning_rate, lambda=weight_decay)

    train_state = Training.TrainState(cvae, ps, st, opt)

    @printf "Total Trainable Parameters: %0.4f M\n" (Lux.parameterlength(ps)/1e6)

    ## Intermediate image grids are only displayed when running inside VS Code.
    is_vscode = isdefined(Main, :VSCodeServer)
    empty_row, model_img_full = nothing, nothing

    for epoch in 1:epochs
        loss_total = 0.0f0
        total_samples = 0
        total_time = 0.0

        for (i, X) in enumerate(train_dataloader)
            ## Time only the optimization step to report training throughput.
            throughput_tic = time()
            (_, loss, stats, train_state) = Training.single_train_step!(
                AutoEnzyme(), loss_function, X, train_state)
            throughput_toc = time()

            loss_total += loss
            total_samples += size(X, ndims(X))
            total_time += throughput_toc - throughput_tic

            if i % 250 == 0 || i == length(train_dataloader)
                throughput = total_samples / total_time
                @printf "Epoch %d, Iter %d, Loss: %.7f, Throughput: %.6f im/s\n" epoch i loss throughput
            end
        end

        train_loss = loss_total / length(train_dataloader)
        throughput = total_samples / total_time
        @printf "Epoch %d, Train Loss: %.7f, Time: %.4fs, Throughput: %.6f im/s\n" epoch train_loss total_time throughput

        ## Visualize reconstructions (top) and prior samples (bottom) after
        ## every epoch in VS Code, or once at the end otherwise.
        if is_vscode || epoch == epochs
            recon_images = reconstruct_images(
                cvae_compiled, train_state.parameters, train_state.states,
                first(train_dataloader))
            gen_images = generate_images(
                cvae, train_state.parameters, train_state.states;
                num_samples, num_latent_dims, decode_compiled)
            if empty_row === nothing
                ## Blank separator row between the two grids; created once.
                empty_row = similar(gen_images, image_size[1], size(gen_images, 2))
                fill!(empty_row, 0)
            end
            model_img_full = vcat(recon_images, empty_row, gen_images)
            is_vscode && display(model_img_full)
        end
    end

    return model_img_full
end

main()
Loading