From 0adbce78cee591940b4742a32cbd2488fb2afc0b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 18 Mar 2024 14:22:40 -0400 Subject: [PATCH] Handle special case for simplechains --- ext/LuxSimpleChainsExt.jl | 24 +++++++++--------------- test/transform/simple_chains_tests.jl | 17 +++++++++++++++++ 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/ext/LuxSimpleChainsExt.jl b/ext/LuxSimpleChainsExt.jl index 62403244c..499306ef3 100644 --- a/ext/LuxSimpleChainsExt.jl +++ b/ext/LuxSimpleChainsExt.jl @@ -7,7 +7,15 @@ import Lux: SimpleChainsModelConversionError, __to_simplechains_adaptor, import Optimisers function __fix_input_dims_simplechain(layers::Vector, input_dims) - return SimpleChains.SimpleChain(input_dims, layers...) + L = Tuple(layers) + return SimpleChains.SimpleChain{typeof(input_dims), typeof(L)}(input_dims, L) +end + +function __fix_input_dims_simplechain(layers, input_dims) + @warn "The model provided is not a `Chain`. Trying to wrap it into a `Chain` but this \ + might fail. Please consider using `Chain` directly (potentially with \ + `disable_optimizations = true`)." + return __fix_input_dims_simplechain([layers], input_dims) end __equivalent_simplechains_fn(::typeof(Lux.relu)) = SimpleChains.relu @@ -75,18 +83,4 @@ function NNlib.logsoftmax!(y::SimpleChains.StrideArray{T1, 2}, return y end -# Nicer Interactions with Optimisers.jl -# function Optimisers._setup(opt::Optimisers.AbstractRule, -# ps::Union{SimpleChains.StrideArray, SimpleChains.PtrArray}; cache) -# ℓ = Leaf(rule, init(rule, x)) -# if isbits(x) -# cache[nothing] = nothing # just to disable the warning -# ℓ -# else -# cache[x] = ℓ -# end -# error(1) -# return Optimisers.setup(opt, ps .- ps) -# end - end diff --git a/test/transform/simple_chains_tests.jl b/test/transform/simple_chains_tests.jl index 8b5f97cbf..c4c186cc7 100644 --- a/test/transform/simple_chains_tests.jl +++ b/test/transform/simple_chains_tests.jl @@ -47,4 +47,21 @@ gs = Zygote.gradient((x, p) -> sum(first(simple_chains_model(x, p, st))), x, ps) @test size(gs[1]) == size(x) @test length(gs[2].params) == length(ps.params) + + @testset "Single Layer Conversion: LuxDL/Lux.jl#545" begin + lux_model = Dense(10 => 5) + + adaptor = ToSimpleChainsAdaptor((static(10),)) + + simple_chains_model = @test_warn "The model provided is not a `Chain`. Trying to wrap it into a `Chain` but this might fail. Please consider using `Chain` directly (potentially with `disable_optimizations = true`)." adaptor(lux_model) + + ps, st = Lux.setup(Random.default_rng(), simple_chains_model) + + x = randn(Float32, 10, 3) + @test size(first(simple_chains_model(x, ps, st))) == (5, 3) + + gs = Zygote.gradient((x, p) -> sum(first(simple_chains_model(x, p, st))), x, ps) + @test size(gs[1]) == size(x) + @test length(gs[2].params) == length(ps.params) + end end