From ca96618daebcb030caf5ffca113a9ecd18e8d4ec Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 18 Mar 2024 13:54:14 -0400 Subject: [PATCH 1/4] Fail docs if tutorials fail --- docs/tutorials.jl | 37 +------------------------------------ 1 file changed, 1 insertion(+), 36 deletions(-) diff --git a/docs/tutorials.jl b/docs/tutorials.jl index 4c7b2f29b..565331627 100644 --- a/docs/tutorials.jl +++ b/docs/tutorials.jl @@ -21,41 +21,6 @@ TUTORIALS = [collect(Iterators.product(["beginner"], BEGINNER_TUTORIALS))..., collect(Iterators.product(["intermediate"], INTERMEDIATE_TUTORIALS))..., collect(Iterators.product(["advanced"], ADVANCED_TUTORIALS))...] -@info "Installing and Precompiling Tutorial Dependencies" - -const storage_dir = joinpath(@__DIR__, "..", "tutorial_deps") - -mkpath(storage_dir) - -try - pmap(TUTORIALS) do (d, p) - p_ = get_example_path(p) - name = first(split(p, '/')) - - pkg_log_path = joinpath(storage_dir, "$(name)_pkg.log") - tutorial_proj = dirname(p_) - lux_path = joinpath(@__DIR__, "..") - - withenv("PKG_LOG_PATH" => pkg_log_path, "LUX_PATH" => lux_path) do - cmd = `$(Base.julia_cmd()) --color=yes --project=$(tutorial_proj) -e \ - 'using Pkg; - io=open(ENV["PKG_LOG_PATH"], "w"); - Pkg.develop(; path=ENV["LUX_PATH"], io); - Pkg.instantiate(; io); - Pkg.precompile(; io); - eval(Meta.parse("using " * join(keys(Pkg.project().dependencies), ", "))); - close(io)'` - @info "Running Command: $(cmd)" - run(cmd) - return - end - return - end -catch e - rmprocs(workers()...) - @error e -end - @info "Starting tutorial build" try @@ -87,5 +52,5 @@ try end catch e rmprocs(workers()...) - @error e + rethrow(e) end From 850610303d83a8c72cdaa119fe02a201030cb93d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 18 Mar 2024 13:54:45 -0400 Subject: [PATCH 2/4] Add warning for incorrect uses of f<__> functions --- Project.toml | 2 +- src/utils.jl | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 79e65883c..d7fe3b99e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "0.5.23" +version = "0.5.24" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/utils.jl b/src/utils.jl index dffb0f236..19547b4d5 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -212,6 +212,26 @@ for (fname, ftype) in zip((:f16, :f32, :f64), (Float16, Float32, Float64)) end end +# Common incorrect usage +for f in (f16, f32, f64) + warn_msg = "$(f) is not meant to be broadcasted like `$(f).(x)` or `x .|> $(f)`, and \ + this might give unexpected results and could lead to crashes. Directly use \ + `$(f)` as `$(f)(x)` or `x |> $(f)` instead." + @eval begin + function Base.Broadcast.broadcasted(::typeof($(f)), arg1) + @warn $(warn_msg) + arg1′ = Broadcast.broadcastable(arg1) + return Broadcast.broadcasted(Broadcast.combine_styles(arg1′), $(f), arg1′) + end + + function Base.Broadcast.broadcasted(::typeof(|>), arg1, ::typeof($(f))) + @warn $(warn_msg) + arg1′ = Broadcast.broadcastable(arg1) + return Broadcast.broadcasted(Broadcast.combine_styles(arg1′), $(f), arg1′) + end + end +end + # Used in freezing ## Extend for custom types @inline _pairs(x) = pairs(x) From 5a0f26b2b95c068e52c72b1873a727b7024e90f2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 18 Mar 2024 14:07:33 -0400 Subject: [PATCH 3/4] Fix bug in eltype adaptor --- docs/tutorials.jl | 35 +++++++++++++++++++++++++++++++++++ ext/LuxComponentArraysExt.jl | 2 +- src/utils.jl | 3 +++ 3 files changed, 39 insertions(+), 1 deletion(-) diff --git a/docs/tutorials.jl b/docs/tutorials.jl index 565331627..ef994b691 100644 --- a/docs/tutorials.jl +++ b/docs/tutorials.jl @@ -21,6 +21,41 @@ TUTORIALS = [collect(Iterators.product(["beginner"], BEGINNER_TUTORIALS))..., collect(Iterators.product(["intermediate"], INTERMEDIATE_TUTORIALS))..., collect(Iterators.product(["advanced"], ADVANCED_TUTORIALS))...] +@info "Installing and Precompiling Tutorial Dependencies" + +const storage_dir = joinpath(@__DIR__, "..", "tutorial_deps") + +mkpath(storage_dir) + +try + pmap(TUTORIALS) do (d, p) + p_ = get_example_path(p) + name = first(split(p, '/')) + + pkg_log_path = joinpath(storage_dir, "$(name)_pkg.log") + tutorial_proj = dirname(p_) + lux_path = joinpath(@__DIR__, "..") + + withenv("PKG_LOG_PATH" => pkg_log_path, "LUX_PATH" => lux_path) do + cmd = `$(Base.julia_cmd()) --color=yes --project=$(tutorial_proj) -e \ + 'using Pkg; + io=open(ENV["PKG_LOG_PATH"], "w"); + Pkg.develop(; path=ENV["LUX_PATH"], io); + Pkg.instantiate(; io); + Pkg.precompile(; io); + eval(Meta.parse("using " * join(keys(Pkg.project().dependencies), ", "))); + close(io)'` + @info "Running Command: $(cmd)" + run(cmd) + return + end + return + end +catch e + rmprocs(workers()...) + @error e +end + @info "Starting tutorial build" try diff --git a/ext/LuxComponentArraysExt.jl b/ext/LuxComponentArraysExt.jl index 6a9a245e0..ff63065f2 100644 --- a/ext/LuxComponentArraysExt.jl +++ b/ext/LuxComponentArraysExt.jl @@ -1,6 +1,6 @@ module LuxComponentArraysExt -using ComponentArrays, Functors, Lux, Optimisers +using Adapt, ComponentArrays, Functors, Lux, Optimisers import TruncatedStacktraces: @truncate_stacktrace import ChainRulesCore as CRC diff --git a/src/utils.jl b/src/utils.jl index 19547b4d5..f2801786a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -189,6 +189,9 @@ end struct LuxEltypeAdaptor{T} end (l::LuxEltypeAdaptor)(x) = fmap(adapt(l), x) +function (l::LuxEltypeAdaptor)(x::AbstractArray{T}) where {T} + return isbitstype(T) ? adapt(l, x) : map(adapt(l), x) +end function Adapt.adapt_storage( ::LuxEltypeAdaptor{T}, x::AbstractArray{<:AbstractFloat}) where {T <: AbstractFloat} From 0adbce78cee591940b4742a32cbd2488fb2afc0b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 18 Mar 2024 14:22:40 -0400 Subject: [PATCH 4/4] Handle special case for simplechains --- ext/LuxSimpleChainsExt.jl | 24 +++++++++--------------- test/transform/simple_chains_tests.jl | 17 +++++++++++++++++ 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/ext/LuxSimpleChainsExt.jl b/ext/LuxSimpleChainsExt.jl index 62403244c..499306ef3 100644 --- a/ext/LuxSimpleChainsExt.jl +++ b/ext/LuxSimpleChainsExt.jl @@ -7,7 +7,15 @@ import Lux: SimpleChainsModelConversionError, __to_simplechains_adaptor, import Optimisers function __fix_input_dims_simplechain(layers::Vector, input_dims) - return SimpleChains.SimpleChain(input_dims, layers...) + L = Tuple(layers) + return SimpleChains.SimpleChain{typeof(input_dims), typeof(L)}(input_dims, L) +end + +function __fix_input_dims_simplechain(layers, input_dims) + @warn "The model provided is not a `Chain`. Trying to wrap it into a `Chain` but this \ + might fail. Please consider using `Chain` directly (potentially with \ + `disable_optimizations = true`)." + return __fix_input_dims_simplechain([layers], input_dims) end __equivalent_simplechains_fn(::typeof(Lux.relu)) = SimpleChains.relu @@ -75,18 +83,4 @@ function NNlib.logsoftmax!(y::SimpleChains.StrideArray{T1, 2}, return y end -# Nicer Interactions with Optimisers.jl -# function Optimisers._setup(opt::Optimisers.AbstractRule, -# ps::Union{SimpleChains.StrideArray, SimpleChains.PtrArray}; cache) -# ℓ = Leaf(rule, init(rule, x)) -# if isbits(x) -# cache[nothing] = nothing # just to disable the warning -# ℓ -# else -# cache[x] = ℓ -# end -# error(1) -# return Optimisers.setup(opt, ps .- ps) -# end - end diff --git a/test/transform/simple_chains_tests.jl b/test/transform/simple_chains_tests.jl index 8b5f97cbf..c4c186cc7 100644 --- a/test/transform/simple_chains_tests.jl +++ b/test/transform/simple_chains_tests.jl @@ -47,4 +47,21 @@ gs = Zygote.gradient((x, p) -> sum(first(simple_chains_model(x, p, st))), x, ps) @test size(gs[1]) == size(x) @test length(gs[2].params) == length(ps.params) + + @testset "Single Layer Conversion: LuxDL/Lux.jl#545" begin + lux_model = Dense(10 => 5) + + adaptor = ToSimpleChainsAdaptor((static(10),)) + + simple_chains_model = @test_warn "The model provided is not a `Chain`. Trying to wrap it into a `Chain` but this might fail. Please consider using `Chain` directly (potentially with `disable_optimizations = true`)." adaptor(lux_model) + + ps, st = Lux.setup(Random.default_rng(), simple_chains_model) + + x = randn(Float32, 10, 3) + @test size(first(simple_chains_model(x, ps, st))) == (5, 3) + + gs = Zygote.gradient((x, p) -> sum(first(simple_chains_model(x, p, st))), x, ps) + @test size(gs[1]) == size(x) + @test length(gs[2].params) == length(ps.params) + end end