Skip to content

Commit

Permalink
chore!: remove annotation of WrappedFunction
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 4, 2024
1 parent 509e152 commit 077a391
Show file tree
Hide file tree
Showing 10 changed files with 25 additions and 74 deletions.
4 changes: 2 additions & 2 deletions ext/LuxDynamicExpressionsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module LuxDynamicExpressionsExt
using ChainRulesCore: ChainRulesCore, NoTangent
using DynamicExpressions: DynamicExpressions, Node, OperatorEnum, eval_grad_tree_array
using FastClosures: @closure
using Lux: Lux, NAME_TYPE, Chain, Parallel, WrappedFunction
using Lux: Lux, NAME_TYPE, Chain, Parallel

const CRC = ChainRulesCore

Expand All @@ -20,7 +20,7 @@ function Lux.DynamicExpressionsLayer(
Parallel(nothing,
ntuple(i -> Lux.DynamicExpressionsLayer(operator_enum, expressions[i],
_name_fn(i), turbo, bumper), length(expressions))...),
WrappedFunction{:direct_call}(Lux.__stack1);
Lux.__stack1;
name="DynamicExpressionsLayer")
#! format: on
end
Expand Down
4 changes: 2 additions & 2 deletions ext/LuxFluxExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ function __from_flux_adaptor(l::T; preserve_ps_st::Bool=false, kwargs...) where
return Lux.FluxLayer(l)
end

__from_flux_adaptor(l::Function; kwargs...) = Lux.WrappedFunction{:direct_call}(l)
__from_flux_adaptor(l::Function; kwargs...) = Lux.WrappedFunction(l)

function __from_flux_adaptor(l::Flux.Chain; kwargs...)
fn = x -> __from_flux_adaptor(x; kwargs...)
Expand All @@ -29,7 +29,7 @@ end
function __from_flux_adaptor(l::Flux.Dense; preserve_ps_st::Bool=false, kwargs...)
out_dims, in_dims = size(l.weight)
if preserve_ps_st
bias = l.bias isa Bool ? nothing : reshape(copy(l.bias), out_dims, 1)
bias = l.bias isa Bool ? nothing : copy(l.bias)
return Lux.Dense(in_dims => out_dims, l.σ; init_weight=Returns(copy(l.weight)),
init_bias=Returns(bias), use_bias=!(l.bias isa Bool))
else
Expand Down
41 changes: 7 additions & 34 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,7 @@ struct NoOpLayer <: AbstractExplicitLayer end
@inline (noop::NoOpLayer)(x, ps, st::NamedTuple) = x, st

"""
WrappedFunction{DC}(f)
WrappedFunction(f) -> WrappedFunction{:runtime_check}(f)
WrappedFunction(f)
Wraps a stateless and parameterless function. Might be used when a function is added to
`Chain`. For example, `Chain(x -> relu.(x))` would not work and the right thing to do would
Expand All @@ -224,9 +223,6 @@ be `Chain((x, ps, st) -> (relu.(x), st))`. An easier thing to do would be
## Arguments
- `DC`: If `:runtime_check`, then we check if the function can be called with the input
`x`, `ps`, and `st` using `hasmethod`. If `:direct_call`, we call `f(x)` directly.
For all other values, we call `f(x, ps, st)` which must return a tuple.
- `f`: Some function.
## Inputs
Expand All @@ -238,33 +234,14 @@ be `Chain((x, ps, st) -> (relu.(x), st))`. An easier thing to do would be
- Output of `f(x)`
- Empty `NamedTuple()`
"""
@concrete struct WrappedFunction{DC} <: AbstractExplicitLayer
@concrete struct WrappedFunction <: AbstractExplicitLayer
func <: Function
end

WrappedFunction(f::F) where {F} = WrappedFunction{:runtime_check}(f)
(wf::WrappedFunction)(x, ps, st::NamedTuple{}) = wf.func(x), st

function (wf::WrappedFunction{:direct_call})(x, ps, st::NamedTuple)
return __maybe_direct_call(wf.func, x, ps, st, Val(true))
end

function (wf::WrappedFunction)(x, ps, st::NamedTuple)
return __maybe_direct_call(wf.func, x, ps, st, Val(false))
end

