diff --git a/ext/LuxDynamicExpressionsExt.jl b/ext/LuxDynamicExpressionsExt.jl index e554a8f217..5295eafc9b 100644 --- a/ext/LuxDynamicExpressionsExt.jl +++ b/ext/LuxDynamicExpressionsExt.jl @@ -3,7 +3,7 @@ module LuxDynamicExpressionsExt using ChainRulesCore: ChainRulesCore, NoTangent using DynamicExpressions: DynamicExpressions, Node, OperatorEnum, eval_grad_tree_array using FastClosures: @closure -using Lux: Lux, NAME_TYPE, Chain, Parallel, WrappedFunction +using Lux: Lux, NAME_TYPE, Chain, Parallel const CRC = ChainRulesCore @@ -20,7 +20,7 @@ function Lux.DynamicExpressionsLayer( Parallel(nothing, ntuple(i -> Lux.DynamicExpressionsLayer(operator_enum, expressions[i], _name_fn(i), turbo, bumper), length(expressions))...), - WrappedFunction{:direct_call}(Lux.__stack1); + Lux.__stack1; name="DynamicExpressionsLayer") #! format: on end diff --git a/ext/LuxFluxExt.jl b/ext/LuxFluxExt.jl index 30dc4ab76f..2e2b56d9d6 100644 --- a/ext/LuxFluxExt.jl +++ b/ext/LuxFluxExt.jl @@ -17,7 +17,7 @@ function __from_flux_adaptor(l::T; preserve_ps_st::Bool=false, kwargs...) where return Lux.FluxLayer(l) end -__from_flux_adaptor(l::Function; kwargs...) = Lux.WrappedFunction{:direct_call}(l) +__from_flux_adaptor(l::Function; kwargs...) = Lux.WrappedFunction(l) function __from_flux_adaptor(l::Flux.Chain; kwargs...) fn = x -> __from_flux_adaptor(x; kwargs...) @@ -29,7 +29,7 @@ end function __from_flux_adaptor(l::Flux.Dense; preserve_ps_st::Bool=false, kwargs...) out_dims, in_dims = size(l.weight) if preserve_ps_st - bias = l.bias isa Bool ? nothing : reshape(copy(l.bias), out_dims, 1) + bias = l.bias isa Bool ? nothing : copy(l.bias) return Lux.Dense(in_dims => out_dims, l.σ; init_weight=Returns(copy(l.weight)), init_bias=Returns(bias), use_bias=!(l.bias isa Bool)) else diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 740c8265cd..0fc3e2bb83 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -214,8 +214,7 @@ struct NoOpLayer <: AbstractExplicitLayer end @inline (noop::NoOpLayer)(x, ps, st::NamedTuple) = x, st """ - WrappedFunction{DC}(f) - WrappedFunction(f) -> WrappedFunction{:runtime_check}(f) + WrappedFunction(f) Wraps a stateless and parameter less function. Might be used when a function is added to `Chain`. For example, `Chain(x -> relu.(x))` would not work and the right thing to do would @@ -224,9 +223,6 @@ be `Chain((x, ps, st) -> (relu.(x), st))`. An easier thing to do would be ## Arguments - - `DC`: If `:runtime_check`, then we check if the function can be called with the input - `x`, `ps`, and `st` using `hasmethod`. If `:direct_call`, we call `f(x)` directly. - For all other values, we call `f(x, ps, st)` which must return a tuple. - `f`: Some function. ## Inputs @@ -238,33 +234,14 @@ be `Chain((x, ps, st) -> (relu.(x), st))`. An easier thing to do would be - Output of `f(x)` - Empty `NamedTuple()` """ -@concrete struct WrappedFunction{DC} <: AbstractExplicitLayer +@concrete struct WrappedFunction <: AbstractExplicitLayer func <: Function end -WrappedFunction(f::F) where {F} = WrappedFunction{:runtime_check}(f) +(wf::WrappedFunction)(x, ps, st::NamedTuple{}) = wf.func(x), st -function (wf::WrappedFunction{:direct_call})(x, ps, st::NamedTuple) - return __maybe_direct_call(wf.func, x, ps, st, Val(true)) -end - -function (wf::WrappedFunction)(x, ps, st::NamedTuple) - return __maybe_direct_call(wf.func, x, ps, st, Val(false)) -end - -@generated function (wf::WrappedFunction{:runtime_check, F})( - x, ps, st::NamedTuple) where {F} - DC = hasmethod(F, Tuple{typeof(x)}) - return :(__maybe_direct_call(wf.func, x, ps, st, Val($(DC)))) -end - -@inline __maybe_direct_call(f, x, ps, st, ::Val{false}) = f(x, ps, st) -@inline __maybe_direct_call(f, x, ps, st, ::Val{true}) = f(x), st - -function Base.show(io::IO, w::WrappedFunction{T}) where {T} - print(io, "WrappedFunction{$(Meta.quot(T))}(") - show(io, w.func) - print(io, ")") +function Base.show(io::IO, w::WrappedFunction) + print(io, "WrappedFunction(", w.func, ")") end """ @@ -332,7 +309,7 @@ end function initialparameters(rng::AbstractRNG, d::Dense{use_bias}) where {use_bias} if use_bias return (weight=d.init_weight(rng, d.out_dims, d.in_dims), - bias=d.init_bias(rng, d.out_dims, 1)) #TODO: In v0.6 make it a vector + bias=d.init_bias(rng, d.out_dims)) else return (weight=d.init_weight(rng, d.out_dims, d.in_dims),) end @@ -345,10 +322,6 @@ statelength(d::Dense) = 0 outputsize(d::Dense) = (d.out_dims,) -@inline function (d::Dense)(x::AbstractVector, ps, st::NamedTuple) - return vec(first(d(reshape(x, :, 1), ps, st))), st -end - @inline function (d::Dense)(x::AbstractArray, ps, st::NamedTuple) return reshape(first(d(reshape(x, size(x, 1), :), ps, st)), :, size(x)[2:end]...), st end @@ -357,7 +330,7 @@ end y = match_eltype(d, ps, st, x) return ( fused_dense_bias_activation( - d.activation, ps.weight, y, _vec(_getproperty(ps, Val(:bias)))), + d.activation, ps.weight, y, _getproperty(ps, Val(:bias))), st) end diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 7182b369e0..a4fe3298b2 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -504,7 +504,7 @@ function set to `Base.Fix2(pixel_shuffle, r)` - Output of size `(r x W, r x H, C, N)` for 4D-arrays, and `(r x W, r x H, ..., C, N)` for D-dimensional data, where `D = ndims(x) - 2` """ -PixelShuffle(r::Int) = WrappedFunction{:direct_call}(Base.Fix2(pixel_shuffle, r)) +PixelShuffle(r::Int) = WrappedFunction(Base.Fix2(pixel_shuffle, r)) @doc doc""" CrossCor(k::NTuple{N,Integer}, (in_chs => out_chs)::Pair{<:Integer,<:Integer}, diff --git a/src/layers/extension.jl b/src/layers/extension.jl index 186c9e7783..87974b594e 100644 --- a/src/layers/extension.jl +++ b/src/layers/extension.jl @@ -54,7 +54,7 @@ DynamicExpressionsLayer( layer_1 = DynamicExpressionsLayer(DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*)}, Tuple{typeof(cos)}}((+, -, *), (cos,)), x1 * cos(x2 - 3.2); turbo=Val{false}(), bumper=Val{false}()), # 1 parameters layer_2 = DynamicExpressionsLayer(DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*)}, Tuple{typeof(cos)}}((+, -, *), (cos,)), ((x2 - (x1 * x2)) + 2.5) - (1.0 * x1); turbo=Val{false}(), bumper=Val{false}()), # 2 parameters ), - layer_2 = WrappedFunction{:direct_call}(Lux.__stack1), + layer_2 = WrappedFunction(__stack1), ) # Total: 3 parameters, # plus 0 states. diff --git a/test/contrib/share_parameters_tests.jl b/test/contrib/share_parameters_tests.jl index cdee97d57b..430156811d 100644 --- a/test/contrib/share_parameters_tests.jl +++ b/test/contrib/share_parameters_tests.jl @@ -16,9 +16,9 @@ @test ps_1.d3.weight == ps_1.d2.l1.weight @test ps_1.d3.bias == ps_1.d2.l1.bias - ps_new_1 = (; weight=randn(rng, Float32, 4, 2), bias=randn(rng, Float32, 4, 1)) |> + ps_new_1 = (; weight=randn(rng, Float32, 4, 2), bias=randn(rng, Float32, 4)) |> device - ps_new_2 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 2, 1)) |> + ps_new_2 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 2)) |> device ps_2 = Lux.Experimental.share_parameters(ps, sharing, (ps_new_1, ps_new_2)) @@ -46,9 +46,9 @@ ps, sharing, (ps_new_1,)) # Parameter Structure Mismatch - ps_new_1 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 4, 1)) |> + ps_new_1 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 4)) |> device - ps_new_2 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 2, 1)) |> + ps_new_2 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 2)) |> device @test_throws ArgumentError Lux.Experimental.share_parameters( diff --git a/test/helpers/compact_tests.jl b/test/helpers/compact_tests.jl index 41201433b6..c522183438 100644 --- a/test/helpers/compact_tests.jl +++ b/test/helpers/compact_tests.jl @@ -104,14 +104,14 @@ ps, st = Lux.setup(rng, model) |> device @test size(ps.w1.weight) == (128, 1) - @test size(ps.w1.bias) == (128, 1) + @test size(ps.w1.bias) == (128,) @test length(ps.w2) == nlayers for i in 1:nlayers @test size(ps.w2[i].weight) == (128, 128) - @test size(ps.w2[i].bias) == (128, 1) + @test size(ps.w2[i].bias) == (128,) end @test size(ps.w3.weight) == (1, 128) - @test size(ps.w3.bias) == (1, 1) + @test size(ps.w3.bias) == (1,) x = randn(n_in, 32) |> aType diff --git a/test/layers/basic_tests.jl b/test/layers/basic_tests.jl index 8fcca38005..91e471ef6a 100644 --- a/test/layers/basic_tests.jl +++ b/test/layers/basic_tests.jl @@ -98,26 +98,6 @@ @jet layer(x, ps, st) __f = x -> sum(first(layer(x, ps, st))) @eval @test_gradients $__f $x gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 - - f11(x) = x .* x - - layer = WrappedFunction{:runtime_check}(f11) - display(layer) - ps, st = Lux.setup(rng, layer) .|> device - x = randn(rng, 2, 3) |> aType - - @test layer(x, ps, st)[1] ≈ x .* x - @inferred layer(x, ps, st) - - f12(x, ps, st) = x .+ 1, st - - layer = WrappedFunction{:runtime_check}(f12) - display(layer) - ps, st = Lux.setup(rng, layer) .|> device - x = randn(rng, 2, 3) |> aType - - @test layer(x, ps, st)[1] ≈ x .+ 1 - @inferred layer(x, ps, st) end @testset "PeriodicEmbedding" begin @@ -149,7 +129,7 @@ end ps, st = Lux.setup(rng, layer) .|> device @test size(ps.weight) == (100, 10) - @test size(ps.bias) == (100, 1) + @test size(ps.bias) == (100,) @test layer.activation == identity layer = Dense(10, 100, relu; use_bias=false) diff --git a/test/layers/containers_tests.jl b/test/layers/containers_tests.jl index 93022d1737..a7db9f2e01 100644 --- a/test/layers/containers_tests.jl +++ b/test/layers/containers_tests.jl @@ -4,8 +4,7 @@ @testset "$mode" for (mode, aType, device, ongpu) in MODES @testset "zero sum" begin - layer = SkipConnection( - WrappedFunction{:direct_call}(Broadcast.BroadcastFunction(zero)), .+) + layer = SkipConnection(WrappedFunction(Broadcast.BroadcastFunction(zero)), .+) display(layer) ps, st = Lux.setup(rng, layer) .|> device x = randn(rng, 10, 10, 10, 10) |> aType @@ -38,8 +37,7 @@ end @testset "$mode" for (mode, aType, device, ongpu) in MODES @testset "zero sum" begin layer = Parallel( - +, WrappedFunction{:direct_call}(Broadcast.BroadcastFunction(zero)), - NoOpLayer()) + +, WrappedFunction(Broadcast.BroadcastFunction(zero)), NoOpLayer()) display(layer) ps, st = Lux.setup(rng, layer) .|> device x = randn(rng, 10, 10, 10, 10) |> aType @@ -320,7 +318,7 @@ end end @testset "constructors" begin - f1(x, ps, st::NamedTuple) = (x .+ 1, st) + f1(x) = x .+ 1 f2(x) = x .+ 2 model = Chain((Dense(2 => 3), Dense(3 => 2)), f1, f2, NoOpLayer()) diff --git a/test/transform/flux_tests.jl b/test/transform/flux_tests.jl index 6d2e31da28..835c9c7493 100644 --- a/test/transform/flux_tests.jl +++ b/test/transform/flux_tests.jl @@ -503,7 +503,7 @@ @testset "Functions" begin @test tolux(Flux.flatten) isa Lux.FlattenLayer @test tolux(identity) isa Lux.NoOpLayer - @test tolux(+) isa Lux.WrappedFunction{:direct_call} + @test tolux(+) isa Lux.WrappedFunction end @testset "Unsupported Layers" begin