diff --git a/Project.toml b/Project.toml
index 38ad2b5..955e80f 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,6 +5,7 @@ version = "0.1.0"
 [deps]
 BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
+BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
 CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
 ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
 Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
@@ -24,10 +25,12 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
 ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"
+Unzip = "41fe7b60-77ed-43a1-b4f0-825fd5a5650d"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [extras]
diff --git a/examples/Sliced_ISL/MNIST_sliced.jl b/examples/Sliced_ISL/MNIST_sliced.jl
index 47a13ea..ade17c5 100644
--- a/examples/Sliced_ISL/MNIST_sliced.jl
+++ b/examples/Sliced_ISL/MNIST_sliced.jl
@@ -5,6 +5,14 @@ using Images
 using ImageTransformations # For resizing images if necessary
 using LinearAlgebra

+function load_mnist()
+    # Load the full MNIST training set (the test split is loaded but currently unused)
+    train_x, train_y = MLDatasets.MNIST.traindata()
+    test_x, test_y = MLDatasets.MNIST.testdata()
+
+    return (reshape(Float32.(train_x), 28 * 28, :), train_y)#, (test_x, test_y)
+end
+
 function load_mnist(digit::Int)
     # Load MNIST data
     train_x, train_y = MLDatasets.MNIST.traindata()
@@ -41,14 +49,14 @@ function load_mnist_normalized(digit::Int, max::Int)

     image_tensor = reshape(@.(2.0f0 * selected_images - 1.0f0), 28, 28, :)

-    train_data = reshape(image_tensor, 28 * 28, :)
-
-    return (train_data, train_y)
+    return (reshape(Float32.(image_tensor), 28 * 28, :), train_y)
 end

+(train_x, train_y) = load_mnist()
+(train_x, train_y) = (train_x[:, 1:5000], train_y[1:5000])
 (train_x, train_y) = load_mnist(0)
 (train_x, train_y) = load_mnist(9, 100)
-(train_x, train_y) = load_mnist_normalized(9, 100)
+(train_x, train_y) = load_mnist_normalized(8, 100)

 # Dimension
 dims = 100
@@ -103,6 +111,9 @@ function Discriminator()
     )
 end

+latent_dim = 100
+# Weight initialization as given in the DCGAN paper, https://arxiv.org/abs/1511.06434
+dcgan_init(shape...) = randn(Float32, shape...) * 0.02f0
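+# Hypothetical quick check of the scale (requires `using Statistics`):
+# std(dcgan_init(4, 4, 128, 64)) ≈ 0.02f0, i.e. weights drawn from N(0, 0.02²).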
 function Generator(latent_dim::Int)
     return Chain(
         Dense(latent_dim, 7 * 7 * 256),
@@ -113,11 +124,12 @@ function Generator(latent_dim::Int)
         ConvTranspose((4, 4), 128 => 64; stride=2, pad=1, init=dcgan_init),
         BatchNorm(64, relu),
         ConvTranspose((4, 4), 64 => 1; stride=2, pad=1, init=dcgan_init),
+        Flux.flatten,
         x -> tanh.(x),
     )
 end

-model = Generator(dims)
+model = Generator(latent_dim)
 #model = Chain( ConvTranspose((7, 7), 100 => 256, stride=1, padding=0), BatchNorm(256, relu), ConvTranspose((4, 4), 256 => 128, stride=2, padding=1), BatchNorm(128, relu), ConvTranspose((4, 4), 128 => 1, stride=2, padding=1), tanh ))

 # Mean vector (zero vector of length dim)
@@ -132,16 +144,16 @@ noise_model = MvNormal(mean_vector, cov_matrix)
 n_samples = 10000

 hparams = HyperParamsSlicedISL(;
-    K=10, samples=100, epochs=1, η=1e-2, noise_model=noise_model, m=200
+    K=10, samples=100, epochs=1, η=1e-2, noise_model=noise_model, m=10
 )

 # Create a data loader for training
 batch_size = 100
-train_loader = DataLoader(train_x; batchsize=batch_size, shuffle=false, partial=false)
+train_loader = DataLoader(train_x; batchsize=batch_size, shuffle=true, partial=false)

 total_loss = []
-@showprogress for _ in 1:20
-    append!(total_loss, sliced_invariant_statistical_loss(model, train_loader, hparams))
+@showprogress for _ in 1:200
+    append!(total_loss, sliced_invariant_statistical_loss_optimized(model, train_loader, hparams))
 end

 img = model(Float32.(rand(hparams.noise_model, 1)))
diff --git a/examples/Sliced_ISL/MNIST_sliced2.jl b/examples/Sliced_ISL/MNIST_sliced2.jl
new file mode 100644
index 0000000..5833958
--- /dev/null
+++ b/examples/Sliced_ISL/MNIST_sliced2.jl
@@ -0,0 +1,47 @@
+using ISL
+using Flux
+using MLDatasets
+using Images
+using ImageTransformations # For resizing images if necessary
+using LinearAlgebra
+using ProgressMeter # provides @showprogress, used below
+
+function load_mnist()
+    # Load MNIST data
+    train_x, train_y = MLDatasets.MNIST.traindata()
+    test_x, test_y = MLDatasets.MNIST.testdata()
+    return (reshape(Float32.(train_x), 28 * 28, :), train_y)#, (test_x, test_y)
+end
+
+(images, labels) = load_mnist()
+
+n_outputs = length(unique(labels))
+
+ys = [Flux.onehot(label, 0:9) for label in labels]
+
+n_inputs, n_latent, n_outputs = 28 * 28, 50, 10
+model = Chain(
+    Dense(n_inputs, n_latent, identity),
+    Dense(n_latent, n_latent, identity),
+    Dense(n_latent, n_outputs, identity),
+    softmax,
+)
+loss(x, y) = Flux.crossentropy(model(x), y)
+
+function create_batch(r)
+    xs = images[:, r]
+    ys = [Flux.onehot(label, 0:9) for label in labels[r]]
+    return (xs, Flux.batch(ys))
+end
+
+trainbatch = create_batch(1:5000)
+
+# Implicit-params training below, so use the classic optimizer interface
+opt = ADAM()
+
+@showprogress for _ in 1:1000
+    Flux.train!(loss, Flux.params(model), [trainbatch], opt)
+end
+
+model(images[:, 1])
+img2 = reshape(images[:, 1], 28, 28)
+display(Gray.(img2))
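+
+# Hypothetical sanity check (not in the original script): after training, the
+# most probable class for the first image should match its true label.
+println("predicted: ", argmax(model(images[:, 1])) - 1, ", label: ", labels[1])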
diff --git a/src/CustomLossFunction.jl b/src/CustomLossFunction.jl
index 81ce6e1..59d31f3 100644
--- a/src/CustomLossFunction.jl
+++ b/src/CustomLossFunction.jl
@@ -86,7 +86,7 @@ The contribution is computed according to the formula:
 ```
 """
 function γ(yₖ::Matrix{T}, yₙ::T, m::Int64) where {T<:AbstractFloat}
-    eₘ(m) = [j == m ? 1.0 : 0.0 for j in 0:length(yₖ)]
+    eₘ(m) = [j == m ? T(1.0) : T(0.0) for j in 0:length(yₖ)]

     return eₘ(m) * ψₘ(ϕ(yₖ, yₙ), m)
 end;
@@ -140,7 +140,7 @@ The formula for generating `aₖ` is as follows:
     aₖ = ∑_{k=0}^K γ(ŷ, y, k) = ∑_{k=0}^K ∑_{i=1}^N ψₖ(ŷ, yᵢ)
 ```
 """
-function generate_aₖ(ŷ::Matrix{T}, y::T) where {T<:AbstractFloat}
+@inline function generate_aₖ(ŷ::Matrix{T}, y::T) where {T<:AbstractFloat}
     return sum([γ(ŷ, y, k) for k in 0:length(ŷ)])
 end

@@ -153,7 +153,8 @@ Scalar difference between the vector representing our surrogate histogram and th
 ```
     loss = ||q - 1/(K+1)||₂² = ∑_{k=0}^K (qₖ - 1/(K+1))²
 ```
 """
-scalar_diff(q::Vector{T}) where {T<:AbstractFloat} = sum((q .- (1 ./ length(q))) .^ 2)
+@inline scalar_diff(q::Vector{T}) where {T<:AbstractFloat} =
+    sum((q .- (T(1) / T(length(q)))) .^ 2)

 """
 `jensen_shannon_∇(aₖ)`
@@ -277,6 +278,13 @@ function get_window_of_Aₖ(transform, model, data, K::Int64)
     return [count(x -> x == i, window) for i in 0:K]
 end;

+# Sliced variant: rank the ω-projections of the K generated samples against the
+# ω-projections of the data columns (one projection per sample, not per entry).
+@inline function get_window_of_Aₖ(transform, model, ω, data, K::Int64)
+    ŷₖ = model(Float32.(rand(transform, K)))
+    ŷₖ_proj = [dot(ω, ŷₖ[:, i]) for i in 1:size(ŷₖ, 2)]
+    data_proj = [dot(ω, data[:, i]) for i in 1:size(data, 2)]
+    window = count.([ŷₖ_proj .< d for d in data_proj])
+    return [count(x -> x == i, window) for i in 0:K]
+end;

 """
 `convergence_to_uniform(aₖ)`
@@ -475,7 +483,7 @@ Base.@kwdef mutable struct HyperParamsSlicedISL
     m::Int = 10 # Number of random directions
 end

-function sample_random_direction(n::Int)::Vector{Float32}
+@inline function sample_random_direction(n::Int)::Vector{Float32}
     # Generate a random vector where each component is from a standard normal distribution
     random_vector = randn(Float32, n)

@@ -492,21 +500,22 @@ function sample_ornormal_random_direction(n::Int, m::Int)::Vector{Vector{Float32
     return [matrix[:, i] for i in 1:m]
 end

+using ThreadsX
 function sliced_invariant_statistical_loss(nn_model, loader, hparams::HyperParamsSlicedISL)
     @assert loader.batchsize == hparams.samples
     @assert length(loader) == hparams.epochs
     losses = Vector{Float32}()
     optim = Flux.setup(Flux.Adam(hparams.η), nn_model)
     @showprogress for data in loader
+        Ω = ThreadsX.map(_ -> sample_random_direction(size(data)[1]), 1:(hparams.m))
         loss, grads = Flux.withgradient(nn_model) do nn
-            Ω = [sample_random_direction(size(data)[1]) for _ in 1:(hparams.m)]
             total = 0.0f0
             for ω in Ω
                 aₖ = zeros(hparams.K + 1)
                 for i in 1:(hparams.samples)
                     x = Float32.(rand(hparams.noise_model, hparams.K))
                     yₖ = nn(x)
-                    s = collect(reshape(ω' * yₖ, 1, hparams.K))
+                    s = Matrix(ω' * yₖ)
                     aₖ += generate_aₖ(s, ω ⋅ data[:, i])
                 end
                 total += scalar_diff(aₖ ./ sum(aₖ))
@@ -519,6 +528,46 @@ function sliced_invariant_statistical_loss(nn_model, loader, hparams::HyperParam
     return losses
 end;
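+"""
+    sliced_invariant_statistical_loss_optimized(nn_model, loader, hparams)
+
+Same objective as `sliced_invariant_statistical_loss`, but draws all
+`samples * K` noise vectors for each direction in a single `rand` call and
+pushes them through the network as one batch, so the projection `ω' * yₖ`
+becomes one matrix product per direction instead of `samples` small ones.
+"""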
+function sliced_invariant_statistical_loss_optimized(nn_model, loader, hparams)
+    @assert loader.batchsize == hparams.samples
+    @assert length(loader) == hparams.epochs
+    losses = Vector{Float32}()
+    optim = Flux.setup(Flux.Adam(hparams.η), nn_model)
+
+    @showprogress for data in loader
+        Ω = ThreadsX.map(_ -> sample_random_direction(size(data)[1]), 1:(hparams.m))
+        loss, grads = Flux.withgradient(nn_model) do nn
+            total = 0.0f0
+            for ω in Ω
+                aₖ = zeros(Float32, hparams.K + 1) # Reset aₖ for each new ω
+
+                # Generate all random numbers in one go
+                x_batch = rand(hparams.noise_model, hparams.samples * hparams.K)
+
+                # Process the whole batch through nn_model at once
+                yₖ_batch = nn(Float32.(x_batch))
+
+                s = Matrix(ω' * yₖ_batch)
+
+                # Sample i owns columns K*(i-1)+1 : K*i of s (Julia is 1-indexed)
+                @inbounds for i in 1:(hparams.samples)
+                    start_col = hparams.K * (i - 1) + 1
+                    end_col = hparams.K * i
+
+                    aₖ_slice = s[:, start_col:end_col]
+                    ω_data_dot_product = ω ⋅ data[:, i]
+
+                    aₖ += generate_aₖ(aₖ_slice, ω_data_dot_product)
+                end
+                total += scalar_diff(aₖ ./ sum(aₖ))
+            end
+            total / hparams.m
+        end
+        Flux.update!(optim, nn_model, grads[1])
+        push!(losses, loss)
+    end
+    return losses
+end

 function sliced_ortonormal_invariant_statistical_loss(
     nn_model, loader, hparams::HyperParamsSlicedISL
 )
@@ -559,17 +608,23 @@ function sliced_invariant_statistical_loss_selected_directions(
     @assert length(loader) == hparams.epochs
     losses = Vector{Float32}()
     optim = Flux.setup(Flux.Adam(hparams.η), nn_model)
+
+    function compute_p_value(nn, data, hparams)
+        ω = sample_random_direction(size(data)[1])
+        return (
+            ω,
+            convergence_to_uniform(
+                get_window_of_Aₖ(hparams.noise_model, nn, ω, data, hparams.K)
+            ),
+        )
+    end
+
     @showprogress for data in loader
+        pvalues = ThreadsX.map(_ -> compute_p_value(nn_model, data, hparams), 1:1000)
+
+        sorted_Ω = [direction for (direction, _) in sort(pvalues; by=x -> x[2], rev=true)][1:(hparams.m)]
+
         loss, grads = Flux.withgradient(nn_model) do nn
-            Ω = [sample_random_direction(size(data)[1]) for _ in 1:1000]
-            p_values = [
-                convergence_to_uniform(
-                    get_window_of_Aₖ(hparams.noise_model, nn, ω .⋅ data, hparams.K)
-                ) for ω in Ω
-            ]
-            direction_pvalues = zip(Ω, p_values)
-            sorted_directions = sort(direction_pvalues; by=x -> x[2])
-            sorted_Ω = [direction for (direction, pvalue) in sorted_directions]
             total = 0.0f0
             for ω in sorted_Ω
                 aₖ = zeros(hparams.K + 1)
diff --git a/src/ISL.jl b/src/ISL.jl
index 43a86cd..954d4d3 100644
--- a/src/ISL.jl
+++ b/src/ISL.jl
@@ -9,6 +9,7 @@ using MLUtils
 using LinearAlgebra
 using Parameters: @with_kw
 using ProgressMeter
+using Random
 using StaticArrays

@@ -39,5 +40,6 @@ export _sigmoid,
     sliced_invariant_statistical_loss_multithreaded,
     sliced_invariant_statistical_loss_multithreaded_2,
     sliced_invariant_statistical_loss_selected_directions,
-    sliced_ortonormal_invariant_statistical_loss
+    sliced_ortonormal_invariant_statistical_loss,
+    sliced_invariant_statistical_loss_optimized
 end