From da6d8f66443ce2339e6042b095db8f038154e8ab Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 6 Sep 2024 22:27:03 -0400 Subject: [PATCH] fix: update simplechains layer API --- src/layers/extension.jl | 41 ++++++++++++++--------------------- src/transform/simplechains.jl | 2 +- 2 files changed, 17 insertions(+), 26 deletions(-) diff --git a/src/layers/extension.jl b/src/layers/extension.jl index e4d7298ca7..8242790a86 100644 --- a/src/layers/extension.jl +++ b/src/layers/extension.jl @@ -48,8 +48,8 @@ Base.show(io::IO, ::MIME"text/plain", l::FluxLayer) = print(io, "FluxLayer($(l.l ## SimpleChains.jl """ - SimpleChainsLayer{ToArray}(layer, lux_layer=nothing) - SimpleChainsLayer(layer, ToArray::Union{Bool, Val}=Val(false)) + SimpleChainsLayer(layer, to_array::Union{Bool, Val}=Val(false)) + SimpleChainsLayer(layer, lux_layer, to_array) Wraps a `SimpleChains` layer into a `Lux` layer. All operations are performed using `SimpleChains` but the layer satisfies the `AbstractLuxLayer` interface. @@ -62,39 +62,30 @@ regular `Array` or not. Default is `false`. - `layer`: SimpleChains layer - `lux_layer`: Potentially equivalent Lux layer that is used for printing """ -struct SimpleChainsLayer{ToArray, SL, LL <: Union{Nothing, AbstractLuxLayer}} <: - AbstractLuxLayer - to_array::ToArray - layer::SL - lux_layer::LL - - function SimpleChainsLayer{ToArray}(layer, lux_layer=nothing) where {ToArray} - to_array = static(ToArray) - return new{typeof(to_array), typeof(layer), typeof(lux_layer)}( - to_array, layer, lux_layer) - end - function SimpleChainsLayer(layer, ToArray::BoolType=False()) - to_array = static(ToArray) - return new{typeof(to_array), typeof(layer), Nothing}(to_array, layer, nothing) - end +@concrete struct SimpleChainsLayer <: AbstractLuxLayer + layer + lux_layer <: Union{Nothing, AbstractLuxLayer} + to_array <: StaticBool +end + +function SimpleChainsLayer(layer, to_array::BoolType=False()) + return SimpleChainsLayer(layer, nothing, static(to_array)) end -function Base.show( - io::IO, ::MIME"text/plain", s::SimpleChainsLayer{ToArray}) where {ToArray} - PrettyPrinting.print_wrapper_model( - io, "SimpleChainsLayer{to_array=$ToArray}", s.lux_layer) +function Base.show(io::IO, ::MIME"text/plain", s::SimpleChainsLayer) + PrettyPrinting.print_wrapper_model(io, "SimpleChainsLayer", s.lux_layer) end function (sc::SimpleChainsLayer)(x, ps, st) y = match_eltype(sc, ps, st, x) return ( - simple_chain_output( - sc, apply_simple_chain(sc.layer, y, ps.params, MLDataDevices.get_device(x))), + to_array(sc.to_array, + apply_simple_chain(sc.layer, y, ps.params, MLDataDevices.get_device(x))), st) end -simple_chain_output(::SimpleChainsLayer{False}, y) = y -simple_chain_output(::SimpleChainsLayer{True}, y) = convert(Array, y) +to_array(::False, y) = y +to_array(::True, y) = convert(Array, y) apply_simple_chain(layer, x, ps, ::CPUDevice) = layer(x, ps) diff --git a/src/transform/simplechains.jl b/src/transform/simplechains.jl index f6e2ecb7e9..ca0626dbc2 100644 --- a/src/transform/simplechains.jl +++ b/src/transform/simplechains.jl @@ -69,7 +69,7 @@ function Adapt.adapt(to::ToSimpleChainsAdaptor, L::AbstractLuxLayer) error("`ToSimpleChainsAdaptor` requires `SimpleChains.jl` to be loaded.") end sc_layer = fix_simplechain_input_dims(make_simplechain_network(L), to.input_dims) - return SimpleChainsLayer{to.convert_to_array}(sc_layer, L) + return SimpleChainsLayer(sc_layer, L, to.convert_to_array) end function make_simplechain_network end