Skip to content

Commit

Permalink
amend: fixed predict so that it return a vector of distributions-> fi…
Browse files Browse the repository at this point in the history
…xed evaluate!
  • Loading branch information
pasq-cat committed Oct 18, 2024
1 parent 74d778e commit be80e32
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 10 deletions.
8 changes: 4 additions & 4 deletions docs/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.10.5"
manifest_format = "2.0"
project_hash = "0bd11d5fa58aad2714bf7893e520fc7c086ef3ca"
project_hash = "616a9e89f5c520a58672ad91b5525001e0dadab3"

[[deps.ANSIColoredPrinters]]
git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c"
Expand Down Expand Up @@ -955,10 +955,10 @@ uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
version = "1.3.1"

[[deps.LaplaceRedux]]
deps = ["ChainRulesCore", "Compat", "ComputationalResources", "Distributions", "Flux", "LinearAlgebra", "MLJBase", "MLJFlux", "MLJModelInterface", "MLUtils", "Optimisers", "ProgressMeter", "Random", "Statistics", "Tables", "Tullio", "Zygote"]
git-tree-sha1 = "a84b72a27c93c72a6af5d22216eb81a419b1b97a"
deps = ["CategoricalDistributions", "ChainRulesCore", "Compat", "ComputationalResources", "Distributions", "Flux", "LinearAlgebra", "MLJBase", "MLJFlux", "MLJModelInterface", "MLUtils", "Optimisers", "ProgressMeter", "Random", "Statistics", "Tables", "Tullio", "Zygote"]
path = "C:\\Users\\Pasqu\\Documents\\julia_projects\\LaplaceRedux.jl"
uuid = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
version = "1.0.2"
version = "1.1.1"

[[deps.Latexify]]
deps = ["Format", "InteractiveUtils", "LaTeXStrings", "MacroTools", "Markdown", "OrderedCollections", "Requires"]
Expand Down
2 changes: 1 addition & 1 deletion src/direct_mlj.jl
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ function MMI.predict(m::LaplaceModels, fitresult, Xnew)
means, variances = yhat

# Create Normal distributions from the means and variances
return [Normal(μ, sqrt(σ)) for (μ, σ) in zip(means, variances)]
return vec([Normal(μ, sqrt(σ)) for (μ, σ) in zip(means, variances)])

else
predictions =
Expand Down
18 changes: 13 additions & 5 deletions test/direct_mlj_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ cv = CV(; nfolds=3)
model = LaplaceRegressor(model=flux_model,epochs=50)

X, y = make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1)
#train, test = partition(eachindex(y), 0.7); # 70:30 split
mach = machine(model, X, y) #|> MLJBase.fit! #|> (fitresult,cache,report)
MLJBase.fit!(mach,verbosity=1)
Xnew, _ = make_regression(3, 4; rng=123)
yhat = MLJBase.predict(mach, Xnew) # probabilistic predictions
MLJBase.predict_mode(mach, Xnew) # point predictions
MLJBase.fit!(mach, verbosity=1)
#Xnew, ynew = make_regression(3, 4; rng=123)
yhat = MLJBase.predict(mach, X) # probabilistic predictions
MLJBase.predict_mode(mach, X) # point predictions
MLJBase.fitted_params(mach) #fitted params function
MLJBase.training_losses(mach) #training loss history
model.epochs= 100 #changing number of epochs
Expand All @@ -31,7 +32,14 @@ cv = CV(; nfolds=3)
MLJBase.fit!(mach) #testing update function
model.fit_prior_nsteps = 200 #changing LaplaceRedux fit steps
MLJBase.fit!(mach) #testing update function (the laplace part)
# evaluate!(mach, resampling=cv, measure=l2, verbosity=0)
yhat = MLJBase.predict(mach, X) # probabilistic predictions
println( typeof(yhat) )
println( size(yhat) )
println( typeof(y) )
println( size(y) )

evaluate!(mach, resampling=cv, measure=log_loss, verbosity=0)

end


Expand Down

0 comments on commit be80e32

Please sign in to comment.