@generated function (wf::WrappedFunction{:runtime_check, F})(
x, ps, st::NamedTuple) where {F}
DC = hasmethod(F, Tuple{typeof(x)})
return :(__maybe_direct_call(wf.func, x, ps, st, Val($(DC))))
end

@inline __maybe_direct_call(f, x, ps, st, ::Val{false}) = f(x, ps, st)
@inline __maybe_direct_call(f, x, ps, st, ::Val{true}) = f(x), st

function Base.show(io::IO, w::WrappedFunction{T}) where {T}
print(io, "WrappedFunction{$(Meta.quot(T))}(")
show(io, w.func)
print(io, ")")
function Base.show(io::IO, w::WrappedFunction)
print(io, "WrappedFunction(", w.func, ")")
end

"""
Expand Down Expand Up @@ -332,7 +309,7 @@ end
function initialparameters(rng::AbstractRNG, d::Dense{use_bias}) where {use_bias}
if use_bias
return (weight=d.init_weight(rng, d.out_dims, d.in_dims),
bias=d.init_bias(rng, d.out_dims, 1)) #TODO: In v0.6 make it a vector
bias=d.init_bias(rng, d.out_dims))
else
return (weight=d.init_weight(rng, d.out_dims, d.in_dims),)
end
Expand All @@ -345,10 +322,6 @@ statelength(d::Dense) = 0

outputsize(d::Dense) = (d.out_dims,)

@inline function (d::Dense)(x::AbstractVector, ps, st::NamedTuple)
return vec(first(d(reshape(x, :, 1), ps, st))), st
end

@inline function (d::Dense)(x::AbstractArray, ps, st::NamedTuple)
return reshape(first(d(reshape(x, size(x, 1), :), ps, st)), :, size(x)[2:end]...), st
end
Expand All @@ -357,7 +330,7 @@ end
y = match_eltype(d, ps, st, x)
return (
fused_dense_bias_activation(
d.activation, ps.weight, y, _vec(_getproperty(ps, Val(:bias)))),
d.activation, ps.weight, y, _getproperty(ps, Val(:bias))),
st)
end

Expand Down
2 changes: 1 addition & 1 deletion src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ function set to `Base.Fix2(pixel_shuffle, r)`
- Output of size `(r x W, r x H, C, N)` for 4D-arrays, and `(r x W, r x H, ..., C, N)`
for D-dimensional data, where `D = ndims(x) - 2`
"""
PixelShuffle(r::Int) = WrappedFunction{:direct_call}(Base.Fix2(pixel_shuffle, r))
PixelShuffle(r::Int) = WrappedFunction(Base.Fix2(pixel_shuffle, r))

