Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixed issue with MLJFlux.train #113

Merged
merged 4 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),

*Note*: We try to adhere to these practices as of version [v0.2.1].


## Version [1.0.2] - 2024-08-12
###
- added TaijaPlotting to the docs env
### Changed
- modified the MLJFlux.train function so that it now properly return a trained chain [[#112](https://github.com/JuliaTrustworthyAI/LaplaceRedux.jl/issues/112)]

## Version [1.0.0] - 2024-07-22

### Changed
Expand Down
99 changes: 20 additions & 79 deletions docs/Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# This file is machine-generated - editing it directly is not advised

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Rockdeldiablo, changing the julia version is correct?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MojiFarmanbar yes it shouldn't be a problem.

julia_version = "1.10.4"
julia_version = "1.10.3"
manifest_format = "2.0"
project_hash = "07bab9fa5d046478b21247a44464171c6b19ad4c"
project_hash = "e74e14fb1831d3b0a43faae4918a6a4b752d6a56"

[[deps.ANSIColoredPrinters]]
git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c"
Expand Down Expand Up @@ -213,18 +213,6 @@ version = "0.10.8"
SentinelArrays = "91c51154-3ec4-41a3-a24f-3f23e20d615c"
StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"

[[deps.CategoricalDistributions]]
deps = ["CategoricalArrays", "Distributions", "Missings", "OrderedCollections", "Random", "ScientificTypes"]
git-tree-sha1 = "926862f549a82d6c3a7145bc7f1adff2a91a39f0"
uuid = "af321ab8-2d2e-40a6-b165-3d674595d28e"
version = "0.1.15"

[deps.CategoricalDistributions.extensions]
UnivariateFiniteDisplayExt = "UnicodePlots"

[deps.CategoricalDistributions.weakdeps]
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"

[[deps.ChainRules]]
deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"]
git-tree-sha1 = "227985d885b4dbce5e18a96f9326ea1e836e5a03"
Expand Down Expand Up @@ -372,6 +360,12 @@ version = "1.0.0"
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"

[[deps.Dbus_jll]]
deps = ["Artifacts", "Expat_jll", "JLLWrappers", "Libdl"]
git-tree-sha1 = "fc173b380865f70627d7dd1190dc2fce6cc105af"
uuid = "ee1fde0b-3d02-5ea6-8484-8dfef6360eab"
version = "1.14.10+0"

[[deps.DefineSingletons]]
git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c"
uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52"
Expand Down Expand Up @@ -580,7 +574,7 @@ deps = ["Random"]
uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820"

[[deps.GLFW_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Libglvnd_jll", "Xorg_libXcursor_jll", "Xorg_libXi_jll", "Xorg_libXinerama_jll", "Xorg_libXrandr_jll", "xkbcommon_jll"]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Libglvnd_jll", "Xorg_libXcursor_jll", "Xorg_libXi_jll", "Xorg_libXinerama_jll", "Xorg_libXrandr_jll", "libdecor_jll", "xkbcommon_jll"]
git-tree-sha1 = "3f74912a156096bd8fdbef211eff66ab446e7297"
uuid = "0656b61e-2033-5cc2-a64a-77c0f6c09b89"
version = "3.4.0+0"
Expand Down Expand Up @@ -831,12 +825,6 @@ git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec"
uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
version = "1.3.1"

[[deps.LaplaceRedux]]
deps = ["ChainRulesCore", "Compat", "ComputationalResources", "Distributions", "Flux", "LinearAlgebra", "MLJBase", "MLJFlux", "MLJModelInterface", "MLUtils", "Optimisers", "ProgressMeter", "Random", "Statistics", "Tables", "Tullio", "Zygote"]
path = ".."
uuid = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
version = "1.0.0"

[[deps.Latexify]]
deps = ["Format", "InteractiveUtils", "LaTeXStrings", "MacroTools", "Markdown", "OrderedCollections", "Requires"]
git-tree-sha1 = "5b0d630f3020b82c0775a51d05895852f8506f50"
Expand All @@ -860,12 +848,6 @@ version = "1.2.2"
deps = ["Artifacts", "Pkg"]
uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3"

[[deps.LearnAPI]]
deps = ["InteractiveUtils", "Statistics"]
git-tree-sha1 = "ec695822c1faaaa64cee32d0b21505e1977b4809"
uuid = "92ad9a40-7767-427a-9ee6-6e577f1266cb"
version = "0.1.0"

[[deps.LibCURL]]
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
Expand Down Expand Up @@ -970,18 +952,6 @@ git-tree-sha1 = "c1dd6d7978c12545b4179fb6153b9250c96b0075"
uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36"
version = "1.0.3"

[[deps.MLJBase]]
deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LearnAPI", "LinearAlgebra", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "RecipesBase", "Reexport", "ScientificTypes", "Serialization", "StatisticalMeasuresBase", "StatisticalTraits", "Statistics", "StatsBase", "Tables"]
git-tree-sha1 = "6f45e12073bc2f2e73ed0473391db38c31e879c9"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
version = "1.7.0"

[deps.MLJBase.extensions]
DefaultMeasuresExt = "StatisticalMeasures"

[deps.MLJBase.weakdeps]
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"

[[deps.MLJFlux]]
deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "Metalhead", "Optimisers", "ProgressMeter", "Random", "Statistics", "Tables"]
git-tree-sha1 = "50c7f24b84005a2a80875c10d4f4059df17a0f68"
Expand Down Expand Up @@ -1188,11 +1158,11 @@ git-tree-sha1 = "949347156c25054de2db3b166c52ac4728cbad65"
uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
version = "0.11.31"

[[deps.Parameters]]
deps = ["OrderedCollections", "UnPack"]
git-tree-sha1 = "34c0e9ad262e5f7fc75b10a9952ca7692cfc5fbe"
uuid = "d96e819e-fc66-5662-9728-84c9c7592b0a"
version = "0.12.3"
[[deps.Pango_jll]]
deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "FriBidi_jll", "Glib_jll", "HarfBuzz_jll", "JLLWrappers", "Libdl"]
git-tree-sha1 = "cb5a2ab6763464ae0f19c86c56c63d4a2b0f5bda"
uuid = "36c8627f-9965-5494-a995-c6b170f724f3"
version = "1.52.2+0"

[[deps.Parsers]]
deps = ["Dates", "PrecompileTools", "UUIDs"]
Expand Down Expand Up @@ -1423,12 +1393,6 @@ version = "0.4.2+0"
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
version = "0.7.0"

[[deps.ScientificTypes]]
deps = ["CategoricalArrays", "ColorTypes", "Dates", "Distributions", "PrettyTables", "Reexport", "ScientificTypesBase", "StatisticalTraits", "Tables"]
git-tree-sha1 = "75ccd10ca65b939dab03b812994e571bf1e3e1da"
uuid = "321657f4-b219-11e9-178b-2701a2544e81"
version = "3.0.2"

[[deps.ScientificTypesBase]]
git-tree-sha1 = "a8e18eb383b5ecf1b5e6fc237eb39255044fd92b"
uuid = "30f210dd-8aff-4c5f-94ba-8e64358c1161"
Expand Down Expand Up @@ -1529,12 +1493,6 @@ git-tree-sha1 = "192954ef1208c7019899fbf8049e717f92959682"
uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
version = "1.4.3"

[[deps.StatisticalMeasuresBase]]
deps = ["CategoricalArrays", "InteractiveUtils", "MLUtils", "MacroTools", "OrderedCollections", "PrecompileTools", "ScientificTypesBase", "Statistics"]
git-tree-sha1 = "17dfb22e2e4ccc9cd59b487dce52883e0151b4d3"
uuid = "c062fc1d-0d66-479b-b6ac-8b44719de4cc"
version = "0.1.1"

[[deps.StatisticalTraits]]
deps = ["ScientificTypesBase"]
git-tree-sha1 = "542d979f6e756f13f862aa00b224f04f9e445f11"
Expand Down Expand Up @@ -1685,24 +1643,6 @@ git-tree-sha1 = "79eb0ed763084a3e7de81fe1838379ac6a23b6a0"
uuid = "592b5752-818d-11e9-1e9a-2b8ca4a44cd1"
version = "2.0.3"

[[deps.Tullio]]
deps = ["DiffRules", "LinearAlgebra", "Requires"]
git-tree-sha1 = "6d476962ba4e435d7f4101a403b1d3d72afe72f3"
uuid = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
version = "0.3.7"

[deps.Tullio.extensions]
TullioCUDAExt = "CUDA"
TullioChainRulesCoreExt = "ChainRulesCore"
TullioFillArraysExt = "FillArrays"
TullioTrackerExt = "Tracker"

[deps.Tullio.weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[[deps.URIs]]
git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b"
uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4"
Expand All @@ -1712,11 +1652,6 @@ version = "1.5.1"
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[[deps.UnPack]]
git-tree-sha1 = "387c1f73762231e86e0c9c5443ce3b4a0a9a0c2b"
uuid = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
version = "1.0.2"

[[deps.Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

Expand Down Expand Up @@ -2024,6 +1959,12 @@ deps = ["Artifacts", "Libdl"]
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
version = "5.8.0+1"

[[deps.libdecor_jll]]
deps = ["Artifacts", "Dbus_jll", "JLLWrappers", "Libdl", "Libglvnd_jll", "Pango_jll", "Wayland_jll", "xkbcommon_jll"]
git-tree-sha1 = "9bf7903af251d2050b467f76bdbe57ce541f7f4f"
uuid = "1183f4f0-6f2a-5f1a-908b-139f9cdfea6f"
version = "0.2.2+0"

[[deps.libevdev_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "141fe65dc3efabb0b1d5ba74e91f6ad26f84cc22"
Expand Down
1 change: 0 additions & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b"
Expand Down
63 changes: 31 additions & 32 deletions src/mlj_flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,22 +219,6 @@ function MLJFlux.train(
)
X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X

if !isa(chain, AbstractLaplace)
la = LaplaceRedux.Laplace(
chain;
likelihood=:regression,
subset_of_weights=model.subset_of_weights,
subnetwork_indices=model.subnetwork_indices,
hessian_structure=model.hessian_structure,
backend=model.backend,
σ=model.σ,
μ₀=model.μ₀,
P₀=model.P₀,
)
else
la = chain
end

# Initialize history:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Rockdeldiablo , I would suggest to keep this if-else on top as it was because it checks if chain is abstractlaplace or not. moving this piece down before LaplaceRedux.fit doesn't add any logics. if you agree i would suggest the followings:

  • keeping that if-else as it was to check if chain is laplace object
  • line 242 and 243, changing chain to la in for-loop, it would solve the problem
    what do you think?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MojiFarmanbar honestly i do not think it would work because
MLJFlux.train_epoch(
model, chain, regularized_optimiser, optimiser_state, X, y
) requires an actual chain not a laplace object

the laplace object is required only by LaplaceRedux.fit which is at the end.

have you tested it?

history = []
verbose_laplace = false
Expand Down Expand Up @@ -263,6 +247,22 @@ function MLJFlux.train(
push!(history, current_loss)
end

if !isa(chain, AbstractLaplace)
la = LaplaceRedux.Laplace(
chain;
likelihood=:regression,
subset_of_weights=model.subset_of_weights,
subnetwork_indices=model.subnetwork_indices,
hessian_structure=model.hessian_structure,
backend=model.backend,
σ=model.σ,
μ₀=model.μ₀,
P₀=model.P₀,
)
else
la = chain
end

# fit the Laplace model:
LaplaceRedux.fit!(la, zip(X, y))
optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps)
Expand Down Expand Up @@ -387,22 +387,6 @@ function MLJFlux.train(
)
X = X isa Tables.MatrixTable ? MLJBase.matrix(X) : X

if !isa(chain, AbstractLaplace)
la = LaplaceRedux.Laplace(
chain;
likelihood=:classification,
subset_of_weights=model.subset_of_weights,
subnetwork_indices=model.subnetwork_indices,
hessian_structure=model.hessian_structure,
backend=model.backend,
σ=model.σ,
μ₀=model.μ₀,
P₀=model.P₀,
)
else
la = chain
end

# Initialize history:
history = []
verbose_laplace = false
Expand Down Expand Up @@ -432,6 +416,21 @@ function MLJFlux.train(
push!(history, current_loss)
end

if !isa(chain, AbstractLaplace)
la = LaplaceRedux.Laplace(
chain;
likelihood=:classification,
subset_of_weights=model.subset_of_weights,
subnetwork_indices=model.subnetwork_indices,
hessian_structure=model.hessian_structure,
backend=model.backend,
σ=model.σ,
μ₀=model.μ₀,
P₀=model.P₀,
)
else
la = chain
end
# fit the Laplace model:
LaplaceRedux.fit!(la, zip(X, y))
optimize_prior!(la; verbose=verbose_laplace, n_steps=model.fit_prior_nsteps)
Expand Down
Loading