Skip to content

Commit

Permalink
Handle special case for simplechains
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Mar 18, 2024
1 parent 5a0f26b commit 0adbce7
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 15 deletions.
24 changes: 9 additions & 15 deletions ext/LuxSimpleChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,15 @@ import Lux: SimpleChainsModelConversionError, __to_simplechains_adaptor,
import Optimisers

function __fix_input_dims_simplechain(layers::Vector, input_dims)
return SimpleChains.SimpleChain(input_dims, layers...)
L = Tuple(layers)
return SimpleChains.SimpleChain{typeof(input_dims), typeof(L)}(input_dims, L)
end

function __fix_input_dims_simplechain(layers, input_dims)
@warn "The model provided is not a `Chain`. Trying to wrap it into a `Chain` but this \
might fail. Please consider using `Chain` directly (potentially with \
`disable_optimizations = true`)."
return __fix_input_dims_simplechain([layers], input_dims)
end

__equivalent_simplechains_fn(::typeof(Lux.relu)) = SimpleChains.relu
Expand Down Expand Up @@ -75,18 +83,4 @@ function NNlib.logsoftmax!(y::SimpleChains.StrideArray{T1, 2},
return y
end

# Nicer Interactions with Optimisers.jl
# function Optimisers._setup(opt::Optimisers.AbstractRule,
# ps::Union{SimpleChains.StrideArray, SimpleChains.PtrArray}; cache)
# ℓ = Leaf(rule, init(rule, x))
# if isbits(x)
# cache[nothing] = nothing # just to disable the warning
#
# else
# cache[x] = ℓ
# end
# error(1)
# return Optimisers.setup(opt, ps .- ps)
# end

end
17 changes: 17 additions & 0 deletions test/transform/simple_chains_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,21 @@
gs = Zygote.gradient((x, p) -> sum(first(simple_chains_model(x, p, st))), x, ps)
@test size(gs[1]) == size(x)
@test length(gs[2].params) == length(ps.params)

@testset "Single Layer Conversion: LuxDL/Lux.jl#545" begin
lux_model = Dense(10 => 5)

adaptor = ToSimpleChainsAdaptor((static(10),))

simple_chains_model = @test_warn "The model provided is not a `Chain`. Trying to wrap it into a `Chain` but this might fail. Please consider using `Chain` directly (potentially with `disable_optimizations = true`)." adaptor(lux_model)

ps, st = Lux.setup(Random.default_rng(), simple_chains_model)

x = randn(Float32, 10, 3)
@test size(first(simple_chains_model(x, ps, st))) == (5, 3)

gs = Zygote.gradient((x, p) -> sum(first(simple_chains_model(x, p, st))), x, ps)
@test size(gs[1]) == size(x)
@test length(gs[2].params) == length(ps.params)
end
end

0 comments on commit 0adbce7

Please sign in to comment.