@doc doc"""
CrossCor(k::NTuple{N,Integer}, (in_chs => out_chs)::Pair{<:Integer,<:Integer},
Expand Down
2 changes: 1 addition & 1 deletion src/layers/extension.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ DynamicExpressionsLayer(
layer_1 = DynamicExpressionsLayer(DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*)}, Tuple{typeof(cos)}}((+, -, *), (cos,)), x1 * cos(x2 - 3.2); turbo=Val{false}(), bumper=Val{false}()), # 1 parameters
layer_2 = DynamicExpressionsLayer(DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*)}, Tuple{typeof(cos)}}((+, -, *), (cos,)), ((x2 - (x1 * x2)) + 2.5) - (1.0 * x1); turbo=Val{false}(), bumper=Val{false}()), # 2 parameters
),
layer_2 = WrappedFunction{:direct_call}(Lux.__stack1),
layer_2 = WrappedFunction(__stack1),
) # Total: 3 parameters,
# plus 0 states.
Expand Down
8 changes: 4 additions & 4 deletions test/contrib/share_parameters_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
@test ps_1.d3.weight == ps_1.d2.l1.weight
@test ps_1.d3.bias == ps_1.d2.l1.bias

ps_new_1 = (; weight=randn(rng, Float32, 4, 2), bias=randn(rng, Float32, 4, 1)) |>
ps_new_1 = (; weight=randn(rng, Float32, 4, 2), bias=randn(rng, Float32, 4)) |>
device
ps_new_2 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 2, 1)) |>
ps_new_2 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 2)) |>
device

ps_2 = Lux.Experimental.share_parameters(ps, sharing, (ps_new_1, ps_new_2))
Expand Down Expand Up @@ -46,9 +46,9 @@
ps, sharing, (ps_new_1,))

# Parameter Structure Mismatch
ps_new_1 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 4, 1)) |>
ps_new_1 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 4)) |>
device
ps_new_2 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 2, 1)) |>
ps_new_2 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 2)) |>
device

@test_throws ArgumentError Lux.Experimental.share_parameters(
Expand Down
6 changes: 3 additions & 3 deletions test/helpers/compact_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,14 @@
ps, st = Lux.setup(rng, model) |> device

@test size(ps.w1.weight) == (128, 1)
@test size(ps.w1.bias) == (128, 1)
@test size(ps.w1.bias) == (128,)
@test length(ps.w2) == nlayers
for i in 1:nlayers
@test size(ps.w2[i].weight) == (128, 128)
@test size(ps.w2[i].bias) == (128, 1)
@test size(ps.w2[i].bias) == (128,)
end
@test size(ps.w3.weight) == (1, 128)
@test size(ps.w3.bias) == (1, 1)
@test size(ps.w3.bias) == (1,)

x = randn(n_in, 32) |> aType

Expand Down
22 changes: 1 addition & 21 deletions test/layers/basic_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,26 +98,6 @@
@jet layer(x, ps, st)
__f = x -> sum(first(layer(x, ps, st)))
@eval @test_gradients $__f $x gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3

f11(x) = x .* x

layer = WrappedFunction{:runtime_check}(f11)
display(layer)
ps, st = Lux.setup(rng, layer) .|> device
x = randn(rng, 2, 3) |> aType

@test layer(x, ps, st)[1] ≈ x .* x
@inferred layer(x, ps, st)

f12(x, ps, st) = x .+ 1, st

layer = WrappedFunction{:runtime_check}(f12)
display(layer)
ps, st = Lux.setup(rng, layer) .|> device
x = randn(rng, 2, 3) |> aType

@test layer(x, ps, st)[1] ≈ x .+ 1
@inferred layer(x, ps, st)
end

@testset "PeriodicEmbedding" begin
Expand Down Expand Up @@ -149,7 +129,7 @@ end
ps, st = Lux.setup(rng, layer) .|> device

@test size(ps.weight) == (100, 10)
@test size(ps.bias) == (100, 1)
@test size(ps.bias) == (100,)
@test layer.activation == identity

layer = Dense(10, 100, relu; use_bias=false)
Expand Down
8 changes: 3 additions & 5 deletions test/layers/containers_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@

@testset "$mode" for (mode, aType, device, ongpu) in MODES
@testset "zero sum" begin
layer = SkipConnection(
WrappedFunction{:direct_call}(Broadcast.BroadcastFunction(zero)), .+)
layer = SkipConnection(WrappedFunction(Broadcast.BroadcastFunction(zero)), .+)
display(layer)
ps, st = Lux.setup(rng, layer) .|> device
x = randn(rng, 10, 10, 10, 10) |> aType
Expand Down Expand Up @@ -38,8 +37,7 @@ end
@testset "$mode" for (mode, aType, device, ongpu) in MODES
@testset "zero sum" begin
layer = Parallel(
+, WrappedFunction{:direct_call}(Broadcast.BroadcastFunction(zero)),
NoOpLayer())
+, WrappedFunction(Broadcast.BroadcastFunction(zero)), NoOpLayer())
display(layer)
ps, st = Lux.setup(rng, layer) .|> device
x = randn(rng, 10, 10, 10, 10) |> aType
Expand Down Expand Up @@ -320,7 +318,7 @@ end
end

@testset "constructors" begin
f1(x, ps, st::NamedTuple) = (x .+ 1, st)
f1(x) = x .+ 1
f2(x) = x .+ 2
model = Chain((Dense(2 => 3), Dense(3 => 2)), f1, f2, NoOpLayer())

Expand Down
2 changes: 1 addition & 1 deletion test/transform/flux_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@
@testset "Functions" begin
@test tolux(Flux.flatten) isa Lux.FlattenLayer
@test tolux(identity) isa Lux.NoOpLayer
@test tolux(+) isa Lux.WrappedFunction{:direct_call}
@test tolux(+) isa Lux.WrappedFunction
end

@testset "Unsupported Layers" begin
Expand Down

0 comments on commit 077a391

Please sign in to comment.