Skip to content

Commit

Permalink
init poc kme (#3)
Browse files Browse the repository at this point in the history
* add code for dummy data

* rename functions in dummy data

* add initial example for scalar to distribution regression using KME

* updatesssssss

* add initial example for distribution to distribution regression using KME
  • Loading branch information
DaanVanHauwermeiren authored Dec 10, 2024
1 parent 6f67873 commit 173c0ca
Show file tree
Hide file tree
Showing 4 changed files with 253 additions and 0 deletions.
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
[deps]
COSMO = "1e616198-aa4e-51ec-90a2-23f7fbd31d8d"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
63 changes: 63 additions & 0 deletions src/dummydata.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# some functions to get dummy histogram data

using Plots
using Distributions

"""
    get_grid(; binwidth=10, xmin=5, xmax=250)

Return the histogram grid as a tuple of ranges `(x, xl, xu)`.

- `x`: bin midpoints, `xmin:binwidth:xmax`
- `xl`: lower bin edges, i.e. the midpoints shifted down by `binwidth / 2`
- `xu`: upper bin edges, i.e. the midpoints shifted up by `binwidth / 2`

With the default values this yields 25 bins with midpoints 5, 15, ..., 245 and edges
running from 0 to 250.
"""
function get_grid(; binwidth=10, xmin=5, xmax=250)
    halfwidth = binwidth / 2
    # midpoints of the bins
    midpoints = xmin:binwidth:xmax
    # The edges use the same step as the midpoints, so all three ranges have equal length
    # and each upper edge coincides with the next bin's lower edge.
    lower_edges = (xmin - halfwidth):binwidth:(xmax - halfwidth)
    upper_edges = (xmin + halfwidth):binwidth:(xmax + halfwidth)
    return midpoints, lower_edges, upper_edges
end

"""
Get some dummy data.
This is normalised histogram data of some truncated normal distribution.
x defines the midpoints of the bins (i.e. for a physical interpretation, think the equivalent particle diameter).
p defines the probability mass of each bin.
sum(p) should be approximately 1.
"""
function dummydata_1(; mu=150, sigma=20)
    x, xl, xu = get_grid()

    # Truncate the normal to the span of the grid (lowest lower edge to highest upper
    # edge) so that all probability mass falls inside the bins.
    lower, upper = first(xl), last(xu)
    d = truncated(Normal(mu, sigma), lower, upper)

    # Probability mass of each bin: CDF at its upper edge minus CDF at its lower edge.
    p = cdf.(d, xu) .- cdf.(d, xl)

    # Internal sanity check: the bin masses must sum to one up to floating-point error.
    @assert sum(p) ≈ 1

    return x, p
end

"""
Get some dummy data, part 2: the bimodal one.
This is normalised histogram data of mixture model of 2 truncated normal distributions.
x defines the midpoints of the bins (i.e. for a physical interpretation, think the equivalent particle diameter).
p defines the probability mass of each bin.
sum(p) should be approximately 1.
"""
function dummydata_2(;
    mu_1=100, sigma_1=50, mu_2=200, sigma_2=20, p_1=0.5, p_2=0.5
)
    x, xl, xu = get_grid()

    # Both components are truncated to the span of the grid (lowest lower edge to highest
    # upper edge) so that all probability mass falls inside the bins.
    lower, upper = first(xl), last(xu)
    components = [
        truncated(Normal(mu_1, sigma_1), lower, upper),
        truncated(Normal(mu_2, sigma_2), lower, upper),
    ]
    # p_1 and p_2 are the mixture weights of the first and second component.
    d = MixtureModel(components, [p_1, p_2])

    # Probability mass of each bin: CDF at its upper edge minus CDF at its lower edge.
    p = cdf.(d, xu) .- cdf.(d, xl)

    # Internal sanity check: the bin masses must sum to one up to floating-point error.
    @assert sum(p) ≈ 1

    return x, p
end
96 changes: 96 additions & 0 deletions src/kme__distr2distr.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
using KernelFunctions, LinearAlgebra, StatsBase
using JuMP, COSMO
# using Plots

include("dummydata.jl")

N = 4  # grid points per parameter: 4^4 = 256 combinations

# ---- input distributions -------------------------------------------------------------
mu_1 = range(start=30, stop=45, length=N)
sigma_1 = range(start=10, stop=20, length=N)
mu_2 = range(start=55, stop=95, length=N)
sigma_2 = range(start=20, stop=50, length=N)
# p_1 and p_2 are kept fixed for now.
# One row per parameter combination, columns are (mu_1, sigma_1, mu_2, sigma_2).
Z = stack(Iterators.product(mu_1, sigma_1, mu_2, sigma_2); dims=1)

res = [
    dummydata_2(mu_1=z[1], sigma_1=z[2], mu_2=z[3], sigma_2=z[4], p_1=0.5, p_2=0.5)
    for z in eachrow(Z)
]
# Every call returns the same grid, so the bins of the first result are used (25 bins).
# `collect` ensures it is a vector.
bins = collect(first(first(res)))
# 256 x 25 matrix: one row of bin masses per parameter combination
A_in = stack(last.(res); dims=1)


# ---- output distributions ------------------------------------------------------------
mu_1 = range(start=75, stop=125, length=N)
sigma_1 = range(start=10, stop=30, length=N)
mu_2 = range(start=175, stop=225, length=N)
sigma_2 = range(start=20, stop=50, length=N)
# p_1 and p_2 are kept fixed for now.
Z = stack(Iterators.product(mu_1, sigma_1, mu_2, sigma_2); dims=1)

res = [
    dummydata_2(mu_1=z[1], sigma_1=z[2], mu_2=z[3], sigma_2=z[4], p_1=0.5, p_2=0.5)
    for z in eachrow(Z)
]
# Same grid as above (25 bins); `collect` ensures it is a vector.
bins = collect(first(first(res)))
# 256 x 25 matrix: one row of bin masses per parameter combination
A_out = stack(last.(res); dims=1)



# RBF kernel over the log10-transformed bin midpoints.
# This hyperparameter should be estimated
σ_RBF_grid = 0.712762951907117
k = with_lengthscale(RBFKernel(), σ_RBF_grid)
# adding some small bias for numerical stability
K = kernelmatrix(k, log10.(bins)) + 0.1*I


# kernel over the output distributions
Q = A_out * K

# Gram matrix between the input distributions: entry (i, j) is A_in[i, :]' * K * A_in[j, :].
# NOTE(review): the original line read `H = A_in * KC * A_in'`, with `KC` undefined and
# `λ` never used. It is assumed here that the input Gram matrix was intended, followed by
# the same regularised smoother as in kme__scalar2distr.jl. Please confirm.
C = A_in * K * A_in'

# regularisation strength
λ = 1e-4
# building a model
H = C / (C + λ*I)
F = H * Q
# Leave-one-out predictions via the hat-matrix shortcut:
# row i is (F[i, :] - H[i, i] * Q[i, :]) / (1 - H[i, i]).
F_loo = (I - Diagonal(H)) \ (F - Diagonal(H) * Q)


"""
    compute_pre_image(F_loo, K, A=A_out)

Recover one weight vector per row of `F_loo` by solving a small quadratic program.

For every row `i` the problem

    minimise    β' * K * β - 2 * dot(β, F_loo[i, :])
    subject to  β .>= 0,  sum(β) == 1

is solved with COSMO through JuMP, and the optimal `β` is stored in row `i` of the result.

`A` is only used for its size `(n_distr, n_classes)`; it defaults to the global `A_out`.
Returns a `Matrix{Float64}` with the same size as `A`.
"""
function compute_pre_image(F_loo::Matrix{Float64}, K::Matrix{Float64}, A::Matrix{Float64}=A_out)::Matrix{Float64}
    n_distr, n_classes = size(A)
    predicted_weights = Array{Float64}(undef, size(A))
    for i in 1:n_distr
        # COSMO is the solver in use; the original code kept Ipopt as a commented-out
        # alternative (`JuMP.Model(Ipopt.Optimizer)`).
        model = JuMP.Model(COSMO.Optimizer)
        # no printing to stdout !
        set_silent(model)
        # β is constrained to be non-negative and to sum to one.
        @variable(model, β[1:n_classes] >= 0.0)
        @constraint(model, sum(β) == 1.0)
        @objective(model, Min, β' * K * β - 2 * dot(β, F_loo[i, :]))
        optimize!(model)
        predicted_weights[i, :] = JuMP.value.(β)
        # Uncomment for solver diagnostics:
        # @show termination_status(model)
        # @show primal_status(model)
        # @show objective_value(model)
    end
    return predicted_weights
end


predicted_weights = compute_pre_image(F_loo, K)
heatmap(predicted_weights; color=:viridis)
# sum of squared errors (SSE) between the predicted weights and A_out
sum(abs2, predicted_weights - A_out)



# store figures
# for i in 1:size(A)[1]
# plot(bins, A[i,:], label="measured", dpi=150)
# plot!(bins, predicted_weights[i,:], label="predicted")
# plot!(xscale=:log10, xlabel="particle size", ylabel="volume fraction", legend=:topleft)
# fn = "prediction_$(exps[i]).png"
# savefig(fn)
# end

90 changes: 90 additions & 0 deletions src/kme__scalar2distr.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
using KernelFunctions, LinearAlgebra, StatsBase
using JuMP, COSMO
import Base.Iterators
# using Plots

include("dummydata.jl")

N = 4  # grid points per parameter: 4^4 = 256 combinations
mu_1 = range(start=75, stop=125, length=N)
sigma_1 = range(start=10, stop=30, length=N)
mu_2 = range(start=175, stop=225, length=N)
sigma_2 = range(start=20, stop=50, length=N)
# p_1 and p_2 are kept fixed for now.
# One row per parameter combination, columns are (mu_1, sigma_1, mu_2, sigma_2).
Z = stack(Iterators.product(mu_1, sigma_1, mu_2, sigma_2); dims=1)

res = [
    dummydata_2(mu_1=z[1], sigma_1=z[2], mu_2=z[3], sigma_2=z[4], p_1=0.5, p_2=0.5)
    for z in eachrow(Z)
]
# Every call returns the same grid, so the bins of the first result are used (25 bins).
# `collect` ensures it is a vector.
bins = collect(first(first(res)))
# 256 x 25 matrix: one row of bin masses per parameter combination
A = stack(last.(res); dims=1)



# standardize process settings
Z = standardize(ZScoreTransform, Z; dims=1)

# RBF kernel over the log10-transformed bin midpoints.
# This hyperparameter should be estimated
σ_RBF_grid = 0.712762951907117
k = with_lengthscale(RBFKernel(), σ_RBF_grid)
# adding some small bias for numerical stability
K = kernelmatrix(k, log10.(bins)) + 0.1I

# RBF kernel over the standardized process settings (one observation per row of Z).
# This hyperparameter should be estimated
σ_RBF_procvars = 3
k = with_lengthscale(RBFKernel(), σ_RBF_procvars)
C = kernelmatrix(k, RowVecs(Z)) + 0.05I

# heatmap(K, color = :viridis)
# heatmap(C, color = :viridis)


# kernel over the output distributions
Q = A * K

# regularisation strength
λ = 1e-4
# building a model
H = C / (C + λ * I)
F = H * Q
# Leave-one-out predictions via the hat-matrix shortcut:
# row i is (F[i, :] - H[i, i] * Q[i, :]) / (1 - H[i, i]).
F_loo = (I - Diagonal(H)) \ (F - Diagonal(H) * Q)


"""
    compute_pre_image(F_loo, K, A=A)

Recover one weight vector per row of `F_loo` by solving a small quadratic program.

For every row `i` the problem

    minimise    β' * K * β - 2 * dot(β, F_loo[i, :])
    subject to  β .>= 0,  sum(β) == 1

is solved with COSMO through JuMP, and the optimal `β` is stored in row `i` of the result.

`A` is only used for its size `(n_distr, n_classes)`; it defaults to the global `A`.
Returns a `Matrix{Float64}` with the same size as `A`.
"""
function compute_pre_image(F_loo::Matrix{Float64}, K::Matrix{Float64}, A::Matrix{Float64}=A)::Matrix{Float64}
    n_distr, n_classes = size(A)
    predicted_weights = Array{Float64}(undef, size(A))
    for i in 1:n_distr
        # COSMO is the solver in use; the original code kept Ipopt as a commented-out
        # alternative (`JuMP.Model(Ipopt.Optimizer)`).
        model = JuMP.Model(COSMO.Optimizer)
        # no printing to stdout !
        set_silent(model)
        # β is constrained to be non-negative and to sum to one.
        @variable(model, β[1:n_classes] >= 0.0)
        @constraint(model, sum(β) == 1.0)
        @objective(model, Min, β' * K * β - 2 * dot(β, F_loo[i, :]))
        optimize!(model)
        predicted_weights[i, :] = JuMP.value.(β)
        # Uncomment for solver diagnostics:
        # @show termination_status(model)
        # @show primal_status(model)
        # @show objective_value(model)
    end
    return predicted_weights
end


predicted_weights = compute_pre_image(F_loo, K)
heatmap(predicted_weights; color=:viridis)
# sum of squared errors (SSE) between the predicted weights and A
sum(abs2, predicted_weights - A)


# store figures
# for i in 1:size(A)[1]
# plot(bins, A[i,:], label="measured", dpi=150)
# plot!(bins, predicted_weights[i,:], label="predicted")
# plot!(xscale=:log10, xlabel="particle size", ylabel="volume fraction", legend=:topleft)
# fn = "prediction_$(exps[i]).png"
# savefig(fn)
# end

0 comments on commit 173c0ca

Please sign in to comment.