From 3afe9d71733e0a3ddfa30e5dc8ba112fdd6928f2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 2 Jul 2024 13:35:46 -0700 Subject: [PATCH 01/95] chore!: remove cpu/gpu/stacktrace_truncation --- Project.toml | 2 +- docs/src/api/Lux/utilities.md | 24 --------------- src/deprecated.jl | 57 ----------------------------------- test/utils_tests.jl | 4 --- 4 files changed, 1 insertion(+), 86 deletions(-) diff --git a/Project.toml b/Project.toml index 55fd497790..bd381999a9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "0.5.68" +version = "1.0.0-DEV" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/docs/src/api/Lux/utilities.md b/docs/src/api/Lux/utilities.md index 7d8a94f5b0..19489e766d 100644 --- a/docs/src/api/Lux/utilities.md +++ b/docs/src/api/Lux/utilities.md @@ -119,27 +119,3 @@ StatefulLuxLayer @init_fn @non_trainable ``` - -## Preferences - -```@docs -Lux.set_dispatch_doctor_preferences! -``` - -## Truncated Stacktraces (Deprecated) - -```@docs -Lux.disable_stacktrace_truncation! -``` - -## Device Management / Data Transfer (Deprecated) - -```@docs -Lux.cpu -Lux.gpu -``` - -!!! warning - - For detailed API documentation on Data Transfer check out the - [LuxDeviceUtils.jl](@ref LuxDeviceUtils-API) diff --git a/src/deprecated.jl b/src/deprecated.jl index 2073be8a6a..e69de29bb2 100644 --- a/src/deprecated.jl +++ b/src/deprecated.jl @@ -1,57 +0,0 @@ -# Deprecations for v1 -""" - cpu(x) - -Transfer `x` to CPU. - -!!! danger "Deprecation Notice" - - This function has been deprecated. Use [`cpu_device`](@ref) instead. -""" -function cpu end - -@deprecate cpu(x) (MLDataDevices.cpu_device())(x) - -""" - gpu(x) - -Transfer `x` to GPU determined by the backend set using [`Lux.gpu_backend!`](@ref). - -!!! danger "Deprecation Notice" - - This function has been deprecated. Use [`gpu_device`](@ref) instead. Using this function - inside performance critical code will cause massive slowdowns due to type inference - failure. -""" -function gpu end - -@deprecate gpu(x) (MLDataDevices.gpu_device())(x) - -""" - disable_stacktrace_truncation!(; disable::Bool=true) - -An easy way to update `TruncatedStacktraces.VERBOSE` without having to load it manually. - -Effectively does `TruncatedStacktraces.VERBOSE[] = disable` - -!!! danger "Deprecation Notice" - - This function is now deprecated and will be removed in v1. -""" -function disable_stacktrace_truncation!(; disable::Bool=true) - Base.depwarn( - "`disable_stacktrace_truncation!` is not needed anymore, as stacktraces are \ - truncated by default. This function is now deprecated and will be removed in v1.", - :disable_stacktrace_truncation) - return -end - -# Other deprecated functions -@deprecate xlogx(x::Number) LuxOps.xlogx(x) -@deprecate xlogy(x::Number, y::Number) LuxOps.xlogy(x, y) -@deprecate foldl_init(args...) LuxOps.foldl_init(args...) -@deprecate istraining(args...) LuxOps.istraining(args...) - -# While the ones below aren't public, we ended up using them at quite a few places -@deprecate _getproperty(args...) LuxOps.getproperty(args...) -@deprecate _eachslice(args...) LuxOps.eachslice(args...) 
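# Replacement calls for the deprecations removed above — a minimal sketch, assuming
# the MLDataDevices-backed device API and the `LuxOps` module that the deleted
# `@deprecate` rules forward to (the sample array `x` is illustrative, not from the patch).
using Lux  # `cpu_device` / `gpu_device` are re-exported from MLDataDevices

x = rand(Float32, 3, 2)

cdev = cpu_device()          # was: Lux.cpu(x)
gdev = gpu_device()          # was: Lux.gpu(x); falls back to the CPU device when no GPU backend is available
x_cpu = cdev(x)
x_dev = gdev(x)

y = Lux.LuxOps.xlogx.(x)     # was: Lux.xlogx.(x)
z = Lux.LuxOps.xlogy.(x, x)  # was: Lux.xlogy.(x, x)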
diff --git a/test/utils_tests.jl b/test/utils_tests.jl index d017b11dfa..5a63a574be 100644 --- a/test/utils_tests.jl +++ b/test/utils_tests.jl @@ -34,10 +34,6 @@ end @testitem "Deprecations" tags=[:others] begin using Functors - @test_deprecated Lux.disable_stacktrace_truncation!() - @test_deprecated Lux.cpu(rand(2)) - @test_deprecated Lux.gpu(rand(2)) - model = NoOpLayer() @test_deprecated Lux.Experimental.StatefulLuxLayer(model, (;), (;)) From 1fbe795de714d5379872511909211c78c5127f88 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 2 Jul 2024 13:37:49 -0700 Subject: [PATCH 02/95] chore!: remove old preferences --- src/preferences.jl | 33 ++++----------------------------- 1 file changed, 4 insertions(+), 29 deletions(-) diff --git a/src/preferences.jl b/src/preferences.jl index 920f749ad3..fdb64717dc 100644 --- a/src/preferences.jl +++ b/src/preferences.jl @@ -1,26 +1,3 @@ -module LuxPreferences - -using ArgCheck: @argcheck -using Preferences: load_preference, has_preference, set_preferences! - -using ..Lux: Lux - -macro deprecate_preference(old_pref, new_pref, default) - msg1 = "Preference `$(old_pref)` is deprecated and will be removed in a future \ - release. Use `$(new_pref)` instead." - msg2 = "Both `$(old_pref)` and `$(new_pref)` preferences are set. Please remove \ - `$(old_pref)`." - return esc(quote - if has_preference($(Lux), $(old_pref)) - Base.depwarn($msg1, $(Meta.quot(Symbol(Lux)))) - has_preference($(Lux), $(new_pref)) && error($msg2) - load_preference($(Lux), $(old_pref), $(default)) - else - load_preference($(Lux), $(new_pref), $(default)) - end - end) -end - macro load_preference_with_choices(pref, default, choices) msg1 = "Invalid value for `$(pref)` preference: " msg2 = ". Valid choices are: $(choices)" @@ -32,14 +9,12 @@ macro load_preference_with_choices(pref, default, choices) end # Nested AD -const AUTOMATIC_NESTED_AD_SWITCHING = @deprecate_preference("DisableAutomaticNestedADSwitching", - "automatic_nested_ad_switching", true) +const AUTOMATIC_NESTED_AD_SWITCHING = @load_preference("automatic_nested_ad_switching", + true) # GPU-Aware MPI -const MPI_CUDA_AWARE = @deprecate_preference("LuxDistributedMPICUDAAware", "cuda_aware_mpi", - false) -const MPI_ROCM_AWARE = @deprecate_preference("LuxDistributedMPIROCMAware", "rocm_aware_mpi", - false) +const MPI_CUDA_AWARE = @load_preference("cuda_aware_mpi", false) +const MPI_ROCM_AWARE = @load_preference("rocm_aware_mpi", false) # Eltype Auto Conversion const ELTYPE_MISMATCH_HANDLING = @load_preference_with_choices("eltype_mismatch_handling", From 5a3b01037d3464ea9c26dc6076930d2169b88f95 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 2 Jul 2024 13:41:31 -0700 Subject: [PATCH 03/95] chore!: remove contrib deprecations --- docs/src/api/Lux/contrib.md | 18 ------ src/Lux.jl | 5 -- src/contrib/contrib.jl | 21 ------- src/contrib/deprecated.jl | 16 ------ src/deprecated.jl | 0 test/contrib/freeze_tests.jl | 2 - test/contrib/share_parameters_tests.jl | 2 - test/helpers/stateful_tests.jl | 2 - test/helpers/training_tests.jl | 79 ++++++++++++++++---------- test/utils_tests.jl | 11 ---- 10 files changed, 48 insertions(+), 108 deletions(-) delete mode 100644 src/contrib/deprecated.jl delete mode 100644 src/deprecated.jl diff --git a/docs/src/api/Lux/contrib.md b/docs/src/api/Lux/contrib.md index a5143ae01f..23cd26e14e 100644 --- a/docs/src/api/Lux/contrib.md +++ b/docs/src/api/Lux/contrib.md @@ -8,12 +8,6 @@ All features listed on this page are **experimental** which means: experimental sooner. 3. 
None of the features are exported. -!!! warning - - Starting v"0.5.2" all Experimental features need to be accessed via - `Lux.Experimental.`. Direct access via `Lux.` will be removed in - v"0.6". - ## Index ```@index @@ -56,15 +50,3 @@ Lux.Experimental.DebugLayer ```@docs Lux.Experimental.share_parameters ``` - -## StatefulLuxLayer - -[`Lux.StatefulLuxLayer`](@ref) used to be part of experimental features, but has been -promoted to stable API. It is now available via `Lux.StatefulLuxLayer`. Change all uses of -`Lux.Experimental.StatefulLuxLayer` to `Lux.StatefulLuxLayer`. - -## Compact Layer API - -[`Lux.@compact`](@ref) used to be part of experimental features, but has been promoted to -stable API. It is now available via `Lux.@compact`. Change all uses of -`Lux.Experimental.@compact` to `Lux.@compact`. diff --git a/src/Lux.jl b/src/Lux.jl index 712996c457..3db606c7bf 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -84,12 +84,7 @@ include("transform/simplechains.jl") include("distributed/backend.jl") include("distributed/public_api.jl") -# Deprecations -include("deprecated.jl") - # Layers -export cpu, gpu # deprecated - export Chain, Parallel, SkipConnection, PairwiseFusion, BranchLayer, Maxout, RepeatedLayer export Bilinear, Dense, Embedding, Scale, PeriodicEmbedding export Conv, ConvTranspose, CrossCor, MaxPool, MeanPool, GlobalMaxPool, GlobalMeanPool, diff --git a/src/contrib/contrib.jl b/src/contrib/contrib.jl index fdcf6b0c2a..a4f49aded1 100644 --- a/src/contrib/contrib.jl +++ b/src/contrib/contrib.jl @@ -22,7 +22,6 @@ include("map.jl") include("freeze.jl") include("share_parameters.jl") include("debug.jl") -include("deprecated.jl") @compat public layer_map, @layer_map @compat public FrozenLayer, freeze, unfreeze @@ -30,23 +29,3 @@ include("deprecated.jl") @compat public DebugLayer, @debug_mode end - -# Deprecations for v1.0 -macro layer_map(f, l, ps, st) - Base.depwarn( - "`Lux.@layer_map` has been deprecated in favor of `Lux.Experimental.@layer_map`", - Symbol("@layer_map")) - quote - Experimental.layer_map($(esc(f)), $(esc(l)), $(esc(ps)), $(esc(st)), $(string(l))) - end -end - -for f in (:layer_map, :share_parameters, :FrozenLayer, :freeze, :unfreeze) - msg = "`Lux.$(f)` has been deprecated in favor of `Lux.Experimental.$(f)`" - @eval begin - $(f)(args...; kwargs...) = begin - Base.depwarn($(msg), Symbol($(f))) - return Experimental.$(f)(args...; kwargs...) - end - end -end diff --git a/src/contrib/deprecated.jl b/src/contrib/deprecated.jl deleted file mode 100644 index c81ff89ee4..0000000000 --- a/src/contrib/deprecated.jl +++ /dev/null @@ -1,16 +0,0 @@ -macro compact(exs...) - Base.depwarn( - "Lux.Experimental.@compact` has been promoted out of `Lux.Experimental` and is now \ - available in `Lux`. In other words this has been deprecated and will be removed \ - in v1. Use `Lux.@compact` instead.", - Symbol("@compact")) - return Lux.CompactMacroImpl.compact_macro_impl(exs...) -end - -Base.@deprecate StatefulLuxLayer(args...; kwargs...) Lux.StatefulLuxLayer( - args...; kwargs...) false - -for f in (:TrainState, :TrainingBackendCache, :single_train_step, :single_train_step!, - :apply_gradients, :apply_gradients!, :compute_gradients) - @eval Base.@deprecate $f(args...; kwargs...) Training.$f(args...; kwargs...) 
false -end diff --git a/src/deprecated.jl b/src/deprecated.jl deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/test/contrib/freeze_tests.jl b/test/contrib/freeze_tests.jl index ad9a2416a5..96c449135d 100644 --- a/test/contrib/freeze_tests.jl +++ b/test/contrib/freeze_tests.jl @@ -7,8 +7,6 @@ d = Dense(5 => 5) psd, std = Lux.setup(rng, d) .|> dev - @test_deprecated Lux.freeze(d, psd, std, nothing) - fd, ps, st = Lux.Experimental.freeze(d, psd, std, nothing) @test length(keys(ps)) == 0 @test length(keys(st)) == 2 diff --git a/test/contrib/share_parameters_tests.jl b/test/contrib/share_parameters_tests.jl index f39f6bca79..92a8c19d67 100644 --- a/test/contrib/share_parameters_tests.jl +++ b/test/contrib/share_parameters_tests.jl @@ -9,8 +9,6 @@ sharing = (("d2.l2", "d1"), ("d3", "d2.l1")) - @test_deprecated Lux.share_parameters(ps, sharing) - ps_1 = Lux.Experimental.share_parameters(ps, sharing) @test ps_1.d2.l2.weight == ps_1.d1.weight diff --git a/test/helpers/stateful_tests.jl b/test/helpers/stateful_tests.jl index ba3c24691b..8392182a3d 100644 --- a/test/helpers/stateful_tests.jl +++ b/test/helpers/stateful_tests.jl @@ -12,8 +12,6 @@ @test st isa NamedTuple{()} - @test_deprecated StatefulLuxLayer(model, ps, st) - smodel = StatefulLuxLayer{false}(model, ps, st) display(smodel) @test smodel(1) isa Any diff --git a/test/helpers/training_tests.jl b/test/helpers/training_tests.jl index 7982ed6df3..38c934b2ad 100644 --- a/test/helpers/training_tests.jl +++ b/test/helpers/training_tests.jl @@ -7,7 +7,7 @@ model = Dense(3, 2) opt = Adam(0.01f0) - tstate = Lux.Experimental.TrainState(Lux.replicate(rng), model, opt) + tstate = Lux.Training.TrainState(Lux.replicate(rng), model, opt) x = randn(Lux.replicate(rng), Float32, (3, 1)) @@ -36,7 +36,7 @@ end model = Dense(3, 2) opt = Adam(0.01f0) - tstate = Lux.Experimental.TrainState( + tstate = Lux.Training.TrainState( Lux.replicate(rng), model, opt; transform_variables=dev) x = randn(Lux.replicate(rng), Float32, (3, 1)) |> aType @@ -45,9 +45,8 @@ end ongpu && (ad isa AutoReverseDiff || ad isa AutoEnzyme) && continue !LuxTestUtils.ENZYME_TESTING_ENABLED && ad isa AutoEnzyme && continue - grads, _, _, _ = Lux.Experimental.compute_gradients( - ad, _loss_function, x, tstate) - tstate_ = Lux.Experimental.apply_gradients(tstate, grads) + grads, _, _, _ = Lux.Training.compute_gradients(ad, _loss_function, x, tstate) + tstate_ = Lux.Training.apply_gradients(tstate, grads) @test tstate_.step == 1 @test tstate != tstate_ end @@ -78,9 +77,6 @@ end ongpu && (ad isa AutoReverseDiff || ad isa AutoEnzyme) && continue !LuxTestUtils.ENZYME_TESTING_ENABLED && ad isa AutoEnzyme && continue - @test_deprecated Lux.Experimental.TrainState( - Lux.replicate(rng), model, opt; transform_variables=dev) - tstate = Lux.Training.TrainState( Lux.replicate(rng), model, opt; transform_variables=dev) @@ -88,29 +84,20 @@ end for epoch in 1:100, (x, y) in dataset_ grads, loss, _, tstate = allow_unstable() do - Lux.Experimental.compute_gradients(ad, mse, (x, y), tstate) + Lux.Training.compute_gradients(ad, mse, (x, y), tstate) end - tstate = Lux.Experimental.apply_gradients!(tstate, grads) - end - - (x, y) = first(dataset_) - allow_unstable() do - @test_deprecated Lux.Experimental.compute_gradients(ad, mse, (x, y), tstate) - end - grads, loss, _, tstate = allow_unstable() do - Lux.Experimental.compute_gradients(ad, mse, (x, y), tstate) + tstate = Lux.Training.apply_gradients!(tstate, grads) end - @test_deprecated Lux.Experimental.apply_gradients(tstate, grads) for epoch in 
1:100, (x, y) in dataset_ grads, loss, _, tstate = allow_unstable() do - Lux.Experimental.single_train_step!(ad, mse, (x, y), tstate) + Lux.Training.single_train_step!(ad, mse, (x, y), tstate) end end for epoch in 1:100, (x, y) in dataset_ grads, loss, _, tstate = allow_unstable() do - Lux.Experimental.single_train_step(ad, mse, (x, y), tstate) + Lux.Training.single_train_step(ad, mse, (x, y), tstate) end end @@ -134,10 +121,10 @@ end struct AutoCustomAD <: ADTypes.AbstractADType end - tstate = Lux.Experimental.TrainState( + tstate = Lux.Training.TrainState( Lux.replicate(rng), model, opt; transform_variables=dev) - @test_throws ArgumentError Lux.Experimental.compute_gradients( + @test_throws ArgumentError Lux.Training.compute_gradients( AutoCustomAD(), mse, dataset_[1], tstate) end end @@ -187,6 +174,36 @@ end else @test_broken false end + + rng = StableRNG(12345) + + model = Chain(Dense(4 => 3), VariationalHiddenDropout(0.5f0), Dense(3 => 4)) + ps, st = Lux.setup(rng, model) + x = randn(rng, Float32, 4, 32) + opt = Adam(0.001f0) + + tstate = Lux.Training.TrainState(model, ps, st, opt) + + _, _, _, tstate_new = @inferred Lux.Training.compute_gradients( + AutoEnzyme(), mse, (x, x), tstate) + + @test tstate_new.states !== tstate.states + + model = Chain(Dense(4 => 3), Dense(3 => 4)) + ps, st = Lux.setup(rng, model) + + tstate = Lux.Training.TrainState(model, ps, st, opt) + + _, _, _, tstate_new = @inferred Lux.Training.compute_gradients( + AutoEnzyme(), mse, (x, x), tstate) + + @test @inferred(Lux.Training.compute_gradients( + AutoEnzyme(), mse, (x, x), tstate_new)) isa Any + + _, _, _, tstate_new2 = @inferred Lux.Training.compute_gradients( + AutoEnzyme(), mse2, (x, x), tstate_new) + @test hasfield(typeof(tstate_new2.cache.extras), :forward) + @test hasfield(typeof(tstate_new2.cache.extras), :reverse) end @testitem "Compiled ReverseDiff" setup=[SharedTestSetup] tags=[:helpers] begin @@ -207,19 +224,19 @@ end Dense(32, 32, tanh), BatchNorm(32), Dense(32, 4)) ps, st = Lux.setup(rng, model) - tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0)) + tstate = Lux.Training.TrainState(model, ps, st, Adam(0.001f0)) # Stateful models are not supported - @test_throws ArgumentError Lux.Experimental.compute_gradients( + @test_throws ArgumentError Lux.Training.compute_gradients( AutoReverseDiff(; compile=true), mse1, dataset[1], tstate) model = Chain(Dense(4, 32, tanh), Dense(32, 32, tanh), Dense(32, 4)) ps, st = Lux.setup(rng, model) - tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0)) + tstate = Lux.Training.TrainState(model, ps, st, Adam(0.001f0)) # Loss functions that return non-empty `stats` are not supported - @test_throws ArgumentError Lux.Experimental.compute_gradients( + @test_throws ArgumentError Lux.Training.compute_gradients( AutoReverseDiff(; compile=true), mse2, dataset[1], tstate) struct StrangeModel <: Lux.AbstractExplicitLayer end @@ -231,23 +248,23 @@ end model = StrangeModel() ps, st = Lux.setup(rng, model) - tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0)) + tstate = Lux.Training.TrainState(model, ps, st, Adam(0.001f0)) # Stateful models are not supported - @test_throws ArgumentError Lux.Experimental.compute_gradients( + @test_throws ArgumentError Lux.Training.compute_gradients( AutoReverseDiff(; compile=true), mse1, dataset[1], tstate) end model = Chain(Dense(4, 32, tanh), Dense(32, 32, tanh), Dense(32, 4)) ps, st = Lux.setup(rng, model) - tstate = Lux.Experimental.TrainState(model, ps, st, Adam(0.001f0)) + tstate = 
Lux.Training.TrainState(model, ps, st, Adam(0.001f0)) loss_initial = first(mse1(model, ps, st, dataset[1])) for i in 1:100 for (x, y) in dataset _, _, _, tstate = allow_unstable() do - Lux.Experimental.single_train_step!( + Lux.Training.single_train_step!( AutoReverseDiff(; compile=true), mse1, (x, y), tstate) end end diff --git a/test/utils_tests.jl b/test/utils_tests.jl index 5a63a574be..8db7b95eda 100644 --- a/test/utils_tests.jl +++ b/test/utils_tests.jl @@ -31,17 +31,6 @@ end @test eltype(ComponentArray(Any[:a, 1], (FlatAxis(),))) == Any end -@testitem "Deprecations" tags=[:others] begin - using Functors - - model = NoOpLayer() - @test_deprecated Lux.Experimental.StatefulLuxLayer(model, (;), (;)) - - @test_deprecated Lux.Experimental.DebugLayer(model; location="model") - dmodel = Lux.Experimental.DebugLayer(model; location="model") - @test dmodel.location == KeyPath(:model) -end - @testitem "multigate" setup=[SharedTestSetup] tags=[:others] begin rng = StableRNG(12345) From eff6f50d30ef134e5c846e2704a9bdce54dc40da Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 2 Jul 2024 13:45:52 -0700 Subject: [PATCH 04/95] chore!: remove `st_fixed_type` --- src/contrib/contrib.jl | 3 ++- src/contrib/debug.jl | 14 ++------------ src/helpers/stateful.jl | 26 ++++++++------------------ src/utils.jl | 2 +- 4 files changed, 13 insertions(+), 32 deletions(-) diff --git a/src/contrib/contrib.jl b/src/contrib/contrib.jl index a4f49aded1..aee5036c62 100644 --- a/src/contrib/contrib.jl +++ b/src/contrib/contrib.jl @@ -1,6 +1,7 @@ module Experimental -using ..Lux: Lux, Training, Utils, Optional +using ..Lux: Lux, Training, Optional +using ..Utils: Utils, BoolType, SymbolType using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer, apply using ADTypes: ADTypes diff --git a/src/contrib/debug.jl b/src/contrib/debug.jl index 3187208824..5e7fd3f78b 100644 --- a/src/contrib/debug.jl +++ b/src/contrib/debug.jl @@ -50,19 +50,9 @@ See [`Lux.Experimental.@debug_mode`](@ref) to construct this layer. location::KeyPath end -function DebugLayer(layer::AbstractExplicitLayer; - nan_check::Union{Symbol, StaticSymbol, Val}=static(:both), - error_check::Union{StaticBool, Bool, Val{true}, Val{false}}=True(), - location::Union{KeyPath, String}=KeyPath()) +function DebugLayer(layer::AbstractExplicitLayer; nan_check::SymbolType=static(:both), + error_check::BoolType=True(), location::KeyPath=KeyPath()) @argcheck dynamic(nan_check) in (:both, :forward, :backward, :none) - - if location isa String - Base.depwarn( - "Using a String for location in DebugLayer is deprecated. Use \ - `Functors.KeyPath` instead.", :DebugLayer) - location = KeyPath(Symbol.(split(location, "."))...) - end - return DebugLayer(static(nan_check), static(error_check), layer, location) end diff --git a/src/helpers/stateful.jl b/src/helpers/stateful.jl index 9d47a1c86b..376c2db3ac 100644 --- a/src/helpers/stateful.jl +++ b/src/helpers/stateful.jl @@ -1,6 +1,5 @@ """ - StatefulLuxLayer(model, ps, st; st_fixed_type = Val(true)) # deprecated - StatefulLuxLayer{ST}(model, ps, st) + StatefulLuxLayer{FT}(model, ps, st) !!! warning @@ -18,6 +17,13 @@ This is meant to be used in internal implementation of layers. - Facilitates Nested AD support in Lux. For more details on this feature, see the [Nested AD Manual Page](@ref nested_autodiff). +## Static Parameters + + - If `FT = true` then the type of the `state` is fixed, i.e., + `typeof(last(model(x, ps, st))) == st`. + - If `FT = false` then type of the state might change. 
Note that while this works in all + cases, it will introduce type instability. + ## Arguments - `model`: A Lux layer @@ -25,13 +31,6 @@ This is meant to be used in internal implementation of layers. the parameters on function call - `st`: The state of the layer -## Keyword Arguments - - - `st_fixed_type`: If `Val(true)`, then the type of the `state` is fixed, i.e., - `typeof(last(model(x, ps, st))) == st`. If this is not the case, then `st_fixed_type` - must be set to `Val(false)`. If `st_fixed_type` is set to `Val(false)`, then type - stability is not guaranteed. - ## Inputs - `x`: The input to the layer @@ -59,15 +58,6 @@ function StatefulLuxLayer{ST}(model, ps, st, st_any) where {ST} return StatefulLuxLayer(model, ps, st, st_any, static(ST)) end -function StatefulLuxLayer(model::AbstractExplicitLayer, st::NamedTuple; kwargs...) - return StatefulLuxLayer(model, nothing, st; kwargs...) -end -function StatefulLuxLayer(model::AbstractExplicitLayer, ps, st::NamedTuple; - st_fixed_type::Val{ST}=Val(true)) where {ST} - Base.depwarn("`st_fixed_type` is deprecated. Use `StatefulLuxLayer{ST}` instead.", - :StatefulLuxLayer) - return StatefulLuxLayer{ST}(model, ps, st) -end function StatefulLuxLayer{true}(model::AbstractExplicitLayer, ps, st::NamedTuple) return StatefulLuxLayer{true}(model, ps, st, nothing) end diff --git a/src/utils.jl b/src/utils.jl index d47cea6139..4009c9b115 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -17,7 +17,7 @@ const CRC = ChainRulesCore const BoolType = Union{StaticBool, Bool, Val{true}, Val{false}} const IntegerType = Union{Integer, StaticInteger} -const SymbolType = Union{Symbol, StaticSymbol} +const SymbolType = Union{Symbol, StaticSymbol, Val} # Aliased `size` from Base size(x::AbstractArray) = Base.size(x) From 3e68b542743f869110d59595146f0d4b30cfa3cd Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 2 Jul 2024 14:14:39 -0700 Subject: [PATCH 05/95] chore!: update layer_map/freeze --- docs/src/api/Lux/contrib.md | 7 +-- docs/src/manual/freezing_model_parameters.md | 50 ++++++++------------ src/contrib/map.jl | 40 ++++++---------- src/preferences.jl | 7 +++ 4 files changed, 42 insertions(+), 62 deletions(-) diff --git a/docs/src/api/Lux/contrib.md b/docs/src/api/Lux/contrib.md index 23cd26e14e..e79d6872e7 100644 --- a/docs/src/api/Lux/contrib.md +++ b/docs/src/api/Lux/contrib.md @@ -16,11 +16,6 @@ Pages = ["contrib.md"] ## Parameter Freezing -!!! info - - In the long term, this will be supported via - [Optimisers.jl](https://github.com/FluxML/Optimisers.jl/pull/49). - ```@docs Lux.Experimental.FrozenLayer Lux.Experimental.freeze @@ -32,8 +27,8 @@ For detailed usage example look at the [manual page](@ref freezing-model-paramet ## Map over Layer ```@docs -Lux.Experimental.layer_map Lux.Experimental.@layer_map +Lux.Experimental.layer_map ``` ## Debugging Functionality diff --git a/docs/src/manual/freezing_model_parameters.md b/docs/src/manual/freezing_model_parameters.md index 9b5c8ffb99..588972b67c 100644 --- a/docs/src/manual/freezing_model_parameters.md +++ b/docs/src/manual/freezing_model_parameters.md @@ -12,11 +12,10 @@ To freeze a particular kind of layer, let's say [`Dense`](@ref) in the following We can use [`Lux.Experimental.@layer_map`](@ref) and freeze layers if they are of type `Dense`. 
-```@example -using Lux, Random +```@example freezing_model_parameters +using Lux, Functors, Random -rng = Random.default_rng() -Random.seed!(rng, 0) +rng = Xoshiro(0) model = Chain(Dense(3, 4), Chain(Dense(4, 4), Dropout(0.5f0), BatchNorm(4)), Dense(4, 1); disable_optimizations=true) @@ -27,10 +26,10 @@ x = randn(rng, Float32, 3, 2) model(x, ps, st) -function freeze_dense(d::Lux.Dense, ps, st, ::String) - return Lux.freeze(d, ps, st, (:weight, :bias)) +function freeze_dense(d::Lux.Dense, ps, st, path) + return Lux.Experimental.freeze(d, ps, st, (:weight, :bias)) end -freeze_dense(l, ps, st, name) = (l, ps, st) +freeze_dense(l, ps, st, path) = (l, ps, st) model_frozen, ps_frozen, st_frozen = Lux.Experimental.@layer_map freeze_dense model ps st @@ -47,19 +46,17 @@ would be `.layer_2.layer_1`. ```julia [Freezing by Layer Name] -function freeze_by_name(d, ps, st, name::String) - if name == "model.layer_2.layer_1" +function freeze_by_name(d, ps, st, name::KeyPath) + name == KeyPath(:model, :layer_2, :layer_1) && return Lux.Experimental.freeze(d, ps, st, (:weight, :bias)) - else - return d, ps, st - end + return d, ps, st end ``` ```julia [Freezing by Layer Type] -function freeze_dense(d::Dense, ps, st, ::String) +function freeze_dense(d::Dense, ps, st, _) return Lux.Experimental.freeze(d, ps, st, (:weight, :bias)) end freeze_dense(l, ps, st, _) = (l, ps, st) @@ -77,24 +74,20 @@ the `weight` parameter while training the `bias` parameter. ```julia [Freezing Some Parameters of a Layer] -function freeze_by_name(d, ps, st, name::String) - if name == "model.layer_2.layer_1" - return Lux.freeze(d, ps, st, (:weight,)) - else - return d, ps, st - end +function freeze_by_name(d, ps, st, name::KeyPath) + name == KeyPath(:model, :layer_2, :layer_1) && + return Lux.Experimental.freeze(d, ps, st, (:weight,)) + return d, ps, st end ``` ```julia [Freezing All Parameters of a Layer] -function freeze_by_name(d, ps, st, name::String) - if name == "model.layer_2.layer_1" - return Lux.freeze(d, ps, st, (:weight, :bias)) - else - return d, ps, st - end +function freeze_by_name(d, ps, st, name::KeyPath) + name == KeyPath(:model, :layer_2, :layer_1) && + return Lux.Experimental.freeze(d, ps, st, (:weight, :bias)) + return d, ps, st end ``` @@ -103,10 +96,7 @@ end ## Freezing Part of a Chain -Starting `v0.4.22`, we can directly index into a `Chain`. So freezing a part of a `Chain`, -is extremely easy. - -```@example +```@example freezing_model_parameters using Lux, Random rng = Random.default_rng() @@ -114,7 +104,7 @@ Random.seed!(rng, 0) model = Chain(Dense(3, 4), Dense(4, 4), Dropout(0.5f0), BatchNorm(4), Dense(4, 1)) -model_frozen = Chain(model[1:2], Lux.freeze(model[3:4]), model[5]) +model_frozen = Chain(model[1:2], Lux.Experimental.freeze(model[3:4]), model[5]) ps, st = Lux.setup(rng, model_frozen) x = randn(rng, Float32, 3, 2) diff --git a/src/contrib/map.jl b/src/contrib/map.jl index 17f1612c7a..6cb6c6cf6f 100644 --- a/src/contrib/map.jl +++ b/src/contrib/map.jl @@ -2,8 +2,8 @@ @layer_map func layer ps st See the documentation of [`Lux.Experimental.layer_map`](@ref) for more details. This macro -eliminates the need to the set the layer name, and uses the variable name as the starting -point. +eliminates the need to the set the layer name, and uses the variable name of layer as the +starting point. 
## Example @@ -28,9 +28,9 @@ julia> # Makes parameters of Dense Layers inside Chain zero end; julia> _, ps_new, _ = Lux.Experimental.@layer_map zero_dense_params c ps st; -zeroing params of c.layers.chain.layers.dense_1 -zeroing params of c.layers.chain.layers.dense_2 -zeroing params of c.layers.dense_3 +zeroing params of KeyPath(:c, :layers, :chain, :layers, :dense_1) +zeroing params of KeyPath(:c, :layers, :chain, :layers, :dense_2) +zeroing params of KeyPath(:c, :layers, :dense_3) julia> all(iszero, (ps_new.chain.dense_1.weight, ps_new.chain.dense_1.bias, ps_new.chain.dense_2.weight, ps_new.chain.dense_2.bias, @@ -39,14 +39,14 @@ true ``` """ macro layer_map(f, l, ps, st) - return quote - layer_map($(esc(f)), $(esc(l)), $(esc(ps)), $(esc(st)), $(string(l))) + quote + layer_map($(esc(f)), $(esc(l)), $(esc(ps)), $(esc(st)), $(Meta.quot(l))) end end @doc doc""" layer_map(f::Function, l::AbstractExplicitLayer, ps, st::NamedTuple, - name::String="model") + name::Symbol=:model) Map the function `f` over the model `l`, with the parameters `ps` and states `st`. This is different from `Functors.fmap` since it zips the layers, parameters, and states and invokes @@ -55,7 +55,7 @@ the function on all of them together. ## Call Signature for `f` - Must take 4 inputs -- `AbstractExplicitLayer`, Corresponding Parameters, Corresponding - States, and the name of the layer. + States, and the `Functors.KeyPath` to the layer. - Must return a tuple of 3 elements -- `AbstractExplicitLayer`, new parameters and the new states. @@ -64,12 +64,6 @@ the function on all of them together. We recommend using the macro `Lux.Experimental.@layer_map` instead of this function. It automatically sets the `name` of the layer to be the variable name. -!!! danger "Deprecation Notice" - - Starting `v1`, instead of the name of the layer, we will provide the [KeyPath to the - layer](https://fluxml.ai/Functors.jl/stable/api/#KeyPath). The current version of - providing a String has been deprecated. - # Extended Help ## Example @@ -77,7 +71,6 @@ the function on all of them together. 
```jldoctest julia> using Lux, Random - julia> c = Parallel( +; chain=Chain(; dense_1=Dense(2 => 3), bn=BatchNorm(3), dense_2=Dense(3 => 5)), dense_3=Dense(5 => 1)); @@ -96,9 +89,9 @@ julia> # Makes parameters of Dense Layers inside Chain zero end; julia> _, ps_new, _ = Lux.Experimental.layer_map(zero_dense_params, c, ps, st); -zeroing params of model.layers.chain.layers.dense_1 -zeroing params of model.layers.chain.layers.dense_2 -zeroing params of model.layers.dense_3 +zeroing params of KeyPath(:model, :layers, :chain, :layers, :dense_1) +zeroing params of KeyPath(:model, :layers, :chain, :layers, :dense_2) +zeroing params of KeyPath(:model, :layers, :dense_3) julia> all(iszero, (ps_new.chain.dense_1.weight, ps_new.chain.dense_1.bias, ps_new.chain.dense_2.weight, ps_new.chain.dense_2.bias, @@ -106,16 +99,11 @@ julia> all(iszero, (ps_new.chain.dense_1.weight, ps_new.chain.dense_1.bias, true ``` """ -function layer_map(f::F, l, ps, st, name::String="model") where {F <: Function} - # TODO: In v1 deprecate passing the string - f_wrapper = @closure (kp, layer, ps_, st_) -> f( - layer, ps_, st_, __keypath_to_string(name, kp)) +function layer_map(f::F, l, ps, st, name::Symbol=:model) where {F <: Function} + f_wrapper = @closure (kp, layer, ps_, st_) -> f(layer, ps_, st_, KeyPath(name, kp)) return fmap_with_path(f_wrapper, l, ps, st; walk=LayerWalkWithPath()) end -__keypath_to_string(kp::KeyPath) = join(kp.keys, ".") -__keypath_to_string(str::String, kp::KeyPath) = "$(str).$(__keypath_to_string(kp))" - struct LayerWalkWithPath <: Functors.AbstractWalk end function (::LayerWalkWithPath)(recurse, kp::KeyPath, layer, ps, st) diff --git a/src/preferences.jl b/src/preferences.jl index fdb64717dc..a1ec0eef3b 100644 --- a/src/preferences.jl +++ b/src/preferences.jl @@ -1,3 +1,10 @@ +module LuxPreferences + +using ArgCheck: @argcheck +using Preferences: load_preference, has_preference, set_preferences! + +using ..Lux: Lux + macro load_preference_with_choices(pref, default, choices) msg1 = "Invalid value for `$(pref)` preference: " msg2 = ". Valid choices are: $(choices)" From 7d21d8c04ebd659e18ecf74302837fdcb9543967 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 2 Jul 2024 14:24:34 -0700 Subject: [PATCH 06/95] chore!: remove flattening of chains --- docs/src/manual/debugging.md | 12 ++- docs/src/manual/freezing_model_parameters.md | 3 +- ext/LuxFluxExt.jl | 7 +- ext/LuxSimpleChainsExt.jl | 3 +- src/layers/basic.jl | 14 +--- src/layers/containers.jl | 78 ++++++-------------- test/contrib/debug_tests.jl | 14 ++-- test/layers/containers_tests.jl | 2 +- test/transform/simple_chains_tests.jl | 2 +- test/utils_tests.jl | 6 +- 10 files changed, 45 insertions(+), 96 deletions(-) diff --git a/docs/src/manual/debugging.md b/docs/src/manual/debugging.md index 642c1fb925..3c3395a7d5 100644 --- a/docs/src/manual/debugging.md +++ b/docs/src/manual/debugging.md @@ -21,8 +21,7 @@ will see how easy it is to pin-point the problematic layer. 
```@example manual_debugging using Lux, Random -model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 3), Dense(1 => 1)), - BatchNorm(1); disable_optimizations=true) +model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 3), Dense(1 => 1)), BatchNorm(1)) model_debug = Lux.Experimental.@debug_mode model ``` @@ -63,12 +62,12 @@ model = Chain(Dense(1 => 16, relu), Dense(16 => 3), # [!code --] Dense(16 => 1), # [!code ++] Dense(1 => 1)), - BatchNorm(1); disable_optimizations=true) + BatchNorm(1)) ``` ```@example manual_debugging model_fixed = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 1), Dense(1 => 1)), - BatchNorm(1); disable_optimizations=true) + BatchNorm(1)) ps, st = Lux.setup(rng, model_fixed) @@ -88,7 +87,7 @@ debug model. (or even disable it by setting it to `:none`) ```@example manual_debugging model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 1), Dense(1 => 1)), - BatchNorm(1); disable_optimizations=true) + BatchNorm(1)) ps, st = Lux.setup(rng, model) @@ -131,8 +130,7 @@ offending_layer(x) = 2 .* x ``` ```@example manual_debugging -model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 1), offending_layer), - BatchNorm(1); disable_optimizations=true) +model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 1), offending_layer), BatchNorm(1)) ps, st = Lux.setup(rng, model) diff --git a/docs/src/manual/freezing_model_parameters.md b/docs/src/manual/freezing_model_parameters.md index 588972b67c..027de13c64 100644 --- a/docs/src/manual/freezing_model_parameters.md +++ b/docs/src/manual/freezing_model_parameters.md @@ -17,8 +17,7 @@ using Lux, Functors, Random rng = Xoshiro(0) -model = Chain(Dense(3, 4), Chain(Dense(4, 4), Dropout(0.5f0), BatchNorm(4)), - Dense(4, 1); disable_optimizations=true) +model = Chain(Dense(3, 4), Chain(Dense(4, 4), Dropout(0.5f0), BatchNorm(4)), Dense(4, 1)) ps, st = Lux.setup(rng, model) diff --git a/ext/LuxFluxExt.jl b/ext/LuxFluxExt.jl index 54c9675854..6766b925d9 100644 --- a/ext/LuxFluxExt.jl +++ b/ext/LuxFluxExt.jl @@ -22,11 +22,8 @@ Lux.convert_flux_model(l::Function; kwargs...) = Lux.WrappedFunction{:direct_cal function Lux.convert_flux_model(l::Flux.Chain; kwargs...) fn = x -> Lux.convert_flux_model(x; kwargs...) layers = map(fn, l.layers) - if layers isa NamedTuple - return Lux.Chain(layers; disable_optimizations=true) - else - return Lux.Chain(layers...; disable_optimizations=true) - end + layers isa NamedTuple && return Lux.Chain(layers) + return Lux.Chain(layers...) end function Lux.convert_flux_model(l::Flux.Dense; preserve_ps_st::Bool=false, kwargs...) diff --git a/ext/LuxSimpleChainsExt.jl b/ext/LuxSimpleChainsExt.jl index c7a607b250..1d10fc106a 100644 --- a/ext/LuxSimpleChainsExt.jl +++ b/ext/LuxSimpleChainsExt.jl @@ -17,8 +17,7 @@ end function Lux.fix_simplechain_input_dims(layers, input_dims) @warn "The model provided is not a `Chain`. Trying to wrap it into a `Chain` but this \ - might fail. Please consider using `Chain` directly (potentially with \ - `disable_optimizations = true`)." + might fail. Please consider using `Chain` directly." return fix_simplechain_input_dims([layers], input_dims) end diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 77cd98f482..bc2501899d 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -218,7 +218,7 @@ struct NoOpLayer <: AbstractExplicitLayer end """ WrappedFunction{DC}(f) - WrappedFunction(f) -> WrappedFunction{:direct_call}(f) + WrappedFunction(f) -> WrappedFunction{:runtime_check}(f) Wraps a stateless and parameter less function. Might be used when a function is added to `Chain`. 
For example, `Chain(x -> relu.(x))` would not work and the right thing to do would @@ -229,8 +229,7 @@ be `Chain((x, ps, st) -> (relu.(x), st))`. An easier thing to do would be - `DC`: If `:runtime_check`, then we check if the function can be called with the input `x`, `ps`, and `st` using `hasmethod`. If `:direct_call`, we call `f(x)` directly. - For all other values, we call `f(x, ps, st)` which must return a tuple. **(In future - versions, we will default to `:runtime_check`)** + For all other values, we call `f(x, ps, st)` which must return a tuple. - `f`: Some function. ## Inputs @@ -252,14 +251,7 @@ function WrappedFunction{call_mode}(f::F) where {call_mode, F} return WrappedFunction(static(call_mode), f) end -function WrappedFunction(f::F) where {F} - # Not a depwarn but helpful to call this - Base.depwarn("The current default of `:direct_call` will be replaced with \ - `:runtime_check` from v1). Please make sure that the assumptions of \ - this function are correct or specify `WrappedFunction{:direct_call}(f)`", - :WrappedFunction) - return WrappedFunction{:direct_call}(f) -end +WrappedFunction(f::F) where {F} = WrappedFunction{:runtime_check}(f) function (wf::WrappedFunction{:direct_call})(x, ps, st::NamedTuple) return wrapped_function_call(wf.func, x, ps, st, True()) diff --git a/src/layers/containers.jl b/src/layers/containers.jl index 21c1506a8c..20b605d36c 100644 --- a/src/layers/containers.jl +++ b/src/layers/containers.jl @@ -194,8 +194,6 @@ end return Expr(:block, calls...) end -Base.keys(m::Parallel) = Base.keys(getfield(m, :layers)) - """ BranchLayer(layers...) BranchLayer(; name=nothing, layers...) @@ -283,8 +281,6 @@ BranchLayer(; name::NAME_TYPE=nothing, kwargs...) = BranchLayer((; kwargs...), n return Expr(:block, calls...) end -Base.keys(m::BranchLayer) = Base.keys(getfield(m, :layers)) - """ PairwiseFusion(connection, layers...; name=nothing) PairwiseFusion(connection; name=nothing, layers...) @@ -391,11 +387,9 @@ end return Expr(:block, calls...) end -Base.keys(m::PairwiseFusion) = Base.keys(getfield(m, :layers)) - """ - Chain(layers...; name=nothing, disable_optimizations::Bool = false) - Chain(; layers..., name=nothing, disable_optimizations::Bool = false) + Chain(layers...; name=nothing) + Chain(; layers..., name=nothing) Collects multiple layers / functions to be called in sequence on a given input. @@ -433,20 +427,6 @@ of the internal layers. - States of each `layer` wrapped in a NamedTuple with `fields = layer_1, layer_2, ..., layer_N` (naming changes if using the kwargs API) -## Optimizations - -Performs a few optimizations to generate reasonable architectures. Can be disabled using -keyword argument `disable_optimizations`. - - - All sublayers are recursively optimized. - - If a function `f` is passed as a layer and it doesn't take 3 inputs, it is converted to - a [`WrappedFunction`](@ref)(`f`) which takes only one input. - - If the layer is a Chain, it is flattened. - - [`NoOpLayer`](@ref)s are removed. - - If there is only 1 layer (left after optimizations), then it is returned without the - `Chain` wrapper. - - If there are no layers (left after optimizations), a [`NoOpLayer`](@ref) is returned. - ## Miscellaneous Properties - Allows indexing and field access syntax. We can access the `i`th layer by `m[i]` or @@ -462,6 +442,14 @@ Chain( layer_3 = Dense(3 => 2), # 8 parameters ) # Total: 23 parameters, # plus 7 states. 
+ +julia> Chain(Dense(2, 3, relu), BatchNorm(3), Dense(3, 2); name="MyFancyChain") +MyFancyChain( + layer_1 = Dense(2 => 3, relu), # 9 parameters + layer_2 = BatchNorm(3, affine=true, track_stats=true), # 6 parameters, plus 7 + layer_3 = Dense(3 => 2), # 8 parameters +) # Total: 23 parameters, + # plus 7 states. ``` """ @concrete struct Chain <: AbstractExplicitContainerLayer{(:layers,)} @@ -469,51 +457,31 @@ Chain( name end -function Chain(xs...; name::NAME_TYPE=nothing, disable_optimizations::Bool=false) - xs = disable_optimizations ? xs : flatten_lux_chain(xs) - length(xs) == 0 && return NoOpLayer() - length(xs) == 1 && return first(xs) - return Chain(Utils.named_tuple_layers(xs...), name) +function Chain(xs...; name::NAME_TYPE=nothing) + return Chain(Utils.named_tuple_layers(wrap_functions_in_chain_call(xs)...), name) end - Chain(xs::AbstractVector; kwargs...) = Chain(xs...; kwargs...) +Chain(nt::NamedTuple; name::NAME_TYPE=nothing) = Chain(nt, name) +Chain(; name::NAME_TYPE=nothing, kwargs...) = Chain((; kwargs...); name) -function Chain(nt::NamedTuple; disable_optimizations::Bool=true, name::NAME_TYPE=nothing) - if !disable_optimizations - throw(ArgumentError("Chain(::NamedTuple) is not compatible with disable_optimizations=true")) - end - return Chain(nt, name) -end - -function Chain(; disable_optimizations::Bool=true, name::NAME_TYPE=nothing, kwargs...) - return Chain((; kwargs...); disable_optimizations, name) -end - -function flatten_lux_chain(layers::Union{AbstractVector, Tuple}) +function wrap_functions_in_chain_call(layers::Union{AbstractVector, Tuple}) new_layers = [] for l in layers - f = flatten_lux_chain(l) + f = wrap_functions_in_chain_call(l) if f isa Tuple || f isa AbstractVector append!(new_layers, f) elseif f isa Function - if !hasmethod(f, (Any, Any, NamedTuple)) - f === identity && continue - push!(new_layers, WrappedFunction{:direct_call}(f)) - else - push!(new_layers, WrappedFunction{:layer}(f)) - end - elseif f isa Chain - append!(new_layers, f.layers) - elseif f isa NoOpLayer - continue - else + push!(new_layers, WrappedFunction(f)) + elseif f isa AbstractExplicitLayer push!(new_layers, f) + else + throw("Encountered a non-AbstractExplicitLayer in Chain.") end end return layers isa AbstractVector ? new_layers : Tuple(new_layers) end -flatten_lux_chain(x) = x +wrap_functions_in_chain_call(x) = x (c::Chain)(x, ps, st::NamedTuple) = applychain(c.layers, x, ps, st) @@ -530,8 +498,6 @@ flatten_lux_chain(x) = x return Expr(:block, calls...) end -Base.keys(c::Chain) = Base.keys(getfield(c, :layers)) - Base.getindex(c::Chain, i::Int) = c.layers[i] Base.getindex(c::Chain, i::AbstractArray) = Chain(Utils.index_namedtuple(c.layers, i)) @@ -620,8 +586,6 @@ Maxout(f::Function, n_alts::Int) = Maxout(ntuple(Returns(f()), n_alts)...) return Expr(:block, calls...) 
end -Base.keys(m::Maxout) = Base.keys(getfield(m, :layers)) - """ RepeatedLayer(model; repeats::Val = Val(10), input_injection::Val = Val(false)) diff --git a/test/contrib/debug_tests.jl b/test/contrib/debug_tests.jl index 469f7787a6..b06038af83 100644 --- a/test/contrib/debug_tests.jl +++ b/test/contrib/debug_tests.jl @@ -3,9 +3,9 @@ rng = StableRNG(12345) - @testset "$mode" for (mode, aType, dev, ongpu) in MODES - model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 3), Dense(1 => 1)), - BatchNorm(1); disable_optimizations=true) + @testset "$mode" for (mode, aType, device, ongpu) in MODES + model = Chain( + Dense(1 => 16, relu), Chain(Dense(16 => 3), Dense(1 => 1)), BatchNorm(1)) ps, st = Lux.setup(rng, model) |> dev x = randn(rng, Float32, 1, 5) |> aType @@ -29,8 +29,8 @@ catch end - model_fixed = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 1), Dense(1 => 1)), - BatchNorm(1); disable_optimizations=true) + model_fixed = Chain( + Dense(1 => 16, relu), Chain(Dense(16 => 1), Dense(1 => 1)), BatchNorm(1)) ps, st = Lux.setup(rng, model_fixed) |> dev @@ -85,8 +85,8 @@ end model_debug4 = Lux.Experimental.@debug_mode model nan_check=:none @test any(isnan, first(model_debug4(x, ps, st)) |> Array) - model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 1), offending_layer), - BatchNorm(1); disable_optimizations=true) + model = Chain( + Dense(1 => 16, relu), Chain(Dense(16 => 1), offending_layer), BatchNorm(1)) ps, st = Lux.setup(rng, model) |> dev diff --git a/test/layers/containers_tests.jl b/test/layers/containers_tests.jl index 6d1eaaed7f..aeb0fd8e88 100644 --- a/test/layers/containers_tests.jl +++ b/test/layers/containers_tests.jl @@ -332,7 +332,7 @@ end @test_throws ArgumentError Chain(; l1=Dense(10 => 5, sigmoid), d52=Dense(5 => 2, tanh), - d21=Dense(2 => 1), d2=Dense(2 => 1), disable_optimizations=false) + d21=Dense(2 => 1), d2=Dense(2 => 1)) @testset "indexing and field access" begin encoder = Chain(Dense(10 => 5, sigmoid), Dense(5 => 2, tanh)) diff --git a/test/transform/simple_chains_tests.jl b/test/transform/simple_chains_tests.jl index 5f22bea523..76938a6ae3 100644 --- a/test/transform/simple_chains_tests.jl +++ b/test/transform/simple_chains_tests.jl @@ -87,7 +87,7 @@ for dims in (static(10), (static(10),)) adaptor = ToSimpleChainsAdaptor(dims) - simple_chains_model = @test_warn "The model provided is not a `Chain`. Trying to wrap it into a `Chain` but this might fail. Please consider using `Chain` directly (potentially with `disable_optimizations = true`)." adaptor(lux_model) + simple_chains_model = @test_warn "The model provided is not a `Chain`. Trying to wrap it into a `Chain` but this might fail. Please consider using `Chain` directly." 
adaptor(lux_model) ps, st = Lux.setup(Random.default_rng(), simple_chains_model) diff --git a/test/utils_tests.jl b/test/utils_tests.jl index 8db7b95eda..5086c5109c 100644 --- a/test/utils_tests.jl +++ b/test/utils_tests.jl @@ -127,9 +127,9 @@ end @testitem "FP Conversions" setup=[SharedTestSetup] tags=[:others] begin rng = StableRNG(12345) - @testset "$mode" for (mode, aType, dev, ongpu) in MODES - model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 1), Dense(1 => 1)), - BatchNorm(1); disable_optimizations=true) + @testset "$mode" for (mode, aType, device, ongpu) in MODES + model = Chain( + Dense(1 => 16, relu), Chain(Dense(16 => 1), Dense(1 => 1)), BatchNorm(1)) for (f, ftype) in zip((f16, f32, f64), (Float16, Float32, Float64)) ps, st = Lux.setup(rng, model) |> dev |> f From ab46dd8998b4fabc89a0324df217e49a146c9a19 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 2 Jul 2024 18:03:03 -0700 Subject: [PATCH 07/95] fix: remove old usage of TrainState --- examples/ConvMixer/main.jl | 3 ++- examples/DDIM/main.jl | 4 ++-- examples/HyperNet/main.jl | 4 ++-- examples/PolynomialFitting/main.jl | 10 ++++++---- examples/SimpleChains/main.jl | 4 ++-- examples/SimpleRNN/main.jl | 6 ++++-- test/helpers/training_tests.jl | 14 +++++++------- 7 files changed, 25 insertions(+), 20 deletions(-) diff --git a/examples/ConvMixer/main.jl b/examples/ConvMixer/main.jl index 372259d817..21ece57fc1 100644 --- a/examples/ConvMixer/main.jl +++ b/examples/ConvMixer/main.jl @@ -77,12 +77,13 @@ end trainloader, testloader = get_dataloaders(batchsize) model = ConvMixer(; dim=hidden_dim, depth, kernel_size, patch_size) + ps, st = Lux.setup(rng, model) |> gdev opt = AdamW(; eta=lr_max, lambda=weight_decay) clip_norm && (opt = OptimiserChain(ClipNorm(), opt)) train_state = Training.TrainState( - rng, model, AdamW(; eta=lr_max, lambda=weight_decay); transform_variables=gdev) + model, ps, st, AdamW(; eta=lr_max, lambda=weight_decay)) lr_schedule = linear_interpolation( [0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0]) diff --git a/examples/DDIM/main.jl b/examples/DDIM/main.jl index ec31991b5f..e97d72939f 100644 --- a/examples/DDIM/main.jl +++ b/examples/DDIM/main.jl @@ -330,6 +330,7 @@ end @info "Building model" model = ddim(rng, (image_size, image_size); channels, block_depth, min_freq, max_freq, embedding_dims, min_signal_rate, max_signal_rate) + ps, st = Lux.setup(rng, model) |> gdev if inference_mode @argcheck saved_model_path!==nothing "`saved_model_path` must be specified for inference" @@ -354,8 +355,7 @@ end tb_logger = TBLogger(tb_dir) tstate = Training.TrainState( - rng, model, AdamW(; eta=learning_rate_start, lambda=weight_decay); - transform_variables=gdev) + model, ps, st, AdamW(; eta=learning_rate_start, lambda=weight_decay)) @info "Preparing dataset" ds = FlowersDataset(x -> preprocess_image(x, image_size), true) diff --git a/examples/HyperNet/main.jl b/examples/HyperNet/main.jl index 9045091c77..fef29796fd 100644 --- a/examples/HyperNet/main.jl +++ b/examples/HyperNet/main.jl @@ -80,10 +80,10 @@ function train() dataloaders = load_datasets() dev = gpu_device() - rng = Xoshiro(0) + ps, st = Lux.setup(rng, model) |> dev - train_state = Training.TrainState(rng, model, Adam(3.0f-4); transform_variables=dev) + train_state = Training.TrainState(model, ps, st, Adam(3.0f-4)) ### Lets train the model nepochs = 10 diff --git a/examples/PolynomialFitting/main.jl b/examples/PolynomialFitting/main.jl index 2db9b65916..217f1f3e1b 100644 --- a/examples/PolynomialFitting/main.jl +++ 
b/examples/PolynomialFitting/main.jl @@ -55,12 +55,17 @@ opt = Adam(0.03f0) # functions provided by Lux. const loss_function = MSELoss() +const dev_cpu = cpu_device() +const dev_gpu = gpu_device() + +ps, st = Lux.setup(rng, model) |> dev_gpu + # ## Training # First we will create a [`Training.TrainState`](@ref) which is essentially a # convenience wrapper over parameters, states and optimizer states. -tstate = Training.TrainState(rng, model, opt) +tstate = Training.TrainState(model, ps, st, opt) # Now we will use Zygote for our AD requirements. @@ -79,9 +84,6 @@ function main(tstate::Training.TrainState, vjp, data, epochs) return tstate end -dev_cpu = cpu_device() -dev_gpu = gpu_device() - tstate = main(tstate, vjp_rule, (x, y), 250) y_pred = dev_cpu(Lux.apply(tstate.model, dev_gpu(x), tstate.parameters, tstate.states)[1]) nothing #hide diff --git a/examples/SimpleChains/main.jl b/examples/SimpleChains/main.jl index 480865cd2f..0bbd62944b 100644 --- a/examples/SimpleChains/main.jl +++ b/examples/SimpleChains/main.jl @@ -61,9 +61,9 @@ end # ## Define the Training Loop function train(model; rng=Xoshiro(0), kwargs...) train_dataloader, test_dataloader = loadmnist(128, 0.9) + ps, st = Lux.setup(rng, model) - train_state = Training.TrainState( - rng, model, Adam(3.0f-4); transform_variables=identity) + train_state = Training.TrainState(model, ps, st, Adam(3.0f-4)) ### Warmup the model x_proto = randn(rng, Float32, 28, 28, 1, 1) diff --git a/examples/SimpleRNN/main.jl b/examples/SimpleRNN/main.jl index ef120c5a08..8e7dd45a18 100644 --- a/examples/SimpleRNN/main.jl +++ b/examples/SimpleRNN/main.jl @@ -133,12 +133,14 @@ function main(model_type) ## Get the dataloaders (train_loader, val_loader) = get_dataloaders() + dev = gpu_device() + ## Create the model model = model_type(2, 8, 1) rng = Xoshiro(0) + ps, st = Lux.setup(rng, model) |> dev - dev = gpu_device() - train_state = Training.TrainState(rng, model, Adam(0.01f0); transform_variables=dev) + train_state = Training.TrainState(model, ps, st, Adam(0.01f0)) for epoch in 1:25 ## Train the model diff --git a/test/helpers/training_tests.jl b/test/helpers/training_tests.jl index 38c934b2ad..21f3dc3fc3 100644 --- a/test/helpers/training_tests.jl +++ b/test/helpers/training_tests.jl @@ -6,8 +6,9 @@ @testset "$mode" for (mode, aType, dev, ongpu) in MODES model = Dense(3, 2) opt = Adam(0.01f0) + ps, st = Lux.setup(rng, model) |> dev - tstate = Lux.Training.TrainState(Lux.replicate(rng), model, opt) + tstate = Lux.Training.TrainState(model, ps, st, opt) x = randn(Lux.replicate(rng), Float32, (3, 1)) @@ -35,9 +36,9 @@ end @testset "$mode" for (mode, aType, dev, ongpu) in MODES model = Dense(3, 2) opt = Adam(0.01f0) + ps, st = Lux.setup(rng, model) |> dev - tstate = Lux.Training.TrainState( - Lux.replicate(rng), model, opt; transform_variables=dev) + tstate = Lux.Training.TrainState(model, ps, st, opt) x = randn(Lux.replicate(rng), Float32, (3, 1)) |> aType @@ -71,14 +72,14 @@ end Dense(32, 32, tanh), BatchNorm(32), Dense(32, 4)) dataset_ = [dev((x, y)) for (x, y) in dataset] opt = Adam(0.001f0) + ps, st = Lux.setup(rng, model) |> dev @testset "$(ad)" for ad in ( AutoZygote(), AutoTracker(), AutoReverseDiff(), AutoEnzyme()) ongpu && (ad isa AutoReverseDiff || ad isa AutoEnzyme) && continue !LuxTestUtils.ENZYME_TESTING_ENABLED && ad isa AutoEnzyme && continue - tstate = Lux.Training.TrainState( - Lux.replicate(rng), model, opt; transform_variables=dev) + tstate = Lux.Training.TrainState(model, ps, st, opt) initial_loss = first(mse(model, tstate.parameters, 
tstate.states, dataset_[1])) @@ -121,8 +122,7 @@ end struct AutoCustomAD <: ADTypes.AbstractADType end - tstate = Lux.Training.TrainState( - Lux.replicate(rng), model, opt; transform_variables=dev) + tstate = Lux.Training.TrainState(model, ps, st, opt) @test_throws ArgumentError Lux.Training.compute_gradients( AutoCustomAD(), mse, dataset_[1], tstate) From 9ea6e607cfd5ef69bbacb10e5d62db493808632d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 2 Jul 2024 18:04:06 -0700 Subject: [PATCH 08/95] chore: remove old `transform` export --- src/Lux.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/Lux.jl b/src/Lux.jl index 3db606c7bf..c6f562ead2 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -113,7 +113,6 @@ export GenericLossFunction export f16, f32, f64 export match_eltype -export transform export FromFluxAdaptor, FluxLayer export ToSimpleChainsAdaptor, SimpleChainsLayer export DynamicExpressionsLayer From ae66cc8f1b1fa75ba4a39228ea963736984bfd77 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 2 Jul 2024 18:33:35 -0700 Subject: [PATCH 09/95] ci: remove unncessary vars --- .github/workflows/CI.yml | 2 -- test/contrib/map_tests.jl | 11 ++++++++--- test/layers/containers_tests.jl | 7 +------ 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index f008abdef8..9790c8d09d 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -154,8 +154,6 @@ jobs: with: version: ${{ matrix.version }} - uses: julia-actions/julia-downgrade-compat@v1 - with: - skip: 'AMDGPU' - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 env: diff --git a/test/contrib/map_tests.jl b/test/contrib/map_tests.jl index cd3cdfb323..5932d5d393 100644 --- a/test/contrib/map_tests.jl +++ b/test/contrib/map_tests.jl @@ -1,8 +1,13 @@ @testitem "Layer Map" setup=[SharedTestSetup] tags=[:contrib] begin - using Setfield + using Setfield, Functors + + function __occurs_in(kp::KeyPath, x::KeyPath) + length(kp) ≤ length(x) && return all(==(x[i], kp[i]) for i in 1:length(x)) + return false + end function zero_dense_params_1(l, ps, st, name) - if l isa Dense && occursin("model.layers.chain", name) + if l isa Dense && __occurs_in(KeyPath(:model, :layers, :chain), name) @set! ps.weight = zero.(ps.weight) @set! ps.bias = zero.(ps.bias) end @@ -10,7 +15,7 @@ end function zero_dense_params_2(l, ps, st, name) - if l isa Dense && occursin("c.layers.chain", name) + if l isa Dense && __occurs_in(KeyPath(:c, :layers, :chain), name) @set! ps.weight = zero.(ps.weight) @set! 
ps.bias = zero.(ps.bias) end diff --git a/test/layers/containers_tests.jl b/test/layers/containers_tests.jl index aeb0fd8e88..d7fd54537f 100644 --- a/test/layers/containers_tests.jl +++ b/test/layers/containers_tests.jl @@ -44,7 +44,6 @@ end layer = Parallel( +, WrappedFunction{:direct_call}(Broadcast.BroadcastFunction(zero)), NoOpLayer()) - @test :layer_1 in keys(layer) && :layer_2 in keys(layer) display(layer) ps, st = Lux.setup(rng, layer) |> dev x = randn(rng, 10, 10, 10, 10) |> aType @@ -342,8 +341,6 @@ end @test encoder[2] == encoder.layer_2 @test autoencoder[1] == autoencoder.encoder @test autoencoder[2] == autoencoder.decoder - @test keys(encoder) == (:layer_1, :layer_2) - @test keys(autoencoder) == (:encoder, :decoder) @test autoencoder.layers isa NamedTuple @test autoencoder.encoder isa Chain @test_throws ArgumentError autoencoder.layer_1 @@ -351,13 +348,11 @@ end end @testset "constructors" begin - @test Chain([Dense(10 => 5, sigmoid)]) == Dense(10 => 5, sigmoid) - f1(x, ps, st::NamedTuple) = (x .+ 1, st) f2(x) = x .+ 2 model = Chain((Dense(2 => 3), Dense(3 => 2)), f1, f2, NoOpLayer()) - @test length(model) == 4 + @test length(model) == 5 x = rand(Float32, 2, 5) ps, st = Lux.setup(rng, model) From 6ea462220856a9e8318ea534a3ea206000639143 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 3 Jul 2024 00:07:49 -0700 Subject: [PATCH 10/95] test: fix some tests --- src/helpers/training.jl | 19 +------------------ test/contrib/map_tests.jl | 2 +- test/helpers/training_tests.jl | 33 +++++++++++++++------------------ test/layers/containers_tests.jl | 4 ---- 4 files changed, 17 insertions(+), 41 deletions(-) diff --git a/src/helpers/training.jl b/src/helpers/training.jl index fa52f357a3..3819b39e72 100644 --- a/src/helpers/training.jl +++ b/src/helpers/training.jl @@ -45,11 +45,7 @@ Internal fields: end """ - TrainState(rng::Random.AbstractRNG, model::LuxCore.AbstractExplicitLayer, - optimizer::Optimisers.AbstractRule; - transform_variables::Union{Function, AbstractDevice}=gpu_device()) - TrainState(model::LuxCore.AbstractExplicitLayer, ps, st, - optimizer::Optimisers.AbstractRule) + TrainState(model::Lux.AbstractExplicitLayer, ps, st, optimizer::Optimisers.AbstractRule) Constructor for [`TrainState`](@ref). @@ -67,19 +63,6 @@ Constructor for [`TrainState`](@ref). [`TrainState`](@ref) object. 
""" -function TrainState(rng::AbstractRNG, model::AbstractExplicitLayer, optimizer::AbstractRule; - transform_variables=MLDataDevices.gpu_device()) - Base.depwarn( - "`TrainState(rng::AbstractRNG, model::AbstractExplicitLayer, \ - optimizer::Optimisers.AbstractRule; transform_variables::Union{Function, \ - AbstractLuxDevice}=gpu_device())` has been deprecated in favor of \ - `TrainState(model::AbstractExplicitLayer, ps, st, \ - optimizer::Optimisers.AbstractRule)`", - :TrainState) - ps, st = LuxCore.setup(rng, model) .|> transform_variables - return TrainState(model, ps, st, optimizer) -end - function TrainState( model::AbstractExplicitLayer, ps, st, optimizer::Optimisers.AbstractRule) st_opt = Optimisers.setup(optimizer, ps) diff --git a/test/contrib/map_tests.jl b/test/contrib/map_tests.jl index 5932d5d393..842cd18581 100644 --- a/test/contrib/map_tests.jl +++ b/test/contrib/map_tests.jl @@ -2,7 +2,7 @@ using Setfield, Functors function __occurs_in(kp::KeyPath, x::KeyPath) - length(kp) ≤ length(x) && return all(==(x[i], kp[i]) for i in 1:length(x)) + length(kp) ≤ length(x) && return all(==(x[i], kp[i]) for i in 1:length(kp)) return false end diff --git a/test/helpers/training_tests.jl b/test/helpers/training_tests.jl index 21f3dc3fc3..86bbd649d8 100644 --- a/test/helpers/training_tests.jl +++ b/test/helpers/training_tests.jl @@ -6,18 +6,14 @@ @testset "$mode" for (mode, aType, dev, ongpu) in MODES model = Dense(3, 2) opt = Adam(0.01f0) - ps, st = Lux.setup(rng, model) |> dev + ps, st = Lux.setup(Lux.replicate(rng), model) |> dev tstate = Lux.Training.TrainState(model, ps, st, opt) x = randn(Lux.replicate(rng), Float32, (3, 1)) - - ps, st = Lux.setup(Lux.replicate(rng), model) opt_st = Optimisers.setup(opt, tstate.parameters) @test check_approx(tstate.model, model) - @test check_approx(tstate.parameters, ps) - @test check_approx(tstate.states, st) @test check_approx(tstate.optimizer_state, opt_st) @test tstate.step == 0 end @@ -72,36 +68,40 @@ end Dense(32, 32, tanh), BatchNorm(32), Dense(32, 4)) dataset_ = [dev((x, y)) for (x, y) in dataset] opt = Adam(0.001f0) - ps, st = Lux.setup(rng, model) |> dev @testset "$(ad)" for ad in ( AutoZygote(), AutoTracker(), AutoReverseDiff(), AutoEnzyme()) ongpu && (ad isa AutoReverseDiff || ad isa AutoEnzyme) && continue !LuxTestUtils.ENZYME_TESTING_ENABLED && ad isa AutoEnzyme && continue - tstate = Lux.Training.TrainState(model, ps, st, opt) + ps, st = Lux.setup(rng, model) |> dev + tstate = Training.TrainState(model, ps, st, opt) initial_loss = first(mse(model, tstate.parameters, tstate.states, dataset_[1])) - for epoch in 1:100, (x, y) in dataset_ + for epoch in 1:1000, (x, y) in dataset_ grads, loss, _, tstate = allow_unstable() do - Lux.Training.compute_gradients(ad, mse, (x, y), tstate) + Lux.Experimental.compute_gradients(ad, mse, (x, y), tstate) end tstate = Lux.Training.apply_gradients!(tstate, grads) end - for epoch in 1:100, (x, y) in dataset_ + for epoch in 1:1000, (x, y) in dataset_ grads, loss, _, tstate = allow_unstable() do - Lux.Training.single_train_step!(ad, mse, (x, y), tstate) + Training.single_train_step!(ad, mse, (x, y), tstate) end end - for epoch in 1:100, (x, y) in dataset_ + for epoch in 1:1000, (x, y) in dataset_ grads, loss, _, tstate = allow_unstable() do - Lux.Training.single_train_step(ad, mse, (x, y), tstate) + Training.single_train_step(ad, mse, (x, y), tstate) end end + final_loss = first(mse(model, tstate.parameters, tstate.states, dataset_[1])) + + @test final_loss * 100 < initial_loss + # Test the adjust API tstate 
= Optimisers.adjust(tstate, 0.1f0) @test tstate.optimizer_state.layer_1.weight.rule.eta ≈ 0.1f0 @@ -114,15 +114,12 @@ end Optimisers.adjust!(tstate; eta=0.11f0) @test tstate.optimizer_state.layer_1.weight.rule.eta ≈ 0.11f0 - - final_loss = first(mse(model, tstate.parameters, tstate.states, dataset_[1])) - - @test final_loss * 100 < initial_loss end struct AutoCustomAD <: ADTypes.AbstractADType end - tstate = Lux.Training.TrainState(model, ps, st, opt) + ps, st = Lux.setup(rng, model) |> dev + tstate = Training.TrainState(model, ps, st, opt) @test_throws ArgumentError Lux.Training.compute_gradients( AutoCustomAD(), mse, dataset_[1], tstate) diff --git a/test/layers/containers_tests.jl b/test/layers/containers_tests.jl index d7fd54537f..b5134caad6 100644 --- a/test/layers/containers_tests.jl +++ b/test/layers/containers_tests.jl @@ -329,10 +329,6 @@ end __f = (x, ps) -> sum(first(layer(x, ps, st))) test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3) - @test_throws ArgumentError Chain(; - l1=Dense(10 => 5, sigmoid), d52=Dense(5 => 2, tanh), - d21=Dense(2 => 1), d2=Dense(2 => 1)) - @testset "indexing and field access" begin encoder = Chain(Dense(10 => 5, sigmoid), Dense(5 => 2, tanh)) decoder = Chain(Dense(2 => 5, tanh), Dense(5 => 10, sigmoid)) From 6db3f6fd2dccc59af93214ba2344e3d167309d73 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 3 Jul 2024 19:22:01 -0700 Subject: [PATCH 11/95] chore!: remove annotation of WrappedFunction --- ext/LuxDynamicExpressionsExt.jl | 2 +- ext/LuxFluxExt.jl | 4 +-- src/layers/basic.jl | 41 ++++---------------------- src/layers/conv.jl | 2 +- test/contrib/share_parameters_tests.jl | 16 +++++----- test/helpers/compact_tests.jl | 6 ++-- test/layers/basic_tests.jl | 24 ++------------- test/layers/containers_tests.jl | 8 ++--- test/transform/flux_tests.jl | 2 +- 9 files changed, 27 insertions(+), 78 deletions(-) diff --git a/ext/LuxDynamicExpressionsExt.jl b/ext/LuxDynamicExpressionsExt.jl index 2b28885c4e..c552b0c83c 100644 --- a/ext/LuxDynamicExpressionsExt.jl +++ b/ext/LuxDynamicExpressionsExt.jl @@ -34,7 +34,7 @@ function Lux.DynamicExpressionsLayer(operator_enum::OperatorEnum, expressions::N Parallel(nothing, ntuple(i -> DynamicExpressionsLayer(operator_enum, expressions[i], name_fn(i), eval_options), length(expressions))...), - WrappedFunction{:direct_call}(Lux.Utils.stack1); + WrappedFunction(Lux.Utils.stack1); name="DynamicExpressionsLayer") #! format: on end diff --git a/ext/LuxFluxExt.jl b/ext/LuxFluxExt.jl index 6766b925d9..d0201dad2c 100644 --- a/ext/LuxFluxExt.jl +++ b/ext/LuxFluxExt.jl @@ -17,7 +17,7 @@ function Lux.convert_flux_model(l::T; preserve_ps_st::Bool=false, kwargs...) whe return Lux.FluxLayer(l) end -Lux.convert_flux_model(l::Function; kwargs...) = Lux.WrappedFunction{:direct_call}(l) +Lux.convert_flux_model(l::Function; kwargs...) = Lux.WrappedFunction(l) function Lux.convert_flux_model(l::Flux.Chain; kwargs...) fn = x -> Lux.convert_flux_model(x; kwargs...) @@ -29,7 +29,7 @@ end function Lux.convert_flux_model(l::Flux.Dense; preserve_ps_st::Bool=false, kwargs...) out_dims, in_dims = size(l.weight) if preserve_ps_st - bias = l.bias isa Bool ? nothing : reshape(copy(l.bias), out_dims, 1) + bias = l.bias isa Bool ? 
nothing : copy(l.bias) return Lux.Dense(in_dims => out_dims, l.σ; init_weight=Returns(copy(l.weight)), init_bias=Returns(bias), use_bias=!(l.bias isa Bool)) else diff --git a/src/layers/basic.jl b/src/layers/basic.jl index bc2501899d..801b948094 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -217,8 +217,7 @@ struct NoOpLayer <: AbstractExplicitLayer end (noop::NoOpLayer)(x, _, st::NamedTuple) = x, st """ - WrappedFunction{DC}(f) - WrappedFunction(f) -> WrappedFunction{:runtime_check}(f) + WrappedFunction(f) Wraps a stateless and parameter less function. Might be used when a function is added to `Chain`. For example, `Chain(x -> relu.(x))` would not work and the right thing to do would @@ -227,9 +226,6 @@ be `Chain((x, ps, st) -> (relu.(x), st))`. An easier thing to do would be ## Arguments - - `DC`: If `:runtime_check`, then we check if the function can be called with the input - `x`, `ps`, and `st` using `hasmethod`. If `:direct_call`, we call `f(x)` directly. - For all other values, we call `f(x, ps, st)` which must return a tuple. - `f`: Some function. ## Inputs @@ -242,38 +238,13 @@ be `Chain((x, ps, st) -> (relu.(x), st))`. An easier thing to do would be - Output of `f(x)` - Empty `NamedTuple()` """ -struct WrappedFunction{DC, F} <: AbstractExplicitLayer - call_mode::StaticSymbol{DC} - func::F +@concrete struct WrappedFunction <: AbstractExplicitLayer + func <: Function end -function WrappedFunction{call_mode}(f::F) where {call_mode, F} - return WrappedFunction(static(call_mode), f) -end - -WrappedFunction(f::F) where {F} = WrappedFunction{:runtime_check}(f) - -function (wf::WrappedFunction{:direct_call})(x, ps, st::NamedTuple) - return wrapped_function_call(wf.func, x, ps, st, True()) -end - -function (wf::WrappedFunction)(x, ps, st::NamedTuple) - return wrapped_function_call(wf.func, x, ps, st, False()) -end - -function (wf::WrappedFunction{:runtime_check})(x, ps, st::NamedTuple) - return wrapped_function_call(wf.func, x, ps, st, - static(!hasmethod(wf.func, (typeof(x), typeof(ps), typeof(st))))) -end +(wf::WrappedFunction)(x, ps, st::NamedTuple{}) = wf.func(x), st -wrapped_function_call(f, x, ps, st, ::False) = f(x, ps, st) -wrapped_function_call(f, x, _, st, ::True) = f(x), st - -function Base.show(io::IO, w::WrappedFunction{T}) where {T} - print(io, "WrappedFunction(", static(w.call_mode), ", ") - show(io, w.func) - print(io, ")") -end +Base.show(io::IO, w::WrappedFunction) = print(io, "WrappedFunction(", w.func, ")") """ Dense(in_dims => out_dims, activation=identity; init_weight=glorot_uniform, @@ -342,7 +313,7 @@ end function initialparameters(rng::AbstractRNG, d::Dense) if has_bias(d) return (weight=d.init_weight(rng, d.out_dims, d.in_dims), - bias=d.init_bias(rng, d.out_dims, 1)) #TODO: In v1 make it a vector + bias=d.init_bias(rng, d.out_dims)) else return (weight=d.init_weight(rng, d.out_dims, d.in_dims),) end diff --git a/src/layers/conv.jl b/src/layers/conv.jl index ad40385a04..4e727b89c7 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -571,7 +571,7 @@ function set to `Base.Fix2(pixel_shuffle, r)` - Output of size `(r x W, r x H, C, N)` for 4D-arrays, and `(r x W, r x H, ..., C, N)` for D-dimensional data, where `D = ndims(x) - 2` """ -PixelShuffle(r::IntegerType) = WrappedFunction{:direct_call}(Base.Fix2(pixel_shuffle, r)) +PixelShuffle(r::IntegerType) = WrappedFunction(Base.Fix2(pixel_shuffle, r)) @doc doc""" CrossCor(k::NTuple{N,Integer}, (in_chs => out_chs)::Pair{<:Integer,<:Integer}, diff --git a/test/contrib/share_parameters_tests.jl 
b/test/contrib/share_parameters_tests.jl index 92a8c19d67..d87a477cd7 100644 --- a/test/contrib/share_parameters_tests.jl +++ b/test/contrib/share_parameters_tests.jl @@ -16,10 +16,10 @@ @test ps_1.d3.weight == ps_1.d2.l1.weight @test ps_1.d3.bias == ps_1.d2.l1.bias - ps_new_1 = (; weight=randn(rng, Float32, 4, 2), bias=randn(rng, Float32, 4, 1)) |> - dev - ps_new_2 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 2, 1)) |> - dev + ps_new_1 = (; weight=randn(rng, Float32, 4, 2), bias=randn(rng, Float32, 4)) |> + device + ps_new_2 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 2)) |> + device ps_2 = Lux.Experimental.share_parameters(ps, sharing, (ps_new_1, ps_new_2)) @@ -46,10 +46,10 @@ ps, sharing, (ps_new_1,)) # Parameter Structure Mismatch - ps_new_1 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 4, 1)) |> - dev - ps_new_2 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 2, 1)) |> - dev + ps_new_1 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 4)) |> + device + ps_new_2 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 2)) |> + device @test_throws ArgumentError Lux.Experimental.share_parameters( ps, sharing, (ps_new_1, ps_new_2)) diff --git a/test/helpers/compact_tests.jl b/test/helpers/compact_tests.jl index 48c28fffb5..1844e937b2 100644 --- a/test/helpers/compact_tests.jl +++ b/test/helpers/compact_tests.jl @@ -104,14 +104,14 @@ ps, st = Lux.setup(rng, model) |> dev @test size(ps.w1.weight) == (128, 1) - @test size(ps.w1.bias) == (128, 1) + @test size(ps.w1.bias) == (128,) @test length(ps.w2) == nlayers for i in 1:nlayers @test size(ps.w2[i].weight) == (128, 128) - @test size(ps.w2[i].bias) == (128, 1) + @test size(ps.w2[i].bias) == (128,) end @test size(ps.w3.weight) == (1, 128) - @test size(ps.w3.bias) == (1, 1) + @test size(ps.w3.bias) == (1,) x = randn(n_in, 32) |> aType diff --git a/test/layers/basic_tests.jl b/test/layers/basic_tests.jl index 7f1acd4880..15982a27c4 100644 --- a/test/layers/basic_tests.jl +++ b/test/layers/basic_tests.jl @@ -102,27 +102,7 @@ @jet layer(x, ps, st) __f = x -> sum(first(layer(x, ps, st))) - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3) - - f11(x) = x .* x - - layer = WrappedFunction{:runtime_check}(f11) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev - x = randn(rng, 2, 3) |> aType - - @test layer(x, ps, st)[1] ≈ x .* x - @test @inferred(layer(x, ps, st)) isa Any - - f12(x, ps, st) = x .+ 1, st - - layer = WrappedFunction{:runtime_check}(f12) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev - x = randn(rng, 2, 3) |> aType - - @test layer(x, ps, st)[1] ≈ x .+ 1 - @test @inferred(layer(x, ps, st)) isa Any + @eval @test_gradients $__f $x gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 end @testset "PeriodicEmbedding" begin @@ -155,7 +135,7 @@ end ps, st = Lux.setup(rng, layer) |> dev @test size(ps.weight) == (100, 10) - @test size(ps.bias) == (100, 1) + @test size(ps.bias) == (100,) @test layer.activation == identity layer = Dense(10, 100, relu; use_bias=false) diff --git a/test/layers/containers_tests.jl b/test/layers/containers_tests.jl index b5134caad6..cbd425f7f3 100644 --- a/test/layers/containers_tests.jl +++ b/test/layers/containers_tests.jl @@ -4,8 +4,7 @@ @testset "$mode" for (mode, aType, dev, ongpu) in MODES @testset "zero sum" begin - layer = SkipConnection( - WrappedFunction{:direct_call}(Broadcast.BroadcastFunction(zero)), .+) + layer = SkipConnection(WrappedFunction(Broadcast.BroadcastFunction(zero)), .+) display(layer) ps, st = 
Lux.setup(rng, layer) |> dev x = randn(rng, Float32, 10, 10, 10, 10) |> aType @@ -42,8 +41,7 @@ end @testset "$mode" for (mode, aType, dev, ongpu) in MODES @testset "zero sum" begin layer = Parallel( - +, WrappedFunction{:direct_call}(Broadcast.BroadcastFunction(zero)), - NoOpLayer()) + +, WrappedFunction(Broadcast.BroadcastFunction(zero)), NoOpLayer()) display(layer) ps, st = Lux.setup(rng, layer) |> dev x = randn(rng, 10, 10, 10, 10) |> aType @@ -344,7 +342,7 @@ end end @testset "constructors" begin - f1(x, ps, st::NamedTuple) = (x .+ 1, st) + f1(x) = x .+ 1 f2(x) = x .+ 2 model = Chain((Dense(2 => 3), Dense(3 => 2)), f1, f2, NoOpLayer()) diff --git a/test/transform/flux_tests.jl b/test/transform/flux_tests.jl index 7c6cc0ed09..b063bd67fb 100644 --- a/test/transform/flux_tests.jl +++ b/test/transform/flux_tests.jl @@ -503,7 +503,7 @@ @testset "Functions" begin @test tolux(Flux.flatten) isa Lux.FlattenLayer @test tolux(identity) isa Lux.NoOpLayer - @test tolux(+) isa Lux.WrappedFunction{:direct_call} + @test tolux(+) isa Lux.WrappedFunction end @testset "Unsupported Layers" begin From 8fb9ad454586cdcde357a4363f5e29ca97944df1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 10 Jul 2024 22:36:31 -0700 Subject: [PATCH 12/95] fix!: remove potentially incorrect Tracker gradients for SimpleChains --- ext/LuxTrackerExt/rules.jl | 19 ------------------- src/layers/extension.jl | 11 +---------- test/transform/simple_chains_tests.jl | 6 ++++-- 3 files changed, 5 insertions(+), 31 deletions(-) diff --git a/ext/LuxTrackerExt/rules.jl b/ext/LuxTrackerExt/rules.jl index 70883d3a55..39f9a879c8 100644 --- a/ext/LuxTrackerExt/rules.jl +++ b/ext/LuxTrackerExt/rules.jl @@ -7,25 +7,6 @@ for T1 in (:TrackedArray, :AbstractArray), T2 in (:TrackedArray, :AbstractArray) end end -Tracker.@grad function Lux.apply_simple_chain(layer, x, ps, ::CPUDevice) - Base.depwarn("`Tracker.jl` often produces incorrect gradients for `SimpleChains.jl` \ - models. In future versions of Lux.jl you will need to load `Zygote.jl` \ - to use `Tracker.jl` for your model.", - :apply_simple_chain) - @warn "`Tracker.jl` often produces incorrect gradients for `SimpleChains.jl` models. \ - As such please test your model with `FiniteDiff.jl` or `Zygote.jl` before using \ - `Tracker.jl` for your model." maxlog=1 - y, pb_f = CRC.rrule(layer, Tracker.data(x), Tracker.data(ps)) - ∇apply_simple_chain = let pb_f = pb_f - Δ -> begin - _, ∂x, ∂ps = pb_f(convert(Array, Tracker.data(Δ))) - return Tracker.nobacksies(:apply_simple_chain, (nothing, ∂x, ∂ps, nothing)) - end - end - # Tracker is not great at handling arbitrary types, so we convert to Array - return Array(y), ∇apply_simple_chain -end - # DynamicExpressions.jl for T1 in (:TrackedArray, :AbstractArray), T2 in (:TrackedArray, :AbstractArray) T1 === :AbstractArray && T2 === :AbstractArray && continue diff --git a/src/layers/extension.jl b/src/layers/extension.jl index 6997f4538c..9e1d1fd7de 100644 --- a/src/layers/extension.jl +++ b/src/layers/extension.jl @@ -179,16 +179,7 @@ regular `Array` or not. Default is `false`. ## Arguments - `layer`: SimpleChains layer - -!!! note - - If using `Tracker.jl`, the output will always be a regular `Array`. - -!!! danger - - `Tracker.jl` sometimes produces incorrect gradients for `SimpleChains.jl` models. As - such please test your model with `FiniteDiff.jl` or `Zygote.jl` before using - `Tracker.jl` for your model. 
+ - `lux_layer`: Potentially equivalent Lux layer that is used for printing """ struct SimpleChainsLayer{ToArray, SL, LL <: Union{Nothing, AbstractExplicitLayer}} <: AbstractExplicitLayer diff --git a/test/transform/simple_chains_tests.jl b/test/transform/simple_chains_tests.jl index 76938a6ae3..3c20563195 100644 --- a/test/transform/simple_chains_tests.jl +++ b/test/transform/simple_chains_tests.jl @@ -29,7 +29,8 @@ @test length(gs[2].params) == length(ps.params) # See https://github.com/LuxDL/Lux.jl/issues/644 - test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3, broken_backends=[AutoEnzyme()]) + test_gradients( + __f, x, ps; atol=1.0f-3, rtol=1.0f-3, broken_backends=[AutoEnzyme(), AutoTracker()]) x = randn(Float32, 28, 28, 1, 15) @test size(first(simple_chains_model(x, ps, st))) == (10, 15) @@ -41,7 +42,8 @@ @test length(gs[2].params) == length(ps.params) # See https://github.com/LuxDL/Lux.jl/issues/644 - test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3, broken_backends=[AutoEnzyme()]) + test_gradients( + __f, x, ps; atol=1.0f-3, rtol=1.0f-3, broken_backends=[AutoEnzyme(), AutoTracker()]) @testset "Array Output" begin adaptor = ToSimpleChainsAdaptor((static(28), static(28), static(1)), true) From e2b0e16e22bf6456340488f705b33653666a7a18 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 22 Jul 2024 21:21:07 -0700 Subject: [PATCH 13/95] fix: store the bias as a vector --- src/layers/basic.jl | 7 +++---- src/layers/conv.jl | 15 ++++++--------- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 801b948094..e952969cad 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -326,7 +326,7 @@ outputsize(d::Dense) = (d.out_dims,) function (d::Dense)(x::AbstractArray, ps, st::NamedTuple) y = match_eltype(d, ps, st, x) - bias = safe_vec(safe_getproperty(ps, Val(:bias))) + bias = safe_getproperty(ps, Val(:bias)) z = matrix_to_array( fused_dense_bias_activation(d.activation, ps.weight, make_abstract_matrix(y), bias), y) @@ -505,7 +505,7 @@ end function initialparameters(rng::AbstractRNG, b::Bilinear) if has_bias(b) return (weight=b.init_weight(rng, b.out_dims, b.in1_dims, b.in2_dims), - bias=b.init_bias(rng, b.out_dims, 1)) # TODO: In v1.0 make it a vector + bias=b.init_bias(rng, b.out_dims)) else return (weight=b.init_weight(rng, b.out_dims, b.in1_dims, b.in2_dims),) end @@ -527,8 +527,7 @@ function (b::Bilinear)( Wy = reshape(reshape(ps.weight, (:, s₃)) * y, (s₁, s₂, :)) Wyx = reshape(batched_matmul(Wy, reshape(x, (s₂, 1, :))), (s₁, :)) - bias = safe_vec(safe_getproperty(ps, Val(:bias))) - return (bias_activation!!(b.activation, Wyx, bias), st) + return bias_activation!!(b.activation, Wyx, safe_getproperty(ps, Val(:bias))), st end function (b::Bilinear)((x, y)::Tuple{<:AbstractArray, <:AbstractArray}, ps, st::NamedTuple) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 4e727b89c7..44ce03fa67 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -171,8 +171,7 @@ function initialparameters(rng::AbstractRNG, c::Conv) weight = init_conv_filter( rng, c.kernel_size, c.in_chs => c.out_chs; init=c.init_weight, c.groups) has_bias(c) || return (; weight) - return (; weight, - bias=c.init_bias(rng, ntuple(_ -> 1, length(c.kernel_size))..., c.out_chs, 1)) # TODO: flatten in v1 + return (; weight, bias=c.init_bias(rng, c.out_chs)) end function parameterlength(c::Conv) @@ -182,7 +181,7 @@ end function (c::Conv)(x::AbstractArray, ps, st::NamedTuple) y = match_eltype(c, ps, st, x) cdims = DenseConvDims(y, ps.weight; c.stride, 
padding=c.pad, c.dilation, c.groups) - bias = safe_vec(safe_getproperty(ps, Val(:bias))) + bias = safe_getproperty(ps, Val(:bias)) return fused_conv_bias_activation(c.activation, ps.weight, y, bias, cdims), st end @@ -676,8 +675,7 @@ function initialparameters(rng::AbstractRNG, c::CrossCor) weight = init_conv_filter( rng, c.kernel_size, c.in_chs => c.out_chs; init=c.init_weight, c.groups) has_bias(c) || return (; weight) - return (; weight, - bias=c.init_bias(rng, ntuple(_ -> 1, length(c.kernel_size))..., c.out_chs, 1)) # TODO: flatten in v1 + return (; weight, bias=c.init_bias(rng, c.out_chs)) end function parameterlength(c::CrossCor) @@ -688,7 +686,7 @@ function (c::CrossCor)(x::AbstractArray, ps, st::NamedTuple) y = match_eltype(c, ps, st, x) cdims = DenseConvDims( DenseConvDims(y, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups); F=true) - bias = safe_vec(safe_getproperty(ps, Val(:bias))) + bias = safe_getproperty(ps, Val(:bias)) return fused_conv_bias_activation(c.activation, ps.weight, y, bias, cdims), st end @@ -801,8 +799,7 @@ function initialparameters(rng::AbstractRNG, c::ConvTranspose) weight = init_conv_filter( rng, c.kernel_size, c.out_chs => c.in_chs; init=c.init_weight, c.groups) has_bias(c) || return (; weight) - return (; weight, - bias=c.init_bias(rng, ntuple(_ -> 1, length(c.kernel_size))..., c.out_chs, 1)) # TODO: flatten in v1 + return (; weight, bias=c.init_bias(rng, c.out_chs)) end function parameterlength(c::ConvTranspose) @@ -812,7 +809,7 @@ end function (c::ConvTranspose)(x::AbstractArray, ps, st::NamedTuple) y = match_eltype(c, ps, st, x) cdims = conv_transpose_dims(y, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups) - bias = safe_vec(safe_getproperty(ps, Val(:bias))) + bias = safe_getproperty(ps, Val(:bias)) return bias_activation!!(c.activation, conv_transpose(y, ps.weight, cdims), bias), st end From 287932d53ef5b07f4978bc7eda0842a89e95521f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 22 Jul 2024 21:58:48 -0700 Subject: [PATCH 14/95] fix: test updates from new changes --- ext/LuxFluxExt.jl | 9 +++------ test/layers/conv_tests.jl | 4 ++-- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/ext/LuxFluxExt.jl b/ext/LuxFluxExt.jl index d0201dad2c..1f73de3d78 100644 --- a/ext/LuxFluxExt.jl +++ b/ext/LuxFluxExt.jl @@ -97,8 +97,7 @@ function Lux.convert_flux_model(l::Flux.Conv; preserve_ps_st::Bool=false, kwargs groups = l.groups pad = l.pad isa Flux.SamePad ? SamePad() : l.pad if preserve_ps_st - _bias = l.bias isa Bool ? nothing : - reshape(copy(l.bias), ntuple(_ -> 1, length(k))..., out_chs, 1) + _bias = l.bias isa Bool ? nothing : vec(copy(l.bias)) return Lux.Conv(k, in_chs * groups => out_chs, l.σ; l.stride, pad, l.dilation, groups, init_weight=Returns(Lux.maybe_flip_conv_weight(l.weight)), init_bias=Returns(_bias), use_bias=!(l.bias isa Bool)) @@ -115,8 +114,7 @@ function Lux.convert_flux_model( groups = l.groups pad = l.pad isa Flux.SamePad ? SamePad() : l.pad if preserve_ps_st - _bias = l.bias isa Bool ? nothing : - reshape(copy(l.bias), ntuple(_ -> 1, length(k))..., out_chs, 1) + _bias = l.bias isa Bool ? nothing : vec(copy(l.bias)) return Lux.ConvTranspose(k, in_chs * groups => out_chs, l.σ; l.stride, pad, l.dilation, groups, use_bias=!(l.bias isa Bool), init_weight=Returns(Lux.maybe_flip_conv_weight(l.weight)), @@ -132,8 +130,7 @@ function Lux.convert_flux_model(l::Flux.CrossCor; preserve_ps_st::Bool=false, kw in_chs, out_chs = size(l.weight)[(end - 1):end] pad = l.pad isa Flux.SamePad ? 
SamePad() : l.pad if preserve_ps_st - _bias = l.bias isa Bool ? nothing : - reshape(copy(l.bias), ntuple(_ -> 1, length(k))..., out_chs, 1) + _bias = l.bias isa Bool ? nothing : vec(copy(l.bias)) return Lux.CrossCor(k, in_chs => out_chs, l.σ; l.stride, pad, l.dilation, init_weight=Returns(copy(l.weight)), init_bias=Returns(_bias), use_bias=!(l.bias isa Bool)) diff --git a/test/layers/conv_tests.jl b/test/layers/conv_tests.jl index eb6b0974cf..7a0e813a0f 100644 --- a/test/layers/conv_tests.jl +++ b/test/layers/conv_tests.jl @@ -189,7 +189,7 @@ end display(layer) ps, st = Lux.setup(rng, layer) @test ps.weight isa aType{Float64, 4} - @test ps.bias isa aType{Float16, 4} + @test ps.bias isa aType{Float16, 1} end @testset "Depthwise Conv" begin @@ -449,7 +449,7 @@ end display(layer) ps, st = Lux.setup(rng, layer) @test ps.weight isa aType{Float64, 4} - @test ps.bias isa aType{Float16, 4} + @test ps.bias isa aType{Float16, 1} end @testset "CrossCor SamePad kernelsize $k" for k in ( From d8183283fb378dd5d6b0dc094346d38936de819d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Jul 2024 14:53:43 -0700 Subject: [PATCH 15/95] chore: drop pre-1.0 weight initializers --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index bd381999a9..de9e53da86 100644 --- a/Project.toml +++ b/Project.toml @@ -113,6 +113,6 @@ Statistics = "1.10" Tracker = "0.2.34" UnrolledUtilities = "0.1.2" VectorizationBase = "0.21.70" -WeightInitializers = "0.1.5, 1" +WeightInitializers = "1" Zygote = "0.6.70" julia = "1.10" From e4ae2507b55711e59bf0c32fe491d0b9b1d220d4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Jul 2024 15:03:58 -0700 Subject: [PATCH 16/95] refactor: migrate to `MLDataDevices` --- Project.toml | 2 - docs/Project.toml | 2 - docs/make.jl | 7 +- docs/run_single_tutorial.jl | 6 +- docs/src/.vitepress/config.mts | 2 - .../api/Accelerator_Support/LuxDeviceUtils.md | 50 ----- .../api/Accelerator_Support/MLDataDevices.md | 2 +- docs/src/manual/distributed_utils.md | 8 +- docs/src/manual/gpu_management.md | 21 +- docs/src/manual/preferences.md | 4 +- examples/ImageNet/utils.jl | 10 +- ext/LuxMPIExt.jl | 4 + ext/LuxReverseDiffExt/LuxReverseDiffExt.jl | 4 + ext/LuxTrackerExt.jl | 82 +++++++ ext/LuxZygoteExt/LuxZygoteExt.jl | 2 + src/Lux.jl | 6 +- src/helpers/nested_ad.jl | 203 ++++++++++++++++++ test/contrib/share_parameters_tests.jl | 4 +- test/distributed/common_distributedtest.jl | 4 +- test/distributed/data_distributedtest.jl | 4 +- test/distributed/optimizer_distributedtest.jl | 4 +- .../synchronize_distributedtest.jl | 4 +- test/layers/normalize_tests.jl | 4 +- test/qa_tests.jl | 2 +- test/setup_modes.jl | 6 +- test/transform/flux_tests.jl | 6 +- 26 files changed, 342 insertions(+), 111 deletions(-) delete mode 100644 docs/src/api/Accelerator_Support/LuxDeviceUtils.md create mode 100644 ext/LuxTrackerExt.jl create mode 100644 src/helpers/nested_ad.jl diff --git a/Project.toml b/Project.toml index de9e53da86..4e392cea08 100644 --- a/Project.toml +++ b/Project.toml @@ -20,7 +20,6 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" -LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" @@ -90,7 +89,6 @@ GPUArraysCore = "0.1.6" 
LinearAlgebra = "1.10" LossFunctions = "0.11.1" LuxCore = "0.1.24" -LuxDeviceUtils = "0.1.26" LuxLib = "0.3.42" MLDataDevices = "1.1" MLUtils = "0.4.4" diff --git a/docs/Project.toml b/docs/Project.toml index 2dab4173db..93e13e81ad 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -14,7 +14,6 @@ Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" -LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" @@ -42,7 +41,6 @@ Literate = "2.18.0" Lux = "0.5.62" LuxCUDA = "0.3.2" LuxCore = "0.1.15" -LuxDeviceUtils = "0.1.21" LuxLib = "0.3.42" LuxTestUtils = "1.1" MLDataDevices = "1.1" diff --git a/docs/make.jl b/docs/make.jl index 56c51ae44e..564d46bad2 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,7 +1,6 @@ using Documenter, DocumenterVitepress, Pkg using Lux, LuxCore, LuxLib, WeightInitializers -using LuxTestUtils, LuxDeviceUtils -using MLDataDevices +using LuxTestUtils, MLDataDevices using LuxCUDA using Optimisers # for some docstrings @@ -56,7 +55,6 @@ pages = [ "api/Lux/distributed_utils.md", ], "Accelerator Support" => [ - "api/Accelerator_Support/LuxDeviceUtils.md", "api/Accelerator_Support/MLDataDevices.md" ], "Building Blocks" => [ @@ -80,8 +78,7 @@ makedocs(; sitename="Lux.jl Docs", authors="Avik Pal et al.", clean=true, doctest=false, # We test it in the CI, no need to run it here - modules=[Lux, LuxCore, LuxLib, WeightInitializers, - LuxTestUtils, LuxDeviceUtils, MLDataDevices], + modules=[Lux, LuxCore, LuxLib, WeightInitializers, LuxTestUtils, MLDataDevices], linkcheck=true, repo="https://github.com/LuxDL/Lux.jl/blob/{commit}{path}#{line}", format=DocumenterVitepress.MarkdownVitepress(; diff --git a/docs/run_single_tutorial.jl b/docs/run_single_tutorial.jl index 965f99b942..b163ee244d 100644 --- a/docs/run_single_tutorial.jl +++ b/docs/run_single_tutorial.jl @@ -24,13 +24,13 @@ function preprocess(path, str) using InteractiveUtils InteractiveUtils.versioninfo() - if @isdefined(LuxDeviceUtils) - if @isdefined(CUDA) && LuxDeviceUtils.functional(LuxCUDADevice) + if @isdefined(MLDataDevices) + if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice) println() CUDA.versioninfo() end - if @isdefined(AMDGPU) && LuxDeviceUtils.functional(LuxAMDGPUDevice) + if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice) println() AMDGPU.versioninfo() end diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts index b111a3bf37..10948ced66 100644 --- a/docs/src/.vitepress/config.mts +++ b/docs/src/.vitepress/config.mts @@ -79,7 +79,6 @@ export default defineConfig({ }, { text: 'Accelerator Support', items: [ - { text: 'LuxDeviceUtils', link: '/api/Accelerator_Support/LuxDeviceUtils' }, { text: 'MLDataDevices', link: '/api/Accelerator_Support/MLDataDevices' } ] }, @@ -216,7 +215,6 @@ export default defineConfig({ }, { text: 'Accelerator Support', collapsed: false, items: [ - { text: 'LuxDeviceUtils', link: '/api/Accelerator_Support/LuxDeviceUtils' }, { text: 'MLDataDevices', link: '/api/Accelerator_Support/MLDataDevices' }] }, { diff --git a/docs/src/api/Accelerator_Support/LuxDeviceUtils.md b/docs/src/api/Accelerator_Support/LuxDeviceUtils.md deleted file mode 100644 index e4fac9ba01..0000000000 --- a/docs/src/api/Accelerator_Support/LuxDeviceUtils.md +++ /dev/null 
@@ -1,50 +0,0 @@ -# [LuxDeviceUtils](@id LuxDeviceUtils-API) - -`LuxDeviceUtils.jl` is a lightweight package defining rules for transferring data across -devices. Most users should directly use Lux.jl instead. - -!!! note "Transition to `MLDataDevices.jl`" - - Currently this package is in maintenance mode and won't receive any new features, - however, we will backport bug fixes till Lux `v1.0` is released. Post that this package - should be considered deprecated and users should switch to `MLDataDevices.jl`. - - For more information on `MLDataDevices.jl` checkout the - [MLDataDevices.jl Documentation](@ref MLDataDevices-API). - -## Index - -```@index -Pages = ["LuxDeviceUtils.md"] -``` - -## Preferences - -```@docs -LuxDeviceUtils.gpu_backend! -``` - -## Data Transfer - -```@docs -LuxDeviceUtils.cpu_device -LuxDeviceUtils.gpu_device -``` - -## Miscellaneous - -```@docs -LuxDeviceUtils.reset_gpu_device! -LuxDeviceUtils.supported_gpu_backends -LuxDeviceUtils.default_device_rng -LuxDeviceUtils.get_device -LuxDeviceUtils.get_device_type -LuxDeviceUtils.loaded -LuxDeviceUtils.functional -``` - -## Multi-GPU Support - -```@docs -LuxDeviceUtils.set_device! -``` diff --git a/docs/src/api/Accelerator_Support/MLDataDevices.md b/docs/src/api/Accelerator_Support/MLDataDevices.md index df15d913f1..c1c031e82c 100644 --- a/docs/src/api/Accelerator_Support/MLDataDevices.md +++ b/docs/src/api/Accelerator_Support/MLDataDevices.md @@ -3,7 +3,7 @@ `MLDataDevices.jl` is a lightweight package defining rules for transferring data across devices. Most users should directly use Lux.jl instead. -!!! note "Comparison to LuxDeviceUtils.jl" +!!! note "Transitioning from `LuxDeviceUtils.jl`" `LuxDeviceUtils.jl` was renamed to `MLDataDevices.jl` in v1.0 as a part of allowing these packages to have broader adoption outsize the Lux community. However, Lux diff --git a/docs/src/manual/distributed_utils.md b/docs/src/manual/distributed_utils.md index 2f9c4be36f..dbee8ab110 100644 --- a/docs/src/manual/distributed_utils.md +++ b/docs/src/manual/distributed_utils.md @@ -88,10 +88,10 @@ And that's pretty much it! 3. We don't automatically determine if the MPI Implementation is CUDA or ROCM aware. See [GPU-aware MPI](@ref gpu-aware-mpi-preferences) for more information. 4. Older [`Lux.gpu`](@ref) implementations used to "just work" with `FluxMPI.jl`. We expect - [`LuxDeviceUtils.gpu_device`](@ref) to continue working as expected, however, we - recommend using [`LuxDeviceUtils.gpu_device`](@ref) after calling - [`DistributedUtils.initialize`](@ref) to avoid any mismatch between the device set - via `DistributedUtils` and the device stores in `LuxCUDADevice` or `LuxAMDGPUDevice` + [`gpu_device`](@ref) to continue working as expected, however, we recommend using + [`gpu_device`](@ref) after calling [`DistributedUtils.initialize`](@ref) to avoid any + mismatch between the device set via `DistributedUtils` and the device stores in + `CUDADevice` or `AMDGPUDevice`. ## Known Shortcomings diff --git a/docs/src/manual/gpu_management.md b/docs/src/manual/gpu_management.md index b879b12387..b6f578d259 100644 --- a/docs/src/manual/gpu_management.md +++ b/docs/src/manual/gpu_management.md @@ -24,8 +24,7 @@ supported_gpu_backends() Automatic Backend Management is done by two simple functions: `cpu_device` and `gpu_device`. -* [`LuxDeviceUtils.cpu_device`](@ref): This is a simple function and just returns a - `LuxCPUDevice` object. +* [`cpu_device`](@ref): This is a simple function and just returns a `CPUDevice` object. 
```@example gpu_management cdev = cpu_device() @@ -35,9 +34,9 @@ cdev = cpu_device() x_cpu = randn(Float32, 3, 2) ``` -* [`LuxDeviceUtils.gpu_device`](@ref): This function performs automatic GPU device selection - and returns an object. - 1. If no GPU is available, it returns a `LuxCPUDevice` object. +* [`gpu_device`](@ref): This function performs automatic GPU device selection and returns + an object. + 1. If no GPU is available, it returns a `CPUDevice` object. 2. If a LocalPreferences file is present, then the backend specified in the file is used. To set a backend, use `Lux.gpu_backend!()`. (a) If the trigger package corresponding to the device is not loaded, then a warning is displayed. (b) If no @@ -57,19 +56,19 @@ x_gpu = x_cpu |> gdev ## Manual Backend Management -Automatic Device Selection can be circumvented by directly using `LuxCPUDevice` and -`AbstractLuxGPUDevice` objects. +Automatic Device Selection can be circumvented by directly using `CPUDevice` and +`AbstractGPUDevice` objects. ```@example gpu_management cdev = cpu_device() x_cpu = randn(Float32, 3, 2) -if LuxDeviceUtils.functional(LuxCUDADevice) - gdev = LuxCUDADevice() +if MLDataDevices.functional(CUDADevice) + gdev = CUDADevice() x_gpu = x_cpu |> gdev -elseif LuxDeviceUtils.functional(LuxAMDGPUDevice) - gdev = LuxAMDGPUDevice() +elseif MLDataDevices.functional(AMDGPUDevice) + gdev = AMDGPUDevice() x_gpu = x_cpu |> gdev else @info "No GPU is available. Using CPU." diff --git a/docs/src/manual/preferences.md b/docs/src/manual/preferences.md index eaea213ee6..357c77acb5 100644 --- a/docs/src/manual/preferences.md +++ b/docs/src/manual/preferences.md @@ -38,8 +38,8 @@ By default, both of these preferences are set to `false`. 1. `gpu_backend` - Set this to bypass the automatic backend selection and use a specific gpu backend. Valid options are "cuda", "rocm", "metal", and "oneapi". This preference - needs to be set for `LuxDeviceUtils` package. It is recommended to use - [`LuxDeviceUtils.gpu_backend!`](@ref) to set this preference. + needs to be set for `MLDataDevices` package. It is recommended to use + [`MLDataDevices.gpu_backend!`](@ref) to set this preference. ## [Automatic Eltype Conversion](@id automatic-eltypes-preference) diff --git a/examples/ImageNet/utils.jl b/examples/ImageNet/utils.jl index ee5dcf74ff..44b1ef8ef3 100644 --- a/examples/ImageNet/utils.jl +++ b/examples/ImageNet/utils.jl @@ -2,13 +2,13 @@ CUDA.allowscalar(false) function unsafe_free! end -if LuxDeviceUtils.functional(LuxCUDADevice) +if MLDataDevices.functional(CUDADevice) function unsafe_free!(x) return hasmethod(CUDA.unsafe_free!, Tuple{typeof(x)}) ? CUDA.unsafe_free!(x) : nothing end unsafe_free!(x::OneHotArray) = CUDA.unsafe_free!(x.indices) -elseif LuxDeviceUtils.functional(LuxAMDGPUDevice) +elseif MLDataDevices.functional(AMDGPUDevice) function unsafe_free!(x) return hasmethod(AMDGPU.unsafe_free!, Tuple{typeof(x)}) ? AMDGPU.unsafe_free!(x) : nothing @@ -18,8 +18,8 @@ end function reclaim_all() GC.gc(true) - LuxDeviceUtils.functional(LuxCUDADevice) && CUDA.reclaim() - LuxDeviceUtils.functional(LuxAMDGPUDevice) && AMDGPU.reclaim() + MLDataDevices.functional(CUDADevice) && CUDA.reclaim() + MLDataDevices.functional(AMDGPUDevice) && AMDGPU.reclaim() return end @@ -147,7 +147,7 @@ end get_loggable_values(meter::ProgressMeter) = getproperty.(meter.meters, :average) # Optimisers State -function (dev::LuxDeviceUtils.AbstractLuxDevice)(l::Optimisers.Leaf) +function (dev::MLDataDevices.AbstractDevice)(l::Optimisers.Leaf) @set! 
l.state = dev(l.state) return l end diff --git a/ext/LuxMPIExt.jl b/ext/LuxMPIExt.jl index 072663cc99..4ffd5fbed2 100644 --- a/ext/LuxMPIExt.jl +++ b/ext/LuxMPIExt.jl @@ -1,5 +1,9 @@ module LuxMPIExt +using Lux: MPIBackend, NCCLBackend, DistributedUtils, __unwrap_val, MPI_CUDA_AWARE, + MPI_ROCM_AWARE +using MLDataDevices: AbstractDevice, CUDADevice, AMDGPUDevice, cpu_device, set_device!, + functional using MPI: MPI using Lux: Lux, MPIBackend, NCCLBackend, DistributedUtils, Utils diff --git a/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl b/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl index 0295b4a286..8660967e65 100644 --- a/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl +++ b/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl @@ -3,6 +3,10 @@ module LuxReverseDiffExt using ADTypes: ADTypes, AbstractADType, AutoReverseDiff using ArrayInterface: ArrayInterface using FunctionWrappers: FunctionWrapper +using Lux: Lux +using Lux.Training: TrainingBackendCache, TrainState +using LuxCore: LuxCore, AbstractExplicitLayer +using MLDataDevices: CPUDevice using ReverseDiff: ReverseDiff, ForwardExecutor, ReverseExecutor, TrackedArray, TrackedReal, @grad_from_chainrules diff --git a/ext/LuxTrackerExt.jl b/ext/LuxTrackerExt.jl new file mode 100644 index 0000000000..7d2327d7b0 --- /dev/null +++ b/ext/LuxTrackerExt.jl @@ -0,0 +1,82 @@ +module LuxTrackerExt + +using ADTypes: AutoTracker +using ArrayInterface: ArrayInterface +using ChainRulesCore: ChainRulesCore +using Lux: Lux, CPUDevice +using Lux.Training: TrainingBackendCache, TrainState +using LuxCore: LuxCore, AbstractExplicitLayer +using Tracker: Tracker, TrackedArray, TrackedReal, @grad_from_chainrules + +const CRC = ChainRulesCore + +# Weight Norm Patch +@inline Lux._norm(x::TrackedArray; dims=Colon()) = sqrt.(sum(abs2.(x); dims)) + +# multigate chain rules +@inline Lux._gate(x::Tracker.TrackedVector, h::Int, n::Int) = x[Lux._gate(h, n)] +@inline Lux._gate(x::Tracker.TrackedMatrix, h::Int, n::Int) = x[Lux._gate(h, n), :] + +function __construct_tracked_params(ps, dps) + map_fn = (p, dp) -> Tracker.TrackedArray(Tracker.Call(), p, dp) + return Lux.recursive_map(map_fn, ps, dps) +end + +# Lux.Training +function Lux.Training.compute_gradients(::AutoTracker, obj_fn::F, data, + ts::TrainState{<:TrainingBackendCache{:Tracker, FT}}) where {F, FT} + dparams = FT ? ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters) + ps_tracked = __construct_tracked_params(ts.parameters, dparams) + + loss, st, stats = obj_fn(ts.model, ps_tracked, ts.states, data) + Tracker.back!(loss) + + ts_new = TrainState( + TrainingBackendCache{:Tracker, false}(ts.cache.dparameters, nothing), obj_fn, + ts.model, ts.parameters, st, ts.optimizer, ts.optimizer_state, ts.step) + + return dparams, loss.data, stats, ts_new +end + +function Lux.Training.compute_gradients( + ::AutoTracker, obj_fn::F, data, ts::TrainState) where {F} + grads = Lux.recursive_make_zero(ts.parameters) + ts_new = TrainState( + TrainingBackendCache{:Tracker, true}(grads, nothing), obj_fn, ts.model, + ts.parameters, ts.states, ts.optimizer, ts.optimizer_state, ts.step) + return Lux.Training.compute_gradients(AutoTracker(), obj_fn, data, ts_new) +end + +# AoS to SoA conversion +function LuxCore.apply(m::AbstractExplicitLayer, x::AbstractArray{<:TrackedReal}, ps, st) + @warn "LuxCore.apply(m::AbstractExplicitLayer, \ + x::AbstractArray{<:Tracker.TrackedReal}, ps, st) input was corrected to \ + LuxCore.apply(m::AbstractExplicitLayer, x::Tracker.TrackedArray}, ps, st).\n\n\ + 1. 
If this was not the desired behavior overload the dispatch on `m`.\n\n\ + 2. This might have performance implications. Check which layer was causing this \ + problem using `Lux.Experimental.@debug_mode`." maxlog=1 + return LuxCore.apply(m, ArrayInterface.aos_to_soa(x), ps, st) +end + +## Prevent an infinite loop +LuxCore.apply(m::AbstractExplicitLayer, x::TrackedArray, ps, st) = m(x, ps, st) + +@inline Lux.__eltype(::TrackedArray{T}) where {T} = T +@inline Lux.__eltype(::TrackedReal{T}) where {T} = T +@inline Lux.__eltype(::AbstractArray{<:TrackedReal{T}}) where {T} = T + +@inline Lux.__reverse(x::TrackedArray; dims=:) = ArrayInterface.aos_to_soa(reverse(x; dims)) +@inline function Lux.__reverse(x::AbstractArray{<:TrackedReal}; dims=:) + return ArrayInterface.aos_to_soa(reverse(x; dims)) +end + +# DynamicExpressions.jl +for T1 in (:TrackedArray, :AbstractArray), T2 in (:TrackedArray, :AbstractArray) + T1 === :AbstractArray && T2 === :AbstractArray && continue + + @eval @grad_from_chainrules Lux.__apply_dynamic_expression( + de::Lux.DynamicExpressionsLayer, expr, + operator_enum, x::$(T1), ps::$(T2), dev::CPUDevice) +end + +end diff --git a/ext/LuxZygoteExt/LuxZygoteExt.jl b/ext/LuxZygoteExt/LuxZygoteExt.jl index cefbf32164..84b734aceb 100644 --- a/ext/LuxZygoteExt/LuxZygoteExt.jl +++ b/ext/LuxZygoteExt/LuxZygoteExt.jl @@ -4,6 +4,8 @@ using ArgCheck: @argcheck using ADTypes: AutoZygote using ChainRulesCore: ChainRulesCore using ForwardDiff: ForwardDiff +using Lux: Lux +using MLDataDevices: get_device_type, CPUDevice using Setfield: @set! using Zygote: Zygote diff --git a/src/Lux.jl b/src/Lux.jl index c6f562ead2..c004a4ae37 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -20,11 +20,7 @@ using Reexport: @reexport using Statistics: mean using UnrolledUtilities: unrolled_map, unrolled_mapreduce -# TODO: In v1 we remove the LuxDeviceUtils dependency and replace it with MLDataDevices -@reexport using LuxCore, LuxLib, LuxDeviceUtils, WeightInitializers -using MLDataDevices: MLDataDevices, AMDGPUDevice, CPUDevice -using NNlib: NNlib - +@reexport using LuxCore, LuxLib, MLDataDevices, WeightInitializers import LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer, initialparameters, initialstates, parameterlength, statelength, inputsize, outputsize, update_state, trainmode, testmode, setup, apply, replicate diff --git a/src/helpers/nested_ad.jl b/src/helpers/nested_ad.jl new file mode 100644 index 0000000000..4cf036a9d6 --- /dev/null +++ b/src/helpers/nested_ad.jl @@ -0,0 +1,203 @@ +#! format: off +const AD_CONVERTIBLE_FUNCTIONS = [ + # Input Gradient/Jacobian + ComposedFunction{<:Any, <:StatefulLuxLayer}, + ComposedFunction{<:StatefulLuxLayer, <:Any}, + StatefulLuxLayer, + # Parameter Gradient/Jacobian + ComposedFunction{<:Any, <:Base.Fix1{<:StatefulLuxLayer}}, + ComposedFunction{<:Base.Fix1{<:StatefulLuxLayer}, <:Any}, + Base.Fix1{<:StatefulLuxLayer} +] +#! 
format: on + +## Written like this to avoid dynamic dispatch from Zygote +# Input Gradient / Jacobian +@inline __rewrite_ad_call(f::ComposedFunction{F, <:StatefulLuxLayer}) where {F} = ( + f, f.inner.ps) +@inline __rewrite_ad_call(f::ComposedFunction{<:StatefulLuxLayer, F}) where {F} = ( + @closure((x, ps)->f.outer(f.inner(x), ps)), f.outer.ps) +@inline __rewrite_ad_call(f::StatefulLuxLayer) = f, f.ps + +# Parameter Gradient / Jacobian +@inline __rewrite_ad_call(f::ComposedFunction{F, <:Base.Fix1{<:StatefulLuxLayer}}) where {F} = ( + @closure((ps, x)->f.outer(f.inner.f(x, ps))), f.inner.x) +@inline __rewrite_ad_call(f::ComposedFunction{<:Base.Fix1{<:StatefulLuxLayer}, F}) where {F} = ( + @closure((ps, x)->f.outer.f(x, f.inner(ps))), f.outer.x) +@inline __rewrite_ad_call(f::Base.Fix1{<:StatefulLuxLayer}) = ( + @closure((ps, x)->f.f(x, ps)), f.x) + +## Break ambiguity +for op in [ComposedFunction{<:StatefulLuxLayer, <:StatefulLuxLayer}, + ComposedFunction{<:Base.Fix1{<:StatefulLuxLayer}, <:StatefulLuxLayer}, + ComposedFunction{<:StatefulLuxLayer, <:Base.Fix1{<:StatefulLuxLayer}}, + ComposedFunction{<:Base.Fix1{<:StatefulLuxLayer}, <:Base.Fix1{<:StatefulLuxLayer}}] + @eval @inline function __rewrite_ad_call(::$op) + error("Cannot rewrite ComposedFunction with StatefulLuxLayer as inner and outer layers") + end +end + +# Nested Gradients +## Essentially computes the gradient of `f(x, y)` wrt x using the function `grad_fn` +## To compute the gradient of `f(x, y)` wrt y, just reorder the arguments with a wrapper +## over `f` +for fname in (:__internal_ad_gradient_call, :__internal_ad_gradient_call_no_custom_rrule) + @eval @inline function $fname(grad_fn::G, f::F, x, y) where {G, F} + return grad_fn(Base.Fix2(f, y), x) + end +end + +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(__internal_ad_gradient_call), + grad_fn::G, f::F, x, y) where {G, F} + @static if !AUTOMATIC_NESTED_AD_SWITCHING + return CRC.rrule_via_ad( + cfg, __internal_ad_gradient_call_no_custom_rrule, grad_fn, f, x, y) + end + + res = __internal_ad_gradient_call(grad_fn, f, x, y) + ∇internal_gradient_capture = @closure Δ_ -> begin + (Δ_ isa NoTangent || Δ_ isa ZeroTangent) && return ntuple(Returns(NoTangent()), 5) + + Δ = CRC.unthunk(Δ_) + (res isa Tuple || Δ isa Tuple) && (Δ = only(Δ)) # For Zygote and such which return a tuple + ∂x, ∂y = __forwarddiff_jvp(@closure((x, y)->grad_fn(f, x, y)), x, Δ, y) + return NoTangent(), NoTangent(), NoTangent(), ∂x, ∂y + end + + return res, ∇internal_gradient_capture +end + +# Nested Pullbacks +for fname in (:__internal_ad_pullback_call, :__internal_ad_pullback_call_no_custom_rrule) + @eval @inline function $fname(pullback_fn::P, f::F, x, y, u) where {P, F} + return only(last(pullback_fn(Base.Fix2(f, y), x))(u)) + end +end + +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(__internal_ad_pullback_call), + pullback_fn::P, f::F, x, y, u) where {P, F} + @static if !AUTOMATIC_NESTED_AD_SWITCHING + return CRC.rrule_via_ad( + cfg, __internal_ad_pullback_call_no_custom_rrule, pullback_fn, f, x, y, u) + end + + res = __internal_ad_pullback_call(pullback_fn, f, x, y, u) + ∇nested_ad = let pullback_fn = pullback_fn, f = f, x = x, y = y, u = u, res = res + Δ_ -> begin + (Δ_ isa NoTangent || Δ_ isa ZeroTangent) && + return ntuple(Returns(NoTangent()), 6) + + Δ = CRC.unthunk(Δ_) + (res isa Tuple || Δ isa Tuple) && (Δ = only(Δ)) # For Zygote and such which return a tuple + ∂x, ∂y = __forwarddiff_jvp(x, Δ, y) do x_dual, y_ + return last(pullback_fn(f, x_dual, y_))(u) + end + 
return (NoTangent(), NoTangent(), NoTangent(), ∂x, ∂y, NoTangent()) + end + end + + return res, ∇nested_ad +end + +# Nested Jacobians +## `grad_fn` is not needed for the forward pass, we need it for the reverse pass HVP +for fname in (:__internal_ad_jacobian_call, :__internal_ad_jacobian_call_no_custom_rrule) + @eval @inline function $fname( + jac_fn::J, grad_fn::G, f::F, x::AbstractArray, y) where {J, G, F} + return jac_fn(Base.Fix2(f, y), x) + end +end + +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(__internal_ad_jacobian_call), + jac_fn::J, grad_fn::G, f::F, x::AbstractArray, y) where {J, G, F} + @static if !AUTOMATIC_NESTED_AD_SWITCHING + return CRC.rrule_via_ad( + cfg, __internal_ad_jacobian_call_no_custom_rrule, jac_fn, grad_fn, f, x, y) + end + + res = __internal_ad_jacobian_call(jac_fn, grad_fn, f, x, y) + ∇internal_jacobian_capture = let res = res, grad_fn = grad_fn, f = f, x = x, y = y + Δ_ -> begin + (Δ_ isa NoTangent || Δ_ isa ZeroTangent) && + return ntuple(Returns(NoTangent()), 6) + + Δ = CRC.unthunk(Δ_) + (res isa Tuple || Δ isa Tuple) && (Δ = only(Δ)) # For Zygote and such which return a tuple + Δ = __compactify_if_structured_matrix(res isa Tuple ? only(res) : res, Δ) + + __inner_grad_fn = @closure(i->sum ∘ Base.Fix2(getindex, i:i) ∘ vec ∘ f) + map_fn = @closure i -> begin + Δᵢ = __maybe_batched_row(Δ, i) + fn = __inner_grad_fn(i) + __f = let fn = fn + (x, y) -> grad_fn(fn, x, y) + end + return __forwarddiff_jvp(__f, x, Δᵢ, y) + end + + # FIXME: threading on CUDA cause unexpected errors on the first run to CUDNN + # when doing a algorithm lookup + ∂x, ∂y = if get_device_type(x) <: CPUDevice + tasks = map(i -> Threads.@spawn(map_fn(i)), 1:__numrows(Δ)) + mapreduce(fetch, recursive_add!!, tasks) + else + mapreduce(map_fn, recursive_add!!, 1:__numrows(Δ)) + end + + return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), ∂x, ∂y) + end + end + + return res, ∇internal_jacobian_capture +end + +# Convert a structured Matrix to a General Matrix if it doesn't have fast scalar indexing +@inline function __compactify_if_structured_matrix( + J::AbstractArray{T1, N}, Δ::AbstractArray{T2}) where {T1, T2, N} + @argcheck N ∈ (2, 3) "Only 2D and 3D arrays are supported for compactifying." 
+ if !ArrayInterface.fast_scalar_indexing(J) && ArrayInterface.isstructured(Δ) + J_ = similar(J) + copyto!(J_, Δ) + return J_ + end + return reshape(Δ, size(J)) +end + +@inline __numrows(x::AbstractMatrix) = size(x, 1) +@inline __numrows(x::AbstractArray{T, 3}) where {T} = size(x, 1) * size(x, 3) + +@inline __maybe_batched_row(x::AbstractMatrix, i::Integer) = view(x, i, :) +@inline function __maybe_batched_row(x::AbstractArray{T, 3}, i::Integer) where {T} + M, N, K = size(x) + k = (i - 1) ÷ M + 1 + i = mod1(i, M) + y = similar(x, N * K) + data = view(x, i, :, k) + fill!(view(y, 1:(N * (K - 1))), zero(T)) + copyto!(view(y, (N * (k - 1) + 1):(N * k)), data) + fill!(view(y, (N * k + 1):(N * K)), zero(T)) + return y +end + +@inline function __partials(::Type{Tag}, x, i) where {Tag} + x isa ForwardDiff.Dual && return ForwardDiff.partials(Tag, x, i) + if x isa AbstractArray + bfn(xᵢ, iᵢ) = ForwardDiff.partials(Tag, xᵢ, iᵢ) + return bfn.(x, i) + end + map_fn = @closure(xᵢ->__partials(Tag, xᵢ, i)) + (x isa Tuple || x isa NamedTuple) && return map(map_fn, x) + x isa CRC.AbstractTangent && return __partials(Tag, CRC.backing(x), i) + x === nothing && return nothing + return fmap(map_fn, x) +end + +@inline function __dualify(::Type{Tag}, ::Type{T}, x, u) where {Tag, T} + if x isa AbstractArray + bfn(xᵢ, uᵢ) = ForwardDiff.Dual{Tag, T, 1}(xᵢ, ForwardDiff.Partials{1, T}(uᵢ)) + return bfn.(x, tuple.(reshape(u, size(x)))) + end + (x isa Tuple || x isa NamedTuple) && + return map((xᵢ, uᵢ) -> __dualify(Tag, T, xᵢ, uᵢ), x, u) + return fmap((xᵢ, uᵢ) -> __dualify(Tag, T, xᵢ, uᵢ), x, u) +end diff --git a/test/contrib/share_parameters_tests.jl b/test/contrib/share_parameters_tests.jl index d87a477cd7..385e8a4eae 100644 --- a/test/contrib/share_parameters_tests.jl +++ b/test/contrib/share_parameters_tests.jl @@ -29,7 +29,7 @@ @test ps_2.d3.bias == ps_new_2.bias == ps_2.d2.l1.bias # Mix in ComponentArray - ps_new_ca_1 = ComponentArray(ps_new_1 |> LuxCPUDevice()) |> dev + ps_new_ca_1 = ComponentArray(ps_new_1 |> CPUDevice()) |> device ps_3 = Lux.Experimental.share_parameters(ps, sharing, (ps_new_ca_1, ps_new_2)) @@ -54,7 +54,7 @@ @test_throws ArgumentError Lux.Experimental.share_parameters( ps, sharing, (ps_new_1, ps_new_2)) - ps_new_ca_1 = ComponentArray(ps_new_1 |> LuxCPUDevice()) |> dev + ps_new_ca_1 = ComponentArray(ps_new_1 |> CPUDevice()) |> device @test_throws ArgumentError Lux.Experimental.share_parameters( ps, sharing, (ps_new_ca_1, ps_new_2)) diff --git a/test/distributed/common_distributedtest.jl b/test/distributed/common_distributedtest.jl index 4c64927d09..231078b6b6 100644 --- a/test/distributed/common_distributedtest.jl +++ b/test/distributed/common_distributedtest.jl @@ -9,8 +9,8 @@ if input_args[1] == "amdgpu" end const backend_type = input_args[2] == "nccl" ? NCCLBackend : MPIBackend -const dev = input_args[1] == "cpu" ? LuxCPUDevice() : - (input_args[1] == "cuda" ? LuxCUDADevice() : LuxAMDGPUDevice()) +const dev = input_args[1] == "cpu" ? CPUDevice() : + (input_args[1] == "cuda" ? CUDADevice() : AMDGPUDevice()) const aType = input_args[1] == "cpu" ? Array : (input_args[1] == "cuda" ? CuArray : ROCArray) diff --git a/test/distributed/data_distributedtest.jl b/test/distributed/data_distributedtest.jl index d2eb08de78..c2f78adf57 100644 --- a/test/distributed/data_distributedtest.jl +++ b/test/distributed/data_distributedtest.jl @@ -14,8 +14,8 @@ if input_args[1] == "amdgpu" end const backend_type = input_args[2] == "nccl" ? NCCLBackend : MPIBackend -const dev = input_args[1] == "cpu" ? 
LuxCPUDevice() : - (input_args[1] == "cuda" ? LuxCUDADevice() : LuxAMDGPUDevice()) +const dev = input_args[1] == "cpu" ? CPUDevice() : + (input_args[1] == "cuda" ? CUDADevice() : AMDGPUDevice()) rng = Xoshiro(1234) diff --git a/test/distributed/optimizer_distributedtest.jl b/test/distributed/optimizer_distributedtest.jl index 122761f194..6a3992a43a 100644 --- a/test/distributed/optimizer_distributedtest.jl +++ b/test/distributed/optimizer_distributedtest.jl @@ -9,8 +9,8 @@ if input_args[1] == "amdgpu" end const backend_type = input_args[2] == "nccl" ? NCCLBackend : MPIBackend -const dev = input_args[1] == "cpu" ? LuxCPUDevice() : - (input_args[1] == "cuda" ? LuxCUDADevice() : LuxAMDGPUDevice()) +const dev = input_args[1] == "cpu" ? CPUDevice() : + (input_args[1] == "cuda" ? CUDADevice() : AMDGPUDevice()) DistributedUtils.initialize(backend_type) backend = DistributedUtils.get_distributed_backend(backend_type) diff --git a/test/distributed/synchronize_distributedtest.jl b/test/distributed/synchronize_distributedtest.jl index 403cab1d53..388755881e 100644 --- a/test/distributed/synchronize_distributedtest.jl +++ b/test/distributed/synchronize_distributedtest.jl @@ -9,8 +9,8 @@ if input_args[1] == "amdgpu" end const backend_type = input_args[2] == "nccl" ? NCCLBackend : MPIBackend -const dev = input_args[1] == "cpu" ? LuxCPUDevice() : - (input_args[1] == "cuda" ? LuxCUDADevice() : LuxAMDGPUDevice()) +const dev = input_args[1] == "cpu" ? CPUDevice() : + (input_args[1] == "cuda" ? CUDADevice() : AMDGPUDevice()) function __get_array_based_on_rank(backend, dims; root) DistributedUtils.local_rank(backend) == root && return ones(dims...) diff --git a/test/layers/normalize_tests.jl b/test/layers/normalize_tests.jl index a598ae8972..5f67efb1c2 100644 --- a/test/layers/normalize_tests.jl +++ b/test/layers/normalize_tests.jl @@ -15,7 +15,7 @@ @test ps.scale == [1, 1] |> aType # init_scale(2) y, st_ = pullback(m, x, ps, st)[1] - st_ = st_ |> LuxCPUDevice() + st_ = st_ |> CPUDevice() @test check_approx(Array(y), [-1.22474 0 1.22474; -1.22474 0 1.22474]; atol=1.0e-5) # julia> x # 2×3 Array{Float64,2}: @@ -39,7 +39,7 @@ 0.1 .* var(Array(x); dims=2, corrected=false) .* (3 / 2) .+ 0.9 .* [1.0, 1.0]) st_ = Lux.testmode(st_) |> device - x_ = m(x, ps, st_)[1] |> LuxCPUDevice() + x_ = m(x, ps, st_)[1] |> CPUDevice() @test check_approx(x_[1], (1 .- 0.3) / sqrt(1.3), atol=1.0e-5) @jet m(x, ps, st) diff --git a/test/qa_tests.jl b/test/qa_tests.jl index 06e20f31e4..ac59d4ffcd 100644 --- a/test/qa_tests.jl +++ b/test/qa_tests.jl @@ -16,7 +16,7 @@ end # Skip our own packages @test check_no_implicit_imports( - Lux; skip=(Base, Core, LuxCore, LuxDeviceUtils, LuxLib, WeightInitializers)) === + Lux; skip=(Base, Core, LuxCore, MLDataDevices, LuxLib, WeightInitializers)) === nothing @test check_no_stale_explicit_imports( Lux; ignore=(:inputsize, :setup, :testmode, :trainmode, :update_state)) === nothing diff --git a/test/setup_modes.jl b/test/setup_modes.jl index 88b88a247b..1617179a5b 100644 --- a/test/setup_modes.jl +++ b/test/setup_modes.jl @@ -25,9 +25,9 @@ end const MODES = begin # Mode, Array Type, Device Function, GPU? 
modes = [] - cpu_testing() && push!(modes, ("cpu", Array, LuxCPUDevice(), false)) - cuda_testing() && push!(modes, ("cuda", CuArray, LuxCUDADevice(), true)) - amdgpu_testing() && push!(modes, ("amdgpu", ROCArray, LuxAMDGPUDevice(), true)) + cpu_testing() && push!(modes, ("cpu", Array, CPUDevice(), false)) + cuda_testing() && push!(modes, ("cuda", CuArray, CUDADevice(), true)) + amdgpu_testing() && push!(modes, ("amdgpu", ROCArray, AMDGPUDevice(), true)) modes end diff --git a/test/transform/flux_tests.jl b/test/transform/flux_tests.jl index b063bd67fb..a6530e27bb 100644 --- a/test/transform/flux_tests.jl +++ b/test/transform/flux_tests.jl @@ -1,9 +1,9 @@ @testitem "FromFluxAdaptor" setup=[SharedTestSetup] tags=[:fluxcompat] begin import Flux - from_flux = fdev(::Lux.LuxCPUDevice) = Flux.cpu - fdev(::Lux.LuxCUDADevice) = Base.Fix1(Flux.gpu, Flux.FluxCUDAAdaptor()) - fdev(::Lux.LuxAMDGPUDevice) = Base.Fix1(Flux.gpu, Flux.FluxAMDAdaptor()) + from_flux = fdev(::Lux.CPUDevice) = Flux.cpu + fdev(::Lux.CUDADevice) = Base.Fix1(Flux.gpu, Flux.FluxCUDAAdaptor()) + fdev(::Lux.AMDGPUDevice) = Base.Fix1(Flux.gpu, Flux.FluxAMDAdaptor()) toluxpsst = FromFluxAdaptor(; preserve_ps_st=true) tolux = FromFluxAdaptor() From a35f1fb419a4971ee67306bbbd13dc8ae00c4e1f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Jul 2024 15:04:49 -0700 Subject: [PATCH 17/95] feat: reexport NNlib --- src/Lux.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Lux.jl b/src/Lux.jl index c004a4ae37..2c52111266 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -20,11 +20,12 @@ using Reexport: @reexport using Statistics: mean using UnrolledUtilities: unrolled_map, unrolled_mapreduce -@reexport using LuxCore, LuxLib, MLDataDevices, WeightInitializers import LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer, initialparameters, initialstates, parameterlength, statelength, inputsize, outputsize, update_state, trainmode, testmode, setup, apply, replicate +@reexport using LuxCore, LuxLib, MLDataDevices, NNlib, WeightInitializers + const CRC = ChainRulesCore const NAME_TYPE = Union{Nothing, String, Symbol} From dddfd6675ddace9627428511de9aebc760f9643d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Jul 2024 15:16:43 -0700 Subject: [PATCH 18/95] chore: remove old versions --- docs/Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 93e13e81ad..a1bce8d80b 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -38,7 +38,7 @@ GPUArraysCore = "0.1" KernelAbstractions = "0.9" LinearAlgebra = "1.10" Literate = "2.18.0" -Lux = "0.5.62" +Lux = "1" LuxCUDA = "0.3.2" LuxCore = "0.1.15" LuxLib = "0.3.42" @@ -49,6 +49,6 @@ Pkg = "1.10" Printf = "1.10" Random = "1.10" StaticArrays = "1" -WeightInitializers = "0.1.7, 1" +WeightInitializers = "1" Zygote = "0.6.70" julia = "1.10" From fcf27c648f952d161154b1fd4f5b60cc07255db5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Jul 2024 15:41:45 -0700 Subject: [PATCH 19/95] fix: errors in testing --- Project.toml | 2 +- docs/src/api/Building_Blocks/LuxLib.md | 2 +- test/helpers/training_tests.jl | 44 +++++++++++++------------- 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/Project.toml b/Project.toml index 4e392cea08..0ea62fc479 100644 --- a/Project.toml +++ b/Project.toml @@ -69,7 +69,7 @@ LuxZygoteExt = "Zygote" [compat] ADTypes = "1.5" Adapt = "4" -ArgCheck = "2.1" +ArgCheck = "2.3" ArrayInterface = "7.9" CUDA = "5.3.2" ChainRulesCore = "1.24" diff --git 
a/docs/src/api/Building_Blocks/LuxLib.md b/docs/src/api/Building_Blocks/LuxLib.md index 8075d83ce0..72d067eae3 100644 --- a/docs/src/api/Building_Blocks/LuxLib.md +++ b/docs/src/api/Building_Blocks/LuxLib.md @@ -38,7 +38,7 @@ fused_conv_bias_activation ```@docs alpha_dropout -dropout +LuxLib.dropout ``` ## Fully Connected Layers diff --git a/test/helpers/training_tests.jl b/test/helpers/training_tests.jl index 86bbd649d8..b44a7cdd59 100644 --- a/test/helpers/training_tests.jl +++ b/test/helpers/training_tests.jl @@ -8,7 +8,7 @@ opt = Adam(0.01f0) ps, st = Lux.setup(Lux.replicate(rng), model) |> dev - tstate = Lux.Training.TrainState(model, ps, st, opt) + tstate = Training.TrainState(model, ps, st, opt) x = randn(Lux.replicate(rng), Float32, (3, 1)) opt_st = Optimisers.setup(opt, tstate.parameters) @@ -34,7 +34,7 @@ end opt = Adam(0.01f0) ps, st = Lux.setup(rng, model) |> dev - tstate = Lux.Training.TrainState(model, ps, st, opt) + tstate = Training.TrainState(model, ps, st, opt) x = randn(Lux.replicate(rng), Float32, (3, 1)) |> aType @@ -42,8 +42,8 @@ end ongpu && (ad isa AutoReverseDiff || ad isa AutoEnzyme) && continue !LuxTestUtils.ENZYME_TESTING_ENABLED && ad isa AutoEnzyme && continue - grads, _, _, _ = Lux.Training.compute_gradients(ad, _loss_function, x, tstate) - tstate_ = Lux.Training.apply_gradients(tstate, grads) + grads, _, _, _ = Training.compute_gradients(ad, _loss_function, x, tstate) + tstate_ = Training.apply_gradients(tstate, grads) @test tstate_.step == 1 @test tstate != tstate_ end @@ -81,9 +81,9 @@ end for epoch in 1:1000, (x, y) in dataset_ grads, loss, _, tstate = allow_unstable() do - Lux.Experimental.compute_gradients(ad, mse, (x, y), tstate) + Training.compute_gradients(ad, mse, (x, y), tstate) end - tstate = Lux.Training.apply_gradients!(tstate, grads) + tstate = Training.apply_gradients!(tstate, grads) end for epoch in 1:1000, (x, y) in dataset_ @@ -121,7 +121,7 @@ end ps, st = Lux.setup(rng, model) |> dev tstate = Training.TrainState(model, ps, st, opt) - @test_throws ArgumentError Lux.Training.compute_gradients( + @test_throws ArgumentError Training.compute_gradients( AutoCustomAD(), mse, dataset_[1], tstate) end end @@ -179,9 +179,9 @@ end x = randn(rng, Float32, 4, 32) opt = Adam(0.001f0) - tstate = Lux.Training.TrainState(model, ps, st, opt) + tstate = Training.TrainState(model, ps, st, opt) - _, _, _, tstate_new = @inferred Lux.Training.compute_gradients( + _, _, _, tstate_new = @inferred Training.compute_gradients( AutoEnzyme(), mse, (x, x), tstate) @test tstate_new.states !== tstate.states @@ -189,15 +189,15 @@ end model = Chain(Dense(4 => 3), Dense(3 => 4)) ps, st = Lux.setup(rng, model) - tstate = Lux.Training.TrainState(model, ps, st, opt) + tstate = Training.TrainState(model, ps, st, opt) - _, _, _, tstate_new = @inferred Lux.Training.compute_gradients( + _, _, _, tstate_new = @inferred Training.compute_gradients( AutoEnzyme(), mse, (x, x), tstate) - @test @inferred(Lux.Training.compute_gradients( - AutoEnzyme(), mse, (x, x), tstate_new)) isa Any + @test @inferred(Training.compute_gradients(AutoEnzyme(), mse, (x, x), tstate_new)) isa + Any - _, _, _, tstate_new2 = @inferred Lux.Training.compute_gradients( + _, _, _, tstate_new2 = @inferred Training.compute_gradients( AutoEnzyme(), mse2, (x, x), tstate_new) @test hasfield(typeof(tstate_new2.cache.extras), :forward) @test hasfield(typeof(tstate_new2.cache.extras), :reverse) @@ -221,19 +221,19 @@ end Dense(32, 32, tanh), BatchNorm(32), Dense(32, 4)) ps, st = Lux.setup(rng, model) - tstate = 
Lux.Training.TrainState(model, ps, st, Adam(0.001f0)) + tstate = Training.TrainState(model, ps, st, Adam(0.001f0)) # Stateful models are not supported - @test_throws ArgumentError Lux.Training.compute_gradients( + @test_throws ArgumentError Training.compute_gradients( AutoReverseDiff(; compile=true), mse1, dataset[1], tstate) model = Chain(Dense(4, 32, tanh), Dense(32, 32, tanh), Dense(32, 4)) ps, st = Lux.setup(rng, model) - tstate = Lux.Training.TrainState(model, ps, st, Adam(0.001f0)) + tstate = Training.TrainState(model, ps, st, Adam(0.001f0)) # Loss functions that return non-empty `stats` are not supported - @test_throws ArgumentError Lux.Training.compute_gradients( + @test_throws ArgumentError Training.compute_gradients( AutoReverseDiff(; compile=true), mse2, dataset[1], tstate) struct StrangeModel <: Lux.AbstractExplicitLayer end @@ -245,23 +245,23 @@ end model = StrangeModel() ps, st = Lux.setup(rng, model) - tstate = Lux.Training.TrainState(model, ps, st, Adam(0.001f0)) + tstate = Training.TrainState(model, ps, st, Adam(0.001f0)) # Stateful models are not supported - @test_throws ArgumentError Lux.Training.compute_gradients( + @test_throws ArgumentError Training.compute_gradients( AutoReverseDiff(; compile=true), mse1, dataset[1], tstate) end model = Chain(Dense(4, 32, tanh), Dense(32, 32, tanh), Dense(32, 4)) ps, st = Lux.setup(rng, model) - tstate = Lux.Training.TrainState(model, ps, st, Adam(0.001f0)) + tstate = Training.TrainState(model, ps, st, Adam(0.001f0)) loss_initial = first(mse1(model, ps, st, dataset[1])) for i in 1:100 for (x, y) in dataset _, _, _, tstate = allow_unstable() do - Lux.Training.single_train_step!( + Training.single_train_step!( AutoReverseDiff(; compile=true), mse1, (x, y), tstate) end end From 6f2f618f06015c462d5479a44bfe5ebff2454764 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Jul 2024 16:54:28 -0700 Subject: [PATCH 20/95] fix: don't reexport NNlib.dropout --- docs/src/api/Building_Blocks/LuxLib.md | 2 +- src/Lux.jl | 5 +++-- src/helpers/training.jl | 4 ++++ 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/docs/src/api/Building_Blocks/LuxLib.md b/docs/src/api/Building_Blocks/LuxLib.md index 72d067eae3..8075d83ce0 100644 --- a/docs/src/api/Building_Blocks/LuxLib.md +++ b/docs/src/api/Building_Blocks/LuxLib.md @@ -38,7 +38,7 @@ fused_conv_bias_activation ```@docs alpha_dropout -LuxLib.dropout +dropout ``` ## Fully Connected Layers diff --git a/src/Lux.jl b/src/Lux.jl index 2c52111266..db7b0af0b8 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -16,7 +16,7 @@ using Markdown: @doc_str using Optimisers: Optimisers using Random: Random, AbstractRNG using Static: StaticBool, StaticInt, StaticSymbol, True, False, static, known, dynamic -using Reexport: @reexport +using Reexport: Reexport, @reexport using Statistics: mean using UnrolledUtilities: unrolled_map, unrolled_mapreduce @@ -24,7 +24,8 @@ import LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer, initialpa initialstates, parameterlength, statelength, inputsize, outputsize, update_state, trainmode, testmode, setup, apply, replicate -@reexport using LuxCore, LuxLib, MLDataDevices, NNlib, WeightInitializers +@reexport using LuxCore, LuxLib, MLDataDevices, WeightInitializers +@eval Expr(:export, filter(x -> x !== :dropout, Reexport.exported_names(NNlib))...) 
const CRC = ChainRulesCore diff --git a/src/helpers/training.jl b/src/helpers/training.jl index 3819b39e72..238e2d088d 100644 --- a/src/helpers/training.jl +++ b/src/helpers/training.jl @@ -9,7 +9,11 @@ using Random: AbstractRNG using ..Lux: Lux using LuxCore: LuxCore, AbstractExplicitLayer +<<<<<<< HEAD using MLDataDevices: MLDataDevices +======= +using Optimisers: Optimisers +>>>>>>> 30f27fd4 (fix: don't reexport NNlib.dropout) """ TrainState From 3ba442f21a87bb0bb17c407e6a99a26d8ddd55db Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 18 Aug 2024 12:37:33 -0700 Subject: [PATCH 21/95] fix: remove explicit imports --- src/contrib/contrib.jl | 2 +- src/helpers/training.jl | 7 +------ src/preferences.jl | 2 +- 3 files changed, 3 insertions(+), 8 deletions(-) diff --git a/src/contrib/contrib.jl b/src/contrib/contrib.jl index aee5036c62..902b56a67b 100644 --- a/src/contrib/contrib.jl +++ b/src/contrib/contrib.jl @@ -1,6 +1,6 @@ module Experimental -using ..Lux: Lux, Training, Optional +using ..Lux: Lux, Optional using ..Utils: Utils, BoolType, SymbolType using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer, apply diff --git a/src/helpers/training.jl b/src/helpers/training.jl index 238e2d088d..0d4bd7be12 100644 --- a/src/helpers/training.jl +++ b/src/helpers/training.jl @@ -4,16 +4,11 @@ using ADTypes: AbstractADType, AutoEnzyme, AutoReverseDiff, AutoTracker, AutoZyg using Compat: @compat using ConcreteStructs: @concrete using FastClosures: @closure -using Optimisers: Optimisers, AbstractRule -using Random: AbstractRNG +using Optimisers: Optimisers using ..Lux: Lux using LuxCore: LuxCore, AbstractExplicitLayer -<<<<<<< HEAD -using MLDataDevices: MLDataDevices -======= using Optimisers: Optimisers ->>>>>>> 30f27fd4 (fix: don't reexport NNlib.dropout) """ TrainState diff --git a/src/preferences.jl b/src/preferences.jl index a1ec0eef3b..a3eaff5445 100644 --- a/src/preferences.jl +++ b/src/preferences.jl @@ -1,7 +1,7 @@ module LuxPreferences using ArgCheck: @argcheck -using Preferences: load_preference, has_preference, set_preferences! 
+using Preferences: load_preference, has_preference, set_preferences!, @load_preference using ..Lux: Lux From 6553b87df2bb4d31229c8dcd33c8cc22f97b2596 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 18 Aug 2024 12:45:00 -0700 Subject: [PATCH 22/95] fix: bad rebase --- ext/LuxMPIExt.jl | 4 - ext/LuxReverseDiffExt/LuxReverseDiffExt.jl | 5 - ext/LuxTrackerExt.jl | 82 --------- ext/LuxTrackerExt/rules.jl | 9 - ext/LuxZygoteExt/LuxZygoteExt.jl | 2 - src/contrib/map.jl | 2 +- src/helpers/nested_ad.jl | 203 --------------------- test/contrib/share_parameters_tests.jl | 16 +- test/utils_tests.jl | 2 +- 9 files changed, 8 insertions(+), 317 deletions(-) delete mode 100644 ext/LuxTrackerExt.jl delete mode 100644 src/helpers/nested_ad.jl diff --git a/ext/LuxMPIExt.jl b/ext/LuxMPIExt.jl index 4ffd5fbed2..072663cc99 100644 --- a/ext/LuxMPIExt.jl +++ b/ext/LuxMPIExt.jl @@ -1,9 +1,5 @@ module LuxMPIExt -using Lux: MPIBackend, NCCLBackend, DistributedUtils, __unwrap_val, MPI_CUDA_AWARE, - MPI_ROCM_AWARE -using MLDataDevices: AbstractDevice, CUDADevice, AMDGPUDevice, cpu_device, set_device!, - functional using MPI: MPI using Lux: Lux, MPIBackend, NCCLBackend, DistributedUtils, Utils diff --git a/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl b/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl index 8660967e65..fdd6462086 100644 --- a/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl +++ b/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl @@ -10,11 +10,6 @@ using MLDataDevices: CPUDevice using ReverseDiff: ReverseDiff, ForwardExecutor, ReverseExecutor, TrackedArray, TrackedReal, @grad_from_chainrules -using Lux: Lux, Utils -using Lux.Training: TrainingBackendCache, TrainState -using LuxCore: LuxCore -using MLDataDevices: CPUDevice - include("utils.jl") include("rules.jl") include("training.jl") diff --git a/ext/LuxTrackerExt.jl b/ext/LuxTrackerExt.jl deleted file mode 100644 index 7d2327d7b0..0000000000 --- a/ext/LuxTrackerExt.jl +++ /dev/null @@ -1,82 +0,0 @@ -module LuxTrackerExt - -using ADTypes: AutoTracker -using ArrayInterface: ArrayInterface -using ChainRulesCore: ChainRulesCore -using Lux: Lux, CPUDevice -using Lux.Training: TrainingBackendCache, TrainState -using LuxCore: LuxCore, AbstractExplicitLayer -using Tracker: Tracker, TrackedArray, TrackedReal, @grad_from_chainrules - -const CRC = ChainRulesCore - -# Weight Norm Patch -@inline Lux._norm(x::TrackedArray; dims=Colon()) = sqrt.(sum(abs2.(x); dims)) - -# multigate chain rules -@inline Lux._gate(x::Tracker.TrackedVector, h::Int, n::Int) = x[Lux._gate(h, n)] -@inline Lux._gate(x::Tracker.TrackedMatrix, h::Int, n::Int) = x[Lux._gate(h, n), :] - -function __construct_tracked_params(ps, dps) - map_fn = (p, dp) -> Tracker.TrackedArray(Tracker.Call(), p, dp) - return Lux.recursive_map(map_fn, ps, dps) -end - -# Lux.Training -function Lux.Training.compute_gradients(::AutoTracker, obj_fn::F, data, - ts::TrainState{<:TrainingBackendCache{:Tracker, FT}}) where {F, FT} - dparams = FT ? 
ts.cache.dparameters : Lux.recursive_make_zero!!(ts.cache.dparameters) - ps_tracked = __construct_tracked_params(ts.parameters, dparams) - - loss, st, stats = obj_fn(ts.model, ps_tracked, ts.states, data) - Tracker.back!(loss) - - ts_new = TrainState( - TrainingBackendCache{:Tracker, false}(ts.cache.dparameters, nothing), obj_fn, - ts.model, ts.parameters, st, ts.optimizer, ts.optimizer_state, ts.step) - - return dparams, loss.data, stats, ts_new -end - -function Lux.Training.compute_gradients( - ::AutoTracker, obj_fn::F, data, ts::TrainState) where {F} - grads = Lux.recursive_make_zero(ts.parameters) - ts_new = TrainState( - TrainingBackendCache{:Tracker, true}(grads, nothing), obj_fn, ts.model, - ts.parameters, ts.states, ts.optimizer, ts.optimizer_state, ts.step) - return Lux.Training.compute_gradients(AutoTracker(), obj_fn, data, ts_new) -end - -# AoS to SoA conversion -function LuxCore.apply(m::AbstractExplicitLayer, x::AbstractArray{<:TrackedReal}, ps, st) - @warn "LuxCore.apply(m::AbstractExplicitLayer, \ - x::AbstractArray{<:Tracker.TrackedReal}, ps, st) input was corrected to \ - LuxCore.apply(m::AbstractExplicitLayer, x::Tracker.TrackedArray}, ps, st).\n\n\ - 1. If this was not the desired behavior overload the dispatch on `m`.\n\n\ - 2. This might have performance implications. Check which layer was causing this \ - problem using `Lux.Experimental.@debug_mode`." maxlog=1 - return LuxCore.apply(m, ArrayInterface.aos_to_soa(x), ps, st) -end - -## Prevent an infinite loop -LuxCore.apply(m::AbstractExplicitLayer, x::TrackedArray, ps, st) = m(x, ps, st) - -@inline Lux.__eltype(::TrackedArray{T}) where {T} = T -@inline Lux.__eltype(::TrackedReal{T}) where {T} = T -@inline Lux.__eltype(::AbstractArray{<:TrackedReal{T}}) where {T} = T - -@inline Lux.__reverse(x::TrackedArray; dims=:) = ArrayInterface.aos_to_soa(reverse(x; dims)) -@inline function Lux.__reverse(x::AbstractArray{<:TrackedReal}; dims=:) - return ArrayInterface.aos_to_soa(reverse(x; dims)) -end - -# DynamicExpressions.jl -for T1 in (:TrackedArray, :AbstractArray), T2 in (:TrackedArray, :AbstractArray) - T1 === :AbstractArray && T2 === :AbstractArray && continue - - @eval @grad_from_chainrules Lux.__apply_dynamic_expression( - de::Lux.DynamicExpressionsLayer, expr, - operator_enum, x::$(T1), ps::$(T2), dev::CPUDevice) -end - -end diff --git a/ext/LuxTrackerExt/rules.jl b/ext/LuxTrackerExt/rules.jl index 39f9a879c8..0976070b7f 100644 --- a/ext/LuxTrackerExt/rules.jl +++ b/ext/LuxTrackerExt/rules.jl @@ -1,12 +1,3 @@ -# SimpleChains.jl: DON'T REPLACE THESE WITH @grad_from_chainrules -for T1 in (:TrackedArray, :AbstractArray), T2 in (:TrackedArray, :AbstractArray) - T1 === :AbstractArray && T2 === :AbstractArray && continue - - @eval function Lux.apply_simple_chain(layer, x::$(T1), ps::$(T2), dev::CPUDevice) - return Tracker.track(Lux.apply_simple_chain, layer, x, ps, dev) - end -end - # DynamicExpressions.jl for T1 in (:TrackedArray, :AbstractArray), T2 in (:TrackedArray, :AbstractArray) T1 === :AbstractArray && T2 === :AbstractArray && continue diff --git a/ext/LuxZygoteExt/LuxZygoteExt.jl b/ext/LuxZygoteExt/LuxZygoteExt.jl index 84b734aceb..cefbf32164 100644 --- a/ext/LuxZygoteExt/LuxZygoteExt.jl +++ b/ext/LuxZygoteExt/LuxZygoteExt.jl @@ -4,8 +4,6 @@ using ArgCheck: @argcheck using ADTypes: AutoZygote using ChainRulesCore: ChainRulesCore using ForwardDiff: ForwardDiff -using Lux: Lux -using MLDataDevices: get_device_type, CPUDevice using Setfield: @set! 
using Zygote: Zygote diff --git a/src/contrib/map.jl b/src/contrib/map.jl index 6cb6c6cf6f..df4ce7e248 100644 --- a/src/contrib/map.jl +++ b/src/contrib/map.jl @@ -39,7 +39,7 @@ true ``` """ macro layer_map(f, l, ps, st) - quote + return quote layer_map($(esc(f)), $(esc(l)), $(esc(ps)), $(esc(st)), $(Meta.quot(l))) end end diff --git a/src/helpers/nested_ad.jl b/src/helpers/nested_ad.jl deleted file mode 100644 index 4cf036a9d6..0000000000 --- a/src/helpers/nested_ad.jl +++ /dev/null @@ -1,203 +0,0 @@ -#! format: off -const AD_CONVERTIBLE_FUNCTIONS = [ - # Input Gradient/Jacobian - ComposedFunction{<:Any, <:StatefulLuxLayer}, - ComposedFunction{<:StatefulLuxLayer, <:Any}, - StatefulLuxLayer, - # Parameter Gradient/Jacobian - ComposedFunction{<:Any, <:Base.Fix1{<:StatefulLuxLayer}}, - ComposedFunction{<:Base.Fix1{<:StatefulLuxLayer}, <:Any}, - Base.Fix1{<:StatefulLuxLayer} -] -#! format: on - -## Written like this to avoid dynamic dispatch from Zygote -# Input Gradient / Jacobian -@inline __rewrite_ad_call(f::ComposedFunction{F, <:StatefulLuxLayer}) where {F} = ( - f, f.inner.ps) -@inline __rewrite_ad_call(f::ComposedFunction{<:StatefulLuxLayer, F}) where {F} = ( - @closure((x, ps)->f.outer(f.inner(x), ps)), f.outer.ps) -@inline __rewrite_ad_call(f::StatefulLuxLayer) = f, f.ps - -# Parameter Gradient / Jacobian -@inline __rewrite_ad_call(f::ComposedFunction{F, <:Base.Fix1{<:StatefulLuxLayer}}) where {F} = ( - @closure((ps, x)->f.outer(f.inner.f(x, ps))), f.inner.x) -@inline __rewrite_ad_call(f::ComposedFunction{<:Base.Fix1{<:StatefulLuxLayer}, F}) where {F} = ( - @closure((ps, x)->f.outer.f(x, f.inner(ps))), f.outer.x) -@inline __rewrite_ad_call(f::Base.Fix1{<:StatefulLuxLayer}) = ( - @closure((ps, x)->f.f(x, ps)), f.x) - -## Break ambiguity -for op in [ComposedFunction{<:StatefulLuxLayer, <:StatefulLuxLayer}, - ComposedFunction{<:Base.Fix1{<:StatefulLuxLayer}, <:StatefulLuxLayer}, - ComposedFunction{<:StatefulLuxLayer, <:Base.Fix1{<:StatefulLuxLayer}}, - ComposedFunction{<:Base.Fix1{<:StatefulLuxLayer}, <:Base.Fix1{<:StatefulLuxLayer}}] - @eval @inline function __rewrite_ad_call(::$op) - error("Cannot rewrite ComposedFunction with StatefulLuxLayer as inner and outer layers") - end -end - -# Nested Gradients -## Essentially computes the gradient of `f(x, y)` wrt x using the function `grad_fn` -## To compute the gradient of `f(x, y)` wrt y, just reorder the arguments with a wrapper -## over `f` -for fname in (:__internal_ad_gradient_call, :__internal_ad_gradient_call_no_custom_rrule) - @eval @inline function $fname(grad_fn::G, f::F, x, y) where {G, F} - return grad_fn(Base.Fix2(f, y), x) - end -end - -function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(__internal_ad_gradient_call), - grad_fn::G, f::F, x, y) where {G, F} - @static if !AUTOMATIC_NESTED_AD_SWITCHING - return CRC.rrule_via_ad( - cfg, __internal_ad_gradient_call_no_custom_rrule, grad_fn, f, x, y) - end - - res = __internal_ad_gradient_call(grad_fn, f, x, y) - ∇internal_gradient_capture = @closure Δ_ -> begin - (Δ_ isa NoTangent || Δ_ isa ZeroTangent) && return ntuple(Returns(NoTangent()), 5) - - Δ = CRC.unthunk(Δ_) - (res isa Tuple || Δ isa Tuple) && (Δ = only(Δ)) # For Zygote and such which return a tuple - ∂x, ∂y = __forwarddiff_jvp(@closure((x, y)->grad_fn(f, x, y)), x, Δ, y) - return NoTangent(), NoTangent(), NoTangent(), ∂x, ∂y - end - - return res, ∇internal_gradient_capture -end - -# Nested Pullbacks -for fname in (:__internal_ad_pullback_call, :__internal_ad_pullback_call_no_custom_rrule) - @eval @inline 
function $fname(pullback_fn::P, f::F, x, y, u) where {P, F} - return only(last(pullback_fn(Base.Fix2(f, y), x))(u)) - end -end - -function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(__internal_ad_pullback_call), - pullback_fn::P, f::F, x, y, u) where {P, F} - @static if !AUTOMATIC_NESTED_AD_SWITCHING - return CRC.rrule_via_ad( - cfg, __internal_ad_pullback_call_no_custom_rrule, pullback_fn, f, x, y, u) - end - - res = __internal_ad_pullback_call(pullback_fn, f, x, y, u) - ∇nested_ad = let pullback_fn = pullback_fn, f = f, x = x, y = y, u = u, res = res - Δ_ -> begin - (Δ_ isa NoTangent || Δ_ isa ZeroTangent) && - return ntuple(Returns(NoTangent()), 6) - - Δ = CRC.unthunk(Δ_) - (res isa Tuple || Δ isa Tuple) && (Δ = only(Δ)) # For Zygote and such which return a tuple - ∂x, ∂y = __forwarddiff_jvp(x, Δ, y) do x_dual, y_ - return last(pullback_fn(f, x_dual, y_))(u) - end - return (NoTangent(), NoTangent(), NoTangent(), ∂x, ∂y, NoTangent()) - end - end - - return res, ∇nested_ad -end - -# Nested Jacobians -## `grad_fn` is not needed for the forward pass, we need it for the reverse pass HVP -for fname in (:__internal_ad_jacobian_call, :__internal_ad_jacobian_call_no_custom_rrule) - @eval @inline function $fname( - jac_fn::J, grad_fn::G, f::F, x::AbstractArray, y) where {J, G, F} - return jac_fn(Base.Fix2(f, y), x) - end -end - -function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(__internal_ad_jacobian_call), - jac_fn::J, grad_fn::G, f::F, x::AbstractArray, y) where {J, G, F} - @static if !AUTOMATIC_NESTED_AD_SWITCHING - return CRC.rrule_via_ad( - cfg, __internal_ad_jacobian_call_no_custom_rrule, jac_fn, grad_fn, f, x, y) - end - - res = __internal_ad_jacobian_call(jac_fn, grad_fn, f, x, y) - ∇internal_jacobian_capture = let res = res, grad_fn = grad_fn, f = f, x = x, y = y - Δ_ -> begin - (Δ_ isa NoTangent || Δ_ isa ZeroTangent) && - return ntuple(Returns(NoTangent()), 6) - - Δ = CRC.unthunk(Δ_) - (res isa Tuple || Δ isa Tuple) && (Δ = only(Δ)) # For Zygote and such which return a tuple - Δ = __compactify_if_structured_matrix(res isa Tuple ? only(res) : res, Δ) - - __inner_grad_fn = @closure(i->sum ∘ Base.Fix2(getindex, i:i) ∘ vec ∘ f) - map_fn = @closure i -> begin - Δᵢ = __maybe_batched_row(Δ, i) - fn = __inner_grad_fn(i) - __f = let fn = fn - (x, y) -> grad_fn(fn, x, y) - end - return __forwarddiff_jvp(__f, x, Δᵢ, y) - end - - # FIXME: threading on CUDA cause unexpected errors on the first run to CUDNN - # when doing a algorithm lookup - ∂x, ∂y = if get_device_type(x) <: CPUDevice - tasks = map(i -> Threads.@spawn(map_fn(i)), 1:__numrows(Δ)) - mapreduce(fetch, recursive_add!!, tasks) - else - mapreduce(map_fn, recursive_add!!, 1:__numrows(Δ)) - end - - return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), ∂x, ∂y) - end - end - - return res, ∇internal_jacobian_capture -end - -# Convert a structured Matrix to a General Matrix if it doesn't have fast scalar indexing -@inline function __compactify_if_structured_matrix( - J::AbstractArray{T1, N}, Δ::AbstractArray{T2}) where {T1, T2, N} - @argcheck N ∈ (2, 3) "Only 2D and 3D arrays are supported for compactifying." 
- if !ArrayInterface.fast_scalar_indexing(J) && ArrayInterface.isstructured(Δ) - J_ = similar(J) - copyto!(J_, Δ) - return J_ - end - return reshape(Δ, size(J)) -end - -@inline __numrows(x::AbstractMatrix) = size(x, 1) -@inline __numrows(x::AbstractArray{T, 3}) where {T} = size(x, 1) * size(x, 3) - -@inline __maybe_batched_row(x::AbstractMatrix, i::Integer) = view(x, i, :) -@inline function __maybe_batched_row(x::AbstractArray{T, 3}, i::Integer) where {T} - M, N, K = size(x) - k = (i - 1) ÷ M + 1 - i = mod1(i, M) - y = similar(x, N * K) - data = view(x, i, :, k) - fill!(view(y, 1:(N * (K - 1))), zero(T)) - copyto!(view(y, (N * (k - 1) + 1):(N * k)), data) - fill!(view(y, (N * k + 1):(N * K)), zero(T)) - return y -end - -@inline function __partials(::Type{Tag}, x, i) where {Tag} - x isa ForwardDiff.Dual && return ForwardDiff.partials(Tag, x, i) - if x isa AbstractArray - bfn(xᵢ, iᵢ) = ForwardDiff.partials(Tag, xᵢ, iᵢ) - return bfn.(x, i) - end - map_fn = @closure(xᵢ->__partials(Tag, xᵢ, i)) - (x isa Tuple || x isa NamedTuple) && return map(map_fn, x) - x isa CRC.AbstractTangent && return __partials(Tag, CRC.backing(x), i) - x === nothing && return nothing - return fmap(map_fn, x) -end - -@inline function __dualify(::Type{Tag}, ::Type{T}, x, u) where {Tag, T} - if x isa AbstractArray - bfn(xᵢ, uᵢ) = ForwardDiff.Dual{Tag, T, 1}(xᵢ, ForwardDiff.Partials{1, T}(uᵢ)) - return bfn.(x, tuple.(reshape(u, size(x)))) - end - (x isa Tuple || x isa NamedTuple) && - return map((xᵢ, uᵢ) -> __dualify(Tag, T, xᵢ, uᵢ), x, u) - return fmap((xᵢ, uᵢ) -> __dualify(Tag, T, xᵢ, uᵢ), x, u) -end diff --git a/test/contrib/share_parameters_tests.jl b/test/contrib/share_parameters_tests.jl index 385e8a4eae..874ddd2ffc 100644 --- a/test/contrib/share_parameters_tests.jl +++ b/test/contrib/share_parameters_tests.jl @@ -16,10 +16,8 @@ @test ps_1.d3.weight == ps_1.d2.l1.weight @test ps_1.d3.bias == ps_1.d2.l1.bias - ps_new_1 = (; weight=randn(rng, Float32, 4, 2), bias=randn(rng, Float32, 4)) |> - device - ps_new_2 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 2)) |> - device + ps_new_1 = (; weight=randn(rng, Float32, 4, 2), bias=randn(rng, Float32, 4)) |> dev + ps_new_2 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 2)) |> dev ps_2 = Lux.Experimental.share_parameters(ps, sharing, (ps_new_1, ps_new_2)) @@ -29,7 +27,7 @@ @test ps_2.d3.bias == ps_new_2.bias == ps_2.d2.l1.bias # Mix in ComponentArray - ps_new_ca_1 = ComponentArray(ps_new_1 |> CPUDevice()) |> device + ps_new_ca_1 = ComponentArray(ps_new_1 |> CPUDevice()) |> dev ps_3 = Lux.Experimental.share_parameters(ps, sharing, (ps_new_ca_1, ps_new_2)) @@ -46,15 +44,13 @@ ps, sharing, (ps_new_1,)) # Parameter Structure Mismatch - ps_new_1 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 4)) |> - device - ps_new_2 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 2)) |> - device + ps_new_1 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 4)) |> dev + ps_new_2 = (; weight=randn(rng, Float32, 2, 4), bias=randn(rng, Float32, 2)) |> dev @test_throws ArgumentError Lux.Experimental.share_parameters( ps, sharing, (ps_new_1, ps_new_2)) - ps_new_ca_1 = ComponentArray(ps_new_1 |> CPUDevice()) |> device + ps_new_ca_1 = ComponentArray(ps_new_1 |> CPUDevice()) |> dev @test_throws ArgumentError Lux.Experimental.share_parameters( ps, sharing, (ps_new_ca_1, ps_new_2)) diff --git a/test/utils_tests.jl b/test/utils_tests.jl index 5086c5109c..5c6c74b750 100644 --- a/test/utils_tests.jl +++ b/test/utils_tests.jl @@ 
-127,7 +127,7 @@ end @testitem "FP Conversions" setup=[SharedTestSetup] tags=[:others] begin rng = StableRNG(12345) - @testset "$mode" for (mode, aType, device, ongpu) in MODES + @testset "$mode" for (mode, aType, dev, ongpu) in MODES model = Chain( Dense(1 => 16, relu), Chain(Dense(16 => 1), Dense(1 => 1)), BatchNorm(1)) From 221901a77cdc6532439ffebcc00456eb8a73f452 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 18 Aug 2024 12:55:29 -0700 Subject: [PATCH 23/95] fix: recurrent bias flatten --- docs/src/.vitepress/config.mts | 2 +- ext/LuxReverseDiffExt/LuxReverseDiffExt.jl | 9 +++++---- src/Lux.jl | 1 - src/layers/recurrent.jl | 10 ++++------ 4 files changed, 10 insertions(+), 12 deletions(-) diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts index 10948ced66..eca4ec34bd 100644 --- a/docs/src/.vitepress/config.mts +++ b/docs/src/.vitepress/config.mts @@ -6,7 +6,7 @@ import { transformerMetaWordHighlight } from '@shikijs/transformers'; // https://vitepress.dev/reference/site-config export default defineConfig({ - base: 'REPLACE_ME_DOCUMENTER_VITEPRESS',// TODO: replace this in makedocs! + base: 'REPLACE_ME_DOCUMENTER_VITEPRESS', title: 'REPLACE_ME_DOCUMENTER_VITEPRESS', description: 'Documentation for LuxDL Repositories', cleanUrls: true, diff --git a/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl b/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl index fdd6462086..0295b4a286 100644 --- a/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl +++ b/ext/LuxReverseDiffExt/LuxReverseDiffExt.jl @@ -3,13 +3,14 @@ module LuxReverseDiffExt using ADTypes: ADTypes, AbstractADType, AutoReverseDiff using ArrayInterface: ArrayInterface using FunctionWrappers: FunctionWrapper -using Lux: Lux -using Lux.Training: TrainingBackendCache, TrainState -using LuxCore: LuxCore, AbstractExplicitLayer -using MLDataDevices: CPUDevice using ReverseDiff: ReverseDiff, ForwardExecutor, ReverseExecutor, TrackedArray, TrackedReal, @grad_from_chainrules +using Lux: Lux, Utils +using Lux.Training: TrainingBackendCache, TrainState +using LuxCore: LuxCore +using MLDataDevices: CPUDevice + include("utils.jl") include("rules.jl") include("training.jl") diff --git a/src/Lux.jl b/src/Lux.jl index db7b0af0b8..fba0853067 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -121,7 +121,6 @@ export LuxOps # Unexported functions that are part of the public API @compat public Experimental -@compat public xlogx, xlogy # TODO: deprecated in v1.0 @compat public set_dispatch_doctor_preferences! end diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index fd207bcad0..3c609abeee 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -379,8 +379,7 @@ function initialparameters(rng::AbstractRNG, lstm::LSTMCell) for init_weight in lstm.init_weight]...) ps = (; weight_i, weight_h) if has_bias(lstm) - # TODO: in v1 we make this a flat vector - bias = vcat([init_bias(rng, lstm.out_dims, 1) for init_bias in lstm.init_bias]...) + bias = vcat([init_bias(rng, lstm.out_dims) for init_bias in lstm.init_bias]...) 
ps = merge(ps, (bias=bias,)) end has_train_state(lstm) && @@ -425,7 +424,7 @@ const _LSTMCellInputType = Tuple{ function (lstm::LSTMCell)( (x, (hidden_state, memory))::_LSTMCellInputType, ps, st::NamedTuple) y, hidden_stateₙ, memoryₙ = match_eltype(lstm, ps, st, x, hidden_state, memory) - bias = safe_vec(safe_getproperty(ps, Val(:bias))) + bias = safe_getproperty(ps, Val(:bias)) z = fused_dense_bias_activation(identity, ps.weight_h, hidden_stateₙ, bias) g = LuxLib.Impl.matmul(ps.weight_i, y) .+ z @@ -533,8 +532,7 @@ function initialparameters(rng::AbstractRNG, gru::GRUCell) ps = (; weight_i, weight_h) if has_bias(gru) bias_i = gru.init_bias[1](rng, gru.out_dims, 1) - # TODO: in v1 we make this a flat vector - bias_h = vcat([init_bias(rng, gru.out_dims, 1) for init_bias in gru.init_bias]...) + bias_h = vcat([init_bias(rng, gru.out_dims) for init_bias in gru.init_bias]...) ps = merge(ps, (bias_i=bias_i, bias_h=bias_h)) end has_train_state(gru) && @@ -561,7 +559,7 @@ const _GRUCellInputType = Tuple{<:AbstractMatrix, Tuple{<:AbstractMatrix}} function (gru::GRUCell)((x, (hidden_state,))::_GRUCellInputType, ps, st::NamedTuple) y, hidden_stateₙ = match_eltype(gru, ps, st, x, hidden_state) gxs = multigate(ps.weight_i * y, Val(3)) - bias_h = safe_vec(safe_getproperty(ps, Val(:bias_h))) + bias_h = safe_getproperty(ps, Val(:bias_h)) ghbs = multigate( fused_dense_bias_activation(identity, ps.weight_h, hidden_stateₙ, bias_h), Val(3)) From 506f5d201c8ed85471f130ef1b633070adba5abf Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 18 Aug 2024 13:06:11 -0700 Subject: [PATCH 24/95] chore: update to using LuxLib@1.0 --- Project.toml | 2 +- docs/Project.toml | 2 +- src/Lux.jl | 1 + test/Project.toml | 2 +- 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 0ea62fc479..68afdbdf86 100644 --- a/Project.toml +++ b/Project.toml @@ -89,7 +89,7 @@ GPUArraysCore = "0.1.6" LinearAlgebra = "1.10" LossFunctions = "0.11.1" LuxCore = "0.1.24" -LuxLib = "0.3.42" +LuxLib = "1.0" MLDataDevices = "1.1" MLUtils = "0.4.4" MPI = "0.20.19" diff --git a/docs/Project.toml b/docs/Project.toml index a1bce8d80b..0c2b042fb8 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -41,7 +41,7 @@ Literate = "2.18.0" Lux = "1" LuxCUDA = "0.3.2" LuxCore = "0.1.15" -LuxLib = "0.3.42" +LuxLib = "1.0" LuxTestUtils = "1.1" MLDataDevices = "1.1" Optimisers = "0.3.3" diff --git a/src/Lux.jl b/src/Lux.jl index fba0853067..5bc40bb0d9 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -13,6 +13,7 @@ using Functors: Functors, fmap using GPUArraysCore: @allowscalar using LossFunctions: LossFunctions using Markdown: @doc_str +using NNlib: NNlib using Optimisers: Optimisers using Random: Random, AbstractRNG using Static: StaticBool, StaticInt, StaticSymbol, True, False, static, known, dynamic diff --git a/test/Project.toml b/test/Project.toml index b82a1f6640..e82d73a027 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -59,7 +59,7 @@ LinearAlgebra = "1.10" Logging = "1.10" LuxCore = "0.1.16" LuxDeviceUtils = "0.1.26" -LuxLib = "0.3.42" +LuxLib = "1.0" LuxTestUtils = "1.1.4" MLDataDevices = "1.1" MLUtils = "0.4.3" From ef18ced87db59495065e5651fb5645b210d9014c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 18 Aug 2024 13:22:58 -0700 Subject: [PATCH 25/95] chore: update to using LuxCore@1.0 --- Project.toml | 2 +- docs/Project.toml | 2 +- docs/src/api/Building_Blocks/LuxCore.md | 5 +-- docs/src/api/Lux/interop.md | 2 +- docs/src/manual/dispatch_custom_input.md | 10 +++--- docs/src/manual/interface.md | 
16 +++++++--- docs/src/manual/migrate_from_flux.md | 2 +- examples/HyperNet/main.jl | 4 +-- examples/NeuralODE/main.jl | 13 ++++---- examples/SimpleRNN/main.jl | 5 ++- src/Lux.jl | 6 ++-- src/contrib/contrib.jl | 3 +- src/contrib/debug.jl | 8 ++--- src/contrib/freeze.jl | 18 +++++------ src/contrib/map.jl | 6 ++-- src/custom_errors.jl | 2 +- src/helpers/compact.jl | 13 ++++---- src/helpers/losses.jl | 2 +- src/helpers/stateful.jl | 12 ++++---- src/helpers/training.jl | 7 ++--- src/layers/basic.jl | 20 ++++++------ src/layers/containers.jl | 39 ++++++++++++------------ src/layers/conv.jl | 20 ++++++------ src/layers/display.jl | 13 ++++---- src/layers/dropout.jl | 6 ++-- src/layers/extension.jl | 8 ++--- src/layers/normalize.jl | 16 +++++----- src/layers/recurrent.jl | 8 ++--- src/transform/simplechains.jl | 6 ++-- src/utils.jl | 4 +-- test/Project.toml | 2 +- test/contrib/map_tests.jl | 2 +- test/helpers/stateful_tests.jl | 2 +- test/helpers/training_tests.jl | 2 +- test/layers/containers_tests.jl | 2 +- 35 files changed, 147 insertions(+), 141 deletions(-) diff --git a/Project.toml b/Project.toml index 68afdbdf86..c982f9d77a 100644 --- a/Project.toml +++ b/Project.toml @@ -88,7 +88,7 @@ Functors = "0.4.12" GPUArraysCore = "0.1.6" LinearAlgebra = "1.10" LossFunctions = "0.11.1" -LuxCore = "0.1.24" +LuxCore = "1.0" LuxLib = "1.0" MLDataDevices = "1.1" MLUtils = "0.4.4" diff --git a/docs/Project.toml b/docs/Project.toml index 0c2b042fb8..865bcada48 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -40,7 +40,7 @@ LinearAlgebra = "1.10" Literate = "2.18.0" Lux = "1" LuxCUDA = "0.3.2" -LuxCore = "0.1.15" +LuxCore = "1.0" LuxLib = "1.0" LuxTestUtils = "1.1" MLDataDevices = "1.1" diff --git a/docs/src/api/Building_Blocks/LuxCore.md b/docs/src/api/Building_Blocks/LuxCore.md index 894266ccde..54a4789fdd 100644 --- a/docs/src/api/Building_Blocks/LuxCore.md +++ b/docs/src/api/Building_Blocks/LuxCore.md @@ -14,8 +14,9 @@ Pages = ["LuxCore.md"] ## Abstract Types ```@docs -LuxCore.AbstractExplicitLayer -LuxCore.AbstractExplicitContainerLayer +LuxCore.AbstractLuxLayer +LuxCore.AbstractLuxWrapperLayer +LuxCore.AbstractLuxContainerLayer ``` ## General diff --git a/docs/src/api/Lux/interop.md b/docs/src/api/Lux/interop.md index 8dce085a3a..cf377bcbe6 100644 --- a/docs/src/api/Lux/interop.md +++ b/docs/src/api/Lux/interop.md @@ -37,7 +37,7 @@ preserving the [layer interface](@ref lux-interface). `using SimpleChains` must be present somewhere in the code for these to be used. ```@docs -Adapt.adapt(from::ToSimpleChainsAdaptor, L::AbstractExplicitLayer) +Adapt.adapt(from::ToSimpleChainsAdaptor, L::AbstractLuxLayer) ToSimpleChainsAdaptor SimpleChainsLayer ``` diff --git a/docs/src/manual/dispatch_custom_input.md b/docs/src/manual/dispatch_custom_input.md index 3495703d97..67d1a8b6af 100644 --- a/docs/src/manual/dispatch_custom_input.md +++ b/docs/src/manual/dispatch_custom_input.md @@ -5,10 +5,10 @@ * Defining a dispatch on `(::Layer)(x::MyInputType, ps, st::NamedTuple)` is inconvenient, since it requires the user to define a new method for every layer type. -* `(::AbstractExplicitLayer)(x::MyInputType, ps, st::NamedTuple)` doesn't work. +* `(::AbstractLuxLayer)(x::MyInputType, ps, st::NamedTuple)` doesn't work. * Instead, we need to define the dispatch on - `Lux.apply(::AbstractExplicitLayer, x::MyInputType, ps, st::NamedTuple)`. + `Lux.apply(::AbstractLuxLayer, x::MyInputType, ps, st::NamedTuple)`. ## Concrete Example @@ -22,7 +22,7 @@ define a time dependent version of [`Chain`](@ref). 
```@example dispatch using Lux, Random -struct TDChain{L <: NamedTuple} <: Lux.AbstractExplicitContainerLayer{(:layers,)} +struct TDChain{L <: NamedTuple} <: Lux.AbstractLuxWrapperLayer{:layers} layers::L end @@ -66,10 +66,10 @@ struct ArrayAndTime{A <: AbstractArray, T <: Real} end ``` -* Define the dispatch on `Lux.apply(::AbstractExplicitLayer, x::ArrayAndTime, ps, st::NamedTuple)`. +* Define the dispatch on `Lux.apply(::AbstractLuxLayer, x::ArrayAndTime, ps, st::NamedTuple)`. ```@example dispatch -function Lux.apply(layer::Lux.AbstractExplicitLayer, x::ArrayAndTime, ps, st::NamedTuple) +function Lux.apply(layer::Lux.AbstractLuxLayer, x::ArrayAndTime, ps, st::NamedTuple) y, st = layer(x.array, ps, st) return ArrayAndTime(y, x.time), st end diff --git a/docs/src/manual/interface.md b/docs/src/manual/interface.md index 79d1db0f09..37e1cfb056 100644 --- a/docs/src/manual/interface.md +++ b/docs/src/manual/interface.md @@ -22,7 +22,7 @@ framework. ### Singular Layer If the layer doesn't contain any other Lux layer, then it is a `Singular Layer`. This means -it should optionally subtype `Lux.AbstractExplicitLayer` but mandatorily define +it should optionally subtype `Lux.AbstractLuxLayer` but mandatorily define all the necessary functions mentioned in the docstrings. Consider a simplified version of [`Dense`](@ref) called `Linear`. @@ -38,7 +38,7 @@ architecture cannot change. ```@example layer_interface using LuxCore, Random, WeightInitializers # Importing `Lux` also gives you access to `LuxCore` -struct Linear{F1, F2} <: LuxCore.AbstractExplicitLayer +struct Linear{F1, F2} <: LuxCore.AbstractLuxLayer in_dims::Int out_dims::Int init_weight::F1 @@ -120,13 +120,21 @@ LuxCore.apply(l, x, ps, st) # or `l(x, ps, st)` If your layer comprises of other Lux layers, then it is a `Container Layer`. Note that you could treat it as a [`Singular Layer`](#singular-layer), and it is still fine. FWIW, if you -cannot subtype your layer with `LuxCore.AbstractExplicitContainerLayer` then you +cannot subtype your layer with `LuxCore.AbstractLuxContainerLayer` then you should go down the [`Singular Layer`](#singular-layer) route. But subtyping allows us to bypass some of these common definitions. Let us now define a layer, which is basically a composition of two linear layers. +!!! tip "Wrapper Layer" + + If you are defining a layer that is a wrapper around another layer, then you should + subtype `LuxCore.AbstractLuxWrapperLayer` instead of + `LuxCore.AbstractLuxContainerLayer`. The only difference from a container layer is that + it can wrap a single layer and the parameter/state structure is exactly the same as the + wrapped layer. + ```@example layer_interface -struct ComposedLinear{L1, L2} <: LuxCore.AbstractExplicitContainerLayer{(:linear_1, :linear_2)} +struct ComposedLinear{L1, L2} <: LuxCore.AbstractLuxContainerLayer{(:linear_1, :linear_2)} linear_1::L1 linear_2::L2 end diff --git a/docs/src/manual/migrate_from_flux.md b/docs/src/manual/migrate_from_flux.md index 3b4e892323..a7f58ffdd8 100644 --- a/docs/src/manual/migrate_from_flux.md +++ b/docs/src/manual/migrate_from_flux.md @@ -62,7 +62,7 @@ trainable. 
```julia [Lux] using Lux, Random, NNlib, Zygote -struct LuxLinear <: Lux.AbstractExplicitLayer +struct LuxLinear <: Lux.AbstractLuxLayer init_A init_B end diff --git a/examples/HyperNet/main.jl b/examples/HyperNet/main.jl index fef29796fd..f522f37f4a 100644 --- a/examples/HyperNet/main.jl +++ b/examples/HyperNet/main.jl @@ -24,8 +24,8 @@ function load_datasets(n_train=1024, n_eval=32, batchsize=256) end # ## Implement a HyperNet Layer -function HyperNet(weight_generator::Lux.AbstractExplicitLayer, - core_network::Lux.AbstractExplicitLayer) +function HyperNet( + weight_generator::Lux.AbstractLuxLayer, core_network::Lux.AbstractLuxLayer) ca_axes = Lux.initialparameters(Random.default_rng(), core_network) |> ComponentArray |> getaxes diff --git a/examples/NeuralODE/main.jl b/examples/NeuralODE/main.jl index 55e0ae944d..97a201fd12 100644 --- a/examples/NeuralODE/main.jl +++ b/examples/NeuralODE/main.jl @@ -39,7 +39,7 @@ end # First we will use the [`@compact`](@ref) macro to define the Neural ODE Layer. function NeuralODECompact( - model::Lux.AbstractExplicitLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...) + model::Lux.AbstractLuxLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...) return @compact(; model, solver, tspan, kwargs...) do x, p dudt(u, p, t) = vec(model(reshape(u, size(x)), p)) ## Note the `p.model` here @@ -54,8 +54,7 @@ end # The NeuralODE is a ContainerLayer, which stores a `model`. The parameters and states of # the NeuralODE are same as those of the underlying model. -struct NeuralODE{M <: Lux.AbstractExplicitLayer, So, T, K} <: - Lux.AbstractExplicitContainerLayer{(:model,)} +struct NeuralODE{M <: Lux.AbstractLuxLayer, So, T, K} <: Lux.AbstractLuxWrapperLayer{:model} model::M solver::So tspan::T @@ -63,7 +62,7 @@ struct NeuralODE{M <: Lux.AbstractExplicitLayer, So, T, K} <: end function NeuralODE( - model::Lux.AbstractExplicitLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...) + model::Lux.AbstractLuxLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...) return NeuralODE(model, solver, tspan, kwargs) end @@ -177,8 +176,8 @@ train(NeuralODE; sensealg=ReverseDiffAdjoint(), cpu=true) # Starting `v0.5.5`, Lux provides a [`StatefulLuxLayer`](@ref) which can be used # to avoid the [`Box`ing of `st`](https://github.com/JuliaLang/julia/issues/15276). Using # the `@compact` API avoids this problem entirely. -struct StatefulNeuralODE{M <: Lux.AbstractExplicitLayer, So, T, K} <: - Lux.AbstractExplicitContainerLayer{(:model,)} +struct StatefulNeuralODE{M <: Lux.AbstractLuxLayer, So, T, K} <: + Lux.AbstractLuxWrapperLayer{:model} model::M solver::So tspan::T @@ -186,7 +185,7 @@ struct StatefulNeuralODE{M <: Lux.AbstractExplicitLayer, So, T, K} <: end function StatefulNeuralODE( - model::Lux.AbstractExplicitLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...) + model::Lux.AbstractLuxLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...) return StatefulNeuralODE(model, solver, tspan, kwargs) end diff --git a/examples/SimpleRNN/main.jl b/examples/SimpleRNN/main.jl index 8e7dd45a18..f4d22071ad 100644 --- a/examples/SimpleRNN/main.jl +++ b/examples/SimpleRNN/main.jl @@ -42,7 +42,7 @@ end # ## Creating a Classifier -# We will be extending the `Lux.AbstractExplicitContainerLayer` type for our custom model +# We will be extending the `Lux.AbstractLuxContainerLayer` type for our custom model # since it will contain a lstm block and a classifier head. 
# We pass the fieldnames `lstm_cell` and `classifier` to the type to ensure that the @@ -52,8 +52,7 @@ end # To understand more about container layers, please look at # [Container Layer](@ref Container-Layer). -struct SpiralClassifier{L, C} <: - Lux.AbstractExplicitContainerLayer{(:lstm_cell, :classifier)} +struct SpiralClassifier{L, C} <: Lux.AbstractLuxContainerLayer{(:lstm_cell, :classifier)} lstm_cell::L classifier::C end diff --git a/src/Lux.jl b/src/Lux.jl index 5bc40bb0d9..fdad3172fe 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -21,9 +21,9 @@ using Reexport: Reexport, @reexport using Statistics: mean using UnrolledUtilities: unrolled_map, unrolled_mapreduce -import LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer, initialparameters, - initialstates, parameterlength, statelength, inputsize, outputsize, - update_state, trainmode, testmode, setup, apply, replicate +import LuxCore: AbstractLuxLayer, AbstractLuxContainerLayer, AbstractLuxWrapperLayer, + initialparameters, initialstates, parameterlength, statelength, inputsize, + outputsize, update_state, trainmode, testmode, setup, apply, replicate @reexport using LuxCore, LuxLib, MLDataDevices, WeightInitializers @eval Expr(:export, filter(x -> x !== :dropout, Reexport.exported_names(NNlib))...) diff --git a/src/contrib/contrib.jl b/src/contrib/contrib.jl index 902b56a67b..c97022e6c4 100644 --- a/src/contrib/contrib.jl +++ b/src/contrib/contrib.jl @@ -2,7 +2,8 @@ module Experimental using ..Lux: Lux, Optional using ..Utils: Utils, BoolType, SymbolType -using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer, apply +using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxContainerLayer, + AbstractLuxWrapperLayer, apply using ADTypes: ADTypes using ArgCheck: @argcheck diff --git a/src/contrib/debug.jl b/src/contrib/debug.jl index 5e7fd3f78b..cc50dbf32c 100644 --- a/src/contrib/debug.jl +++ b/src/contrib/debug.jl @@ -2,7 +2,7 @@ DebugLayer(layer::AbstractExplicitLayer; nan_check::Union{Symbol, StaticSymbol, Val}=static(:both), error_check::Union{StaticBool, Bool, Val{true}, Val{false}}=True(), - location::Union{KeyPath, String}=KeyPath()) + location::KeyPath=KeyPath()) A wrapper over Lux layers that adds checks for NaNs and errors. This is useful for debugging. @@ -43,14 +43,14 @@ track where the error originates. See [`Lux.Experimental.@debug_mode`](@ref) to construct this layer. """ -@concrete struct DebugLayer <: AbstractExplicitContainerLayer{(:layer,)} +@concrete struct DebugLayer <: AbstractLuxWrapperLayer{:layer} nan_check <: StaticSymbol error_check <: StaticBool - layer <: AbstractExplicitLayer + layer <: AbstractLuxLayer location::KeyPath end -function DebugLayer(layer::AbstractExplicitLayer; nan_check::SymbolType=static(:both), +function DebugLayer(layer::AbstractLuxLayer; nan_check::SymbolType=static(:both), error_check::BoolType=True(), location::KeyPath=KeyPath()) @argcheck dynamic(nan_check) in (:both, :forward, :backward, :none) return DebugLayer(static(nan_check), static(error_check), layer, location) diff --git a/src/contrib/freeze.jl b/src/contrib/freeze.jl index 3b32d7c094..cc0cfac74a 100644 --- a/src/contrib/freeze.jl +++ b/src/contrib/freeze.jl @@ -1,5 +1,5 @@ """ - FrozenLayer(l::AbstractExplicitLayer, which_params::Optional{Tuple}) + FrozenLayer(l::AbstractLuxLayer, which_params::Optional{Tuple}) Freeze the parameters with name `which_params` of the layer `l`. @@ -16,7 +16,7 @@ Freeze the parameters with name `which_params` of the layer `l`. 
## Arguments - - `l`: Lux AbstractExplicitLayer. + - `l`: Lux AbstractLuxLayer. - `which_params`: Parameter Names to be Frozen. Can be set to `nothing`, in which case all parameters are frozen. @@ -46,10 +46,10 @@ FrozenLayer(Dense(2 => 2), (:weight,)) # 2 parameters, plus 4 non-trainable See also [`Lux.Experimental.freeze`](@ref), [`Lux.Experimental.unfreeze`](@ref). """ -struct FrozenLayer{which_params, L <: AbstractExplicitLayer} <: AbstractExplicitLayer +struct FrozenLayer{which_params, L <: AbstractLuxLayer} <: AbstractLuxLayer layer::L - function FrozenLayer(l::AbstractExplicitLayer, which_params::Optional{Tuple}=nothing) + function FrozenLayer(l::AbstractLuxLayer, which_params::Optional{Tuple}=nothing) if which_params !== nothing && length(which_params) == 0 @warn "Layer `FrozenLayer($l, (,))` is same as `l`, returning `l`." return l @@ -92,24 +92,24 @@ function Base.show(io::IO, f::FrozenLayer{which_params}) where {which_params} end """ - freeze(l::AbstractExplicitLayer, which_params::Optional{Tuple} = nothing) + freeze(l::AbstractLuxLayer, which_params::Optional{Tuple} = nothing) Constructs a version of `l` with `which_params` frozen. If `which_params` is nothing, then all parameters are frozen. """ -function freeze(l::AbstractExplicitLayer, which_params::Optional{Tuple}=nothing) +function freeze(l::AbstractLuxLayer, which_params::Optional{Tuple}=nothing) return FrozenLayer(l, which_params) end """ - freeze(l::AbstractExplicitLayer, ps, st::NamedTuple, + freeze(l::AbstractLuxLayer, ps, st::NamedTuple, which_params::Optional{Tuple} = nothing) Construct a [`Lux.Experimental.FrozenLayer`](@ref) for `l` with the current parameters and states. If `which_params` is nothing, then all parameters are frozen. """ function freeze( - l::AbstractExplicitLayer, ps, st::NamedTuple, which_params::Optional{Tuple}=nothing) + l::AbstractLuxLayer, ps, st::NamedTuple, which_params::Optional{Tuple}=nothing) fl = freeze(l, which_params) ps_frozen = [] ps_trainable = [] @@ -137,6 +137,6 @@ unfreeze(l::FrozenLayer) = l.layer Unwraps a [`Lux.Experimental.FrozenLayer`](@ref) `l` with the current parameters and states. """ -function unfreeze(fl::AbstractExplicitLayer, ps, st::NamedTuple) +function unfreeze(fl::AbstractLuxLayer, ps, st::NamedTuple) return unfreeze(fl), merge(ps, st.frozen_params), st.states end diff --git a/src/contrib/map.jl b/src/contrib/map.jl index df4ce7e248..a8cc438f68 100644 --- a/src/contrib/map.jl +++ b/src/contrib/map.jl @@ -45,7 +45,7 @@ macro layer_map(f, l, ps, st) end @doc doc""" - layer_map(f::Function, l::AbstractExplicitLayer, ps, st::NamedTuple, + layer_map(f::Function, l::AbstractLuxLayer, ps, st::NamedTuple, name::Symbol=:model) Map the function `f` over the model `l`, with the parameters `ps` and states `st`. This is @@ -54,9 +54,9 @@ the function on all of them together. ## Call Signature for `f` - - Must take 4 inputs -- `AbstractExplicitLayer`, Corresponding Parameters, Corresponding + - Must take 4 inputs -- `AbstractLuxLayer`, Corresponding Parameters, Corresponding States, and the `Functors.KeyPath` to the layer. - - Must return a tuple of 3 elements -- `AbstractExplicitLayer`, new parameters and the new + - Must return a tuple of 3 elements -- `AbstractLuxLayer`, new parameters and the new states. !!! 
tip "Use `Lux.Experimental.@layer_map` instead" diff --git a/src/custom_errors.jl b/src/custom_errors.jl index bf267a70ae..7ef16a186b 100644 --- a/src/custom_errors.jl +++ b/src/custom_errors.jl @@ -16,7 +16,7 @@ struct SimpleChainsModelConversionException <: AbstractLuxException msg::String end -function SimpleChainsModelConversionException(layer::AbstractExplicitLayer) +function SimpleChainsModelConversionException(layer::AbstractLuxLayer) return SimpleChainsModelConversionException("Conversion to SimpleChains not supported \ for $(typeof(layer))") end diff --git a/src/helpers/compact.jl b/src/helpers/compact.jl index bbbb8544c2..f569cae811 100644 --- a/src/helpers/compact.jl +++ b/src/helpers/compact.jl @@ -293,7 +293,7 @@ macro non_trainable(x) end struct CompactLuxLayer{dispatch, F, N, L, V, SK} <: - AbstractExplicitContainerLayer{(:layers, :value_storage)} + AbstractLuxContainerLayer{(:layers, :value_storage)} d::StaticSymbol{dispatch} f::F name::N @@ -323,15 +323,14 @@ function CompactLuxLayer(dispatch::StaticSymbol, f::F, name::NAME_TYPE, setup_strings = NamedTuple() for (name, val) in pairs(kws) is_lux_layer = false - if val isa AbstractExplicitLayer + if val isa AbstractLuxLayer is_lux_layer = true push!(layers, name => val) elseif LuxCore.contains_lux_layer(val) # FIXME: This might lead to incorrect constructions? If the function is a # closure over the provided keyword arguments? val = CompactMacroImpl.try_make_lux_layer(val) - if LuxCore.check_fmap_condition( - !Base.Fix2(isa, AbstractExplicitLayer), nothing, val) + if LuxCore.check_fmap_condition(!Base.Fix2(isa, AbstractLuxLayer), nothing, val) throw(LuxCompactModelParsingException("A container `$(name) = $(val)` is \ found which combines Lux layers \ with non-Lux layers. This is not \ @@ -422,7 +421,7 @@ using MacroTools: MacroTools, @capture, combinedef, splitdef using Random: AbstractRNG using Static: static -using LuxCore: LuxCore, AbstractExplicitLayer +using LuxCore: LuxCore, AbstractLuxLayer using ..Lux: Lux, CompactLuxLayer, LuxCompactModelParsingException, StatefulLuxLayer, safe_getproperty @@ -585,7 +584,7 @@ end (f::InitFn)(args...) = f.f(args...) -@concrete struct ValueStorage <: AbstractExplicitLayer +@concrete struct ValueStorage <: AbstractLuxLayer ps_init_fns st_init_fns end @@ -657,7 +656,7 @@ function try_make_lux_layer(x::Union{AbstractVector, Tuple}) end try_make_lux_layer(x) = x -function maybe_make_stateful(layer::AbstractExplicitLayer, ps, st) +function maybe_make_stateful(layer::AbstractLuxLayer, ps, st) return StatefulLuxLayer{true}(layer, ps, st) end maybe_make_stateful(::Nothing, ::Nothing, st) = st diff --git a/src/helpers/losses.jl b/src/helpers/losses.jl index 0496f5b640..597f35fea1 100644 --- a/src/helpers/losses.jl +++ b/src/helpers/losses.jl @@ -120,7 +120,7 @@ end abstract type AbstractLossFunction <: Function end -function (loss::AbstractLossFunction)(model::AbstractExplicitLayer, ps, st, (x, y)) +function (loss::AbstractLossFunction)(model::AbstractLuxLayer, ps, st, (x, y)) ŷ, stₙ = model(x, ps, st) return loss(ŷ, y), stₙ, (;) end diff --git a/src/helpers/stateful.jl b/src/helpers/stateful.jl index 376c2db3ac..fffea8cf71 100644 --- a/src/helpers/stateful.jl +++ b/src/helpers/stateful.jl @@ -3,7 +3,7 @@ !!! warning - This is not a Lux.AbstractExplicitLayer + This is not a Lux.AbstractLuxLayer A convenience wrapper over Lux layers which stores the parameters and states internally. This is meant to be used in internal implementation of layers. 
@@ -40,7 +40,7 @@ This is meant to be used in internal implementation of layers. - `y`: The output of the layer """ -mutable struct StatefulLuxLayer{ST, M <: AbstractExplicitLayer, psType, stType} +mutable struct StatefulLuxLayer{ST, M <: AbstractLuxLayer, psType, stType} const model::M ps::psType st::stType @@ -58,10 +58,10 @@ function StatefulLuxLayer{ST}(model, ps, st, st_any) where {ST} return StatefulLuxLayer(model, ps, st, st_any, static(ST)) end -function StatefulLuxLayer{true}(model::AbstractExplicitLayer, ps, st::NamedTuple) +function StatefulLuxLayer{true}(model::AbstractLuxLayer, ps, st::NamedTuple) return StatefulLuxLayer{true}(model, ps, st, nothing) end -function StatefulLuxLayer{false}(model::AbstractExplicitLayer, ps, st::NamedTuple) +function StatefulLuxLayer{false}(model::AbstractLuxLayer, ps, st::NamedTuple) return StatefulLuxLayer{false}(model, ps, nothing, st) end @@ -121,8 +121,8 @@ function (s::StatefulLuxLayer)(x, p=s.ps) return y end -function CRC.rrule(::Type{<:StatefulLuxLayer{FT}}, - model::AbstractExplicitLayer, ps, st, st_any) where {FT} +function CRC.rrule( + ::Type{<:StatefulLuxLayer{FT}}, model::AbstractLuxLayer, ps, st, st_any) where {FT} slayer = StatefulLuxLayer{FT}(model, ps, st, st_any) ∇StatefulLuxLayer(Δ) = NoTangent(), NoTangent(), Δ.ps, NoTangent(), NoTangent() return slayer, ∇StatefulLuxLayer diff --git a/src/helpers/training.jl b/src/helpers/training.jl index 0d4bd7be12..134e30102e 100644 --- a/src/helpers/training.jl +++ b/src/helpers/training.jl @@ -7,7 +7,7 @@ using FastClosures: @closure using Optimisers: Optimisers using ..Lux: Lux -using LuxCore: LuxCore, AbstractExplicitLayer +using LuxCore: LuxCore, AbstractLuxLayer using Optimisers: Optimisers """ @@ -44,7 +44,7 @@ Internal fields: end """ - TrainState(model::Lux.AbstractExplicitLayer, ps, st, optimizer::Optimisers.AbstractRule) + TrainState(model::Lux.AbstractLuxLayer, ps, st, optimizer::Optimisers.AbstractRule) Constructor for [`TrainState`](@ref). @@ -62,8 +62,7 @@ Constructor for [`TrainState`](@ref). [`TrainState`](@ref) object. 
""" -function TrainState( - model::AbstractExplicitLayer, ps, st, optimizer::Optimisers.AbstractRule) +function TrainState(model::AbstractLuxLayer, ps, st, optimizer::Optimisers.AbstractRule) st_opt = Optimisers.setup(optimizer, ps) return TrainState(nothing, nothing, model, ps, st, optimizer, st_opt, 0) end diff --git a/src/layers/basic.jl b/src/layers/basic.jl index e952969cad..fea71d0cdc 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -32,7 +32,7 @@ julia> y, st_new = model(x, ps, st); (2, 2, 3) ``` """ -struct ReshapeLayer{N} <: AbstractExplicitLayer +struct ReshapeLayer{N} <: AbstractLuxLayer dims::NTuple{N, Int} end @@ -81,7 +81,7 @@ julia> y, st_new = model(x, ps, st) ([3.0, 2.0, 1.0], NamedTuple()) ``` """ -@concrete struct ReverseSequence <: AbstractExplicitLayer +@concrete struct ReverseSequence <: AbstractLuxLayer dim <: Union{Nothing, StaticInt} end @@ -141,7 +141,7 @@ julia> y, st_new = model(x, ps, st); (8, 2) ``` """ -@concrete struct FlattenLayer <: AbstractExplicitLayer +@concrete struct FlattenLayer <: AbstractLuxLayer N <: Union{Nothing, StaticInt} end @@ -177,7 +177,7 @@ Return a view of all the data of the input `x` where the index for dimension `di - `view(x,:,:,...,i,:,:,...)` where `i` is in position `d` - Empty `NamedTuple()` """ -@concrete struct SelectDim <: AbstractExplicitLayer +@concrete struct SelectDim <: AbstractLuxLayer dim <: StaticInt index <: StaticInt end @@ -212,7 +212,7 @@ julia> y, st_new = model(x, ps, st) (1, NamedTuple()) ``` """ -struct NoOpLayer <: AbstractExplicitLayer end +struct NoOpLayer <: AbstractLuxLayer end (noop::NoOpLayer)(x, _, st::NamedTuple) = x, st @@ -238,7 +238,7 @@ be `Chain((x, ps, st) -> (relu.(x), st))`. An easier thing to do would be - Output of `f(x)` - Empty `NamedTuple()` """ -@concrete struct WrappedFunction <: AbstractExplicitLayer +@concrete struct WrappedFunction <: AbstractLuxLayer func <: Function end @@ -283,7 +283,7 @@ Create a traditional fully connected layer, whose forward pass is given by: - `weight`: Weight Matrix of size `(out_dims, in_dims)` - `bias`: Bias of size `(out_dims, 1)` (present if `use_bias=true`) """ -@concrete struct Dense <: AbstractExplicitLayer +@concrete struct Dense <: AbstractLuxLayer activation in_dims <: IntegerType out_dims <: IntegerType @@ -370,7 +370,7 @@ Elements are non-zero). The forward pass is given by: `y = activation.(weight .* - `weight`: Weight Array of size `(dims...)` - `bias`: Bias of size `(dims...)` """ -@concrete struct Scale{UB <: StaticBool} <: AbstractExplicitLayer +@concrete struct Scale{UB <: StaticBool} <: AbstractLuxLayer activation dims <: Tuple{Vararg{IntegerType}} init_weight @@ -472,7 +472,7 @@ with `B` the Bilinear layer. - `weight`: Weight Matrix of size `(out_dims, in1_dims, in2_dims)` - `bias`: Bias of size `(out_dims, 1)` (present if `use_bias=true`) """ -@concrete struct Bilinear <: AbstractExplicitLayer +@concrete struct Bilinear <: AbstractLuxLayer activation in1_dims <: IntegerType in2_dims <: IntegerType @@ -578,7 +578,7 @@ This layer is often used to store word embeddings and retrieve them using indice input, an N + 1 dimensional output is returned. 
- Empty `NamedTuple()` """ -@concrete struct Embedding <: AbstractExplicitLayer +@concrete struct Embedding <: AbstractLuxLayer in_dims <: Union{IntegerType, Tuple{Vararg{IntegerType}}} out_dims <: IntegerType init_weight diff --git a/src/layers/containers.jl b/src/layers/containers.jl index 20b605d36c..ae6f84404a 100644 --- a/src/layers/containers.jl +++ b/src/layers/containers.jl @@ -16,7 +16,7 @@ The simplest "ResNet"-type connection is just `SkipConnection(layer, +)`. - `connection`: + A 2-argument function that takes `layer(input)` and the input OR - + An AbstractExplicitLayer that takes `(layer(input), input)` as input + + An AbstractLuxLayer that takes `(layer(input), input)` as input # Extended Help @@ -32,18 +32,18 @@ The simplest "ResNet"-type connection is just `SkipConnection(layer, +)`. ## Parameters - Parameters of `layer` OR - - If `connection` is an AbstractExplicitLayer, then NamedTuple with fields `:layers` and + - If `connection` is an AbstractLuxLayer, then NamedTuple with fields `:layers` and `:connection` ## States - States of `layer` OR - - If `connection` is an AbstractExplicitLayer, then NamedTuple with fields `:layers` and + - If `connection` is an AbstractLuxLayer, then NamedTuple with fields `:layers` and `:connection` See [`Parallel`](@ref) for a more general implementation. """ -@concrete struct SkipConnection <: AbstractExplicitContainerLayer{(:layers,)} +@concrete struct SkipConnection <: AbstractLuxWrapperLayer{:layers} layers connection name @@ -60,13 +60,12 @@ function SkipConnection(; layers, connection, name::NAME_TYPE=nothing) end function initialparameters( - rng::AbstractRNG, l::SkipConnection{T, <:AbstractExplicitLayer}) where {T} + rng::AbstractRNG, l::SkipConnection{T, <:AbstractLuxLayer}) where {T} return (layers=initialparameters(rng, l.layers), connection=initialparameters(rng, l.connection)) end -function initialstates( - rng::AbstractRNG, l::SkipConnection{T, <:AbstractExplicitLayer}) where {T} +function initialstates(rng::AbstractRNG, l::SkipConnection{T, <:AbstractLuxLayer}) where {T} return ( layers=initialstates(rng, l.layers), connection=initialstates(rng, l.connection)) end @@ -76,7 +75,7 @@ function (skip::SkipConnection)(x, ps, st::NamedTuple) return skip.connection(mx, x), st end -function (skip::SkipConnection{<:AbstractExplicitLayer, <:AbstractExplicitLayer})( +function (skip::SkipConnection{<:AbstractLuxLayer, <:AbstractLuxLayer})( x, ps, st::NamedTuple) mx, st1 = apply(skip.layers, x, ps.layers, st.layers) y, st2 = apply(skip.connection, (mx, x), ps.connection, st.connection) @@ -147,7 +146,7 @@ julia> size.(first(model((x1, x2), ps, st))) ((1,), (1,)) ``` """ -@concrete struct Parallel <: AbstractExplicitContainerLayer{(:layers,)} +@concrete struct Parallel <: AbstractLuxWrapperLayer{:layers} connection layers <: NamedTuple name @@ -254,7 +253,7 @@ BranchLayer( # plus 0 states. ``` """ -@concrete struct BranchLayer <: AbstractExplicitContainerLayer{(:layers,)} +@concrete struct BranchLayer <: AbstractLuxWrapperLayer{:layers} layers <: NamedTuple name end @@ -297,7 +296,7 @@ x1 → layer1 → y1 ↘ - `connection`: Takes 2 inputs and combines them - - `layers`: `AbstractExplicitLayer`s. Layers can be specified in two formats: + - `layers`: `AbstractLuxLayer`s. Layers can be specified in two formats: + A list of `N` Lux layers + Specified as `N` keyword arguments. 
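As a quick reference for the two `connection` styles mentioned above, here is a hedged sketch (not part of the patch; the `Parallel(+, NoOpLayer(), NoOpLayer())` connection is just one way to build a layer that consumes the `(layer(input), input)` tuple):

```julia
using Lux, Random

rng = Random.default_rng()

# `connection` as a plain 2-argument function
skip_fn = SkipConnection(Dense(4 => 4, tanh), +)

# `connection` as an AbstractLuxLayer that receives the tuple (layer(input), input)
skip_layer = SkipConnection(Dense(4 => 4, tanh), Parallel(+, NoOpLayer(), NoOpLayer()))

x = randn(rng, Float32, 4, 8)
for m in (skip_fn, skip_layer)
    ps, st = Lux.setup(rng, m)
    y, _ = m(x, ps, st)
    @assert size(y) == (4, 8)
end
```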
@@ -342,7 +341,7 @@ end - States of each `layer` wrapped in a NamedTuple with `fields = layer_1, layer_2, ..., layer_N` (naming changes if using the kwargs API) """ -@concrete struct PairwiseFusion <: AbstractExplicitContainerLayer{(:layers,)} +@concrete struct PairwiseFusion <: AbstractLuxWrapperLayer{:layers} connection layers <: NamedTuple name @@ -452,7 +451,7 @@ MyFancyChain( # plus 7 states. ``` """ -@concrete struct Chain <: AbstractExplicitContainerLayer{(:layers,)} +@concrete struct Chain <: AbstractLuxWrapperLayer{:layers} layers <: NamedTuple name end @@ -472,10 +471,10 @@ function wrap_functions_in_chain_call(layers::Union{AbstractVector, Tuple}) append!(new_layers, f) elseif f isa Function push!(new_layers, WrappedFunction(f)) - elseif f isa AbstractExplicitLayer + elseif f isa AbstractLuxLayer push!(new_layers, f) else - throw("Encountered a non-AbstractExplicitLayer in Chain.") + throw("Encountered a non-AbstractLuxLayer in Chain.") end end return layers isa AbstractVector ? new_layers : Tuple(new_layers) @@ -561,7 +560,7 @@ See also [`Parallel`](@ref) to reduce with other operators. [1] Goodfellow, Warde-Farley, Mirza, Courville & Bengio "Maxout Networks" [https://arxiv.org/abs/1302.4389](https://arxiv.org/abs/1302.4389) """ -@concrete struct Maxout <: AbstractExplicitContainerLayer{(:layers,)} +@concrete struct Maxout <: AbstractLuxWrapperLayer{:layers} layers <: NamedTuple end @@ -616,7 +615,7 @@ times for gradients might be unreasonably high. ## Arguments - - `model` must be an `AbstractExplicitLayer` + - `model` must be an `AbstractLuxLayer` ## Keyword Arguments @@ -643,10 +642,10 @@ times for gradients might be unreasonably high. - State of `model` """ -@concrete struct RepeatedLayer <: AbstractExplicitContainerLayer{(:model,)} +@concrete struct RepeatedLayer <: AbstractLuxWrapperLayer{:model} nrepeats <: StaticInt input_injection <: StaticBool - model <: AbstractExplicitLayer + model <: AbstractLuxLayer end function LuxCore.display_name(r::RepeatedLayer) @@ -655,7 +654,7 @@ function LuxCore.display_name(r::RepeatedLayer) end function RepeatedLayer( - model::AbstractExplicitLayer; repeats::Union{StaticInt, Integer, Val}=Val(10), + model::AbstractLuxLayer; repeats::Union{StaticInt, Integer, Val}=Val(10), input_injection::Union{StaticBool, Bool, Val{true}, Val{false}}=Val(false)) return RepeatedLayer(static(repeats), static(input_injection), model) end diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 44ce03fa67..5e4290621b 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -139,7 +139,7 @@ O_i = \left\lfloor\frac{I_i + p_i + p_{(i + N) \% |p|} - d_i \times (k_i - 1)}{s - `weight`: Convolution kernel - `bias`: Bias (present if `use_bias=true`) """ -@concrete struct Conv <: AbstractExplicitLayer +@concrete struct Conv <: AbstractLuxLayer activation in_chs <: IntegerType out_chs <: IntegerType @@ -243,7 +243,7 @@ value. See also [`Conv`](@ref), [`MeanPool`](@ref), [`GlobalMaxPool`](@ref), [`AdaptiveMaxPool`](@ref) """ -@concrete struct MaxPool <: AbstractExplicitLayer +@concrete struct MaxPool <: AbstractLuxLayer k <: Tuple{Vararg{IntegerType}} pad <: Tuple{Vararg{IntegerType}} stride <: Tuple{Vararg{IntegerType}} @@ -313,7 +313,7 @@ value. 
See also [`Conv`](@ref), [`MaxPool`](@ref), [`GlobalMeanPool`](@ref), [`AdaptiveMeanPool`](@ref) """ -@concrete struct MeanPool <: AbstractExplicitLayer +@concrete struct MeanPool <: AbstractLuxLayer k <: Tuple{Vararg{IntegerType}} pad <: Tuple{Vararg{IntegerType}} stride <: Tuple{Vararg{IntegerType}} @@ -385,7 +385,7 @@ Currently supported upsampling `mode`s and corresponding NNlib's methods are: - Upsampled Input of size `size` or of size `(I_1 x scale[1], ..., I_N x scale[N], C, N)` - Empty `NamedTuple()` """ -@concrete struct Upsample <: AbstractExplicitLayer +@concrete struct Upsample <: AbstractLuxLayer scale size upsample_mode <: StaticSymbol @@ -448,7 +448,7 @@ by performing max pooling on the complete (w,h)-shaped feature maps. See also [`MaxPool`](@ref), [`AdaptiveMaxPool`](@ref), [`GlobalMeanPool`](@ref) """ -struct GlobalMaxPool <: AbstractExplicitLayer end +struct GlobalMaxPool <: AbstractLuxLayer end function (g::GlobalMaxPool)(x, _, st::NamedTuple) return maxpool(x, PoolDims(x, size(x)[1:(end - 2)])), st @@ -471,7 +471,7 @@ by performing mean pooling on the complete (w,h)-shaped feature maps. See also [`MeanPool`](@ref), [`AdaptiveMeanPool`](@ref), [`GlobalMaxPool`](@ref) """ -struct GlobalMeanPool <: AbstractExplicitLayer end +struct GlobalMeanPool <: AbstractLuxLayer end function (g::GlobalMeanPool)(x, _, st::NamedTuple) return meanpool(x, PoolDims(x, size(x)[1:(end - 2)])), st @@ -499,7 +499,7 @@ Adaptive Max Pooling layer. Calculates the necessary window size such that its o See also [`MaxPool`](@ref), [`AdaptiveMeanPool`](@ref). """ -struct AdaptiveMaxPool{S, O <: Tuple{Vararg{IntegerType}}} <: AbstractExplicitLayer +struct AdaptiveMaxPool{S, O <: Tuple{Vararg{IntegerType}}} <: AbstractLuxLayer out::O AdaptiveMaxPool(out) = new{length(out) + 2, typeof(out)}(out) end @@ -532,7 +532,7 @@ Adaptive Mean Pooling layer. Calculates the necessary window size such that its See also [`MeanPool`](@ref), [`AdaptiveMaxPool`](@ref). """ -struct AdaptiveMeanPool{S, O <: Tuple{Vararg{IntegerType}}} <: AbstractExplicitLayer +struct AdaptiveMeanPool{S, O <: Tuple{Vararg{IntegerType}}} <: AbstractLuxLayer out::O AdaptiveMeanPool(out) = new{length(out) + 2, typeof(out)}(out) end @@ -643,7 +643,7 @@ O_i = \left\lfloor\frac{I_i + p_i + p_{(i + N) \% |p|} - d_i \times (k_i - 1)}{s - `weight`: Convolution kernel - `bias`: Bias (present if `use_bias=true`) """ -@concrete struct CrossCor <: AbstractExplicitLayer +@concrete struct CrossCor <: AbstractLuxLayer activation in_chs <: IntegerType out_chs <: IntegerType @@ -762,7 +762,7 @@ Standard convolutional transpose layer. 
- `weight`: Convolution Transpose kernel - `bias`: Bias (present if `use_bias=true`) """ -@concrete struct ConvTranspose <: AbstractExplicitLayer +@concrete struct ConvTranspose <: AbstractLuxLayer activation in_chs <: IntegerType out_chs <: IntegerType diff --git a/src/layers/display.jl b/src/layers/display.jl index 5c2efc1769..48e09f0f3d 100644 --- a/src/layers/display.jl +++ b/src/layers/display.jl @@ -1,10 +1,10 @@ module PrettyPrinting using Functors: Functors -using LuxCore: LuxCore, AbstractExplicitContainerLayer, AbstractExplicitLayer, display_name +using LuxCore: LuxCore, AbstractLuxContainerLayer, AbstractLuxLayer, display_name printable_children(x) = Functors.children(x) -function printable_children(m::AbstractExplicitContainerLayer{layers}) where {layers} +function printable_children(m::AbstractLuxContainerLayer{layers}) where {layers} children = Functors.children(m) length(layers) ≥ 2 && return children field = first(layers) @@ -15,7 +15,7 @@ function printable_children(m::AbstractExplicitContainerLayer{layers}) where {la end show_leaflike(x) = Functors.isleaf(x) # mostly follow Functors, except for: -show_leaflike(x::AbstractExplicitLayer) = false +show_leaflike(x::AbstractLuxLayer) = false function underscorise(n::Integer) return join(reverse(join.(reverse.(Iterators.partition(digits(n), 3)))), '_') @@ -75,7 +75,7 @@ function show_parameters_count(io::IO, layer, indent, str::String) return end -function print_wrapper_model(io::IO, desc::String, model::AbstractExplicitLayer) +function print_wrapper_model(io::IO, desc::String, model::AbstractLuxLayer) if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL print(io, desc, "(\n") big_show(io, model, 4) @@ -96,7 +96,8 @@ tuple_string(pad::Tuple) = all(==(pad[1]), pad) ? string(pad[1]) : string(pad) end -function Base.show(io::IO, ::MIME"text/plain", x::AbstractExplicitContainerLayer) +function Base.show(io::IO, ::MIME"text/plain", + x::Union{AbstractLuxContainerLayer, AbstractLuxWrapperLayer}) if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL PrettyPrinting.big_show(io, x) elseif !get(io, :compact, false) # e.g. printed inside a Vector, but not a Matrix @@ -106,7 +107,7 @@ function Base.show(io::IO, ::MIME"text/plain", x::AbstractExplicitContainerLayer end end -function Base.show(io::IO, ::MIME"text/plain", x::AbstractExplicitLayer) +function Base.show(io::IO, ::MIME"text/plain", x::AbstractLuxLayer) !get(io, :compact, false) && return PrettyPrinting.layer_show(io, x) show(io, x) end diff --git a/src/layers/dropout.jl b/src/layers/dropout.jl index e1d1ffa441..a7ef56400f 100644 --- a/src/layers/dropout.jl +++ b/src/layers/dropout.jl @@ -28,7 +28,7 @@ Call [`Lux.testmode`](@ref) to switch to test mode. See also [`Dropout`](@ref), [`VariationalHiddenDropout`](@ref) """ -struct AlphaDropout{T <: Real} <: AbstractExplicitLayer +struct AlphaDropout{T <: Real} <: AbstractLuxLayer p::T alpha::T scale::T @@ -90,7 +90,7 @@ Call [`Lux.testmode`](@ref) to switch to test mode. See also [`AlphaDropout`](@ref), [`VariationalHiddenDropout`](@ref) """ -@concrete struct Dropout{T} <: AbstractExplicitLayer +@concrete struct Dropout{T} <: AbstractLuxLayer p::T q::T dims @@ -154,7 +154,7 @@ Call [`Lux.testmode`](@ref) to switch to test mode. 
See also [`AlphaDropout`](@ref), [`Dropout`](@ref) """ -@concrete struct VariationalHiddenDropout{T} <: AbstractExplicitLayer +@concrete struct VariationalHiddenDropout{T} <: AbstractLuxLayer p::T q::T dims diff --git a/src/layers/extension.jl b/src/layers/extension.jl index 9e1d1fd7de..feca2e9f0c 100644 --- a/src/layers/extension.jl +++ b/src/layers/extension.jl @@ -144,7 +144,7 @@ API internally. - `p`: Flattened parameters of the `layer` """ -@concrete struct FluxLayer <: AbstractExplicitLayer +@concrete struct FluxLayer <: AbstractLuxLayer layer re <: Optimisers.Restructure init_parameters @@ -171,7 +171,7 @@ Base.show(io::IO, ::MIME"text/plain", l::FluxLayer) = print(io, "FluxLayer($(l.l SimpleChainsLayer(layer, ToArray::Union{Bool, Val}=Val(false)) Wraps a `SimpleChains` layer into a `Lux` layer. All operations are performed using -`SimpleChains` but the layer satisfies the `AbstractExplicitLayer` interface. +`SimpleChains` but the layer satisfies the `AbstractLuxLayer` interface. `ToArray` is a boolean flag that determines whether the output should be converted to a regular `Array` or not. Default is `false`. @@ -181,8 +181,8 @@ regular `Array` or not. Default is `false`. - `layer`: SimpleChains layer - `lux_layer`: Potentially equivalent Lux layer that is used for printing """ -struct SimpleChainsLayer{ToArray, SL, LL <: Union{Nothing, AbstractExplicitLayer}} <: - AbstractExplicitLayer +struct SimpleChainsLayer{ToArray, SL, LL <: Union{Nothing, AbstractLuxLayer}} <: + AbstractLuxLayer to_array::ToArray layer::SL lux_layer::LL diff --git a/src/layers/normalize.jl b/src/layers/normalize.jl index dc7de1252c..6cfc28c970 100644 --- a/src/layers/normalize.jl +++ b/src/layers/normalize.jl @@ -86,7 +86,7 @@ Chain( See also [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`LayerNorm`](@ref), [`WeightNorm`](@ref) """ -@concrete struct BatchNorm{N} <: AbstractExplicitLayer +@concrete struct BatchNorm{N} <: AbstractLuxLayer activation epsilon::N momentum::N @@ -222,7 +222,7 @@ Chain( See also [`GroupNorm`](@ref), [`InstanceNorm`](@ref), [`LayerNorm`](@ref), [`WeightNorm`](@ref) """ -@concrete struct GroupNorm <: AbstractExplicitLayer +@concrete struct GroupNorm <: AbstractLuxLayer activation epsilon chs <: IntegerType @@ -336,7 +336,7 @@ Chain( See also [`BatchNorm`](@ref), [`GroupNorm`](@ref), [`LayerNorm`](@ref), [`WeightNorm`](@ref) """ -@concrete struct InstanceNorm <: AbstractExplicitLayer +@concrete struct InstanceNorm <: AbstractLuxLayer activation epsilon chs <: IntegerType @@ -376,7 +376,7 @@ function Base.show(io::IO, l::InstanceNorm) end @doc doc""" - WeightNorm(layer::AbstractExplicitLayer, which_params::NTuple{N, Symbol}, + WeightNorm(layer::AbstractLuxLayer, which_params::NTuple{N, Symbol}, dims::Union{Tuple, Nothing}=nothing) Applies [weight normalization](https://arxiv.org/abs/1602.07868) to a parameter in the given @@ -416,13 +416,13 @@ parameters: one specifying the magnitude (e.g. 
`weight_g`) and one specifying th - Same as that of `layer` """ -@concrete struct WeightNorm <: AbstractExplicitLayer - layer <: AbstractExplicitLayer +@concrete struct WeightNorm <: AbstractLuxLayer + layer <: AbstractLuxLayer which_params dims function WeightNorm( - layer::AbstractExplicitLayer, which_params, dims::Union{Tuple, Nothing}=nothing) + layer::AbstractLuxLayer, which_params, dims::Union{Tuple, Nothing}=nothing) which_params = static(which_params) dims = static(dims) return new{typeof(layer), typeof(which_params), typeof(dims)}( @@ -557,7 +557,7 @@ where ``\gamma`` & ``\beta`` are trainable parameters if `affine=true`. + `bias`: Bias of shape `(shape..., 1)` + `scale`: Scale of shape `(shape..., 1)` """ -@concrete struct LayerNorm <: AbstractExplicitLayer +@concrete struct LayerNorm <: AbstractLuxLayer shape activation epsilon diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 3c609abeee..fbde8afe49 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -1,4 +1,4 @@ -abstract type AbstractRecurrentCell <: AbstractExplicitLayer end +abstract type AbstractRecurrentCell <: AbstractLuxLayer end const AbstractDebugRecurrentCell = Experimental.DebugLayer{ <:Any, <:Any, <:AbstractRecurrentCell} @@ -85,7 +85,7 @@ automatically operate over a sequence of inputs. For some discussion on this topic, see https://github.com/LuxDL/Lux.jl/issues/472. """ -@concrete struct Recurrence{R <: StaticBool} <: AbstractExplicitContainerLayer{(:cell,)} +@concrete struct Recurrence{R <: StaticBool} <: AbstractLuxWrapperLayer{:cell} cell <: Union{<:AbstractRecurrentCell, <:AbstractDebugRecurrentCell} ordering <: AbstractTimeSeriesDataBatchOrdering return_sequence::R @@ -151,7 +151,7 @@ update the state with `Lux.update_state(st, :carry, nothing)`. + `cell`: Same as `cell`. + `carry`: The carry state of the `cell`. """ -@concrete struct StatefulRecurrentCell <: AbstractExplicitContainerLayer{(:cell,)} +@concrete struct StatefulRecurrentCell <: AbstractLuxWrapperLayer{:cell} cell <: Union{<:AbstractRecurrentCell, <:AbstractDebugRecurrentCell} end @@ -629,7 +629,7 @@ Bidirectional RNN wrapper. - Same as `cell` and `backward_cell`. """ -@concrete struct BidirectionalRNN <: AbstractExplicitContainerLayer{(:model,)} +@concrete struct BidirectionalRNN <: AbstractLuxWrapperLayer{:model} model <: Parallel end diff --git a/src/transform/simplechains.jl b/src/transform/simplechains.jl index d224fcd64b..f6e2ecb7e9 100644 --- a/src/transform/simplechains.jl +++ b/src/transform/simplechains.jl @@ -2,7 +2,7 @@ ToSimpleChainsAdaptor(input_dims, convert_to_array::Bool=false) Adaptor for converting a Lux Model to SimpleChains. The returned model is still a Lux model, -and satisfies the `AbstractExplicitLayer` interfacem but all internal calculations are +and satisfies the `AbstractLuxLayer` interfacem but all internal calculations are performed using SimpleChains. !!! warning @@ -59,12 +59,12 @@ struct ToSimpleChainsAdaptor{ID, AT} <: AbstractFromLuxAdaptor end """ - Adapt.adapt(from::ToSimpleChainsAdaptor, L::AbstractExplicitLayer) + Adapt.adapt(from::ToSimpleChainsAdaptor, L::AbstractLuxLayer) Adapt a Simple Chains model to Lux model. See [`ToSimpleChainsAdaptor`](@ref) for more details. 
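For context, a hedged sketch of how the adaptor above is typically used, assuming `SimpleChains.jl` is loaded and that input sizes are passed as `static` integers (an assumption borrowed from the usual SimpleChains convention, not spelled out in this hunk):

```julia
using Adapt, Lux, Random, SimpleChains
using SimpleChains: static

lux_model = Chain(Dense(2 => 32, tanh), Dense(32 => 2))
adaptor = ToSimpleChainsAdaptor((static(2),))
sc_model = Adapt.adapt(adaptor, lux_model)  # still satisfies the AbstractLuxLayer interface

rng = Random.default_rng()
ps, st = Lux.setup(rng, sc_model)
y, st_ = sc_model(randn(rng, Float32, 2, 8), ps, st)
```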
""" -function Adapt.adapt(to::ToSimpleChainsAdaptor, L::AbstractExplicitLayer) +function Adapt.adapt(to::ToSimpleChainsAdaptor, L::AbstractLuxLayer) if Base.get_extension(@__MODULE__, :LuxSimpleChainsExt) === nothing error("`ToSimpleChainsAdaptor` requires `SimpleChains.jl` to be loaded.") end diff --git a/src/utils.jl b/src/utils.jl index 4009c9b115..4cea362907 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -10,7 +10,7 @@ using Functors: fmapstructure using Random: AbstractRNG using Static: Static, StaticBool, StaticInteger, StaticSymbol -using LuxCore: LuxCore, AbstractExplicitLayer +using LuxCore: LuxCore, AbstractLuxLayer using MLDataDevices: get_device const CRC = ChainRulesCore @@ -189,7 +189,7 @@ set_refval!(x, y) = (x[] = y) @non_differentiable set_refval!(::Any...) EnzymeRules.inactive(::typeof(set_refval!), ::Any...) = nothing -function named_tuple_layers(layers::Vararg{AbstractExplicitLayer, N}) where {N} +function named_tuple_layers(layers::Vararg{AbstractLuxLayer, N}) where {N} return NamedTuple{ntuple(i -> Symbol(:layer_, i), N)}(layers) end diff --git a/test/Project.toml b/test/Project.toml index e82d73a027..c35fd43954 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -57,7 +57,7 @@ Hwloc = "3.2.0" InteractiveUtils = "<0.0.1, 1" LinearAlgebra = "1.10" Logging = "1.10" -LuxCore = "0.1.16" +LuxCore = "1.0" LuxDeviceUtils = "0.1.26" LuxLib = "1.0" LuxTestUtils = "1.1.4" diff --git a/test/contrib/map_tests.jl b/test/contrib/map_tests.jl index 842cd18581..401afbc8d5 100644 --- a/test/contrib/map_tests.jl +++ b/test/contrib/map_tests.jl @@ -57,7 +57,7 @@ @test all(iszero, ps_.dense_3.bias) # Custom Layers -- See https://github.com/LuxDL/Lux.jl/issues/187 - struct SimpleCustom{L1, L2} <: Lux.AbstractExplicitContainerLayer{(:dense, :conv)} + struct SimpleCustom{L1, L2} <: Lux.AbstractLuxContainerLayer{(:dense, :conv)} dense::L1 conv::L2 end diff --git a/test/helpers/stateful_tests.jl b/test/helpers/stateful_tests.jl index 8392182a3d..cc2b4e4afb 100644 --- a/test/helpers/stateful_tests.jl +++ b/test/helpers/stateful_tests.jl @@ -3,7 +3,7 @@ rng = StableRNG(12345) - struct NotFixedStateModel <: Lux.AbstractExplicitLayer end + struct NotFixedStateModel <: Lux.AbstractLuxLayer end (m::NotFixedStateModel)(x, ps, st) = (x, (; s=1)) diff --git a/test/helpers/training_tests.jl b/test/helpers/training_tests.jl index b44a7cdd59..09c24be72b 100644 --- a/test/helpers/training_tests.jl +++ b/test/helpers/training_tests.jl @@ -236,7 +236,7 @@ end @test_throws ArgumentError Training.compute_gradients( AutoReverseDiff(; compile=true), mse2, dataset[1], tstate) - struct StrangeModel <: Lux.AbstractExplicitLayer end + struct StrangeModel <: Lux.AbstractLuxLayer end function (m::StrangeModel)(x, ps, st) return x, (; new_state=0.0) diff --git a/test/layers/containers_tests.jl b/test/layers/containers_tests.jl index cbd425f7f3..218b1a7bf7 100644 --- a/test/layers/containers_tests.jl +++ b/test/layers/containers_tests.jl @@ -131,7 +131,7 @@ end x::X end - struct L1 <: Lux.AbstractExplicitLayer end + struct L1 <: Lux.AbstractLuxLayer end (::L1)(x, ps, st) = (ps.x * x, st) Lux.initialparameters(rng::AbstractRNG, ::L1) = (x=randn(rng, Float32, 3, 3),) Base.:*(a::AbstractArray, b::Input) = a * b.x From f0c737400aedafd04b6ee37d5331f1a883062811 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 18 Aug 2024 13:26:13 -0700 Subject: [PATCH 26/95] fix: broken tests --- src/layers/containers.jl | 5 ----- test/contrib/debug_tests.jl | 4 ++-- test/helpers/training_tests.jl | 12 ++++++------ 
test/layers/basic_tests.jl | 2 +- test/transform/simple_chains_tests.jl | 6 +++--- 5 files changed, 12 insertions(+), 17 deletions(-) diff --git a/src/layers/containers.jl b/src/layers/containers.jl index ae6f84404a..5450376077 100644 --- a/src/layers/containers.jl +++ b/src/layers/containers.jl @@ -399,11 +399,6 @@ Collects multiple layers / functions to be called in sequence on a given input. + A list of `N` Lux layers + Specified as `N` keyword arguments. -## Keyword Arguments - - - `disable_optimizations`: Prevents any structural optimization - - `name`: Name of the layer (optional) - # Extended Help ## Inputs diff --git a/test/contrib/debug_tests.jl b/test/contrib/debug_tests.jl index b06038af83..2053c1189a 100644 --- a/test/contrib/debug_tests.jl +++ b/test/contrib/debug_tests.jl @@ -61,8 +61,8 @@ end end @testset "$mode: NaN Debugging" for (mode, aType, dev, ongpu) in MODES - model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 1), Dense(1 => 1)), - BatchNorm(1); disable_optimizations=true) + model = Chain( + Dense(1 => 16, relu), Chain(Dense(16 => 1), Dense(1 => 1)), BatchNorm(1)) x = randn(rng, Float32, 1, 5) |> aType ps, st = Lux.setup(rng, model) |> dev diff --git a/test/helpers/training_tests.jl b/test/helpers/training_tests.jl index 09c24be72b..4d1cd0ccd0 100644 --- a/test/helpers/training_tests.jl +++ b/test/helpers/training_tests.jl @@ -146,9 +146,9 @@ end x = randn(rng, Float32, 4, 32) opt = Adam(0.001f0) - tstate = Lux.Experimental.TrainState(model, ps, st, opt) + tstate = Training.TrainState(model, ps, st, opt) - _, _, _, tstate_new = @inferred Lux.Experimental.compute_gradients( + _, _, _, tstate_new = @inferred Training.compute_gradients( AutoEnzyme(), mse, (x, x), tstate) @test tstate_new.states !== tstate.states @@ -156,15 +156,15 @@ end model = Chain(Dense(4 => 3), Dense(3 => 4)) ps, st = Lux.setup(rng, model) - tstate = Lux.Experimental.TrainState(model, ps, st, opt) + tstate = Training.TrainState(model, ps, st, opt) - _, _, _, tstate_new = @inferred Lux.Experimental.compute_gradients( + _, _, _, tstate_new = @inferred Training.compute_gradients( AutoEnzyme(), mse, (x, x), tstate) - @test @inferred(Lux.Experimental.compute_gradients( + @test @inferred(Training.compute_gradients( AutoEnzyme(), mse, (x, x), tstate_new)) isa Any - _, _, _, tstate_new2 = @inferred Lux.Experimental.compute_gradients( + _, _, _, tstate_new2 = @inferred Training.compute_gradients( AutoEnzyme(), mse2, (x, x), tstate_new) @test hasfield(typeof(tstate_new2.cache.extras), :forward) @test hasfield(typeof(tstate_new2.cache.extras), :reverse) diff --git a/test/layers/basic_tests.jl b/test/layers/basic_tests.jl index 15982a27c4..fda80d6a65 100644 --- a/test/layers/basic_tests.jl +++ b/test/layers/basic_tests.jl @@ -102,7 +102,7 @@ @jet layer(x, ps, st) __f = x -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x gpu_testing=$ongpu atol=1.0f-3 rtol=1.0f-3 + test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3) end @testset "PeriodicEmbedding" begin diff --git a/test/transform/simple_chains_tests.jl b/test/transform/simple_chains_tests.jl index 3c20563195..97428139df 100644 --- a/test/transform/simple_chains_tests.jl +++ b/test/transform/simple_chains_tests.jl @@ -103,9 +103,9 @@ @test length(gs[2].params) == length(ps.params) # See https://github.com/LuxDL/Lux.jl/issues/644 - test_gradients( - __f, x, ps; atol=1.0f-3, rtol=1.0f-3, broken_backends=[AutoEnzyme()], - soft_fail=[AutoForwardDiff(), AutoFiniteDiff(), AutoTracker()]) + test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3, + 
broken_backends=[AutoEnzyme(), AutoTracker()], + soft_fail=[AutoForwardDiff(), AutoFiniteDiff()]) end end From f4d3fc8f0a8aed55f9fc11ef6b6de12d8f1c5ce3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 18 Aug 2024 13:29:00 -0700 Subject: [PATCH 27/95] chore: remove all references to LuxDeviceUtils --- docs/src/manual/migrate_from_flux.md | 6 +++--- test/Project.toml | 2 -- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/docs/src/manual/migrate_from_flux.md b/docs/src/manual/migrate_from_flux.md index a7f58ffdd8..7eb4352f1f 100644 --- a/docs/src/manual/migrate_from_flux.md +++ b/docs/src/manual/migrate_from_flux.md @@ -49,9 +49,9 @@ should be implemented. A summary of the differences would be: * Lux relies on the user to define `Lux.initialparameters` and `Lux.initialstates` to distinguish between trainable parameters (called "parameters") and non-trainable parameters (called "states"). Additionally, Lux layers define the model architecture, - hence device transfer utilities like [`LuxDeviceUtils.gpu_device`](@ref), - [`LuxDeviceUtils.cpu_device`](@ref), etc. cannot be applied on Lux layers, instead they - need to be applied on the parameters and states. + hence device transfer utilities like [`gpu_device`](@ref), [`cpu_device`](@ref), etc. + cannot be applied on Lux layers, instead they need to be applied on the parameters and + states. Let's work through a concrete example to demonstrate this. We will implement a very simple layer that computes ``A \times B \times x`` where ``A`` is not trainable and ``B`` is diff --git a/test/Project.toml b/test/Project.toml index c35fd43954..6775daa6de 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -17,7 +17,6 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" -LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" @@ -58,7 +57,6 @@ InteractiveUtils = "<0.0.1, 1" LinearAlgebra = "1.10" Logging = "1.10" LuxCore = "1.0" -LuxDeviceUtils = "0.1.26" LuxLib = "1.0" LuxTestUtils = "1.1.4" MLDataDevices = "1.1" From 947817bff549ab1fb95e4f494fae9f64d3be7263 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 18 Aug 2024 16:34:15 -0700 Subject: [PATCH 28/95] feat: define fallback `outputsize` --- docs/src/api/Building_Blocks/LuxCore.md | 6 --- src/Lux.jl | 2 +- src/helpers/size_propagator.jl | 21 +--------- src/layers/basic.jl | 10 ++--- src/layers/containers.jl | 2 - test/helpers/size_propagator_test.jl | 49 +++++++++--------------- test/helpers/size_propagator_tests.jl | 51 +++++++++++++++++++++++++ test/layers/basic_tests.jl | 12 +++--- test/layers/containers_tests.jl | 8 ++-- test/qa_tests.jl | 2 +- 10 files changed, 88 insertions(+), 75 deletions(-) create mode 100644 test/helpers/size_propagator_tests.jl diff --git a/docs/src/api/Building_Blocks/LuxCore.md b/docs/src/api/Building_Blocks/LuxCore.md index 54a4789fdd..3016597b62 100644 --- a/docs/src/api/Building_Blocks/LuxCore.md +++ b/docs/src/api/Building_Blocks/LuxCore.md @@ -50,12 +50,6 @@ LuxCore.update_state ## Layer size -!!! warning - - These specifications have been added very recently and most layers currently do not - implement them. 
- ```@docs -LuxCore.inputsize LuxCore.outputsize ``` diff --git a/src/Lux.jl b/src/Lux.jl index fdad3172fe..f3ea3cabe5 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -22,7 +22,7 @@ using Statistics: mean using UnrolledUtilities: unrolled_map, unrolled_mapreduce import LuxCore: AbstractLuxLayer, AbstractLuxContainerLayer, AbstractLuxWrapperLayer, - initialparameters, initialstates, parameterlength, statelength, inputsize, + initialparameters, initialstates, parameterlength, statelength, outputsize, update_state, trainmode, testmode, setup, apply, replicate @reexport using LuxCore, LuxLib, MLDataDevices, WeightInitializers diff --git a/src/helpers/size_propagator.jl b/src/helpers/size_propagator.jl index 33d301e70e..954065e53d 100644 --- a/src/helpers/size_propagator.jl +++ b/src/helpers/size_propagator.jl @@ -1,6 +1,4 @@ # Initial design is based off of https://github.com/FluxML/Flux.jl/blob/942c6e5051b7a8cb064432d1f0604319497d5f09/src/outputsize.jl -# Currently this is not being used anywhere. However, with 1.0 release we will define -# outputsize for all layers using this. module NilSizePropagation using ArrayInterface: ArrayInterface @@ -199,25 +197,8 @@ end end -# TODO: In v1 we change to this `outputsize` function, till then this is private API -function compute_output_size(layer::AbstractExplicitLayer, - input_size::NTuple{N, <:Integer}, rng::AbstractRNG) where {N} - x = NilSizePropagation.NilArray{N}(input_size) - return compute_output_size(layer, x, rng) -end - -function compute_output_size( - layer::AbstractExplicitLayer, input_size::NTuple{N, <:Integer}, ps, st) where {N} - x = NilSizePropagation.NilArray{N}(input_size) - return compute_output_size(layer, x, ps, st) -end - -function compute_output_size(layer::AbstractExplicitLayer, x, rng::AbstractRNG) +function LuxCore.outputsize(layer::AbstractLuxLayer, x, rng::AbstractRNG) ps, st = setup(rng, layer) - return compute_output_size(layer, x, ps, st) -end - -function compute_output_size(layer::AbstractExplicitLayer, x, ps, st) x_nil = NilSizePropagation.recursively_nillify(x) ps_nil = NilSizePropagation.recursively_nillify(ps) st_nil = NilSizePropagation.recursively_nillify(st) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index fea71d0cdc..e8124e2e9d 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -36,7 +36,7 @@ struct ReshapeLayer{N} <: AbstractLuxLayer dims::NTuple{N, Int} end -outputsize(r::ReshapeLayer) = r.dims +outputsize(r::ReshapeLayer, _, __) = r.dims function (r::ReshapeLayer)(x::AbstractArray, _, st::NamedTuple) return reshape(x, r.dims..., size(x, ndims(x))), st @@ -322,7 +322,7 @@ end parameterlength(d::Dense) = d.out_dims * d.in_dims + has_bias(d) * d.out_dims statelength(d::Dense) = 0 -outputsize(d::Dense) = (d.out_dims,) +outputsize(d::Dense, _, __) = (d.out_dims,) function (d::Dense)(x::AbstractArray, ps, st::NamedTuple) y = match_eltype(d, ps, st, x) @@ -409,7 +409,7 @@ end parameterlength(d::Scale) = (1 + has_bias(d)) * prod(d.dims) statelength(d::Scale) = 0 -outputsize(d::Scale) = d.dims +outputsize(d::Scale, _, __) = d.dims function (d::Scale{False})(x::AbstractArray, ps, st::NamedTuple) y = match_eltype(d, ps, st, x) @@ -516,7 +516,7 @@ function parameterlength(b::Bilinear) end statelength(b::Bilinear) = 0 -outputsize(b::Bilinear) = (b.out_dims,) +outputsize(b::Bilinear, _, __) = (b.out_dims,) function (b::Bilinear)( (x, y)::Tuple{<:AbstractVecOrMat, <:AbstractVecOrMat}, ps, st::NamedTuple) @@ -596,7 +596,7 @@ function Base.show(io::IO, e::Embedding) return print(io, "Embedding(", 
e.in_dims, " => ", e.out_dims, ")") end -outputsize(e::Embedding) = (e.out_dims,) +outputsize(e::Embedding, _, __) = (e.out_dims,) (e::Embedding)(x::Integer, ps, st::NamedTuple) = view(ps.weight, :, x), st function (e::Embedding)(x::AbstractVector{<:Integer}, ps, st::NamedTuple) diff --git a/src/layers/containers.jl b/src/layers/containers.jl index 5450376077..1b5832c892 100644 --- a/src/layers/containers.jl +++ b/src/layers/containers.jl @@ -506,8 +506,6 @@ Base.length(c::Chain) = length(c.layers) Base.lastindex(c::Chain) = lastindex(c.layers) Base.firstindex(c::Chain) = firstindex(c.layers) -outputsize(c::Chain) = outputsize(c.layers[end]) - """ Maxout(layers...) Maxout(; layers...) diff --git a/test/helpers/size_propagator_test.jl b/test/helpers/size_propagator_test.jl index 9dde070ced..7825cb75d3 100644 --- a/test/helpers/size_propagator_test.jl +++ b/test/helpers/size_propagator_test.jl @@ -5,10 +5,10 @@ lenet = Chain(Conv((5, 5), 1 => 6, relu), MaxPool((2, 2)), Conv((5, 5), 6 => 16, relu), MaxPool((2, 2)), FlattenLayer(3), Dense(256 => 120, relu), Dense(120 => 84, relu), Dense(84 => 10)) - ps, st = Lux.setup(rng, lenet) - @test Lux.compute_output_size(lenet, (28, 28, 1, 3), ps, st) == (10,) - @test Lux.compute_output_size(lenet, (28, 28, 1, 12), ps, st) == (10,) + for x in (randn(rng, Float32, 28, 28, 1, 3), randn(rng, Float32, 28, 28, 1, 12)) + @test Lux.outputsize(lenet, x, rng) == (10,) + end end @testset "Chain with BatchNorm" begin @@ -17,35 +17,24 @@ MaxPool((2, 2)), FlattenLayer(3), Dense(256 => 120, relu), BatchNorm(120, relu), Dense(120 => 84, relu), Dropout(0.5f0), BatchNorm(84, relu), Dense(84 => 10), BatchNorm(10, relu)) - ps, st = Lux.setup(rng, lenet) - @test Lux.compute_output_size(lenet, (28, 28, 1, 3), ps, st) == (10,) - @test Lux.compute_output_size(lenet, (28, 28, 1, 12), ps, st) == (10,) + for x in (randn(rng, Float32, 28, 28, 1, 3), randn(rng, Float32, 28, 28, 1, 12)) + @test Lux.outputsize(lenet, x, rng) == (10,) + end end - @testset "Normalization Layers" begin - layer = BatchNorm(3, relu) - ps, st = Lux.setup(rng, layer) - - @test Lux.compute_output_size(layer, (4, 4, 3, 2), ps, st) == (4, 4, 3) - @test Lux.compute_output_size(layer, (3, 3), ps, st) == (3,) - - layer = GroupNorm(6, 3, relu) - ps, st = Lux.setup(rng, layer) - - @test Lux.compute_output_size(layer, (4, 4, 6, 2), ps, st) == (4, 4, 6) - @test Lux.compute_output_size(layer, (6, 3), ps, st) == (6,) - - layer = InstanceNorm(3, relu) - ps, st = Lux.setup(rng, layer) - - @test Lux.compute_output_size(layer, (4, 4, 3, 2), ps, st) == (4, 4, 3) - @test Lux.compute_output_size(layer, (4, 3, 2), ps, st) == (4, 3) - - layer = LayerNorm((2, 1, 3), relu) - ps, st = Lux.setup(rng, layer) - - @test Lux.compute_output_size(layer, (2, 4, 3, 2), ps, st) == (2, 4, 3) - @test Lux.compute_output_size(layer, (2, 1, 3, 3), ps, st) == (2, 1, 3) + norm_layer = [ + (BatchNorm(3, relu), [randn(rng, Float32, 4, 4, 3, 2), randn(rng, Float32, 3, 3)]), + (GroupNorm(6, 3, relu), + [randn(rng, Float32, 4, 4, 6, 2), randn(rng, Float32, 6, 3)]), + (InstanceNorm(3, relu), + [randn(rng, Float32, 4, 4, 3, 2), randn(rng, Float32, 4, 3, 2)]), + (LayerNorm((2, 1, 3), relu), + [randn(rng, Float32, 2, 4, 3, 2), randn(rng, Float32, 2, 1, 3, 3)])] + + @testset "Normalization: $(nameof(typeof(layer)))" for (layer, xs) in norm_layer + for x in xs + @test Lux.outputsize(layer, x, rng) == size(x)[1:(end - 1)] + end end end diff --git a/test/helpers/size_propagator_tests.jl b/test/helpers/size_propagator_tests.jl new file mode 100644 index 
0000000000..9dde070ced --- /dev/null +++ b/test/helpers/size_propagator_tests.jl @@ -0,0 +1,51 @@ +@testitem "Size Propagator" setup=[SharedTestSetup] tags=[:helpers] begin + rng = StableRNG(12345) + + @testset "Simple Chain (LeNet)" begin + lenet = Chain(Conv((5, 5), 1 => 6, relu), MaxPool((2, 2)), + Conv((5, 5), 6 => 16, relu), MaxPool((2, 2)), FlattenLayer(3), + Dense(256 => 120, relu), Dense(120 => 84, relu), Dense(84 => 10)) + ps, st = Lux.setup(rng, lenet) + + @test Lux.compute_output_size(lenet, (28, 28, 1, 3), ps, st) == (10,) + @test Lux.compute_output_size(lenet, (28, 28, 1, 12), ps, st) == (10,) + end + + @testset "Chain with BatchNorm" begin + lenet = Chain(Conv((5, 5), 1 => 6, relu), BatchNorm(6, relu), MaxPool((2, 2)), + Conv((5, 5), 6 => 16, relu), BatchNorm(16, relu), + MaxPool((2, 2)), FlattenLayer(3), Dense(256 => 120, relu), + BatchNorm(120, relu), Dense(120 => 84, relu), Dropout(0.5f0), + BatchNorm(84, relu), Dense(84 => 10), BatchNorm(10, relu)) + ps, st = Lux.setup(rng, lenet) + + @test Lux.compute_output_size(lenet, (28, 28, 1, 3), ps, st) == (10,) + @test Lux.compute_output_size(lenet, (28, 28, 1, 12), ps, st) == (10,) + end + + @testset "Normalization Layers" begin + layer = BatchNorm(3, relu) + ps, st = Lux.setup(rng, layer) + + @test Lux.compute_output_size(layer, (4, 4, 3, 2), ps, st) == (4, 4, 3) + @test Lux.compute_output_size(layer, (3, 3), ps, st) == (3,) + + layer = GroupNorm(6, 3, relu) + ps, st = Lux.setup(rng, layer) + + @test Lux.compute_output_size(layer, (4, 4, 6, 2), ps, st) == (4, 4, 6) + @test Lux.compute_output_size(layer, (6, 3), ps, st) == (6,) + + layer = InstanceNorm(3, relu) + ps, st = Lux.setup(rng, layer) + + @test Lux.compute_output_size(layer, (4, 4, 3, 2), ps, st) == (4, 4, 3) + @test Lux.compute_output_size(layer, (4, 3, 2), ps, st) == (4, 3) + + layer = LayerNorm((2, 1, 3), relu) + ps, st = Lux.setup(rng, layer) + + @test Lux.compute_output_size(layer, (2, 4, 3, 2), ps, st) == (2, 4, 3) + @test Lux.compute_output_size(layer, (2, 1, 3, 3), ps, st) == (2, 1, 3) + end +end diff --git a/test/layers/basic_tests.jl b/test/layers/basic_tests.jl index fda80d6a65..8030bd18ef 100644 --- a/test/layers/basic_tests.jl +++ b/test/layers/basic_tests.jl @@ -9,7 +9,7 @@ x = randn(rng, 6, 3) |> aType @test size(layer(x, ps, st)[1]) == (2, 3, 3) - @test Lux.outputsize(layer) == (2, 3) + @test Lux.outputsize(layer, x, rng) == (2, 3) @jet layer(x, ps, st) @@ -159,7 +159,7 @@ end @test size(first(Lux.apply(layer, randn(10), ps, st))) == (5,) @test size(first(Lux.apply(layer, randn(10, 2), ps, st))) == (5, 2) - @test LuxCore.outputsize(layer) == (5,) + @test LuxCore.outputsize(layer, randn(10), rng) == (5,) end @testset "zeros" begin @@ -242,7 +242,7 @@ end @test size(first(Lux.apply(layer, randn(10, 5, 2) |> aType, ps, st))) == (10, 5, 2) - @test LuxCore.outputsize(layer) == (10, 5) + @test LuxCore.outputsize(layer, randn(10), rng) == (10, 5) end @testset "zeros" begin @@ -342,7 +342,7 @@ end @test size(layer((x, y), ps, st)[1]) == (3, 1) @test sum(abs2, layer((x, y), ps, st)[1]) == 0.0f0 - @test LuxCore.outputsize(layer) == (3,) + @test LuxCore.outputsize(layer, (x, y), rng) == (3,) @jet layer((x, y), ps, st) @@ -390,7 +390,7 @@ end @test size(ps.weight) == (embed_size, vocab_size) - @test LuxCore.outputsize(layer) == (4,) + @test LuxCore.outputsize(layer, nothing, rng) == (4,) x = rand(1:vocab_size, 1)[1] y, st_ = layer(x, ps, st) @@ -422,7 +422,7 @@ end @test size(ps.weight) == (embed_size, vocab_size...) 
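The surrounding test hunk switches to the new three-argument `outputsize(layer, x, rng)` fallback added in this patch; a standalone sketch of the intended call, with the model and sizes below being purely illustrative:

```julia
using Lux, Random

rng = Random.default_rng()
model = Chain(Dense(10 => 5, tanh), Dense(5 => 2))
x = randn(Float32, 10, 3)

# The fallback propagates `Nil`-valued arrays through the model and drops the
# trailing batch dimension, so this should return (2,)
Lux.outputsize(model, x, rng)
```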
- @test LuxCore.outputsize(layer) == (4,) + @test LuxCore.outputsize(layer, nothing, rng) == (4,) x = (rand(1:vocab_size[1], 1)[1], rand(1:vocab_size[2], 1)[1]) y, st_ = layer(x, ps, st) diff --git a/test/layers/containers_tests.jl b/test/layers/containers_tests.jl index 218b1a7bf7..e58296acc1 100644 --- a/test/layers/containers_tests.jl +++ b/test/layers/containers_tests.jl @@ -262,7 +262,7 @@ end x = rand(Float32, 10, 1) |> aType y, _ = layer(x, ps, st) @test size(y) == (1, 1) - @test Lux.outputsize(layer) == (1,) + @test Lux.outputsize(layer, x, rng) == (1,) @jet layer(x, ps, st) @@ -290,7 +290,7 @@ end x = rand(Float32, 10, 1) |> aType y, _ = layer(x, ps, st) @test size(y) == (2, 1) - @test Lux.outputsize(layer) == (2,) + @test Lux.outputsize(layer, x, rng) == (2,) @jet layer(x, ps, st) @@ -305,7 +305,7 @@ end x = rand(Float32, 10, 1) |> aType y, _ = layer(x, ps, st) @test size(y) == (2, 1) - @test Lux.outputsize(layer) == (2,) + @test Lux.outputsize(layer, x, rng) == (2,) @jet layer(x, ps, st) @@ -320,7 +320,7 @@ end x = rand(Float32, 10, 1) |> aType y, _ = layer(x, ps, st) @test size(y) == (5, 1) - @test Lux.outputsize(layer) == (5,) + @test Lux.outputsize(layer, x, rng) == (5,) @jet layer(x, ps, st) diff --git a/test/qa_tests.jl b/test/qa_tests.jl index ac59d4ffcd..3a8b1e7a97 100644 --- a/test/qa_tests.jl +++ b/test/qa_tests.jl @@ -19,7 +19,7 @@ end Lux; skip=(Base, Core, LuxCore, MLDataDevices, LuxLib, WeightInitializers)) === nothing @test check_no_stale_explicit_imports( - Lux; ignore=(:inputsize, :setup, :testmode, :trainmode, :update_state)) === nothing + Lux; ignore=(:setup, :testmode, :trainmode, :update_state)) === nothing @test check_no_self_qualified_accesses(Lux) === nothing @test check_all_explicit_imports_via_owners(Lux) === nothing @test check_all_qualified_accesses_via_owners( From c52754a94cba700fa739a0f3a399d0e6702a9b20 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 18 Aug 2024 16:54:06 -0700 Subject: [PATCH 29/95] test: fix broken tests --- src/Lux.jl | 3 +- src/contrib/contrib.jl | 9 ++--- src/extended_ops.jl | 3 +- src/helpers/training.jl | 1 - src/layers/basic.jl | 10 ++--- src/layers/display.jl | 1 + test/contrib/debug_tests.jl | 2 +- test/helpers/size_propagator_tests.jl | 53 +++++++++++---------------- test/qa_tests.jl | 2 +- 9 files changed, 37 insertions(+), 47 deletions(-) diff --git a/src/Lux.jl b/src/Lux.jl index f3ea3cabe5..346d6c7db5 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -26,7 +26,8 @@ import LuxCore: AbstractLuxLayer, AbstractLuxContainerLayer, AbstractLuxWrapperL outputsize, update_state, trainmode, testmode, setup, apply, replicate @reexport using LuxCore, LuxLib, MLDataDevices, WeightInitializers -@eval Expr(:export, filter(x -> x !== :dropout, Reexport.exported_names(NNlib))...) 
+using NNlib: NNlib, DenseConvDims, PoolDims, logsigmoid, logsoftmax, maxpool, meanpool, + pixel_shuffle, sigmoid_fast, tanh_fast const CRC = ChainRulesCore diff --git a/src/contrib/contrib.jl b/src/contrib/contrib.jl index c97022e6c4..87360b2118 100644 --- a/src/contrib/contrib.jl +++ b/src/contrib/contrib.jl @@ -1,10 +1,5 @@ module Experimental -using ..Lux: Lux, Optional -using ..Utils: Utils, BoolType, SymbolType -using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxContainerLayer, - AbstractLuxWrapperLayer, apply - using ADTypes: ADTypes using ArgCheck: @argcheck using ChainRulesCore: ChainRulesCore @@ -18,6 +13,10 @@ using Random: AbstractRNG, Random using Setfield: Setfield using Static: StaticSymbol, StaticBool, True, known, static, dynamic +using ..Lux: Lux, Optional +using ..Utils: Utils, BoolType, SymbolType +using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxWrapperLayer, apply + const CRC = ChainRulesCore include("map.jl") diff --git a/src/extended_ops.jl b/src/extended_ops.jl index 2bcd12555d..d2e65b4f28 100644 --- a/src/extended_ops.jl +++ b/src/extended_ops.jl @@ -11,9 +11,10 @@ using ChainRulesCore: ChainRulesCore, NoTangent, ZeroTangent, @thunk, @non_diffe using Compat: @compat using EnzymeCore: EnzymeCore using FastClosures: @closure -using MLDataDevices: get_device_type, AbstractGPUDevice, AbstractDevice using Static: StaticBool, StaticSymbol, known +using MLDataDevices: get_device_type, AbstractGPUDevice, AbstractDevice + using ..Utils: Utils const CRC = ChainRulesCore diff --git a/src/helpers/training.jl b/src/helpers/training.jl index 134e30102e..cd146169cf 100644 --- a/src/helpers/training.jl +++ b/src/helpers/training.jl @@ -8,7 +8,6 @@ using Optimisers: Optimisers using ..Lux: Lux using LuxCore: LuxCore, AbstractLuxLayer -using Optimisers: Optimisers """ TrainState diff --git a/src/layers/basic.jl b/src/layers/basic.jl index e8124e2e9d..1a761767d9 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -36,7 +36,7 @@ struct ReshapeLayer{N} <: AbstractLuxLayer dims::NTuple{N, Int} end -outputsize(r::ReshapeLayer, _, __) = r.dims +outputsize(r::ReshapeLayer, _, ::AbstractRNG) = r.dims function (r::ReshapeLayer)(x::AbstractArray, _, st::NamedTuple) return reshape(x, r.dims..., size(x, ndims(x))), st @@ -322,7 +322,7 @@ end parameterlength(d::Dense) = d.out_dims * d.in_dims + has_bias(d) * d.out_dims statelength(d::Dense) = 0 -outputsize(d::Dense, _, __) = (d.out_dims,) +outputsize(d::Dense, _, ::AbstractRNG) = (d.out_dims,) function (d::Dense)(x::AbstractArray, ps, st::NamedTuple) y = match_eltype(d, ps, st, x) @@ -409,7 +409,7 @@ end parameterlength(d::Scale) = (1 + has_bias(d)) * prod(d.dims) statelength(d::Scale) = 0 -outputsize(d::Scale, _, __) = d.dims +outputsize(d::Scale, _, ::AbstractRNG) = d.dims function (d::Scale{False})(x::AbstractArray, ps, st::NamedTuple) y = match_eltype(d, ps, st, x) @@ -516,7 +516,7 @@ function parameterlength(b::Bilinear) end statelength(b::Bilinear) = 0 -outputsize(b::Bilinear, _, __) = (b.out_dims,) +outputsize(b::Bilinear, _, ::AbstractRNG) = (b.out_dims,) function (b::Bilinear)( (x, y)::Tuple{<:AbstractVecOrMat, <:AbstractVecOrMat}, ps, st::NamedTuple) @@ -596,7 +596,7 @@ function Base.show(io::IO, e::Embedding) return print(io, "Embedding(", e.in_dims, " => ", e.out_dims, ")") end -outputsize(e::Embedding, _, __) = (e.out_dims,) +outputsize(e::Embedding, _, ::AbstractRNG) = (e.out_dims,) (e::Embedding)(x::Integer, ps, st::NamedTuple) = view(ps.weight, :, x), st function (e::Embedding)(x::AbstractVector{<:Integer}, ps, 
st::NamedTuple) diff --git a/src/layers/display.jl b/src/layers/display.jl index 48e09f0f3d..8570a7091d 100644 --- a/src/layers/display.jl +++ b/src/layers/display.jl @@ -1,6 +1,7 @@ module PrettyPrinting using Functors: Functors + using LuxCore: LuxCore, AbstractLuxContainerLayer, AbstractLuxLayer, display_name printable_children(x) = Functors.children(x) diff --git a/test/contrib/debug_tests.jl b/test/contrib/debug_tests.jl index 2053c1189a..2aff0dd4e2 100644 --- a/test/contrib/debug_tests.jl +++ b/test/contrib/debug_tests.jl @@ -3,7 +3,7 @@ rng = StableRNG(12345) - @testset "$mode" for (mode, aType, device, ongpu) in MODES + @testset "$mode" for (mode, aType, dev, ongpu) in MODES model = Chain( Dense(1 => 16, relu), Chain(Dense(16 => 3), Dense(1 => 1)), BatchNorm(1)) diff --git a/test/helpers/size_propagator_tests.jl b/test/helpers/size_propagator_tests.jl index 9dde070ced..bf840de417 100644 --- a/test/helpers/size_propagator_tests.jl +++ b/test/helpers/size_propagator_tests.jl @@ -3,49 +3,38 @@ @testset "Simple Chain (LeNet)" begin lenet = Chain(Conv((5, 5), 1 => 6, relu), MaxPool((2, 2)), - Conv((5, 5), 6 => 16, relu), MaxPool((2, 2)), FlattenLayer(3), + Conv((5, 5), 6 => 16, relu), MaxPool((2, 2)), FlattenLayer(), Dense(256 => 120, relu), Dense(120 => 84, relu), Dense(84 => 10)) - ps, st = Lux.setup(rng, lenet) - @test Lux.compute_output_size(lenet, (28, 28, 1, 3), ps, st) == (10,) - @test Lux.compute_output_size(lenet, (28, 28, 1, 12), ps, st) == (10,) + @testset "size(x) = $(size(x))" for x in ( + randn(rng, Float32, 28, 28, 1, 3), randn(rng, Float32, 28, 28, 1, 12)) + @test Lux.outputsize(lenet, x, rng) == (10,) + end end @testset "Chain with BatchNorm" begin lenet = Chain(Conv((5, 5), 1 => 6, relu), BatchNorm(6, relu), MaxPool((2, 2)), Conv((5, 5), 6 => 16, relu), BatchNorm(16, relu), - MaxPool((2, 2)), FlattenLayer(3), Dense(256 => 120, relu), + MaxPool((2, 2)), FlattenLayer(), Dense(256 => 120, relu), BatchNorm(120, relu), Dense(120 => 84, relu), Dropout(0.5f0), BatchNorm(84, relu), Dense(84 => 10), BatchNorm(10, relu)) - ps, st = Lux.setup(rng, lenet) - @test Lux.compute_output_size(lenet, (28, 28, 1, 3), ps, st) == (10,) - @test Lux.compute_output_size(lenet, (28, 28, 1, 12), ps, st) == (10,) + @testset "size(x) = $(size(x))" for x in ( + randn(rng, Float32, 28, 28, 1, 3), randn(rng, Float32, 28, 28, 1, 12)) + @test Lux.outputsize(lenet, x, rng) == (10,) + end end - @testset "Normalization Layers" begin - layer = BatchNorm(3, relu) - ps, st = Lux.setup(rng, layer) - - @test Lux.compute_output_size(layer, (4, 4, 3, 2), ps, st) == (4, 4, 3) - @test Lux.compute_output_size(layer, (3, 3), ps, st) == (3,) - - layer = GroupNorm(6, 3, relu) - ps, st = Lux.setup(rng, layer) - - @test Lux.compute_output_size(layer, (4, 4, 6, 2), ps, st) == (4, 4, 6) - @test Lux.compute_output_size(layer, (6, 3), ps, st) == (6,) - - layer = InstanceNorm(3, relu) - ps, st = Lux.setup(rng, layer) - - @test Lux.compute_output_size(layer, (4, 4, 3, 2), ps, st) == (4, 4, 3) - @test Lux.compute_output_size(layer, (4, 3, 2), ps, st) == (4, 3) - - layer = LayerNorm((2, 1, 3), relu) - ps, st = Lux.setup(rng, layer) - - @test Lux.compute_output_size(layer, (2, 4, 3, 2), ps, st) == (2, 4, 3) - @test Lux.compute_output_size(layer, (2, 1, 3, 3), ps, st) == (2, 1, 3) + @testset "Normalization: $(nameof(typeof(layer)))" for (layer, xs) in [ + (BatchNorm(3, relu), [randn(rng, Float32, 4, 4, 3, 2), randn(rng, Float32, 3, 3)]), + (GroupNorm(6, 3, relu), + [randn(rng, Float32, 4, 4, 6, 2), randn(rng, Float32, 6, 3)]), + 
(InstanceNorm(3, relu), + [randn(rng, Float32, 4, 4, 3, 2), randn(rng, Float32, 4, 3, 2)]), + (LayerNorm((2, 1, 3), relu), + [randn(rng, Float32, 2, 4, 3, 2), randn(rng, Float32, 2, 1, 3, 3)])] + for x in xs + @test Lux.outputsize(layer, x, rng) == size(x)[1:(end - 1)] + end end end diff --git a/test/qa_tests.jl b/test/qa_tests.jl index 3a8b1e7a97..6897c33467 100644 --- a/test/qa_tests.jl +++ b/test/qa_tests.jl @@ -35,7 +35,7 @@ end doctestexpr = quote using SimpleChains: static using DynamicExpressions - using Adapt, Lux, Random, Optimisers, Zygote + using Adapt, Lux, Random, Optimisers, Zygote, NNlib end DocMeta.setdocmeta!(Lux, :DocTestSetup, doctestexpr; recursive=true) From ec3be20f4fe382a5e3704412b48ee1aef9bc8395 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 18 Aug 2024 17:51:14 -0700 Subject: [PATCH 30/95] fix: printing of container layers --- src/layers/display.jl | 6 ++---- test/enzyme_tests.jl | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/layers/display.jl b/src/layers/display.jl index 8570a7091d..6f5f52c644 100644 --- a/src/layers/display.jl +++ b/src/layers/display.jl @@ -2,13 +2,11 @@ module PrettyPrinting using Functors: Functors -using LuxCore: LuxCore, AbstractLuxContainerLayer, AbstractLuxLayer, display_name +using LuxCore: LuxCore, AbstractLuxWrapperLayer, AbstractLuxLayer, display_name printable_children(x) = Functors.children(x) -function printable_children(m::AbstractLuxContainerLayer{layers}) where {layers} +function printable_children(m::AbstractLuxWrapperLayer{field}) where {field} children = Functors.children(m) - length(layers) ≥ 2 && return children - field = first(layers) hasfield(typeof(children), field) || return children nt = getfield(children, field) nt isa NamedTuple || (nt = NamedTuple{(field,)}((nt,))) diff --git a/test/enzyme_tests.jl b/test/enzyme_tests.jl index b3e5c7b7b8..5d7ac76cf0 100644 --- a/test/enzyme_tests.jl +++ b/test/enzyme_tests.jl @@ -2,7 +2,7 @@ # able to remove this, but this file is still helpful to catch errors in a localized way. @testsetup module EnzymeTestSetup using LuxTestUtils, Enzyme, Zygote, Test -using Lux +using Lux, NNlib using LuxTestUtils: check_approx generic_loss_function(model, x, ps, st) = sum(abs2, first(model(x, ps, st))) From 005ecf4a4ea481dc1b28ff4d727957cb7e9ec52e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 28 Aug 2024 22:33:51 -0400 Subject: [PATCH 31/95] fix!: (re)move deprecated `DynamicExpressionsLayer` --- docs/src/api/Lux/interop.md | 13 -- docs/tutorials.jl | 1 - examples/SymbolicOptimalControl/Project.toml | 9 -- examples/SymbolicOptimalControl/main.jl | 5 - ext/LuxDynamicExpressionsExt.jl | 154 ------------------- ext/LuxReverseDiffExt/rules.jl | 9 -- ext/LuxTrackerExt/rules.jl | 9 -- src/Lux.jl | 1 - src/layers/extension.jl | 119 -------------- test/layers/dynamic_expressions_tests.jl | 60 -------- 10 files changed, 380 deletions(-) delete mode 100644 examples/SymbolicOptimalControl/Project.toml delete mode 100644 examples/SymbolicOptimalControl/main.jl delete mode 100644 ext/LuxDynamicExpressionsExt.jl delete mode 100644 test/layers/dynamic_expressions_tests.jl diff --git a/docs/src/api/Lux/interop.md b/docs/src/api/Lux/interop.md index cf377bcbe6..f5c81fc71a 100644 --- a/docs/src/api/Lux/interop.md +++ b/docs/src/api/Lux/interop.md @@ -41,16 +41,3 @@ Adapt.adapt(from::ToSimpleChainsAdaptor, L::AbstractLuxLayer) ToSimpleChainsAdaptor SimpleChainsLayer ``` - -## Symbolic Expressions - -### Embedding DynamicExpressions.jl Node in Lux Layers - -!!! 
tip - - Accessing these functions require manually loading `DynamicExpressions`, i.e., - `using DynamicExpressions` must be present somewhere in the code for these to be used. - -```@docs -DynamicExpressionsLayer -``` diff --git a/docs/tutorials.jl b/docs/tutorials.jl index 68597e6e26..9d11a55637 100644 --- a/docs/tutorials.jl +++ b/docs/tutorials.jl @@ -12,7 +12,6 @@ const INTERMEDIATE_TUTORIALS = [ ] const ADVANCED_TUTORIALS = [ "GravitationalWaveForm/main.jl", - "SymbolicOptimalControl/main.jl" ] const TUTORIALS = [ diff --git a/examples/SymbolicOptimalControl/Project.toml b/examples/SymbolicOptimalControl/Project.toml deleted file mode 100644 index fd41919312..0000000000 --- a/examples/SymbolicOptimalControl/Project.toml +++ /dev/null @@ -1,9 +0,0 @@ -[deps] -InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" -Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" -Lux = "b2108857-7c20-44ae-9111-449ecde12c47" - -[compat] -InteractiveUtils = "<0.0.1, 1" -Literate = "2" -Lux = "0.5" diff --git a/examples/SymbolicOptimalControl/main.jl b/examples/SymbolicOptimalControl/main.jl deleted file mode 100644 index b23498f844..0000000000 --- a/examples/SymbolicOptimalControl/main.jl +++ /dev/null @@ -1,5 +0,0 @@ -# # Solving Optimal Control Problems with Symbolic Universal Differential Equations - -# This tutorial has been been moved to Boltz.jl documentation. Refer to the the -# [Symbolic Optimal Control](https://luxdl.github.io/Boltz.jl/stable/tutorials/2_SymbolicOptimalControl) -# tutorial for more details. diff --git a/ext/LuxDynamicExpressionsExt.jl b/ext/LuxDynamicExpressionsExt.jl deleted file mode 100644 index c552b0c83c..0000000000 --- a/ext/LuxDynamicExpressionsExt.jl +++ /dev/null @@ -1,154 +0,0 @@ -module LuxDynamicExpressionsExt - -using ChainRulesCore: NoTangent -using DynamicExpressions: DynamicExpressions, Node, OperatorEnum, eval_grad_tree_array, - eval_tree_array -using FastClosures: @closure -using ForwardDiff: ForwardDiff - -using Lux: Lux, NAME_TYPE, Chain, Parallel, WrappedFunction, DynamicExpressionsLayer -using MLDataDevices: CPUDevice - -@static if pkgversion(DynamicExpressions) ≥ v"0.19" - using DynamicExpressions: EvalOptions - - const EvalOptionsTypes = Union{Missing, EvalOptions, NamedTuple} -else - const EvalOptionsTypes = Union{Missing, NamedTuple} -end - -Lux.is_extension_loaded(::Val{:DynamicExpressions}) = true - -function Lux.DynamicExpressionsLayer(operator_enum::OperatorEnum, expressions::Node...; - name::NAME_TYPE=nothing, eval_options::EvalOptionsTypes=missing, - turbo::Union{Bool, Val, Missing}=missing, - bumper::Union{Bool, Val, Missing}=missing) - eval_options = construct_eval_options( - eval_options, construct_eval_options(turbo, bumper)) - - length(expressions) == 1 && return Lux.DynamicExpressionsLayer( - operator_enum, first(expressions), name, eval_options) - name_fn = name === nothing ? Returns(nothing) : @closure(i->"$(name)_$(i)") - #! format: off - return Chain( - Parallel(nothing, - ntuple(i -> DynamicExpressionsLayer(operator_enum, expressions[i], - name_fn(i), eval_options), length(expressions))...), - WrappedFunction(Lux.Utils.stack1); - name="DynamicExpressionsLayer") - #! format: on -end - -function Lux.DynamicExpressionsLayer( - operator_enum::OperatorEnum, expressions::AbstractVector{<:Node}; kwargs...) - return Lux.DynamicExpressionsLayer(operator_enum, expressions...; kwargs...) 
-end - -construct_eval_options(::Missing, ::Missing) = (; turbo=Val(false), bumper=Val(false)) -function construct_eval_options(turbo::Union{Bool, Val}, ::Missing) - return construct_eval_options(turbo, Val(false)) -end -function construct_eval_options(::Missing, bumper::Union{Bool, Val}) - return construct_eval_options(Val(false), bumper) -end -function construct_eval_options(turbo::Union{Bool, Val}, bumper::Union{Bool, Val}) - Base.depwarn("`bumper` and `turbo` are deprecated. Use `eval_options` instead.", - :DynamicExpressionsLayer) - return (; turbo, bumper) -end - -construct_eval_options(::Missing, eval_options::EvalOptionsTypes) = eval_options -construct_eval_options(eval_options::EvalOptionsTypes, ::Missing) = eval_options -function construct_eval_options(::EvalOptionsTypes, ::EvalOptionsTypes) - throw(ArgumentError("`eval_options`, `turbo` and `bumper` are mutually exclusive. \ - Don't specify `eval_options` if you are using `turbo` or \ - `bumper`.")) -end - -function Lux.apply_dynamic_expression_internal( - de::Lux.DynamicExpressionsLayer, expr, operator_enum, x, ps) - Lux.update_de_expression_constants!(expr, ps) - @static if pkgversion(DynamicExpressions) ≥ v"0.19" - eval_options = EvalOptions(; de.eval_options.turbo, de.eval_options.bumper) - return first(eval_tree_array(expr, x, operator_enum; eval_options)) - else - return first(eval_tree_array( - expr, x, operator_enum; de.eval_options.turbo, de.eval_options.bumper)) - end -end - -function Lux.∇apply_dynamic_expression( - de::Lux.DynamicExpressionsLayer, expr, operator_enum, x, ps) - Lux.update_de_expression_constants!(expr, ps) - _, Jₓ, _ = eval_grad_tree_array( - expr, x, operator_enum; variable=Val(true), de.eval_options.turbo) - y, Jₚ, _ = eval_grad_tree_array( - expr, x, operator_enum; variable=Val(false), de.eval_options.turbo) - ∇apply_dynamic_expression_internal = @closure Δ -> begin - ∂x = Jₓ .* reshape(Δ, 1, :) - ∂ps = Jₚ * Δ - return NoTangent(), NoTangent(), NoTangent(), NoTangent(), ∂x, ∂ps, NoTangent() - end - return y, ∇apply_dynamic_expression_internal -end - -# Forward Diff rules -function Lux.apply_dynamic_expression(de::DynamicExpressionsLayer, expr, operator_enum, - x::AbstractMatrix{<:ForwardDiff.Dual{Tag, T, N}}, - ps, ::CPUDevice) where {T, N, Tag} - value_fn(x) = ForwardDiff.value(Tag, x) - partials_fn(x, i) = ForwardDiff.partials(Tag, x, i) - - Lux.update_de_expression_constants!(expr, ps) - y, Jₓ, _ = eval_grad_tree_array( - expr, value_fn.(x), operator_enum; variable=Val(true), de.eval_options.turbo) - partials = ntuple( - @closure(i->dropdims(sum(partials_fn.(x, i) .* Jₓ; dims=1); dims=1)), N) - - fT = promote_type(eltype(y), T, eltype(Jₓ)) - partials_y = ForwardDiff.Partials{N, fT}.(tuple.(partials...)) - return ForwardDiff.Dual{Tag, fT, N}.(y, partials_y) -end - -function Lux.apply_dynamic_expression(de::DynamicExpressionsLayer, expr, operator_enum, x, - ps::AbstractVector{<:ForwardDiff.Dual{Tag, T, N}}, ::CPUDevice) where {T, N, Tag} - value_fn(x) = ForwardDiff.value(Tag, x) - partials_fn(x, i) = ForwardDiff.partials(Tag, x, i) - - Lux.update_de_expression_constants!(expr, value_fn.(ps)) - y, Jₚ, _ = eval_grad_tree_array( - expr, x, operator_enum; variable=Val(false), de.eval_options.turbo) - partials = ntuple( - @closure(i->dropdims(sum(partials_fn.(ps, i) .* Jₚ; dims=1); dims=1)), N) - - fT = promote_type(eltype(y), T, eltype(Jₚ)) - partials_y = ForwardDiff.Partials{N, fT}.(tuple.(partials...)) - return ForwardDiff.Dual{Tag, fT, N}.(y, partials_y) -end - -function 
Lux.apply_dynamic_expression(de::DynamicExpressionsLayer, expr, operator_enum, - x::AbstractMatrix{<:ForwardDiff.Dual{Tag, T1, N}}, - ps::AbstractVector{<:ForwardDiff.Dual{Tag, T2, N}}, - ::CPUDevice) where {T1, T2, N, Tag} - value_fn(x) = ForwardDiff.value(Tag, x) - partials_fn(x, i) = ForwardDiff.partials(Tag, x, i) - - ps_value = value_fn.(ps) - x_value = value_fn.(x) - - Lux.update_de_expression_constants!(expr, ps_value) - _, Jₓ, _ = eval_grad_tree_array( - expr, x_value, operator_enum; variable=Val(true), de.eval_options.turbo) - y, Jₚ, _ = eval_grad_tree_array( - expr, x_value, operator_enum; variable=Val(false), de.eval_options.turbo) - partials = ntuple( - @closure(i->dropdims(sum(partials_fn.(x, i) .* Jₓ; dims=1); dims=1) .+ - dropdims(sum(partials_fn.(ps, i) .* Jₚ; dims=1); dims=1)), - N) - - fT = promote_type(eltype(y), T1, T2, eltype(Jₓ), eltype(Jₚ)) - partials_y = ForwardDiff.Partials{N, fT}.(tuple.(partials...)) - return ForwardDiff.Dual{Tag, fT, N}.(y, partials_y) -end - -end diff --git a/ext/LuxReverseDiffExt/rules.jl b/ext/LuxReverseDiffExt/rules.jl index 08dd5fffb4..247bbd1200 100644 --- a/ext/LuxReverseDiffExt/rules.jl +++ b/ext/LuxReverseDiffExt/rules.jl @@ -4,15 +4,6 @@ @grad_from_chainrules Lux.apply_simple_chain( layer, x::TrackedArray, ps::TrackedArray, ::CPUDevice) -# DynamicExpressions.jl -@grad_from_chainrules Lux.apply_dynamic_expression( - de::Lux.DynamicExpressionsLayer, expr, operator_enum, x::TrackedArray, ps, ::CPUDevice) -@grad_from_chainrules Lux.apply_dynamic_expression( - de::Lux.DynamicExpressionsLayer, expr, operator_enum, x, ps::TrackedArray, ::CPUDevice) -@grad_from_chainrules Lux.apply_dynamic_expression( - de::Lux.DynamicExpressionsLayer, expr, operator_enum, - x::TrackedArray, ps::TrackedArray, ::CPUDevice) - # Nested AD @grad_from_chainrules Lux.AutoDiffInternalImpl.batched_jacobian_internal( f, backend::AbstractADType, x::TrackedArray) diff --git a/ext/LuxTrackerExt/rules.jl b/ext/LuxTrackerExt/rules.jl index 0976070b7f..5a6b5468dd 100644 --- a/ext/LuxTrackerExt/rules.jl +++ b/ext/LuxTrackerExt/rules.jl @@ -1,12 +1,3 @@ -# DynamicExpressions.jl -for T1 in (:TrackedArray, :AbstractArray), T2 in (:TrackedArray, :AbstractArray) - T1 === :AbstractArray && T2 === :AbstractArray && continue - - @eval @grad_from_chainrules Lux.apply_dynamic_expression( - de::Lux.DynamicExpressionsLayer, expr, - operator_enum, x::$(T1), ps::$(T2), dev::CPUDevice) -end - # Nested AD @grad_from_chainrules Lux.AutoDiffInternalImpl.batched_jacobian_internal( f, backend::AbstractADType, x::TrackedArray) diff --git a/src/Lux.jl b/src/Lux.jl index 346d6c7db5..97696532cc 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -115,7 +115,6 @@ export match_eltype export FromFluxAdaptor, FluxLayer export ToSimpleChainsAdaptor, SimpleChainsLayer -export DynamicExpressionsLayer export MPIBackend, NCCLBackend, DistributedUtils diff --git a/src/layers/extension.jl b/src/layers/extension.jl index feca2e9f0c..e4d7298ca7 100644 --- a/src/layers/extension.jl +++ b/src/layers/extension.jl @@ -1,124 +1,5 @@ # Layers here serve as a compatibility layer between different frameworks. 
The core # implementation is present in extensions - -## DynamicExpressions.jl -## We could constrain the type of `operator_enum` to be `OperatorEnum` but defining -## custom types in extensions tends to be a PITA -""" - DynamicExpressionsLayer(operator_enum::OperatorEnum, expressions::Node...; - name::NAME_TYPE=nothing, eval_options::EvalOptions=EvalOptions()) - DynamicExpressionsLayer(operator_enum::OperatorEnum, - expressions::AbstractVector{<:Node}; kwargs...) - -Wraps a `DynamicExpressions.jl` `Node` into a Lux layer and allows the constant nodes to -be updated using any of the AD Backends. - -For details about these expressions, refer to the -[`DynamicExpressions.jl` documentation](https://symbolicml.org/DynamicExpressions.jl/dev/types/). - -## Arguments - - - `operator_enum`: `OperatorEnum` from `DynamicExpressions.jl` - - `expressions`: `Node` from `DynamicExpressions.jl` or `AbstractVector{<:Node}` - -## Keyword Arguments - - - `name`: Name of the layer - - `turbo`: Use LoopVectorization.jl for faster evaluation **(Deprecated)** - - `bumper`: Use Bumper.jl for faster evaluation **(Deprecated)** - - `eval_options`: EvalOptions from `DynamicExpressions.jl` - -These options are simply forwarded to `DynamicExpressions.jl`'s `eval_tree_array` -and `eval_grad_tree_array` function. - -!!! danger "Deprecation Notice" - - These options are deprecated and will be removed in v1. Please use the version in - [`Boltz.jl`](https://github.com/LuxDL/Boltz.jl) instead. -""" -struct DynamicExpressionsLayer{OE, E, N, EO} <: AbstractExplicitLayer - operator_enum::OE - expression::E - name::N - eval_options::EO - - function DynamicExpressionsLayer(operator_enum::OE, expression::E, name::N, - eval_options::EO) where {OE, E, N, EO} - Base.depwarn( - "`DynamicExpressionsLayer` is deprecated and will be removed in v1. Please \ - use the corresponding version in `Boltz.jl` instead.", - :DynamicExpressionsLayer) - return new{OE, E, N, EO}(operator_enum, expression, name, eval_options) - end -end - -function Base.show(io::IO, l::DynamicExpressionsLayer) - print(io, - "DynamicExpressionsLayer($(l.operator_enum), $(l.expression); eval_options=$(l.eval_options))") -end - -function initialparameters(::AbstractRNG, layer::DynamicExpressionsLayer) - params = map(Base.Fix2(getproperty, :val), - filter(node -> node.degree == 0 && node.constant, layer.expression)) - return (; params) -end - -function update_de_expression_constants!(expression, ps) - # Don't use `set_constant_refs!` here, since it requires the types to match. In our - # case we just warn the user - params = filter(node -> node.degree == 0 && node.constant, expression) - foreach(enumerate(params)) do (i, node) - (node.val isa typeof(ps[i])) || - @warn lazy"node.val::$(typeof(node.val)) != ps[$i]::$(typeof(ps[i])). Type of node.val takes precedence. Fix the input expression if this is unintended." 
maxlog=1 - return node.val = ps[i] - end - return -end - -function (de::DynamicExpressionsLayer)(x::AbstractVector, ps, st) - y, stₙ = de(reshape(x, :, 1), ps, st) - return vec(y), stₙ -end - -# NOTE: Unfortunately we can't use `get_device_type` since it causes problems with -# ReverseDiff -function (de::DynamicExpressionsLayer)(x::AbstractMatrix, ps, st) - y = match_eltype(de, ps, st, x) - return ( - apply_dynamic_expression( - de, de.expression, de.operator_enum, y, ps.params, MLDataDevices.get_device(x)), - st) -end - -function apply_dynamic_expression_internal end - -function apply_dynamic_expression( - de::DynamicExpressionsLayer, expr, operator_enum, x, ps, ::CPUDevice) - if !is_extension_loaded(Val(:DynamicExpressions)) - error("`DynamicExpressions.jl` is not loaded. Please load it before using \ - `DynamicExpressionsLayer`.") - end - return apply_dynamic_expression_internal(de, expr, operator_enum, x, ps) -end - -function ∇apply_dynamic_expression end - -function CRC.rrule(::typeof(apply_dynamic_expression), de::DynamicExpressionsLayer, - expr, operator_enum, x, ps, ::CPUDevice) - if !is_extension_loaded(Val(:DynamicExpressions)) - error("`DynamicExpressions.jl` is not loaded. Please load it before using \ - `DynamicExpressionsLayer`.") - end - return ∇apply_dynamic_expression(de, expr, operator_enum, x, ps) -end - -function apply_dynamic_expression(de, expr, operator_enum, x, ps, dev) - throw(ArgumentError("`DynamicExpressions.jl` only supports CPU operations. Current \ - device detected as $(dev). CUDA.jl will be supported after \ - https://github.com/SymbolicML/DynamicExpressions.jl/pull/65 is \ - merged upstream.")) -end - ## Flux.jl """ FluxLayer(layer) diff --git a/test/layers/dynamic_expressions_tests.jl b/test/layers/dynamic_expressions_tests.jl deleted file mode 100644 index fae956c365..0000000000 --- a/test/layers/dynamic_expressions_tests.jl +++ /dev/null @@ -1,60 +0,0 @@ -@testitem "Dynamic Expressions" setup=[SharedTestSetup] tags=[:others] begin - using DynamicExpressions, ForwardDiff, ComponentArrays, Bumper - - operators = OperatorEnum(; binary_operators=[+, -, *], unary_operators=[cos]) - - x1 = Node(; feature=1) - x2 = Node(; feature=2) - - expr_1 = x1 * cos(x2 - 3.2) - expr_2 = x2 - x1 * x2 + 2.5 - 1.0 * x1 - - for exprs in ((expr_1,), (expr_1, expr_2), ([expr_1, expr_2],)), - turbo in (Val(false), Val(true)), - bumper in (Val(false), Val(true)) - - layer = DynamicExpressionsLayer(operators, exprs...; turbo, bumper) - ps, st = Lux.setup(Random.default_rng(), layer) - - x = [1.0f0 2.0f0 3.0f0 - 4.0f0 5.0f0 6.0f0] - - y, st_ = layer(x, ps, st) - @test eltype(y) == Float32 - __f = (x, p) -> sum(abs2, first(layer(x, p, st))) - test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3, skip_backends=[AutoEnzyme()]) - - # Particular ForwardDiff dispatches - ps_ca = ComponentArray(ps) - dps_ca = ForwardDiff.gradient(ps_ca) do ps_ - sum(abs2, first(layer(x, ps_, st))) - end - dx = ForwardDiff.gradient(x) do x_ - sum(abs2, first(layer(x_, ps, st))) - end - dxps = ForwardDiff.gradient(ComponentArray(; x=x, ps=ps)) do ca - sum(abs2, first(layer(ca.x, ca.ps, st))) - end - - @test dx≈dxps.x atol=1.0f-3 rtol=1.0f-3 - @test dps_ca≈dxps.ps atol=1.0f-3 rtol=1.0f-3 - - x = Float64.(x) - y, st_ = layer(x, ps, st) - @test eltype(y) == Float64 - __f = (x, p) -> sum(abs2, first(layer(x, p, st))) - test_gradients(__f, x, ps; atol=1.0e-3, rtol=1.0e-3, skip_backends=[AutoEnzyme()]) - end - - @testset "$(mode)" for (mode, aType, dev, ongpu) in MODES - layer = DynamicExpressionsLayer(operators, 
expr_1) - ps, st = Lux.setup(Random.default_rng(), layer) |> dev - - x = [1.0f0 2.0f0 3.0f0 - 4.0f0 5.0f0 6.0f0] |> aType - - if ongpu - @test_throws ArgumentError layer(x, ps, st) - end - end -end From cc9d663508a530a0217c62d8bc58744a7b153726 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 28 Aug 2024 22:35:13 -0400 Subject: [PATCH 32/95] fix!: (re)move deprecated `PeriodicEmbedding` --- docs/src/api/Lux/layers.md | 6 ---- src/Lux.jl | 2 +- src/layers/basic.jl | 69 -------------------------------------- test/layers/basic_tests.jl | 19 ----------- 4 files changed, 1 insertion(+), 95 deletions(-) diff --git a/docs/src/api/Lux/layers.md b/docs/src/api/Lux/layers.md index c6672b25dc..1eb7da0f50 100644 --- a/docs/src/api/Lux/layers.md +++ b/docs/src/api/Lux/layers.md @@ -92,9 +92,3 @@ WeightNorm PixelShuffle Upsample ``` - -## SciML Layers - -```@docs -PeriodicEmbedding -``` diff --git a/src/Lux.jl b/src/Lux.jl index 97696532cc..87beb07fd3 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -86,7 +86,7 @@ include("distributed/public_api.jl") # Layers export Chain, Parallel, SkipConnection, PairwiseFusion, BranchLayer, Maxout, RepeatedLayer -export Bilinear, Dense, Embedding, Scale, PeriodicEmbedding +export Bilinear, Dense, Embedding, Scale export Conv, ConvTranspose, CrossCor, MaxPool, MeanPool, GlobalMaxPool, GlobalMeanPool, AdaptiveMaxPool, AdaptiveMeanPool, Upsample, PixelShuffle export AlphaDropout, Dropout, VariationalHiddenDropout diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 1a761767d9..bd0eb1b01d 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -621,72 +621,3 @@ end function (e::Embedding)(::Tuple{}, _, ::NamedTuple) throw(ArgumentError("Input tuple must contain at least one element")) end - -""" - PeriodicEmbedding(idxs, periods) - -Create an embedding periodic in some inputs with specified periods. Input indices not in -`idxs` are passed through unchanged, but inputs in `idxs` are moved to the end of the -output and replaced with their sines, followed by their cosines (scaled appropriately to -have the specified periods). This smooth embedding preserves phase information and enforces -periodicity. - -For example, `layer = PeriodicEmbedding([2, 3], [3.0, 1.0])` will create a layer periodic in -the second input with period 3.0 and periodic in the third input with period 1.0. In this -case, `layer([a, b, c, d], st) == ([a, d, sinpi(2 / 3.0 * b), sinpi(2 / 1.0 * c), cospi(2 / 3.0 * b), cospi(2 / 1.0 * c)], st)`. - -## Arguments - - - `idxs`: Indices of the periodic inputs - - `periods`: Periods of the periodic inputs, in the same order as in `idxs` - -!!! danger "Deprecation Notice" - - This layer is deprecated and will be removed in v1. Please use the version in - [`Boltz.jl`](https://github.com/LuxDL/Boltz.jl) instead. - -# Extended Help - -## Inputs - - - `x` must be an `AbstractArray` with `issubset(idxs, axes(x, 1))` - - `st` must be a `NamedTuple` where `st.k = 2 ./ periods`, but on the same device as `x` - -## Returns - - - `AbstractArray` of size `(size(x, 1) + length(idxs), ...)` where `...` are the other - dimensions of `x`. - - `st`, unchanged -""" -struct PeriodicEmbedding{I, P} <: AbstractExplicitLayer - idxs::I - periods::P - - function PeriodicEmbedding(idxs::I, periods::P) where {I, P} - Base.depwarn("`PeriodicEmbedding` is deprecated and will be removed in v1. 
Please \ - use the corresponding version in `Boltz.jl` instead.", - :PeriodicEmbedding) - return new{I, P}(idxs, periods) - end -end - -initialstates(::AbstractRNG, p::PeriodicEmbedding) = (k=2 ./ p.periods,) - -function (p::PeriodicEmbedding)(x::AbstractVector, ps, st::NamedTuple) - return vec(first(p(reshape(x, :, 1), ps, st))), st -end - -function (p::PeriodicEmbedding)(x::AbstractMatrix, ps, st::NamedTuple) - other_idxs = CRC.@ignore_derivatives setdiff(axes(x, 1), p.idxs) - return ( - vcat(x[other_idxs, :], sinpi.(st.k .* x[p.idxs, :]), cospi.(st.k .* x[p.idxs, :])), - st) -end - -function (p::PeriodicEmbedding)(x::AbstractArray, ps, st::NamedTuple) - return reshape(first(p(reshape(x, size(x, 1), :), ps, st)), :, size(x)[2:end]...), st -end - -function Base.show(io::IO, p::PeriodicEmbedding) - return print(io, "PeriodicEmbedding(", p.idxs, ", ", p.periods, ")") -end diff --git a/test/layers/basic_tests.jl b/test/layers/basic_tests.jl index 8030bd18ef..1c880c3800 100644 --- a/test/layers/basic_tests.jl +++ b/test/layers/basic_tests.jl @@ -104,25 +104,6 @@ __f = x -> sum(first(layer(x, ps, st))) test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3) end - - @testset "PeriodicEmbedding" begin - layer = PeriodicEmbedding([2, 3], [4.0, π / 5]) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev - x = randn(rng, 6, 4, 3, 2) |> aType - Δx = [0.0, 12.0, -2π / 5, 0.0, 0.0, 0.0] |> aType - - val = layer(x, ps, st)[1] |> Array - shifted_val = layer(x .+ Δx, ps, st)[1] |> Array - - @test all(val[1:4, :, :, :] .== shifted_val[1:4, :, :, :]) && all(isapprox.( - val[5:8, :, :, :], shifted_val[5:8, :, :, :]; atol=5 * eps(Float32))) - - @jet layer(x, ps, st) - - __f = x -> sum(first(layer(x, ps, st))) - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, broken_backends=[AutoEnzyme()]) - end end end From fe81ec6d1065f10d570b1e24266fa07d0c283081 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 28 Aug 2024 23:05:47 -0400 Subject: [PATCH 33/95] fix!: remove uses of DynamicExpressions --- Project.toml | 3 --- test/Project.toml | 2 -- test/qa_tests.jl | 1 - test/runtests.jl | 4 ---- 4 files changed, 10 deletions(-) diff --git a/Project.toml b/Project.toml index c982f9d77a..2e6f5a73a2 100644 --- a/Project.toml +++ b/Project.toml @@ -41,7 +41,6 @@ WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" [weakdeps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" -DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" @@ -55,7 +54,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] LuxComponentArraysExt = "ComponentArrays" -LuxDynamicExpressionsExt = "DynamicExpressions" LuxEnzymeExt = "Enzyme" LuxFluxExt = "Flux" LuxMLUtilsExt = "MLUtils" @@ -77,7 +75,6 @@ Compat = "4.15" ComponentArrays = "0.15.16" ConcreteStructs = "0.2.3" DispatchDoctor = "0.4.12" -DynamicExpressions = "0.16, 0.17, 0.18, 0.19" Enzyme = "0.12.26" EnzymeCore = "0.7.7" FastClosures = "0.3.2" diff --git a/test/Project.toml b/test/Project.toml index 6775daa6de..d7de0502d0 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -7,7 +7,6 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b" 
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -47,7 +46,6 @@ ChainRulesCore = "1.24" ComponentArrays = "0.15.16" DispatchDoctor = "0.4.12" Documenter = "1.4" -DynamicExpressions = "0.16, 0.17, 0.18, 0.19" Enzyme = "0.12.26" ExplicitImports = "1.9.0" ForwardDiff = "0.10.36" diff --git a/test/qa_tests.jl b/test/qa_tests.jl index 6897c33467..42428823fd 100644 --- a/test/qa_tests.jl +++ b/test/qa_tests.jl @@ -34,7 +34,6 @@ end doctestexpr = quote using SimpleChains: static - using DynamicExpressions using Adapt, Lux, Random, Optimisers, Zygote, NNlib end diff --git a/test/runtests.jl b/test/runtests.jl index fb0950818d..5ac9215c02 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -71,10 +71,6 @@ Lux.set_dispatch_doctor_preferences!(; luxcore="error", luxlib="error") @test !Lux.is_extension_loaded(Val(:Zygote)) using Zygote @test Lux.is_extension_loaded(Val(:Zygote)) - - @test !Lux.is_extension_loaded(Val(:DynamicExpressions)) - using DynamicExpressions - @test Lux.is_extension_loaded(Val(:DynamicExpressions)) end # These need to be run before MPI or NCCL is ever loaded From 929d60daf073cae2b9183712508e921063ed2a82 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 30 Aug 2024 18:19:22 -0400 Subject: [PATCH 34/95] chore: apply formatting suggestion Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/Lux.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Lux.jl b/src/Lux.jl index 87beb07fd3..c2bf71dc0b 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -22,8 +22,8 @@ using Statistics: mean using UnrolledUtilities: unrolled_map, unrolled_mapreduce import LuxCore: AbstractLuxLayer, AbstractLuxContainerLayer, AbstractLuxWrapperLayer, - initialparameters, initialstates, parameterlength, statelength, - outputsize, update_state, trainmode, testmode, setup, apply, replicate + initialparameters, initialstates, parameterlength, statelength, outputsize, + update_state, trainmode, testmode, setup, apply, replicate @reexport using LuxCore, LuxLib, MLDataDevices, WeightInitializers using NNlib: NNlib, DenseConvDims, PoolDims, logsigmoid, logsoftmax, maxpool, meanpool, From 330545f89fc19dde6607908bc031f7a31312c41b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 30 Aug 2024 18:45:56 -0400 Subject: [PATCH 35/95] test: reexport NNlib in shared test modules --- test/Project.toml | 4 ++-- test/shared_testsetup.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index d7de0502d0..5b9504c671 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,7 +2,6 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" @@ -20,6 +19,7 @@ LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" @@ -41,7 +41,6 @@ Zygote = 
"e88e6eb3-aa80-5325-afca-941959d7151f" ADTypes = "1.5" Adapt = "4" Aqua = "0.8.4" -Bumper = "0.6, 0.7" ChainRulesCore = "1.24" ComponentArrays = "0.15.16" DispatchDoctor = "0.4.12" @@ -59,6 +58,7 @@ LuxLib = "1.0" LuxTestUtils = "1.1.4" MLDataDevices = "1.1" MLUtils = "0.4.3" +NNlib = "0.9.21" OneHotArrays = "0.2.5" Optimisers = "0.3.3" Pkg = "1.10" diff --git a/test/shared_testsetup.jl b/test/shared_testsetup.jl index 4ba455da8f..04cb774b2e 100644 --- a/test/shared_testsetup.jl +++ b/test/shared_testsetup.jl @@ -7,7 +7,7 @@ import Reexport: @reexport using Lux, Functors using DispatchDoctor: allow_unstable @reexport using ComponentArrays, LuxCore, LuxLib, LuxTestUtils, Random, StableRNGs, Test, - Zygote, Statistics, Enzyme, LinearAlgebra, ForwardDiff + Zygote, Statistics, Enzyme, LinearAlgebra, ForwardDiff, NNlib using MLDataDevices: default_device_rng, CPUDevice, CUDADevice, AMDGPUDevice using LuxTestUtils: check_approx From 1fd32e2afac48e341a80c4bd27f8917f7006580e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 30 Aug 2024 22:32:42 -0400 Subject: [PATCH 36/95] chore: mark version for release --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 2e6f5a73a2..6bf1caa8dc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "1.0.0-DEV" +version = "1.0.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From 26e5058ab93d51cf0e690410a2c8086f52dabd12 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 30 Aug 2024 22:43:29 -0400 Subject: [PATCH 37/95] chore: update compat entries for examples --- Project.toml | 2 +- examples/Basics/Project.toml | 6 +++--- examples/BayesianNN/Project.toml | 4 ++-- examples/ConvMixer/Project.toml | 2 +- examples/DDIM/Project.toml | 2 +- examples/GravitationalWaveForm/Project.toml | 4 +--- examples/GravitationalWaveForm/main.jl | 2 +- examples/HyperNet/Project.toml | 20 +++++++++----------- examples/HyperNet/main.jl | 2 +- examples/ImageNet/Project.toml | 14 +++++++------- examples/ImageNet/main.jl | 2 +- examples/NeuralODE/Project.toml | 6 ++---- examples/NeuralODE/main.jl | 2 +- examples/PolynomialFitting/Project.toml | 6 ++---- examples/PolynomialFitting/main.jl | 2 +- examples/SimpleChains/Project.toml | 4 ++-- examples/SimpleRNN/Project.toml | 12 +++++------- examples/SimpleRNN/main.jl | 2 +- 18 files changed, 42 insertions(+), 52 deletions(-) diff --git a/Project.toml b/Project.toml index 6bf1caa8dc..3ebf6b75a3 100644 --- a/Project.toml +++ b/Project.toml @@ -99,7 +99,7 @@ Preferences = "1.4.3" Random = "1.10" Reexport = "1.2.2" ReverseDiff = "1.15" -SIMDTypes = "0.1.0" +SIMDTypes = "0.1" Setfield = "1.1.1" SimpleChains = "0.4.7" Static = "1.1.1" diff --git a/examples/Basics/Project.toml b/examples/Basics/Project.toml index 9e0c4c2943..01b75c2e2a 100644 --- a/examples/Basics/Project.toml +++ b/examples/Basics/Project.toml @@ -14,7 +14,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" ComponentArrays = "0.15" ForwardDiff = "0.10" Literate = "2" -Lux = "0.5.56" -LuxCUDA = "0.2, 0.3" -Optimisers = "0.2, 0.3" +Lux = "1" +LuxCUDA = "0.3" +Optimisers = "0.3" Zygote = "0.6" diff --git a/examples/BayesianNN/Project.toml b/examples/BayesianNN/Project.toml index d4b30f07c3..8d6c24c2ef 100644 --- a/examples/BayesianNN/Project.toml +++ b/examples/BayesianNN/Project.toml @@ -15,8 +15,8 @@ CairoMakie = "0.12" Functors = "0.4" LinearAlgebra = "1" Literate = "2" -Lux = "0.5" +Lux = "1" 
Random = "1" Tracker = "0.2" -Turing = "0.30, 0.31, 0.32, 0.33, 0.34" +Turing = "0.34" Zygote = "0.6.69" diff --git a/examples/ConvMixer/Project.toml b/examples/ConvMixer/Project.toml index 97da256d59..35d1b5fa95 100644 --- a/examples/ConvMixer/Project.toml +++ b/examples/ConvMixer/Project.toml @@ -26,7 +26,7 @@ DataAugmentation = "0.2.12, 0.3" ImageCore = "0.10.2" ImageShow = "0.3.8" Interpolations = "0.15.1" -Lux = "0.5.53" +Lux = "1" LuxCUDA = "0.3.2" MLDatasets = "0.7.14" MLUtils = "0.4.4" diff --git a/examples/DDIM/Project.toml b/examples/DDIM/Project.toml index 21d9914cdf..60166460e8 100644 --- a/examples/DDIM/Project.toml +++ b/examples/DDIM/Project.toml @@ -37,7 +37,7 @@ FileIO = "1.16" ImageCore = "0.9, 0.10" ImageIO = "0.6" JLD2 = "0.4.48" -Lux = "0.5.52" +Lux = "1" LuxCUDA = "0.3" MLUtils = "0.4" Optimisers = " 0.3" diff --git a/examples/GravitationalWaveForm/Project.toml b/examples/GravitationalWaveForm/Project.toml index b60e84cd24..d052bce7b0 100644 --- a/examples/GravitationalWaveForm/Project.toml +++ b/examples/GravitationalWaveForm/Project.toml @@ -5,7 +5,6 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" -AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" @@ -20,9 +19,8 @@ ComponentArrays = "0.15" LineSearches = "7" Literate = "2" Lux = "0.5" -AMDGPU = "0.9.6, 1" LuxCUDA = "0.3" Optimization = "3" -OptimizationOptimJL = "0.1, 0.2, 0.3" +OptimizationOptimJL = "0.3" OrdinaryDiffEq = "6" SciMLSensitivity = "7.57" diff --git a/examples/GravitationalWaveForm/main.jl b/examples/GravitationalWaveForm/main.jl index 718eefd462..345f027258 100644 --- a/examples/GravitationalWaveForm/main.jl +++ b/examples/GravitationalWaveForm/main.jl @@ -7,7 +7,7 @@ # ## Package Imports -using Lux, ComponentArrays, LineSearches, AMDGPU, LuxCUDA, OrdinaryDiffEq, Optimization, +using Lux, ComponentArrays, LineSearches, LuxCUDA, OrdinaryDiffEq, Optimization, OptimizationOptimJL, Printf, Random, SciMLSensitivity using CairoMakie diff --git a/examples/HyperNet/Project.toml b/examples/HyperNet/Project.toml index 161ecc3c89..354d535b55 100644 --- a/examples/HyperNet/Project.toml +++ b/examples/HyperNet/Project.toml @@ -4,7 +4,6 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" -AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" @@ -17,16 +16,15 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -ADTypes = "0.2, 1" -ComponentArrays = "0.13, 0.14, 0.15" +ADTypes = "1" +ComponentArrays = "0.15" Literate = "2" -Lux = "0.5" -AMDGPU = "0.9.6, 1" -LuxCUDA = "0.2, 0.3" -MLDatasets = "0.5, 0.7" -MLUtils = "0.2, 0.3, 0.4" -OneHotArrays = "0.1, 0.2" -Optimisers = "0.2, 0.3" -Setfield = "0.8, 1" +Lux = "1" +LuxCUDA = "0.3" +MLDatasets = "0.7" +MLUtils = "0.4" +OneHotArrays = "0.2" +Optimisers = "0.3" +Setfield = "1" Statistics = "1" Zygote = "0.6" diff --git a/examples/HyperNet/main.jl b/examples/HyperNet/main.jl index 
f522f37f4a..3c3b119eb7 100644 --- a/examples/HyperNet/main.jl +++ b/examples/HyperNet/main.jl @@ -2,7 +2,7 @@ # ## Package Imports -using Lux, ADTypes, ComponentArrays, AMDGPU, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, +using Lux, ADTypes, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random, Setfield, Statistics, Zygote CUDA.allowscalar(false) diff --git a/examples/ImageNet/Project.toml b/examples/ImageNet/Project.toml index c4ecbdc4ee..f5dd3601ff 100644 --- a/examples/ImageNet/Project.toml +++ b/examples/ImageNet/Project.toml @@ -27,25 +27,25 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] +AMDGPU = "1" Augmentor = "0.6" Boltz = "0.1, 0.2, 0.3" Configurations = "0.17" FLoops = "0.2" FileIO = "1.16" Format = "1.3" -Functors = "0.2, 0.3, 0.4" +Functors = "0.4" Images = "0.26" JLD2 = "0.4.46" JpegTurbo = "0.1" -Lux = "0.5" -AMDGPU = "0.9.6, 1" -LuxCUDA = "0.2, 0.3" -MLUtils = "0.2.10, 0.3, 0.4" +Lux = "1" +LuxCUDA = "0.3" +MLUtils = "0.4" MPI = "0.20.19" Metalhead = "0.9" NCCL = "0.1.1" -OneHotArrays = "0.1, 0.2" -Optimisers = "0.2, 0.3" +OneHotArrays = "0.2" +Optimisers = "0.3" ParameterSchedulers = "0.4" Setfield = "1" SimpleConfig = "0.1" diff --git a/examples/ImageNet/main.jl b/examples/ImageNet/main.jl index 6332352950..5a8f225db6 100644 --- a/examples/ImageNet/main.jl +++ b/examples/ImageNet/main.jl @@ -5,7 +5,7 @@ using Augmentor, Configurations, Dates, FileIO, Functors, Images, MLUtils, OneHo import FLoops: ThreadedEx import Metalhead import MPI, NCCL -using AMDGPU, LuxCUDA +using LuxCUDA using Format # Distributed Training: NCCL for NVIDIA GPUs and MPI for anything else diff --git a/examples/NeuralODE/Project.toml b/examples/NeuralODE/Project.toml index 3893288566..f586f60679 100644 --- a/examples/NeuralODE/Project.toml +++ b/examples/NeuralODE/Project.toml @@ -1,5 +1,4 @@ [deps] -AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" @@ -17,10 +16,9 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -AMDGPU = "0.9.6, 1" -ComponentArrays = "0.13, 0.14, 0.15" +ComponentArrays = "0.15" Literate = "2" -Lux = "0.5" +Lux = "1" LuxCUDA = "0.2, 0.3" MLDatasets = "0.5, 0.7" MLUtils = "0.2, 0.3, 0.4" diff --git a/examples/NeuralODE/main.jl b/examples/NeuralODE/main.jl index 97a201fd12..4510fa8dda 100644 --- a/examples/NeuralODE/main.jl +++ b/examples/NeuralODE/main.jl @@ -7,7 +7,7 @@ # ## Package Imports -using Lux, ComponentArrays, SciMLSensitivity, AMDGPU, LuxCUDA, Optimisers, OrdinaryDiffEq, +using Lux, ComponentArrays, SciMLSensitivity, LuxCUDA, Optimisers, OrdinaryDiffEq, Random, Statistics, Zygote, OneHotArrays, InteractiveUtils, Printf import MLDatasets: MNIST import MLUtils: DataLoader, splitobs diff --git a/examples/PolynomialFitting/Project.toml b/examples/PolynomialFitting/Project.toml index 15eb039d58..a5c1183548 100644 --- a/examples/PolynomialFitting/Project.toml +++ b/examples/PolynomialFitting/Project.toml @@ -4,7 +4,6 @@ CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" -AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" Optimisers = 
"3bd65402-5787-11e9-1adc-39752487f4e2" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" @@ -13,11 +12,10 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -ADTypes = "0.2, 1" +ADTypes = "1" CairoMakie = "0.12" Literate = "2" -Lux = "0.5" -AMDGPU = "0.9.6, 1" +Lux = "1" LuxCUDA = "0.3" Optimisers = "0.3" Statistics = "1" diff --git a/examples/PolynomialFitting/main.jl b/examples/PolynomialFitting/main.jl index 217f1f3e1b..50f32b447f 100644 --- a/examples/PolynomialFitting/main.jl +++ b/examples/PolynomialFitting/main.jl @@ -5,7 +5,7 @@ # ## Package Imports -using Lux, ADTypes, AMDGPU, LuxCUDA, Optimisers, Printf, Random, Statistics, Zygote +using Lux, ADTypes, LuxCUDA, Optimisers, Printf, Random, Statistics, Zygote using CairoMakie # ## Dataset diff --git a/examples/SimpleChains/Project.toml b/examples/SimpleChains/Project.toml index 1ff7ce3a2a..009fd8dcad 100644 --- a/examples/SimpleChains/Project.toml +++ b/examples/SimpleChains/Project.toml @@ -13,9 +13,9 @@ SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -ADTypes = "0.2, 1" +ADTypes = "1" Literate = "2" -Lux = "0.5.20" +Lux = "1" MLDatasets = "0.7.14" MLUtils = "0.4" OneHotArrays = "0.2.5" diff --git a/examples/SimpleRNN/Project.toml b/examples/SimpleRNN/Project.toml index 0932fbdcc8..9917b042a1 100644 --- a/examples/SimpleRNN/Project.toml +++ b/examples/SimpleRNN/Project.toml @@ -4,7 +4,6 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" -AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" @@ -14,13 +13,12 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -ADTypes = "0.2.6, 1" +ADTypes = "1" JLD2 = "0.4" Literate = "2" -Lux = "0.5" -AMDGPU = "0.9.6, 1" -LuxCUDA = "0.2, 0.3" -MLUtils = "0.2, 0.3, 0.4" -Optimisers = "0.2, 0.3" +Lux = "1" +LuxCUDA = "0.3" +MLUtils = "0.4" +Optimisers = "0.3" Statistics = "1" Zygote = "0.6" diff --git a/examples/SimpleRNN/main.jl b/examples/SimpleRNN/main.jl index f4d22071ad..9fcef7b61d 100644 --- a/examples/SimpleRNN/main.jl +++ b/examples/SimpleRNN/main.jl @@ -9,7 +9,7 @@ # ## Package Imports -using ADTypes, Lux, AMDGPU, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, +using ADTypes, Lux, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, Statistics # ## Dataset From 6d320b25387266ff7f2c02a127997b32465c468c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 30 Aug 2024 22:53:12 -0400 Subject: [PATCH 38/95] fix: remove old code from benchmarks --- benchmarks/setups/models.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/setups/models.jl b/benchmarks/setups/models.jl index c5f8146f0c..b4f3763039 100644 --- a/benchmarks/setups/models.jl +++ b/benchmarks/setups/models.jl @@ -25,10 +25,10 @@ function setup_vgg16_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, conv_bn((3, 3), 512 => 512, relu; pad=(1, 1), stride=(1, 1)), conv_bn((3, 3), 512 => 512, relu; pad=(1, 1), stride=(1, 1)), conv_bn((3, 3), 512 => 512, relu; pad=(1, 1), stride=(1, 1)), - MaxPool((2, 2)); disable_optimizations=true), + MaxPool((2, 2))), FlattenLayer(), Chain(Dense(512, 4096, relu), Dropout(0.5f0), 
Dense(4096, 4096, relu), - Dropout(0.5f0), Dense(4096, 10); name="Classifier"); disable_optimizations=true) + Dropout(0.5f0), Dense(4096, 10); name="Classifier")) for bsize in (32, 64, 128) setup_forward_pass_benchmark!(suite, "vgg16(32, 32, 3, $bsize)", From 6cb1b52b621989e7dce81c97721ec58697c30ae0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 31 Aug 2024 09:28:28 -0400 Subject: [PATCH 39/95] chore: run formatter --- examples/HyperNet/main.jl | 4 ++-- examples/NeuralODE/main.jl | 4 ++-- examples/SimpleRNN/main.jl | 3 +-- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/examples/HyperNet/main.jl b/examples/HyperNet/main.jl index 3c3b119eb7..f20f94643b 100644 --- a/examples/HyperNet/main.jl +++ b/examples/HyperNet/main.jl @@ -2,8 +2,8 @@ # ## Package Imports -using Lux, ADTypes, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, - Optimisers, Printf, Random, Setfield, Statistics, Zygote +using Lux, ADTypes, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers, + Printf, Random, Setfield, Statistics, Zygote CUDA.allowscalar(false) diff --git a/examples/NeuralODE/main.jl b/examples/NeuralODE/main.jl index 4510fa8dda..ccc6b07b32 100644 --- a/examples/NeuralODE/main.jl +++ b/examples/NeuralODE/main.jl @@ -7,8 +7,8 @@ # ## Package Imports -using Lux, ComponentArrays, SciMLSensitivity, LuxCUDA, Optimisers, OrdinaryDiffEq, - Random, Statistics, Zygote, OneHotArrays, InteractiveUtils, Printf +using Lux, ComponentArrays, SciMLSensitivity, LuxCUDA, Optimisers, OrdinaryDiffEq, Random, + Statistics, Zygote, OneHotArrays, InteractiveUtils, Printf import MLDatasets: MNIST import MLUtils: DataLoader, splitobs diff --git a/examples/SimpleRNN/main.jl b/examples/SimpleRNN/main.jl index 9fcef7b61d..b85692fe4c 100644 --- a/examples/SimpleRNN/main.jl +++ b/examples/SimpleRNN/main.jl @@ -9,8 +9,7 @@ # ## Package Imports -using ADTypes, Lux, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, - Statistics +using ADTypes, Lux, LuxCUDA, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, Statistics # ## Dataset From a52f3d91a74b5cd1a328de58fb1e06bc8d2d7d0a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 31 Aug 2024 13:27:12 -0400 Subject: [PATCH 40/95] fix: qa testing --- ext/LuxTrackerExt/LuxTrackerExt.jl | 1 - test/qa_tests.jl | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/LuxTrackerExt/LuxTrackerExt.jl b/ext/LuxTrackerExt/LuxTrackerExt.jl index 8ef071ee51..34dd0d5270 100644 --- a/ext/LuxTrackerExt/LuxTrackerExt.jl +++ b/ext/LuxTrackerExt/LuxTrackerExt.jl @@ -7,7 +7,6 @@ using Tracker: Tracker, TrackedArray, TrackedReal, @grad_from_chainrules using Lux: Lux, Utils using Lux.Training: TrainingBackendCache, TrainState -using MLDataDevices: CPUDevice const CRC = ChainRulesCore diff --git a/test/qa_tests.jl b/test/qa_tests.jl index 42428823fd..074f464b09 100644 --- a/test/qa_tests.jl +++ b/test/qa_tests.jl @@ -1,12 +1,13 @@ @testitem "Aqua: Quality Assurance" tags=[:others] begin using Aqua, ChainRulesCore, ForwardDiff - Aqua.test_all(Lux; ambiguities=false) + Aqua.test_all(Lux; ambiguities=false, piracies=false) Aqua.test_ambiguities(Lux; exclude=[ForwardDiff.jacobian, ForwardDiff.gradient, Lux.AutoDiffInternalImpl.batched_jacobian, Lux.AutoDiffInternalImpl.jacobian_vector_product, Lux.AutoDiffInternalImpl.jacobian_vector_product_impl]) + Aqua.test_piracies(Lux; treat_as_own=[Lux.outputsize]) end @testitem "Explicit Imports: Quality Assurance" setup=[SharedTestSetup] tags=[:others] begin From 80c7f0bf59d74065f1c04d42e4a4128a0a6574c5 
Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 31 Aug 2024 17:04:31 -0400 Subject: [PATCH 41/95] feat: controlled reexport of NNlib --- src/Lux.jl | 40 ++++++++++++++++++++++++++++++++++++++++ test/shared_testsetup.jl | 2 +- 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/src/Lux.jl b/src/Lux.jl index c2bf71dc0b..c8d668bec1 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -124,4 +124,44 @@ export LuxOps @compat public Experimental @compat public set_dispatch_doctor_preferences! +# NNlib.jl reexports +## Functional API for common layers. Recommended to use the LuxLib versions +using NNlib: ConvDims, DenseConvDims, PoolDims, batched_adjoint, batched_mul, batched_mul!, + batched_transpose, batched_vec, bias_act!, conv, conv!, conv_bias_act, + conv_bias_act!, dot_product_attention, dot_product_attention_scores, + make_causal_mask, lpnormpool, lpnormpool!, maxpool, maxpool!, meanpool, + meanpool!, pixel_shuffle, imrotate, ∇conv_data, ∇conv_data!, ∇conv_filter, + ∇conv_filter!, ∇lpnormpool, ∇lpnormpool!, ∇maxpool, ∇maxpool!, ∇meanpool, + ∇meanpool!, ∇imrotate +export ConvDims, DenseConvDims, PoolDims, batched_adjoint, batched_mul, batched_mul!, + batched_transpose, batched_vec, bias_act!, conv, conv!, conv_bias_act, + conv_bias_act!, dot_product_attention, dot_product_attention_scores, + make_causal_mask, lpnormpool, lpnormpool!, maxpool, maxpool!, meanpool, meanpool!, + pixel_shuffle, imrotate, ∇conv_data, ∇conv_data!, ∇conv_filter, ∇conv_filter!, + ∇lpnormpool, ∇lpnormpool!, ∇maxpool, ∇maxpool!, ∇meanpool, ∇meanpool!, ∇imrotate + +## Padding +using NNlib: pad_circular, pad_constant, pad_reflect, pad_repeat, pad_symmetric, pad_zeros +export pad_circular, pad_constant, pad_reflect, pad_repeat, pad_symmetric, pad_zeros + +## Upsample +using NNlib: upsample_linear, upsample_bilinear, upsample_trilinear, upsample_nearest, + ∇upsample_linear, ∇upsample_bilinear, ∇upsample_trilinear, ∇upsample_nearest +export upsample_linear, upsample_bilinear, upsample_trilinear, upsample_nearest, + ∇upsample_linear, ∇upsample_bilinear, ∇upsample_trilinear, ∇upsample_nearest + +## Activation Functions +using NNlib: σ, celu, elu, gelu, glu, hardsigmoid, hardswish, hardtanh, hardσ, leakyrelu, + lisht, logcosh, logsigmoid, logσ, mish, relu, relu6, rrelu, selu, sigmoid, + sigmoid_fast, softplus, softshrink, softsign, swish, tanhshrink, tanh_fast, + thresholdrelu, trelu +export σ, celu, elu, gelu, glu, hardsigmoid, hardswish, hardtanh, hardσ, leakyrelu, lisht, + logcosh, logsigmoid, logσ, mish, relu, relu6, rrelu, selu, sigmoid, sigmoid_fast, + softplus, softshrink, softsign, swish, tanhshrink, tanh_fast, thresholdrelu, trelu + +using NNlib: softmax, softmax!, logsoftmax, logsoftmax!, logsumexp, ∇logsoftmax, + ∇logsoftmax!, ∇softmax, ∇softmax! +export softmax, softmax!, logsoftmax, logsoftmax!, logsumexp, ∇logsoftmax, ∇logsoftmax!, + ∇softmax, ∇softmax! 
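The curated `using NNlib: ...` / `export ...` block added above replaces the previous direct re-export of `NNlib`. As a rough downstream illustration (not taken from the patch itself), code that only touches the re-exported names keeps working with `using Lux` alone, while anything outside this list now needs `NNlib` loaded explicitly:

```julia
using Lux  # note: no `using NNlib` here

relu(-1.0f0)                  # activations such as relu/gelu/sigmoid stay exported by Lux
softmax(rand(Float32, 3, 2))  # softmax / logsoftmax remain available as well

# Functions outside the curated list (for example `gather`, assuming it stays
# un-reexported) now require loading NNlib explicitly:
using NNlib
NNlib.gather(rand(Float32, 3, 5), [1, 3])
```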
+ end diff --git a/test/shared_testsetup.jl b/test/shared_testsetup.jl index 04cb774b2e..4ba455da8f 100644 --- a/test/shared_testsetup.jl +++ b/test/shared_testsetup.jl @@ -7,7 +7,7 @@ import Reexport: @reexport using Lux, Functors using DispatchDoctor: allow_unstable @reexport using ComponentArrays, LuxCore, LuxLib, LuxTestUtils, Random, StableRNGs, Test, - Zygote, Statistics, Enzyme, LinearAlgebra, ForwardDiff, NNlib + Zygote, Statistics, Enzyme, LinearAlgebra, ForwardDiff using MLDataDevices: default_device_rng, CPUDevice, CUDADevice, AMDGPUDevice using LuxTestUtils: check_approx From b9bbf99904374818681f8a6fd3edd4e8add3b4fe Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 1 Sep 2024 11:46:24 -0400 Subject: [PATCH 42/95] docs: add a migration to v1 docs --- README.md | 3 + docs/make.jl | 1 + docs/src/.vitepress/config.mts | 1 + docs/src/api/Building_Blocks/LuxLib.md | 2 +- docs/src/introduction/index.md | 5 ++ docs/src/introduction/updating_to_v1.md | 109 ++++++++++++++++++++++++ 6 files changed, 120 insertions(+), 1 deletion(-) create mode 100644 docs/src/introduction/updating_to_v1.md diff --git a/README.md b/README.md index bd747cb396..f21a9ae7e5 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,9 @@ import Pkg Pkg.add("Lux") ``` +> [!TIP] +> If you are using a pre-v1 version of Lux.jl, please see the [Updating to v1 section](https://lux.csail.mit.edu/dev/introduction/updating_to_v1/) for instructions on how to update. + ## 🤸 Quickstart ```julia diff --git a/docs/make.jl b/docs/make.jl index 564d46bad2..c67eb73d1f 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -14,6 +14,7 @@ pages = [ "Introduction" => "introduction/index.md", "Overview" => "introduction/overview.md", "Resources" => "introduction/resources.md", + "Updating to v1" => "introduction/updating_to_v1.md", "Citation" => "introduction/citation.md" ], "Tutorials" => [ diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts index eca4ec34bd..05069ee867 100644 --- a/docs/src/.vitepress/config.mts +++ b/docs/src/.vitepress/config.mts @@ -111,6 +111,7 @@ export default defineConfig({ { text: 'Introduction', link: '/introduction' }, { text: 'Overview', link: '/introduction/overview' }, { text: 'Resources', link: '/introduction/resources' }, + { text: 'Updating to v1', link: '/introduction/updating_to_v1' }, { text: 'Citation', link: '/introduction/citation' }] }, "/tutorials/": { diff --git a/docs/src/api/Building_Blocks/LuxLib.md b/docs/src/api/Building_Blocks/LuxLib.md index 8075d83ce0..21bbe1510a 100644 --- a/docs/src/api/Building_Blocks/LuxLib.md +++ b/docs/src/api/Building_Blocks/LuxLib.md @@ -1,4 +1,4 @@ -# LuxLib +# [LuxLib](@id LuxLib-API) Backend for Lux.jl diff --git a/docs/src/introduction/index.md b/docs/src/introduction/index.md index 46a35f2839..43a1c8717d 100644 --- a/docs/src/introduction/index.md +++ b/docs/src/introduction/index.md @@ -11,6 +11,11 @@ import Pkg Pkg.add("Lux") ``` +!!! tip "Update to v1" + + If you are using a pre-v1 version of Lux.jl, please see the + [Updating to v1 section](@ref updating-to-v1) for instructions on how to update. + ## Quickstart !!! tip "Pre-Requisites" diff --git a/docs/src/introduction/updating_to_v1.md b/docs/src/introduction/updating_to_v1.md new file mode 100644 index 0000000000..67e96e929e --- /dev/null +++ b/docs/src/introduction/updating_to_v1.md @@ -0,0 +1,109 @@ +# [Updating to Lux v1](@id updating-to-v1) + +Lux v1 is a Major Release, mostly to signify the stability of the API. 
In this page, we list
+out a concrete set of changes that need to be made to your code to update to Lux v1. We also
+list out some new exciting features that were added as part of this release.
+
+## `LuxLib.jl`
+
+### Breaking Changes
+
+- Old deprecated API with keyword arguments has been removed. See the new docs in [LuxLib
+  API](@ref LuxLib-API) for more details.
+- Default for [`layernorm`](@ref) dims has been changed to exclude the batch dimension.
+
+## `LuxCore.jl`
+
+### Breaking Changes
+
+- `AbstractExplicitLayer` has been renamed to `AbstractLuxLayer`.
+- `AbstractExplicitContainerLayer` behaviour
+  - This has been renamed to `AbstractLuxContainerLayer`.
+  - Previously, `AbstractExplicitContainerLayer{(:a,)}` (i.e. singleton containers) would
+    produce default initial parameters and states without wrapping them in a
+    `NamedTuple{(:a,)}`. This was inconsistent with non-singleton containers, and was a
+    source of confusion. With `v1`, we return `(; a = <parameters>)` and `(; a = <states>)`
+    by default. See [`AbstractLuxWrapperLayer`](@ref) for a replacement of this
+    functionality.
+- `inputsize` has been removed since it was ambiguous and not used anywhere.
+- Changes to `outputsize`:
+  - Single argument version has been removed. See [LuxCore.jl Pull Request
+    43](https://github.com/LuxDL/LuxCore.jl/pull/43#issuecomment-2254232817) for more
+    details on the rationale behind this change.
+  - Fallback implementation has been moved to `Lux.jl`. (i.e. users using Lux shouldn't
+    see a difference, but if `Lux.jl` isn't loaded, this function will throw an error.)
+  - Internally this uses a `NilArray` that is able to compute sizes without actually
+    running the computation.
+- `Functors` and `Setfield` have been made into optional dependencies. Certain `LuxCore`
+  functionality that relies on these functions will throw an error if these packages are not
+  loaded.
+
+### New Major Features
+
+- Introduction of [`AbstractLuxWrapperLayer`](@ref). This behaves exactly like the old
+  singleton container. For example, the old `AbstractExplicitContainerLayer{(:a,)}` is
+  equivalent to `AbstractLuxWrapperLayer{:a}`.
+
+## `WeightInitializers.jl`
+
+This was a major release to signify the stability of the API. There were no breaking
+changes. We do support a wider range of RNG types, see
+[Supported RNG Types](@ref Supported-RNG-Types-WeightInit) for more details.
+
+## `MLDataDevices.jl`
+
+This is the most aggressive change that was made. We renamed the `LuxDeviceUtils.jl` package
+to `MLDataDevices.jl`, to allow for non-Lux packages to use this shared device management
+abstraction.
+
+!!! warning "Deprecation of `LuxDeviceUtils.jl`"
+
+    This also marks the deprecation of the `LuxDeviceUtils.jl` package. We won't be making
+    any updates to that package, including fixing any bugs. All users should switch to
+    `MLDataDevices.jl` instead.
+
+### Breaking Changes
+
+- `Lux(___)Device` objects have been renamed to `(___)Device`. For example, `LuxCUDADevice`
+  has been renamed to `CUDADevice`.
+- `Lux(___)Adaptor` objects have been removed. The corresponding `Device` objects should be
+  used directly instead.
+
+### New Major Features
+
+- [`DeviceIterator`](@ref) provides a generalization of `CUDA.CuIterator` and works for all
+  backends and more data types (using `Functors.jl`). `MLUtils.DataLoader |> gdev` now
+  returns a `DeviceIterator` instead of being a no-op.
+
+## `Lux.jl`
+
+### Breaking Changes (Removed Functionality)
+
+- Direct reexport of `NNlib` has been removed.
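As a rough illustration of the `MLDataDevices.jl` changes summarized above (the device renames and the new `DeviceIterator` behaviour), a minimal sketch assuming Lux v1 with `MLUtils` available could look like the following; the `Lux.jl` breaking-changes list resumes right after it:

```julia
using Lux, MLUtils  # the MLDataDevices API is re-exported through Lux

cdev = cpu_device()  # returns `CPUDevice()`; replaces the old `LuxCPUDevice`
gdev = gpu_device()  # falls back to the CPU device (with a warning) when no GPU backend is loaded

x = rand(Float32, 4, 16) |> gdev  # move arrays / parameter trees to the selected device
x = x |> cdev

loader = DataLoader((rand(Float32, 2, 64), rand(Float32, 1, 64)); batchsize=16)
for (xb, yb) in gdev(loader)  # now a `DeviceIterator` rather than a no-op
    # xb and yb already live on the selected device here
end
```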
We reexport selected functionality from + `NNlib`. Direactly load `NNlib` if you need to use the other functions. +- Flattening of [`Chain`](@ref) layers has been removed, and the corresponding + `disable_optimizations` kwarg has been removed. +- Some layers overloaded `Base.keys`, these have been removed. These were mostly + un-documented and weren't supposed to be used outside of the `Lux.jl` package. +- [`Training.TrainState`](@ref) construction with `rng` has been removed. +- Older versions of Preferences have been removed. +- `disable_stacktrace_truncation!` has been removed. From Julia 1.9 onwards, stacktrace + truncation is enabled by default. +- Certain Experimental features were present outside the `Lux.Experimental` module. These + have been removed, use them via `Lux.Experimental` instead. Run Julia with with `depwarn` + as `error` and Lux `v0.5` to see the deprecations. + +### Breaking Changes (Moved Functionality) + +- `Lux.Experimental.Training` has been moved to `Lux.Training`. We guarantee SemVar + on this new module. +- `Lux.cpu` and `Lux.gpu` have been removed. Use [`cpu_device`](@ref) and + [`gpu_device`](@ref) instead. +- `Experimental.@compact` can be directly used via [`@compact`](@ref) now. +- `Experimental.StatefulLuxLayer` has been moved to [`Lux.StatefulLuxLayer`](@ref). +- `st_fixed_path` kwarg has been removed from [`Lux.StatefulLuxLayer`](@ref), instead use it + as `StatefulLuxLayer{st_fixed_path}(...)`. +- Strings as inputs to [`Experimental.@layer_map`](@ref) and + [`Experimental.@debug_mode`](@ref) are removed, use `Functors.KeyPath` instead. + +### Breaking Changes (Changes in Defaults) From 020afeb427089af72f309fd869c003ab9661080e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 3 Sep 2024 10:28:41 -0400 Subject: [PATCH 43/95] test: try fixing the tests --- examples/GravitationalWaveForm/Project.toml | 2 +- test/helpers/training_tests.jl | 59 ++++++++++----------- 2 files changed, 29 insertions(+), 32 deletions(-) diff --git a/examples/GravitationalWaveForm/Project.toml b/examples/GravitationalWaveForm/Project.toml index d052bce7b0..73f3948089 100644 --- a/examples/GravitationalWaveForm/Project.toml +++ b/examples/GravitationalWaveForm/Project.toml @@ -18,7 +18,7 @@ CairoMakie = "0.12" ComponentArrays = "0.15" LineSearches = "7" Literate = "2" -Lux = "0.5" +Lux = "1" LuxCUDA = "0.3" Optimization = "3" OptimizationOptimJL = "0.3" diff --git a/test/helpers/training_tests.jl b/test/helpers/training_tests.jl index 4d1cd0ccd0..fc6c307d9f 100644 --- a/test/helpers/training_tests.jl +++ b/test/helpers/training_tests.jl @@ -126,51 +126,48 @@ end end end -@testitem "Enzyme: Invalidate Cache on State Update" setup=[SharedTestSetup] tags=[:helpers] begin +@testitem "Enzyme: Invalidate Cache on State Update" setup=[SharedTestSetup] tags=[:helpers] skip=:(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED) begin using ADTypes, Optimisers using Enzyme - if LuxTestUtils.ENZYME_TESTING_ENABLED - Enzyme.API.runtimeActivity!(true) + Enzyme.API.runtimeActivity!(true) - mse = MSELoss() - function mse2(model, ps, st, (x, y)) - z, st = model(x, ps, st) - return sum(abs2, z .- y), st, () - end + mse = MSELoss() - rng = StableRNG(12345) + function mse2(model, ps, st, (x, y)) + z, st = model(x, ps, st) + return sum(abs2, z .- y), st, () + end - model = Chain(Dense(4 => 3), VariationalHiddenDropout(0.5f0), Dense(3 => 4)) - ps, st = Lux.setup(rng, model) - x = randn(rng, Float32, 4, 32) - opt = Adam(0.001f0) + rng = StableRNG(12345) - tstate = Training.TrainState(model, ps, 
st, opt) + model = Chain(Dense(4 => 3), VariationalHiddenDropout(0.5f0), Dense(3 => 4)) + ps, st = Lux.setup(rng, model) + x = randn(rng, Float32, 4, 32) + opt = Adam(0.001f0) - _, _, _, tstate_new = @inferred Training.compute_gradients( - AutoEnzyme(), mse, (x, x), tstate) + tstate = Training.TrainState(model, ps, st, opt) - @test tstate_new.states !== tstate.states + _, _, _, tstate_new = @inferred Training.compute_gradients( + AutoEnzyme(), mse, (x, x), tstate) - model = Chain(Dense(4 => 3), Dense(3 => 4)) - ps, st = Lux.setup(rng, model) + @test tstate_new.states !== tstate.states - tstate = Training.TrainState(model, ps, st, opt) + model = Chain(Dense(4 => 3), Dense(3 => 4)) + ps, st = Lux.setup(rng, model) - _, _, _, tstate_new = @inferred Training.compute_gradients( - AutoEnzyme(), mse, (x, x), tstate) + tstate = Training.TrainState(model, ps, st, opt) - @test @inferred(Training.compute_gradients( - AutoEnzyme(), mse, (x, x), tstate_new)) isa Any + _, _, _, tstate_new = @inferred Training.compute_gradients( + AutoEnzyme(), mse, (x, x), tstate) - _, _, _, tstate_new2 = @inferred Training.compute_gradients( - AutoEnzyme(), mse2, (x, x), tstate_new) - @test hasfield(typeof(tstate_new2.cache.extras), :forward) - @test hasfield(typeof(tstate_new2.cache.extras), :reverse) - else - @test_broken false - end + @test @inferred(Training.compute_gradients(AutoEnzyme(), mse, (x, x), tstate_new)) isa + Any + + _, _, _, tstate_new2 = @inferred Training.compute_gradients( + AutoEnzyme(), mse2, (x, x), tstate_new) + @test hasfield(typeof(tstate_new2.cache.extras), :forward) + @test hasfield(typeof(tstate_new2.cache.extras), :reverse) rng = StableRNG(12345) From 9f0cf238a5692081caf82e98ea39c66a8eafdd38 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 3 Sep 2024 13:30:10 -0400 Subject: [PATCH 44/95] fix: missing state type in StatefulLuxLayer --- examples/BayesianNN/main.jl | 2 +- examples/GravitationalWaveForm/main.jl | 2 +- examples/NeuralODE/main.jl | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/BayesianNN/main.jl b/examples/BayesianNN/main.jl index 62525f3f00..7eea940a82 100644 --- a/examples/BayesianNN/main.jl +++ b/examples/BayesianNN/main.jl @@ -110,7 +110,7 @@ end # To interface with external libraries it is often desirable to use the # [`StatefulLuxLayer`](@ref) to automatically handle the neural network states. -const model = StatefulLuxLayer(nn, st) +const model = StatefulLuxLayer{true}(nn, st) ## Specify the probabilistic model. 
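The hunks in this commit move the examples onto the v1 `StatefulLuxLayer` constructor, where the flag that used to be the `st_fixed_path` keyword becomes a type parameter. A small standalone sketch of the new form, using a toy network invented purely for illustration:

```julia
using Lux, Random

nn = Chain(Dense(2 => 3, tanh), Dense(3 => 1))
ps, st = Lux.setup(Random.default_rng(), nn)

# Pre-v1 (now removed): smodel = StatefulLuxLayer(nn, ps, st)
# v1: the old keyword moves into the type parameter
smodel = StatefulLuxLayer{true}(nn, ps, st)
y = smodel(rand(Float32, 2, 8), ps)  # states are carried inside `smodel`
```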
@model function bayes_nn(xs, ts) diff --git a/examples/GravitationalWaveForm/main.jl b/examples/GravitationalWaveForm/main.jl index 345f027258..34e901a6e4 100644 --- a/examples/GravitationalWaveForm/main.jl +++ b/examples/GravitationalWaveForm/main.jl @@ -234,7 +234,7 @@ ps, st = Lux.setup(Xoshiro(), nn) const params = ComponentArray{Float64}(ps) -const nn_model = StatefulLuxLayer(nn, st) +const nn_model = StatefulLuxLayer{true}(nn, st) # Now we define a system of odes which describes motion of point like particle with # Newtonian physics, uses diff --git a/examples/NeuralODE/main.jl b/examples/NeuralODE/main.jl index ccc6b07b32..013b64f20d 100644 --- a/examples/NeuralODE/main.jl +++ b/examples/NeuralODE/main.jl @@ -190,7 +190,7 @@ function StatefulNeuralODE( end function (n::StatefulNeuralODE)(x, ps, st) - st_model = StatefulLuxLayer(n.model, ps, st) + st_model = StatefulLuxLayer{true}(n.model, ps, st) dudt(u, p, t) = st_model(u, p) prob = ODEProblem{false}(ODEFunction{false}(dudt), x, n.tspan, ps) return solve(prob, n.solver; n.kwargs...), st_model.st From b3c746aa4bcd771075e6fa2cc6cff969171f15a4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 3 Sep 2024 14:19:30 -0400 Subject: [PATCH 45/95] chore: remove unnecessary `.0` --- docs/Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 865bcada48..2126c3122d 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -40,8 +40,8 @@ LinearAlgebra = "1.10" Literate = "2.18.0" Lux = "1" LuxCUDA = "0.3.2" -LuxCore = "1.0" -LuxLib = "1.0" +LuxCore = "1" +LuxLib = "1" LuxTestUtils = "1.1" MLDataDevices = "1.1" Optimisers = "0.3.3" From e074134aacff6106968bb3d32c5001f9c6e329b8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 3 Sep 2024 15:22:15 -0400 Subject: [PATCH 46/95] fix!: cleanup the implementation of `layer_map` --- docs/src/api/Lux/contrib.md | 1 - docs/src/introduction/updating_to_v1.md | 5 +- docs/src/manual/freezing_model_parameters.md | 4 +- src/contrib/contrib.jl | 2 +- src/contrib/map.jl | 119 +++++++------------ test/contrib/map_tests.jl | 43 ++++--- 6 files changed, 73 insertions(+), 101 deletions(-) diff --git a/docs/src/api/Lux/contrib.md b/docs/src/api/Lux/contrib.md index e79d6872e7..93f412d3e7 100644 --- a/docs/src/api/Lux/contrib.md +++ b/docs/src/api/Lux/contrib.md @@ -27,7 +27,6 @@ For detailed usage example look at the [manual page](@ref freezing-model-paramet ## Map over Layer ```@docs -Lux.Experimental.@layer_map Lux.Experimental.layer_map ``` diff --git a/docs/src/introduction/updating_to_v1.md b/docs/src/introduction/updating_to_v1.md index 67e96e929e..c64f5037e3 100644 --- a/docs/src/introduction/updating_to_v1.md +++ b/docs/src/introduction/updating_to_v1.md @@ -92,6 +92,9 @@ abstraction. - Certain Experimental features were present outside the `Lux.Experimental` module. These have been removed, use them via `Lux.Experimental` instead. Run Julia with with `depwarn` as `error` and Lux `v0.5` to see the deprecations. +- `Lux.Experimental.@layer_map` is not longer needed and has been removed. The name of the + variable prevents writing generic functions and is no longer pre-pended to the `KeyPath`. + See the docstring of [`Lux.Experimental.layer_map`](@ref) for more details. ### Breaking Changes (Moved Functionality) @@ -103,7 +106,7 @@ abstraction. - `Experimental.StatefulLuxLayer` has been moved to [`Lux.StatefulLuxLayer`](@ref). 
- `st_fixed_path` kwarg has been removed from [`Lux.StatefulLuxLayer`](@ref), instead use it as `StatefulLuxLayer{st_fixed_path}(...)`. -- Strings as inputs to [`Experimental.@layer_map`](@ref) and +- Strings as inputs to [`Experimental.layer_map`](@ref) and [`Experimental.@debug_mode`](@ref) are removed, use `Functors.KeyPath` instead. ### Breaking Changes (Changes in Defaults) diff --git a/docs/src/manual/freezing_model_parameters.md b/docs/src/manual/freezing_model_parameters.md index 027de13c64..14425de4fd 100644 --- a/docs/src/manual/freezing_model_parameters.md +++ b/docs/src/manual/freezing_model_parameters.md @@ -9,7 +9,7 @@ In this manual entry, we will go over how to freeze certain parameters in a mode ## Freezing Layers of a Particular Kind To freeze a particular kind of layer, let's say [`Dense`](@ref) in the following example. -We can use [`Lux.Experimental.@layer_map`](@ref) and freeze layers if they are of type +We can use [`Lux.Experimental.layer_map`](@ref) and freeze layers if they are of type `Dense`. ```@example freezing_model_parameters @@ -30,7 +30,7 @@ function freeze_dense(d::Lux.Dense, ps, st, path) end freeze_dense(l, ps, st, path) = (l, ps, st) -model_frozen, ps_frozen, st_frozen = Lux.Experimental.@layer_map freeze_dense model ps st +model_frozen, ps_frozen, st_frozen = Lux.Experimental.layer_map(freeze_dense, model, ps, st) model_frozen(x, ps_frozen, st_frozen) ``` diff --git a/src/contrib/contrib.jl b/src/contrib/contrib.jl index 87360b2118..3e62563f32 100644 --- a/src/contrib/contrib.jl +++ b/src/contrib/contrib.jl @@ -24,7 +24,7 @@ include("freeze.jl") include("share_parameters.jl") include("debug.jl") -@compat public layer_map, @layer_map +@compat public layer_map @compat public FrozenLayer, freeze, unfreeze @compat public share_parameters @compat public DebugLayer, @debug_mode diff --git a/src/contrib/map.jl b/src/contrib/map.jl index a8cc438f68..4f15c08a37 100644 --- a/src/contrib/map.jl +++ b/src/contrib/map.jl @@ -1,57 +1,19 @@ @doc doc""" - @layer_map func layer ps st - -See the documentation of [`Lux.Experimental.layer_map`](@ref) for more details. This macro -eliminates the need to the set the layer name, and uses the variable name of layer as the -starting point. 
- -## Example - -```jldoctest -julia> using Lux, Random - -julia> c = Parallel( - +; chain=Chain(; dense_1=Dense(2 => 3), bn=BatchNorm(3), dense_2=Dense(3 => 5)), - dense_3=Dense(5 => 1)); - -julia> rng = Random.default_rng(); - -julia> ps, st = Lux.setup(rng, c); - -julia> # Makes parameters of Dense Layers inside Chain zero - function zero_dense_params(l, ps, st, name) - if l isa Dense - println("zeroing params of $name") - ps = merge(ps, (; weight=zero.(ps.weight), bias=zero.(ps.bias))) - end - return l, ps, st - end; - -julia> _, ps_new, _ = Lux.Experimental.@layer_map zero_dense_params c ps st; -zeroing params of KeyPath(:c, :layers, :chain, :layers, :dense_1) -zeroing params of KeyPath(:c, :layers, :chain, :layers, :dense_2) -zeroing params of KeyPath(:c, :layers, :dense_3) - -julia> all(iszero, (ps_new.chain.dense_1.weight, ps_new.chain.dense_1.bias, - ps_new.chain.dense_2.weight, ps_new.chain.dense_2.bias, - ps_new.dense_3.weight, ps_new.dense_3.bias)) -true -``` -""" -macro layer_map(f, l, ps, st) - return quote - layer_map($(esc(f)), $(esc(l)), $(esc(ps)), $(esc(st)), $(Meta.quot(l))) - end -end - -@doc doc""" - layer_map(f::Function, l::AbstractLuxLayer, ps, st::NamedTuple, - name::Symbol=:model) + layer_map(f, l::AbstractLuxLayer, ps, st::NamedTuple) Map the function `f` over the model `l`, with the parameters `ps` and states `st`. This is different from `Functors.fmap` since it zips the layers, parameters, and states and invokes the function on all of them together. +!!! tip "KeyPath provided to the function" + + The `KeyPath` depths on the structure of the parameters and states. This is of + consequence exclusively for [`AbstractLuxWrapperLayer`](@ref) where the structure of the + layer doesn't match the structure of the parameters and states. In the example, provided + below, the `KeyPath` is `(:chain, :dense_1)` for the first layer (following the + structure in `ps`) while accessing the same layer in the chain is done with `( + :chain, :layers, :dense_1)`. + ## Call Signature for `f` - Must take 4 inputs -- `AbstractLuxLayer`, Corresponding Parameters, Corresponding @@ -59,11 +21,6 @@ the function on all of them together. - Must return a tuple of 3 elements -- `AbstractLuxLayer`, new parameters and the new states. -!!! tip "Use `Lux.Experimental.@layer_map` instead" - - We recommend using the macro `Lux.Experimental.@layer_map` instead of this function. It - automatically sets the `name` of the layer to be the variable name. 
- # Extended Help ## Example @@ -89,9 +46,9 @@ julia> # Makes parameters of Dense Layers inside Chain zero end; julia> _, ps_new, _ = Lux.Experimental.layer_map(zero_dense_params, c, ps, st); -zeroing params of KeyPath(:model, :layers, :chain, :layers, :dense_1) -zeroing params of KeyPath(:model, :layers, :chain, :layers, :dense_2) -zeroing params of KeyPath(:model, :layers, :dense_3) +zeroing params of KeyPath(:chain, :dense_1) +zeroing params of KeyPath(:chain, :dense_2) +zeroing params of KeyPath(:dense_3) julia> all(iszero, (ps_new.chain.dense_1.weight, ps_new.chain.dense_1.bias, ps_new.chain.dense_2.weight, ps_new.chain.dense_2.bias, @@ -99,36 +56,50 @@ julia> all(iszero, (ps_new.chain.dense_1.weight, ps_new.chain.dense_1.bias, true ``` """ -function layer_map(f::F, l, ps, st, name::Symbol=:model) where {F <: Function} - f_wrapper = @closure (kp, layer, ps_, st_) -> f(layer, ps_, st_, KeyPath(name, kp)) - return fmap_with_path(f_wrapper, l, ps, st; walk=LayerWalkWithPath()) +function layer_map(f, l, ps, st) + return fmap_with_path(l, ps, st; walk=LayerWalkWithPath()) do kp, layer, ps_, st_ + return f(layer, ps_, st_, kp) + end end struct LayerWalkWithPath <: Functors.AbstractWalk end -function (::LayerWalkWithPath)(recurse, kp::KeyPath, layer, ps, st) - _layer_children, layer_re = functor(layer) +function (::LayerWalkWithPath)( + recurse::R, kp::KeyPath, layer::AbstractLuxWrapperLayer{field}, + ps, st) where {R, field} + layer_children, layer_re = functor(getfield(layer, field)) + ps_children, ps_re = functor(ps) + st_children, st_re = functor(st) + + layer_children_new, ps_children_new, st_children_new = perform_layer_map( + recurse, kp, ps_children, st_children, layer_children) + + inner_layer = layer_re(layer_children_new) + return (Setfield.set(layer, Setfield.PropertyLens{field}(), inner_layer), + ps_re(ps_children_new), st_re(st_children_new)) +end + +function (::LayerWalkWithPath)( + recurse::R, kp::KeyPath, layer::AbstractLuxLayer, ps, st) where {R} + layer_children, layer_re = functor(layer) ps_children, ps_re = functor(ps) st_children, st_re = functor(st) - _children = keys(ps_children) - needs_correction = _children != keys(_layer_children) - _key = needs_correction ? only(keys(_layer_children)) : nothing - layer_children = needs_correction ? getfield(layer, _key) : _layer_children - @assert keys(layer_children) == keys(ps_children) == keys(st_children) + layer_children_new, ps_children_new, st_children_new = perform_layer_map( + recurse, kp, ps_children, st_children, layer_children) + + return layer_re(layer_children_new), ps_re(ps_children_new), st_re(st_children_new) +end + +function perform_layer_map(recurse, kp, ps_children, st_children, layer_children) + @argcheck keys(layer_children) == keys(ps_children) == keys(st_children) - kps = NamedTuple{_children}(map( - x -> needs_correction ? KeyPath(kp, _key, x) : KeyPath(kp, x), _children)) + kps = NamedTuple{keys(ps_children)}(map(Base.Fix1(KeyPath, kp), keys(ps_children))) ys = map(recurse, kps, layer_children, ps_children, st_children) layer_children_new = map(Base.Fix2(getindex, 1), ys) ps_children_new = map(Base.Fix2(getindex, 2), ys) st_children_new = map(Base.Fix2(getindex, 3), ys) - layer_new = needs_correction ? 
layer_re(NamedTuple{(_key,)}((layer_children_new,))) : - layer_re(layer_children_new) - ps_new = ps_re(ps_children_new) - st_new = st_re(st_children_new) - - return layer_new, ps_new, st_new + return layer_children_new, ps_children_new, st_children_new end diff --git a/test/contrib/map_tests.jl b/test/contrib/map_tests.jl index 401afbc8d5..2df37556a4 100644 --- a/test/contrib/map_tests.jl +++ b/test/contrib/map_tests.jl @@ -1,13 +1,13 @@ @testitem "Layer Map" setup=[SharedTestSetup] tags=[:contrib] begin using Setfield, Functors - function __occurs_in(kp::KeyPath, x::KeyPath) + function occurs_in(kp::KeyPath, x::KeyPath) length(kp) ≤ length(x) && return all(==(x[i], kp[i]) for i in 1:length(kp)) return false end function zero_dense_params_1(l, ps, st, name) - if l isa Dense && __occurs_in(KeyPath(:model, :layers, :chain), name) + if l isa Dense && occurs_in(KeyPath(:chain), name) @set! ps.weight = zero.(ps.weight) @set! ps.bias = zero.(ps.bias) end @@ -15,14 +15,6 @@ end function zero_dense_params_2(l, ps, st, name) - if l isa Dense && __occurs_in(KeyPath(:c, :layers, :chain), name) - @set! ps.weight = zero.(ps.weight) - @set! ps.bias = zero.(ps.bias) - end - return l, ps, st - end - - function zero_dense_params_3(l, ps, st, name) if l isa Dense @set! ps.weight = zero.(ps.weight) @set! ps.bias = zero.(ps.bias) @@ -36,7 +28,7 @@ dense_3=Dense(5 => 1)) rng = StableRNG(12345) - ps, st = Lux.setup(rng, c) .|> dev + ps, st = Lux.setup(rng, c) |> dev c_, ps_, st_ = Lux.Experimental.layer_map(zero_dense_params_1, c, ps, st) @@ -47,15 +39,6 @@ @test !all(iszero, ps_.dense_3.weight) @test all(iszero, ps_.dense_3.bias) - c_, ps_, st_ = Lux.Experimental.@layer_map zero_dense_params_2 c ps st - - @test all(iszero, ps_.chain.dense_1.weight) - @test all(iszero, ps_.chain.dense_1.bias) - @test all(iszero, ps_.chain.dense_2.weight) - @test all(iszero, ps_.chain.dense_2.bias) - @test !all(iszero, ps_.dense_3.weight) - @test all(iszero, ps_.dense_3.bias) - # Custom Layers -- See https://github.com/LuxDL/Lux.jl/issues/187 struct SimpleCustom{L1, L2} <: Lux.AbstractLuxContainerLayer{(:dense, :conv)} dense::L1 @@ -64,9 +47,25 @@ l = SimpleCustom(Dense(3 => 2), Conv((3,), 3 => 2)) - ps, st = Lux.setup(rng, l) .|> dev + ps, st = Lux.setup(rng, l) |> dev + + l_, ps_, st_ = Lux.Experimental.layer_map(zero_dense_params_2, l, ps, st) + + @test all(iszero, ps_.dense.weight) + @test all(iszero, ps_.dense.bias) + @test !all(iszero, ps_.conv.weight) + @test all(iszero, ps_.conv.bias) + + # Custom Wrapper + struct SimpleWrapper{L} <: Lux.AbstractLuxWrapperLayer{:model} + model::L + end + + l = SimpleWrapper(SimpleCustom(Dense(3 => 2), Conv((3,), 3 => 2))) + + ps, st = Lux.setup(rng, l) |> dev - l_, ps_, st_ = Lux.Experimental.@layer_map zero_dense_params_3 l ps st + l_, ps_, st_ = Lux.Experimental.layer_map(zero_dense_params_2, l, ps, st) @test all(iszero, ps_.dense.weight) @test all(iszero, ps_.dense.bias) From fac3d4889ab5d301dd1b4cbd2df2fdbaed8beb03 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 3 Sep 2024 15:52:53 -0400 Subject: [PATCH 47/95] fix: tests --- examples/BayesianNN/main.jl | 2 +- examples/GravitationalWaveForm/main.jl | 2 +- src/contrib/map.jl | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/BayesianNN/main.jl b/examples/BayesianNN/main.jl index 7eea940a82..aa850d2ed8 100644 --- a/examples/BayesianNN/main.jl +++ b/examples/BayesianNN/main.jl @@ -110,7 +110,7 @@ end # To interface with external libraries it is often desirable to use the # [`StatefulLuxLayer`](@ref) to 
automatically handle the neural network states. -const model = StatefulLuxLayer{true}(nn, st) +const model = StatefulLuxLayer{true}(nn, nothing, st) ## Specify the probabilistic model. @model function bayes_nn(xs, ts) diff --git a/examples/GravitationalWaveForm/main.jl b/examples/GravitationalWaveForm/main.jl index 34e901a6e4..4bbed4952b 100644 --- a/examples/GravitationalWaveForm/main.jl +++ b/examples/GravitationalWaveForm/main.jl @@ -234,7 +234,7 @@ ps, st = Lux.setup(Xoshiro(), nn) const params = ComponentArray{Float64}(ps) -const nn_model = StatefulLuxLayer{true}(nn, st) +const nn_model = StatefulLuxLayer{true}(nn, nothing, st) # Now we define a system of odes which describes motion of point like particle with # Newtonian physics, uses diff --git a/src/contrib/map.jl b/src/contrib/map.jl index 4f15c08a37..f5142f0db9 100644 --- a/src/contrib/map.jl +++ b/src/contrib/map.jl @@ -48,7 +48,7 @@ julia> # Makes parameters of Dense Layers inside Chain zero julia> _, ps_new, _ = Lux.Experimental.layer_map(zero_dense_params, c, ps, st); zeroing params of KeyPath(:chain, :dense_1) zeroing params of KeyPath(:chain, :dense_2) -zeroing params of KeyPath(:dense_3) +zeroing params of KeyPath(:dense_3,) julia> all(iszero, (ps_new.chain.dense_1.weight, ps_new.chain.dense_1.bias, ps_new.chain.dense_2.weight, ps_new.chain.dense_2.bias, From 4763413f29c9b64d4dcfe2a49fc5097631d71e29 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 3 Sep 2024 20:22:13 -0400 Subject: [PATCH 48/95] fix: remove symbolic tutorial references --- docs/make.jl | 3 +-- docs/src/.vitepress/config.mts | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index c67eb73d1f..a491ada64d 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -31,8 +31,7 @@ pages = [ "tutorials/intermediate/3_HyperNet.md" ], "Advanced" => [ - "tutorials/advanced/1_GravitationalWaveForm.md", - "tutorials/advanced/2_SymbolicOptimalControl.md" + "tutorials/advanced/1_GravitationalWaveForm.md" ] ], "Manual" => [ diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts index 05069ee867..f6e58e14b3 100644 --- a/docs/src/.vitepress/config.mts +++ b/docs/src/.vitepress/config.mts @@ -132,8 +132,7 @@ export default defineConfig({ }, { text: 'Advanced', collapsed: false, items: [ - { text: 'Training a Neural ODE to Model Gravitational Waveforms', link: '/tutorials/advanced/1_GravitationalWaveForm' }, - { text: 'Solving Optimal Control Problems with Symbolic UDEs', link: '/tutorials/advanced/2_SymbolicOptimalControl' },] + { text: 'Training a Neural ODE to Model Gravitational Waveforms', link: '/tutorials/advanced/1_GravitationalWaveForm' },] }, { text: 'Large Models', collapsed: true, items: [ From 2d1ad7b5c0636a5246df287122743f007fed92aa Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 3 Sep 2024 20:27:36 -0400 Subject: [PATCH 49/95] fix: incorrect size propagator test rebase --- test/helpers/size_propagator_tests.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/helpers/size_propagator_tests.jl b/test/helpers/size_propagator_tests.jl index bf840de417..7ce8e2f572 100644 --- a/test/helpers/size_propagator_tests.jl +++ b/test/helpers/size_propagator_tests.jl @@ -6,8 +6,7 @@ Conv((5, 5), 6 => 16, relu), MaxPool((2, 2)), FlattenLayer(), Dense(256 => 120, relu), Dense(120 => 84, relu), Dense(84 => 10)) - @testset "size(x) = $(size(x))" for x in ( - randn(rng, Float32, 28, 28, 1, 3), randn(rng, Float32, 28, 28, 1, 12)) + for x in (randn(rng, Float32, 28, 28, 1, 3), randn(rng, 
Float32, 28, 28, 1, 12)) @test Lux.outputsize(lenet, x, rng) == (10,) end end @@ -19,13 +18,12 @@ BatchNorm(120, relu), Dense(120 => 84, relu), Dropout(0.5f0), BatchNorm(84, relu), Dense(84 => 10), BatchNorm(10, relu)) - @testset "size(x) = $(size(x))" for x in ( - randn(rng, Float32, 28, 28, 1, 3), randn(rng, Float32, 28, 28, 1, 12)) + for x in (randn(rng, Float32, 28, 28, 1, 3), randn(rng, Float32, 28, 28, 1, 12)) @test Lux.outputsize(lenet, x, rng) == (10,) end end - @testset "Normalization: $(nameof(typeof(layer)))" for (layer, xs) in [ + norm_layer = [ (BatchNorm(3, relu), [randn(rng, Float32, 4, 4, 3, 2), randn(rng, Float32, 3, 3)]), (GroupNorm(6, 3, relu), [randn(rng, Float32, 4, 4, 6, 2), randn(rng, Float32, 6, 3)]), @@ -33,6 +31,8 @@ [randn(rng, Float32, 4, 4, 3, 2), randn(rng, Float32, 4, 3, 2)]), (LayerNorm((2, 1, 3), relu), [randn(rng, Float32, 2, 4, 3, 2), randn(rng, Float32, 2, 1, 3, 3)])] + + @testset "Normalization: $(nameof(typeof(layer)))" for (layer, xs) in norm_layer for x in xs @test Lux.outputsize(layer, x, rng) == size(x)[1:(end - 1)] end From d41e3e36223ed473e4b54f0e28c30787b32b2807 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 4 Sep 2024 15:52:42 -0400 Subject: [PATCH 50/95] fix!: remove allow_fast_activation --- docs/src/introduction/updating_to_v1.md | 2 + src/layers/basic.jl | 46 +++++++------------ src/layers/conv.jl | 43 ++++++----------- src/layers/normalize.jl | 61 +++++++++---------------- test/layers/basic_tests.jl | 14 ------ test/layers/conv_tests.jl | 14 ------ test/layers/normalize_tests.jl | 28 ------------ 7 files changed, 56 insertions(+), 152 deletions(-) diff --git a/docs/src/introduction/updating_to_v1.md b/docs/src/introduction/updating_to_v1.md index c64f5037e3..bbdffd98e3 100644 --- a/docs/src/introduction/updating_to_v1.md +++ b/docs/src/introduction/updating_to_v1.md @@ -95,6 +95,8 @@ abstraction. - `Lux.Experimental.@layer_map` is not longer needed and has been removed. The name of the variable prevents writing generic functions and is no longer pre-pended to the `KeyPath`. See the docstring of [`Lux.Experimental.layer_map`](@ref) for more details. +- `allow_fast_activation` kwarg has been removed completely. Pass an annonymous function + as the activation to prevent internal modivations to the activation function. ### Breaking Changes (Moved Functionality) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index bd0eb1b01d..e10974057f 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -248,7 +248,7 @@ Base.show(io::IO, w::WrappedFunction) = print(io, "WrappedFunction(", w.func, ") """ Dense(in_dims => out_dims, activation=identity; init_weight=glorot_uniform, - init_bias=zeros32, use_bias=True(), allow_fast_activation=True()) + init_bias=zeros32, use_bias=True()) Create a traditional fully connected layer, whose forward pass is given by: `y = activation.(weight * x .+ bias)` @@ -265,9 +265,6 @@ Create a traditional fully connected layer, whose forward pass is given by: (`weight = init_weight(rng, out_dims, in_dims)`) - `init_bias`: initializer for the bias vector (ignored if `use_bias=false`) - `use_bias`: Trainable bias can be disabled entirely by setting this to `false` - - `allow_fast_activation`: If `true`, then certain activations can be approximated with - a faster version. 
The new activation function will be given by - `NNlib.fast_act(activation)` ## Input @@ -304,9 +301,7 @@ function Dense(mapping::Pair{<:IntegerType, <:IntegerType}, activation=identity; end function Dense(in_dims::IntegerType, out_dims::IntegerType, activation=identity; - init_weight=glorot_uniform, init_bias=zeros32, - use_bias::BoolType=True(), allow_fast_activation::BoolType=True()) - activation = dynamic(allow_fast_activation) ? NNlib.fast_act(activation) : activation + init_weight=glorot_uniform, init_bias=zeros32, use_bias::BoolType=True()) return Dense(activation, in_dims, out_dims, init_weight, init_bias, static(use_bias)) end @@ -327,15 +322,14 @@ outputsize(d::Dense, _, ::AbstractRNG) = (d.out_dims,) function (d::Dense)(x::AbstractArray, ps, st::NamedTuple) y = match_eltype(d, ps, st, x) bias = safe_getproperty(ps, Val(:bias)) + σ = NNlib.fast_act(d.activation, x) z = matrix_to_array( - fused_dense_bias_activation(d.activation, ps.weight, make_abstract_matrix(y), bias), - y) + fused_dense_bias_activation(σ, ps.weight, make_abstract_matrix(y), bias), y) return z, st end """ - Scale(dims, activation=identity; init_weight=ones32, init_bias=zeros32, use_bias=True(), - allow_fast_activation=True()) + Scale(dims, activation=identity; init_weight=ones32, init_bias=zeros32, use_bias=True()) Create a Sparsely Connected Layer with a very specific structure (only Diagonal Elements are non-zero). The forward pass is given by: `y = activation.(weight .* x .+ bias)` @@ -351,9 +345,6 @@ Elements are non-zero). The forward pass is given by: `y = activation.(weight .* (`weight = init_weight(rng, out_dims, in_dims)`) - `init_bias`: initializer for the bias vector (ignored if `use_bias=false`) - `use_bias`: Trainable bias can be disabled entirely by setting this to `false` - - `allow_fast_activation`: If `true`, then certain activations can be approximated with - a faster version. The new activation function will be given by - `NNlib.fast_act(activation)` ## Input @@ -386,9 +377,7 @@ function Base.show(io::IO, d::Scale) end function Scale(dims::Tuple{Vararg{IntegerType}}, activation=identity; - init_weight=glorot_uniform, init_bias=zeros32, - use_bias::BoolType=True(), allow_fast_activation::BoolType=True()) - activation = dynamic(allow_fast_activation) ? NNlib.fast_act(activation) : activation + init_weight=glorot_uniform, init_bias=zeros32, use_bias::BoolType=True()) return Scale(activation, dims, init_weight, init_bias, static(use_bias)) end @@ -413,18 +402,20 @@ outputsize(d::Scale, _, ::AbstractRNG) = d.dims function (d::Scale{False})(x::AbstractArray, ps, st::NamedTuple) y = match_eltype(d, ps, st, x) - return @.(d.activation(y .* ps.weight)), st + σ = NNlib.fast_act(d.activation, y) + return @.(σ(y .* ps.weight)), st end function (d::Scale{True})(x::AbstractArray, ps, st::NamedTuple) y = match_eltype(d, ps, st, x) - return @.(d.activation(y * ps.weight + ps.bias)), st + σ = NNlib.fast_act(d.activation, y) + return @.(σ(y * ps.weight + ps.bias)), st end """ Bilinear((in1_dims, in2_dims) => out, activation=identity; init_weight=glorot_uniform, - init_bias=zeros32, use_bias=True(), allow_fast_activation=True()) + init_bias=zeros32, use_bias=True()) Bilinear(in12_dims => out, activation=identity; init_weight=glorot_uniform, - init_bias=zeros32, use_bias=True(), allow_fast_activation=True()) + init_bias=zeros32, use_bias=True()) Create a fully connected layer between two inputs and an output, and otherwise similar to [`Dense`](@ref). 
Its output, given vectors `x` & `y`, is another vector `z` with, for all @@ -449,9 +440,6 @@ with `B` the Bilinear layer. (`weight = init_weight(rng, out_dims, in1_dims, in2_dims)`) - `init_bias`: initializer for the bias vector (ignored if `use_bias=false`) - `use_bias`: Trainable bias can be disabled entirely by setting this to `false` - - `allow_fast_activation`: If `true`, then certain activations can be approximated with - a faster version. The new activation function will be given by - `NNlib.fast_act(activation)` ## Input @@ -494,10 +482,9 @@ function Bilinear((in12_dims, out)::Pair{<:IntegerType, <:IntegerType}, return Bilinear((in12_dims, in12_dims) => out, activation; kwargs...) end -function Bilinear(((in1_dims, in2_dims), out)::Pair{<:Tuple, <:IntegerType}, - activation=identity; init_weight=glorot_uniform, init_bias=zeros32, - use_bias::BoolType=True(), allow_fast_activation::BoolType=True()) - activation = dynamic(allow_fast_activation) ? NNlib.fast_act(activation) : activation +function Bilinear( + ((in1_dims, in2_dims), out)::Pair{<:Tuple, <:IntegerType}, activation=identity; + init_weight=glorot_uniform, init_bias=zeros32, use_bias::BoolType=True()) return Bilinear( activation, in1_dims, in2_dims, out, init_weight, init_bias, static(use_bias)) end @@ -527,7 +514,8 @@ function (b::Bilinear)( Wy = reshape(reshape(ps.weight, (:, s₃)) * y, (s₁, s₂, :)) Wyx = reshape(batched_matmul(Wy, reshape(x, (s₂, 1, :))), (s₁, :)) - return bias_activation!!(b.activation, Wyx, safe_getproperty(ps, Val(:bias))), st + σ = NNlib.fast_act(b.activation, Wyx) + return bias_activation!!(σ, Wyx, safe_getproperty(ps, Val(:bias))), st end function (b::Bilinear)((x, y)::Tuple{<:AbstractArray, <:AbstractArray}, ps, st::NamedTuple) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 5e4290621b..4e0aebeb3b 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -62,7 +62,7 @@ end @doc doc""" Conv(k::NTuple{N,Integer}, (in_chs => out_chs)::Pair{<:Integer,<:Integer}, activation=identity; init_weight=glorot_uniform, init_bias=zeros32, stride=1, - pad=0, dilation=1, groups=1, use_bias=True(), allow_fast_activation=True()) + pad=0, dilation=1, groups=1, use_bias=True()) Standard convolutional layer. @@ -115,9 +115,6 @@ Standard convolutional layer. convolution into (set `groups = in_chs` for Depthwise Convolutions). `in_chs` and `out_chs` must be divisible by `groups`. - `use_bias`: Trainable bias can be disabled entirely by setting this to `false`. - - `allow_fast_activation`: If `true`, then certain activations can be approximated with - a faster version. The new activation function will be given by - `NNlib.fast_act(activation)` ## Inputs @@ -154,15 +151,13 @@ O_i = \left\lfloor\frac{I_i + p_i + p_{(i + N) \% |p|} - d_i \times (k_i - 1)}{s end function Conv(k::Tuple{Vararg{IntegerType}}, ch::Pair{<:IntegerType, <:IntegerType}, - activation=identity; init_weight=glorot_uniform, - init_bias=zeros32, stride=1, pad=0, dilation=1, groups=1, - use_bias::BoolType=True(), allow_fast_activation::BoolType=True()) + activation=identity; init_weight=glorot_uniform, init_bias=zeros32, + stride=1, pad=0, dilation=1, groups=1, use_bias::BoolType=True()) stride = Utils.expand(Val(length(k)), stride) dilation = Utils.expand(Val(length(k)), dilation) pad = calc_padding(pad, k, dilation, stride) @argcheck allequal(length, (stride, dilation, k)) - activation = dynamic(allow_fast_activation) ? 
NNlib.fast_act(activation) : activation return Conv(activation, first(ch), last(ch), k, stride, pad, dilation, groups, init_weight, init_bias, static(use_bias)) end @@ -182,7 +177,8 @@ function (c::Conv)(x::AbstractArray, ps, st::NamedTuple) y = match_eltype(c, ps, st, x) cdims = DenseConvDims(y, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups) bias = safe_getproperty(ps, Val(:bias)) - return fused_conv_bias_activation(c.activation, ps.weight, y, bias, cdims), st + σ = NNlib.fast_act(c.activation, y) + return fused_conv_bias_activation(σ, ps.weight, y, bias, cdims), st end function Base.show(io::IO, l::Conv) @@ -575,7 +571,7 @@ PixelShuffle(r::IntegerType) = WrappedFunction(Base.Fix2(pixel_shuffle, r)) @doc doc""" CrossCor(k::NTuple{N,Integer}, (in_chs => out_chs)::Pair{<:Integer,<:Integer}, activation=identity; init_weight=glorot_uniform, init_bias=zeros32, stride=1, - pad=0, dilation=1, groups=1, use_bias=True(), allow_fast_activation=True()) + pad=0, dilation=1, groups=1, use_bias=True()) Cross Correlation layer. @@ -617,9 +613,6 @@ number of observations in a batch. dimension. - `use_bias`: Trainable bias can be disabled entirely by setting this to `false`. - - `allow_fast_activation`: If `true`, then certain activations can be approximated with - a faster version. The new activation function will be given by - `NNlib.fast_act(activation)` # Extended Help @@ -658,15 +651,13 @@ O_i = \left\lfloor\frac{I_i + p_i + p_{(i + N) \% |p|} - d_i \times (k_i - 1)}{s end function CrossCor(k::Tuple{Vararg{IntegerType}}, ch::Pair{<:IntegerType, <:IntegerType}, - activation=identity; init_weight=glorot_uniform, - init_bias=zeros32, stride=1, pad=0, dilation=1, groups=1, - use_bias::BoolType=True(), allow_fast_activation::BoolType=True()) + activation=identity; init_weight=glorot_uniform, init_bias=zeros32, + stride=1, pad=0, dilation=1, groups=1, use_bias::BoolType=True()) stride = Utils.expand(Val(length(k)), stride) dilation = Utils.expand(Val(length(k)), dilation) pad = calc_padding(pad, k, dilation, stride) @argcheck allequal(length, (stride, dilation, k)) - activation = dynamic(allow_fast_activation) ? NNlib.fast_act(activation) : activation return CrossCor(activation, first(ch), last(ch), k, stride, pad, dilation, groups, init_weight, init_bias, static(use_bias)) end @@ -687,7 +678,8 @@ function (c::CrossCor)(x::AbstractArray, ps, st::NamedTuple) cdims = DenseConvDims( DenseConvDims(y, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups); F=true) bias = safe_getproperty(ps, Val(:bias)) - return fused_conv_bias_activation(c.activation, ps.weight, y, bias, cdims), st + σ = NNlib.fast_act(c.activation, y) + return fused_conv_bias_activation(σ, ps.weight, y, bias, cdims), st end function Base.show(io::IO, l::CrossCor) @@ -706,8 +698,7 @@ end @doc doc""" ConvTranspose(k::NTuple{N,Integer}, (in_chs => out_chs)::Pair{<:Integer,<:Integer}, activation=identity; init_weight=glorot_uniform, init_bias=zeros32, - stride=1, pad=0, dilation=1, groups=1, use_bias=True(), - allow_fast_activation=True()) + stride=1, pad=0, dilation=1, groups=1, use_bias=True()) Standard convolutional transpose layer. @@ -741,9 +732,6 @@ Standard convolutional transpose layer. convolution into (set `groups = in_chs` for Depthwise Convolutions). `in_chs` and `out_chs` must be divisible by `groups`. - `use_bias`: Trainable bias can be disabled entirely by setting this to `false`. - - `allow_fast_activation`: If `true`, then certain activations can be approximated with - a faster version. 
The new activation function will be given by - `NNlib.fast_act(activation)` # Extended Help @@ -778,9 +766,8 @@ end function ConvTranspose( k::Tuple{Vararg{IntegerType}}, ch::Pair{<:IntegerType, <:IntegerType}, - activation=identity; init_weight=glorot_uniform, - init_bias=zeros32, stride=1, pad=0, dilation=1, groups=1, - use_bias::BoolType=True(), allow_fast_activation::BoolType=True()) + activation=identity; init_weight=glorot_uniform, init_bias=zeros32, + stride=1, pad=0, dilation=1, groups=1, use_bias::BoolType=True()) stride = Utils.expand(Val(length(k)), stride) dilation = Utils.expand(Val(length(k)), dilation) pad = if pad isa SamePad @@ -790,7 +777,6 @@ function ConvTranspose( end @argcheck allequal(length, (stride, dilation, k)) - activation = dynamic(allow_fast_activation) ? NNlib.fast_act(activation) : activation return ConvTranspose(activation, first(ch), last(ch), k, stride, pad, dilation, groups, init_weight, init_bias, static(use_bias)) end @@ -810,7 +796,8 @@ function (c::ConvTranspose)(x::AbstractArray, ps, st::NamedTuple) y = match_eltype(c, ps, st, x) cdims = conv_transpose_dims(y, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups) bias = safe_getproperty(ps, Val(:bias)) - return bias_activation!!(c.activation, conv_transpose(y, ps.weight, cdims), bias), st + σ = NNlib.fast_act(c.activation, y) + return bias_activation!!(σ, conv_transpose(y, ps.weight, cdims), bias), st end function Base.show(io::IO, l::ConvTranspose) diff --git a/src/layers/normalize.jl b/src/layers/normalize.jl index 6cfc28c970..146d58f417 100644 --- a/src/layers/normalize.jl +++ b/src/layers/normalize.jl @@ -1,7 +1,6 @@ @doc doc""" BatchNorm(chs::Integer, activation=identity; init_bias=zeros32, init_scale=ones32, - affine=True(), track_stats=True(), epsilon=1f-5, momentum=0.1f0, - allow_fast_activation::Bool=true) + affine=True(), track_stats=True(), epsilon=1f-5, momentum=0.1f0) [Batch Normalization](https://arxiv.org/abs/1502.03167) layer. @@ -22,9 +21,6 @@ slice and normalises the input accordingly. - `epsilon`: a value added to the denominator for numerical stability - `momentum`: the value used for the `running_mean` and `running_var` computation - - `allow_fast_activation`: If `true`, then certain activations can be approximated with - a faster version. The new activation function will be given by - `NNlib.fast_act(activation)` - If `affine=true`, it also applies a shift and a rescale to the input through to learnable per-channel bias and scale parameters. @@ -98,9 +94,8 @@ See also [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`LayerNorm`](@ref), end function BatchNorm(chs::IntegerType, activation=identity; init_bias=zeros32, - init_scale=ones32, affine::BoolType=True(), track_stats::BoolType=True(), - epsilon=1.0f-5, momentum=0.1f0, allow_fast_activation::BoolType=True()) - activation = dynamic(allow_fast_activation) ? 
NNlib.fast_act(activation) : activation + init_scale=ones32, affine::BoolType=True(), + track_stats::BoolType=True(), epsilon=1.0f-5, momentum=0.1f0) return BatchNorm(activation, epsilon, momentum, chs, init_bias, init_scale, static(affine), static(track_stats)) end @@ -129,10 +124,11 @@ function (BN::BatchNorm)(x::AbstractArray, ps, st::NamedTuple) end x′ = match_eltype(BN, ps, st, x) + σ = NNlib.fast_act(BN.activation, x′) y, stats = batchnorm( x′, safe_getproperty(ps, Val(:scale)), safe_getproperty(ps, Val(:bias)), - safe_getproperty(st, Val(:running_mean)), safe_getproperty(st, Val(:running_var)), - st.training, BN.activation, BN.momentum, BN.epsilon) + safe_getproperty(st, Val(:running_mean)), + safe_getproperty(st, Val(:running_var)), st.training, σ, BN.momentum, BN.epsilon) return y, update_batchnorm_state(BN, st, stats) end @@ -153,8 +149,7 @@ end """ GroupNorm(chs::Integer, groups::Integer, activation=identity; init_bias=zeros32, - init_scale=ones32, affine=true, epsilon=1f-5, - allow_fast_activation::Bool=true) + init_scale=ones32, affine=true, epsilon=1f-5) [Group Normalization](https://arxiv.org/abs/1803.08494) layer. @@ -171,9 +166,6 @@ end - `epsilon`: a value added to the denominator for numerical stability - - `allow_fast_activation`: If `true`, then certain activations can be approximated with - a faster version. The new activation function will be given by - `NNlib.fast_act(activation)` - If `affine=true`, it also applies a shift and a rescale to the input through to learnable per-channel bias and scale parameters. @@ -232,11 +224,10 @@ See also [`GroupNorm`](@ref), [`InstanceNorm`](@ref), [`LayerNorm`](@ref), affine <: StaticBool end -function GroupNorm(chs::IntegerType, groups::IntegerType, activation=identity; - init_bias=zeros32, init_scale=ones32, affine::BoolType=True(), - epsilon=1.0f-5, allow_fast_activation::BoolType=True()) +function GroupNorm( + chs::IntegerType, groups::IntegerType, activation=identity; init_bias=zeros32, + init_scale=ones32, affine::BoolType=True(), epsilon=1.0f-5) @argcheck chs % groups==0 "The number of groups ($(groups)) must divide the number of channels ($chs)" - activation = dynamic(allow_fast_activation) ? NNlib.fast_act(activation) : activation return GroupNorm( activation, epsilon, chs, init_bias, init_scale, groups, static(affine)) end @@ -250,8 +241,9 @@ parameterlength(l::GroupNorm) = has_affine(l) ? (l.chs * 2) : 0 function (GN::GroupNorm)(x::AbstractArray, ps, st::NamedTuple) x′ = match_eltype(GN, ps, st, x) + σ = NNlib.fast_act(GN.activation, x′) y = groupnorm(x′, safe_getproperty(ps, Val(:scale)), - safe_getproperty(ps, Val(:bias)), GN.groups, GN.activation, GN.epsilon) + safe_getproperty(ps, Val(:bias)), GN.groups, σ, GN.epsilon) return y, st end @@ -264,7 +256,7 @@ end @doc doc""" InstanceNorm(chs::Integer, activation=identity; init_bias=zeros32, init_scale=ones32, - affine=true, epsilon=1f-5, allow_fast_activation::Bool=true) + affine=True(), epsilon=1f-5) Instance Normalization. For details see [1]. @@ -282,9 +274,6 @@ accordingly. ## Keyword Arguments - `epsilon`: a value added to the denominator for numerical stability - - `allow_fast_activation`: If `true`, then certain activations can be approximated with - a faster version. The new activation function will be given by - `NNlib.fast_act(activation)` - If `affine=true`, it also applies a shift and a rescale to the input through to learnable per-channel bias and scale parameters. 
@@ -345,10 +334,8 @@ See also [`BatchNorm`](@ref), [`GroupNorm`](@ref), [`LayerNorm`](@ref), [`Weight affine <: StaticBool end -function InstanceNorm( - chs::IntegerType, activation=identity; init_bias=zeros32, init_scale=ones32, - affine::BoolType=True(), epsilon=1.0f-5, allow_fast_activation::BoolType=True()) - activation = dynamic(allow_fast_activation) ? NNlib.fast_act(activation) : activation +function InstanceNorm(chs::IntegerType, activation=identity; init_bias=zeros32, + init_scale=ones32, affine::BoolType=True(), epsilon=1.0f-5) return InstanceNorm(activation, epsilon, chs, init_bias, init_scale, static(affine)) end @@ -362,9 +349,9 @@ parameterlength(l::InstanceNorm) = ifelse(has_affine(l), l.chs * 2, 0) function (IN::InstanceNorm)(x::AbstractArray, ps, st::NamedTuple) x′ = match_eltype(IN, ps, st, x) - y, _ = instancenorm( - x′, safe_getproperty(ps, Val(:scale)), safe_getproperty(ps, Val(:bias)), - st.training, IN.activation, IN.epsilon) + σ = NNlib.fast_act(IN.activation, x′) + y, _ = instancenorm(x′, safe_getproperty(ps, Val(:scale)), + safe_getproperty(ps, Val(:bias)), st.training, σ, IN.epsilon) return y, st end @@ -527,9 +514,6 @@ where ``\gamma`` & ``\beta`` are trainable parameters if `affine=true`. ## Keyword Arguments - - `allow_fast_activation`: If `true`, then certain activations can be approximated with - a faster version. The new activation function will be given by - `NNlib.fast_act(activation)` - `epsilon`: a value added to the denominator for numerical stability. - `dims`: Dimensions to normalize the array over. - If `affine=true`, it also applies a shift and a rescale to the input through to @@ -567,10 +551,8 @@ where ``\gamma`` & ``\beta`` are trainable parameters if `affine=true`. affine <: StaticBool end -function LayerNorm( - shape, activation=identity; epsilon=1.0f-5, dims=Colon(), affine::BoolType=True(), - init_bias=zeros32, init_scale=ones32, allow_fast_activation::BoolType=True()) - activation = dynamic(allow_fast_activation) ? 
NNlib.fast_act(activation) : activation +function LayerNorm(shape, activation=identity; epsilon=1.0f-5, dims=Colon(), + affine::BoolType=True(), init_bias=zeros32, init_scale=ones32) return LayerNorm( shape, activation, epsilon, init_bias, init_scale, dims, static(affine)) end @@ -585,8 +567,9 @@ end function (l::LayerNorm)(x::AbstractArray, ps, st::NamedTuple) x′ = match_eltype(l, ps, st, x) + σ = NNlib.fast_act(l.activation, x′) y = layernorm(x′, safe_getproperty(ps, Val(:scale)), - safe_getproperty(ps, Val(:bias)), l.activation, l.dims, l.epsilon) + safe_getproperty(ps, Val(:bias)), σ, l.dims, l.epsilon) return y, st end diff --git a/test/layers/basic_tests.jl b/test/layers/basic_tests.jl index 1c880c3800..1535ab3b04 100644 --- a/test/layers/basic_tests.jl +++ b/test/layers/basic_tests.jl @@ -126,13 +126,6 @@ end @test layer.activation == relu end - @testset "allow fast activation" begin - layer = Dense(10, 10, tanh) - @test layer.activation == tanh_fast - layer = Dense(10, 10, tanh; allow_fast_activation=false) - @test layer.activation == tanh - end - @testset "dimensions" begin layer = Dense(10, 5) ps, st = Lux.setup(rng, layer) @@ -208,13 +201,6 @@ end @test layer.activation == relu end - @testset "allow fast activation" begin - layer = Scale(10, 5, tanh) - @test layer.activation == tanh_fast - layer = Scale(10, 5, tanh; allow_fast_activation=false) - @test layer.activation == tanh - end - @testset "dimensions" begin layer = Scale(10, 5) ps, st = Lux.setup(rng, layer) |> dev diff --git a/test/layers/conv_tests.jl b/test/layers/conv_tests.jl index 7a0e813a0f..2cb7b609ac 100644 --- a/test/layers/conv_tests.jl +++ b/test/layers/conv_tests.jl @@ -301,13 +301,6 @@ end @jet layer(x, ps, st) end - - @testset "allow fast activation" begin - layer = Conv((3, 3), 1 => 1, tanh) - @test layer.activation == tanh_fast - layer = Conv((3, 3), 1 => 1, tanh; allow_fast_activation=false) - @test layer.activation == tanh - end end end @@ -476,13 +469,6 @@ end test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3) end end - - @testset "allow fast activation" begin - layer = CrossCor((3, 3), 1 => 1, tanh) - @test layer.activation == tanh_fast - layer = CrossCor((3, 3), 1 => 1, tanh; allow_fast_activation=false) - @test layer.activation == tanh - end end end diff --git a/test/layers/normalize_tests.jl b/test/layers/normalize_tests.jl index 5f67efb1c2..567a6a72ac 100644 --- a/test/layers/normalize_tests.jl +++ b/test/layers/normalize_tests.jl @@ -101,13 +101,6 @@ @jet m(x, ps, st) end - - @testset "allow fast activation" begin - layer = BatchNorm(10, tanh) - @test layer.activation == tanh_fast - layer = BatchNorm(10, tanh; allow_fast_activation=false) - @test layer.activation == tanh - end end end @@ -192,13 +185,6 @@ end end @test_throws ArgumentError GroupNorm(5, 2) - - @testset "allow fast activation" begin - layer = GroupNorm(10, 2, tanh) - @test layer.activation == tanh_fast - layer = GroupNorm(10, 2, tanh; allow_fast_activation=false) - @test layer.activation == tanh - end end end @@ -371,13 +357,6 @@ end end end end - - @testset "allow fast activation" begin - layer = LayerNorm((3, 1), tanh) - @test layer.activation == tanh_fast - layer = LayerNorm((3, 1), tanh; allow_fast_activation=false) - @test layer.activation == tanh - end end end @@ -428,12 +407,5 @@ end end end end - - @testset "allow fast activation" begin - layer = InstanceNorm(3, tanh) - @test layer.activation == tanh_fast - layer = InstanceNorm(3, tanh; allow_fast_activation=false) - @test layer.activation == tanh - end end end From 
43cee132348dd8e407db4d19772003875cba78f1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 4 Sep 2024 16:00:41 -0400 Subject: [PATCH 51/95] fix: bad rebase --- src/contrib/debug.jl | 2 +- src/extended_ops.jl | 2 +- src/helpers/stateful.jl | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/contrib/debug.jl b/src/contrib/debug.jl index cc50dbf32c..7d7b388618 100644 --- a/src/contrib/debug.jl +++ b/src/contrib/debug.jl @@ -1,5 +1,5 @@ """ - DebugLayer(layer::AbstractExplicitLayer; + DebugLayer(layer::AbstractLuxLayer; nan_check::Union{Symbol, StaticSymbol, Val}=static(:both), error_check::Union{StaticBool, Bool, Val{true}, Val{false}}=True(), location::KeyPath=KeyPath()) diff --git a/src/extended_ops.jl b/src/extended_ops.jl index d2e65b4f28..ce9f662153 100644 --- a/src/extended_ops.jl +++ b/src/extended_ops.jl @@ -236,7 +236,7 @@ const private_foldl_init = LuxOps.foldl_init # These are defined here to avoid a circular dependency among modules for (op, field) in (:bias => :use_bias, :affine => :affine, :track_stats => :track_stats, :train_state => :train_state) - @eval function $(Symbol(:has_, op))(l::AbstractExplicitLayer) + @eval function $(Symbol(:has_, op))(l::AbstractLuxLayer) res = known(safe_getproperty(l, Val($(Meta.quot(field))))) return ifelse(res === nothing, false, res) end diff --git a/src/helpers/stateful.jl b/src/helpers/stateful.jl index fffea8cf71..02c57eeaf4 100644 --- a/src/helpers/stateful.jl +++ b/src/helpers/stateful.jl @@ -48,7 +48,7 @@ mutable struct StatefulLuxLayer{ST, M <: AbstractLuxLayer, psType, stType} fixed_state_type::ST function StatefulLuxLayer( - model::AbstractExplicitLayer, ps, st, st_any, fixed_state_type::StaticBool) + model::AbstractLuxLayer, ps, st, st_any, fixed_state_type::StaticBool) return new{typeof(fixed_state_type), typeof(model), typeof(ps), typeof(st)}( model, ps, st, st_any, fixed_state_type) end From af560b65abd61d7b53d1960a7c9d5a635530dca5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 4 Sep 2024 16:51:52 -0400 Subject: [PATCH 52/95] fix: update freezing docs --- docs/src/manual/freezing_model_parameters.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/src/manual/freezing_model_parameters.md b/docs/src/manual/freezing_model_parameters.md index 14425de4fd..0d5258cfa7 100644 --- a/docs/src/manual/freezing_model_parameters.md +++ b/docs/src/manual/freezing_model_parameters.md @@ -39,14 +39,14 @@ model_frozen(x, ps_frozen, st_frozen) When the function in `layer_map` is called, the 4th argument is the name of the layer. For example, if you want to freeze the 1st layer inside the inner Chain. The name for this -would be `.layer_2.layer_1`. +would be `layer_2.layer_1`. :::code-group ```julia [Freezing by Layer Name] function freeze_by_name(d, ps, st, name::KeyPath) - name == KeyPath(:model, :layer_2, :layer_1) && + name == KeyPath(:layer_2, :layer_1) && return Lux.Experimental.freeze(d, ps, st, (:weight, :bias)) return d, ps, st end @@ -74,7 +74,7 @@ the `weight` parameter while training the `bias` parameter. 
```julia [Freezing Some Parameters of a Layer] function freeze_by_name(d, ps, st, name::KeyPath) - name == KeyPath(:model, :layer_2, :layer_1) && + name == KeyPath(:layer_2, :layer_1) && return Lux.Experimental.freeze(d, ps, st, (:weight,)) return d, ps, st end @@ -84,7 +84,7 @@ end ```julia [Freezing All Parameters of a Layer] function freeze_by_name(d, ps, st, name::KeyPath) - name == KeyPath(:model, :layer_2, :layer_1) && + name == KeyPath(:layer_2, :layer_1) && return Lux.Experimental.freeze(d, ps, st, (:weight, :bias)) return d, ps, st end From 9b35c99aeef1c977508ebc0392ec88bb34639045 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 4 Sep 2024 17:08:05 -0400 Subject: [PATCH 53/95] feat: correctly type-cast momentum and epsilon --- src/contrib/share_parameters.jl | 2 +- src/layers/normalize.jl | 24 +++++++++++++----------- src/utils.jl | 1 + 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/src/contrib/share_parameters.jl b/src/contrib/share_parameters.jl index 59c7e95bdf..9855089800 100644 --- a/src/contrib/share_parameters.jl +++ b/src/contrib/share_parameters.jl @@ -26,7 +26,7 @@ Updated Parameters having the same structure as `ps`. julia> model = Chain(; d1=Dense(2 => 4, tanh), d3=Chain(; l1=Dense(4 => 2), l2=Dense(2 => 4)), d2=Dense(4 => 2)) Chain( - d1 = Dense(2 => 4, tanh_fast), # 12 parameters + d1 = Dense(2 => 4, tanh), # 12 parameters d3 = Chain( l1 = Dense(4 => 2), # 10 parameters l2 = Dense(2 => 4), # 12 parameters diff --git a/src/layers/normalize.jl b/src/layers/normalize.jl index 146d58f417..4d483e8a8b 100644 --- a/src/layers/normalize.jl +++ b/src/layers/normalize.jl @@ -82,10 +82,10 @@ Chain( See also [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`LayerNorm`](@ref), [`WeightNorm`](@ref) """ -@concrete struct BatchNorm{N} <: AbstractLuxLayer +@concrete struct BatchNorm <: AbstractLuxLayer activation - epsilon::N - momentum::N + epsilon <: Real + momentum <: Real chs <: IntegerType init_bias init_scale @@ -127,8 +127,9 @@ function (BN::BatchNorm)(x::AbstractArray, ps, st::NamedTuple) σ = NNlib.fast_act(BN.activation, x′) y, stats = batchnorm( x′, safe_getproperty(ps, Val(:scale)), safe_getproperty(ps, Val(:bias)), - safe_getproperty(st, Val(:running_mean)), - safe_getproperty(st, Val(:running_var)), st.training, σ, BN.momentum, BN.epsilon) + safe_getproperty(st, Val(:running_mean)), safe_getproperty(st, Val(:running_var)), + st.training, σ, convert(unwrapped_eltype(x′), BN.momentum), + convert(unwrapped_eltype(x′), BN.epsilon)) return y, update_batchnorm_state(BN, st, stats) end @@ -242,8 +243,8 @@ parameterlength(l::GroupNorm) = has_affine(l) ? 
(l.chs * 2) : 0 function (GN::GroupNorm)(x::AbstractArray, ps, st::NamedTuple) x′ = match_eltype(GN, ps, st, x) σ = NNlib.fast_act(GN.activation, x′) - y = groupnorm(x′, safe_getproperty(ps, Val(:scale)), - safe_getproperty(ps, Val(:bias)), GN.groups, σ, GN.epsilon) + y = groupnorm(x′, safe_getproperty(ps, Val(:scale)), safe_getproperty(ps, Val(:bias)), + GN.groups, σ, convert(unwrapped_eltype(x′), GN.epsilon)) return y, st end @@ -350,8 +351,9 @@ parameterlength(l::InstanceNorm) = ifelse(has_affine(l), l.chs * 2, 0) function (IN::InstanceNorm)(x::AbstractArray, ps, st::NamedTuple) x′ = match_eltype(IN, ps, st, x) σ = NNlib.fast_act(IN.activation, x′) - y, _ = instancenorm(x′, safe_getproperty(ps, Val(:scale)), - safe_getproperty(ps, Val(:bias)), st.training, σ, IN.epsilon) + y, _ = instancenorm( + x′, safe_getproperty(ps, Val(:scale)), safe_getproperty(ps, Val(:bias)), + st.training, σ, convert(unwrapped_eltype(x′), IN.epsilon)) return y, st end @@ -568,8 +570,8 @@ end function (l::LayerNorm)(x::AbstractArray, ps, st::NamedTuple) x′ = match_eltype(l, ps, st, x) σ = NNlib.fast_act(l.activation, x′) - y = layernorm(x′, safe_getproperty(ps, Val(:scale)), - safe_getproperty(ps, Val(:bias)), σ, l.dims, l.epsilon) + y = layernorm(x′, safe_getproperty(ps, Val(:scale)), safe_getproperty(ps, Val(:bias)), + σ, l.dims, convert(unwrapped_eltype(x′), l.epsilon)) return y, st end diff --git a/src/utils.jl b/src/utils.jl index 4cea362907..e27642eda6 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -208,3 +208,4 @@ using .Utils: Utils, BoolType, IntegerType, SymbolType, make_abstract_matrix, const safe_reverse = Utils.reverse const safe_vec = Utils.vec +const unwrapped_eltype = Utils.eltype From dce920dd379e601118735998a7c69f99a4eedf9c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 4 Sep 2024 18:56:11 -0400 Subject: [PATCH 54/95] feat: use the device iterators in the examples --- examples/ConvMixer/main.jl | 10 ++++------ examples/DDIM/main.jl | 3 +-- examples/HyperNet/main.jl | 35 ++++++++++++++++++----------------- examples/NeuralODE/main.jl | 18 ++++++++---------- examples/SimpleChains/main.jl | 17 +++++++++++------ examples/SimpleRNN/main.jl | 11 +++-------- 6 files changed, 45 insertions(+), 49 deletions(-) diff --git a/examples/ConvMixer/main.jl b/examples/ConvMixer/main.jl index 21ece57fc1..2c9dc2824c 100644 --- a/examples/ConvMixer/main.jl +++ b/examples/ConvMixer/main.jl @@ -55,13 +55,13 @@ function ConvMixer(; dim, depth, kernel_size=5, patch_size=2) #! 
format: on end -function accuracy(model, ps, st, dataloader; dev=gpu_device()) +function accuracy(model, ps, st, dataloader) total_correct, total = 0, 0 st = Lux.testmode(st) cpu_dev = cpu_device() for (x, y) in dataloader - target_class = onecold(y) - predicted_class = onecold(cpu_dev(first(model(dev(x), ps, st)))) + target_class = onecold(cpu_dev(y)) + predicted_class = onecold(cpu_dev(first(model(x, ps, st)))) total_correct += sum(target_class .== predicted_class) total += length(target_class) end @@ -74,7 +74,7 @@ end rng = StableRNG(seed) gdev = gpu_device() - trainloader, testloader = get_dataloaders(batchsize) + trainloader, testloader = get_dataloaders(batchsize) .|> gdev model = ConvMixer(; dim=hidden_dim, depth, kernel_size, patch_size) ps, st = Lux.setup(rng, model) |> gdev @@ -96,8 +96,6 @@ end for (i, (x, y)) in enumerate(trainloader) lr = lr_schedule((epoch - 1) + (i + 1) / length(trainloader)) train_state = Optimisers.adjust!(train_state, lr) - x = x |> gdev - y = y |> gdev (_, _, _, train_state) = Training.single_train_step!( AutoZygote(), loss, (x, y), train_state) end diff --git a/examples/DDIM/main.jl b/examples/DDIM/main.jl index e97d72939f..6e81b88f8d 100644 --- a/examples/DDIM/main.jl +++ b/examples/DDIM/main.jl @@ -359,7 +359,7 @@ end @info "Preparing dataset" ds = FlowersDataset(x -> preprocess_image(x, image_size), true) - data_loader = DataLoader(ds; batchsize, collate=true, parallel=true) + data_loader = DataLoader(ds; batchsize, collate=true, parallel=true) |> gdev scheduler = CosAnneal(learning_rate_start, learning_rate_end, epochs) @@ -376,7 +376,6 @@ end for (i, data) in enumerate(data_loader) step += 1 - data = data |> gdev (_, _, stats, tstate) = Training.single_train_step!( AutoZygote(), loss_function, data, tstate) image_losses[i] = stats.image_loss diff --git a/examples/HyperNet/main.jl b/examples/HyperNet/main.jl index f20f94643b..e2d12e7c91 100644 --- a/examples/HyperNet/main.jl +++ b/examples/HyperNet/main.jl @@ -59,13 +59,11 @@ end # ## Define Utility Functions const loss = CrossEntropyLoss(; logits=Val(true)) -function accuracy(model, ps, st, dataloader, data_idx, gdev=gpu_device()) +function accuracy(model, ps, st, dataloader, data_idx) total_correct, total = 0, 0 st = Lux.testmode(st) cpu_dev = cpu_device() for (x, y) in dataloader - x = x |> gdev - y = y |> gdev target_class = onecold(cpu_dev(y)) predicted_class = onecold(cpu_dev(model((data_idx, x), ps, st)[1])) total_correct += sum(target_class .== predicted_class) @@ -86,26 +84,24 @@ function train() train_state = Training.TrainState(model, ps, st, Adam(3.0f-4)) ### Lets train the model - nepochs = 10 + nepochs = 25 for epoch in 1:nepochs, data_idx in 1:2 - train_dataloader, test_dataloader = dataloaders[data_idx] + train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev stime = time() for (x, y) in train_dataloader - x = x |> dev - y = y |> dev (_, _, _, train_state) = Training.single_train_step!( AutoZygote(), loss, ((data_idx, x), y), train_state) end ttime = time() - stime train_acc = round( - accuracy(model, train_state.parameters, train_state.states, - train_dataloader, data_idx, dev) * 100; + accuracy(model, train_state.parameters, + train_state.states, train_dataloader, data_idx) * 100; digits=2) test_acc = round( - accuracy(model, train_state.parameters, train_state.states, - test_dataloader, data_idx, dev) * 100; + accuracy(model, train_state.parameters, + train_state.states, test_dataloader, data_idx) * 100; digits=2) data_name = data_idx == 1 ? 
"MNIST" : "FashionMNIST" @@ -116,22 +112,27 @@ function train() println() + test_acc_list = [0.0, 0.0] for data_idx in 1:2 - train_dataloader, test_dataloader = dataloaders[data_idx] + train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev train_acc = round( - accuracy(model, train_state.parameters, train_state.states, - train_dataloader, data_idx, dev) * 100; + accuracy(model, train_state.parameters, + train_state.states, train_dataloader, data_idx) * 100; digits=2) test_acc = round( - accuracy(model, train_state.parameters, train_state.states, - test_dataloader, data_idx, dev) * 100; + accuracy(model, train_state.parameters, + train_state.states, test_dataloader, data_idx) * 100; digits=2) data_name = data_idx == 1 ? "MNIST" : "FashionMNIST" @printf "[FINAL] \t %12s \t Training Accuracy: %.2f%% \t Test Accuracy: \ %.2f%%\n" data_name train_acc test_acc + test_acc_list[data_idx] = test_acc end + return test_acc_list end -train() +test_acc_list = train() +@assert test_acc_list[1] > 0.90 && test_acc_list[2] > 0.70 #hide +nothing #hide diff --git a/examples/NeuralODE/main.jl b/examples/NeuralODE/main.jl index 013b64f20d..2b83e13bab 100644 --- a/examples/NeuralODE/main.jl +++ b/examples/NeuralODE/main.jl @@ -9,8 +9,8 @@ using Lux, ComponentArrays, SciMLSensitivity, LuxCUDA, Optimisers, OrdinaryDiffEq, Random, Statistics, Zygote, OneHotArrays, InteractiveUtils, Printf -import MLDatasets: MNIST -import MLUtils: DataLoader, splitobs +using MLDatasets: MNIST +using MLUtils: DataLoader, splitobs CUDA.allowscalar(false) @@ -106,13 +106,13 @@ end # ## Define Utility Functions const logitcrossentropy = CrossEntropyLoss(; logits=Val(true)) -function accuracy(model, ps, st, dataloader; dev=gpu_device()) +function accuracy(model, ps, st, dataloader) total_correct, total = 0, 0 st = Lux.testmode(st) cpu_dev = cpu_device() for (x, y) in dataloader - target_class = onecold(y) - predicted_class = onecold(cpu_dev(first(model(dev(x), ps, st)))) + target_class = onecold(cpu_dev(y)) + predicted_class = onecold(cpu_dev(first(model(x, ps, st)))) total_correct += sum(target_class .== predicted_class) total += length(target_class) end @@ -125,7 +125,7 @@ function train(model_function; cpu::Bool=false, kwargs...) model, ps, st = create_model(model_function; dev, kwargs...) ## Training - train_dataloader, test_dataloader = loadmnist(128, 0.9) + train_dataloader, test_dataloader = loadmnist(128, 0.9) |> dev tstate = Training.TrainState(model, ps, st, Adam(0.001f0)) @@ -134,15 +134,13 @@ function train(model_function; cpu::Bool=false, kwargs...) 
for epoch in 1:nepochs stime = time() for (x, y) in train_dataloader - x = dev(x) - y = dev(y) _, _, _, tstate = Training.single_train_step!( AutoZygote(), logitcrossentropy, (x, y), tstate) end ttime = time() - stime - tr_acc = accuracy(model, tstate.parameters, tstate.states, train_dataloader; dev) - te_acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader; dev) + tr_acc = accuracy(model, tstate.parameters, tstate.states, train_dataloader) + te_acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader) @printf "[%d/%d] \t Time %.2fs \t Training Accuracy: %.5f%% \t Test \ Accuracy: %.5f%%\n" epoch nepochs ttime tr_acc te_acc end diff --git a/examples/SimpleChains/main.jl b/examples/SimpleChains/main.jl index 0bbd62944b..6726800594 100644 --- a/examples/SimpleChains/main.jl +++ b/examples/SimpleChains/main.jl @@ -8,8 +8,8 @@ # ## Package Imports using Lux, ADTypes, MLUtils, Optimisers, Zygote, OneHotArrays, Random, Statistics, Printf -import MLDatasets: MNIST -import SimpleChains: static +using MLDatasets: MNIST +using SimpleChains: SimpleChains # ## Loading MNIST function loadmnist(batchsize, train_split) @@ -19,7 +19,7 @@ function loadmnist(batchsize, train_split) imgs = dataset.features[:, :, 1:N] labels_raw = dataset.targets[1:N] - ## Process images into (H,W,C,BS) batches + ## Process images into (H, W, C, BS) batches x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))) y_data = onehotbatch(labels_raw, 0:9) (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split) @@ -40,7 +40,7 @@ lux_model = Chain(Conv((5, 5), 1 => 6, relu), MaxPool((2, 2)), # We now need to convert the lux_model to SimpleChains.jl. We need to do this by defining # the [`ToSimpleChainsAdaptor`](@ref) and providing the input dimensions. -adaptor = ToSimpleChainsAdaptor((static(28), static(28), static(1))) +adaptor = ToSimpleChainsAdaptor((28, 28, 1)) simple_chains_model = adaptor(lux_model) # ## Helper Functions @@ -72,10 +72,11 @@ function train(model; rng=Xoshiro(0), kwargs...) ### Lets train the model nepochs = 10 + tr_acc, te_acc = 0.0, 0.0 for epoch in 1:nepochs stime = time() for (x, y) in train_dataloader - (gs, _, _, train_state) = Training.single_train_step!( + gs, _, _, train_state = Training.single_train_step!( AutoZygote(), loss, (x, y), train_state) end ttime = time() - stime @@ -88,16 +89,20 @@ function train(model; rng=Xoshiro(0), kwargs...) @printf "[%2d/%2d] \t Time %.2fs \t Training Accuracy: %.2f%% \t Test Accuracy: \ %.2f%%\n" epoch nepochs ttime tr_acc te_acc end + + return tr_acc, te_acc end # ## Finally Training the Model # First we will train the Lux model -train(lux_model) +tr_acc, te_acc = train(lux_model) +@assert tr_acc > 0.75 && te_acc > 0.75 #hide nothing #hide # Now we will train the SimpleChains model train(simple_chains_model) +@assert tr_acc > 0.75 && te_acc > 0.75 #hide nothing #hide # On my local machine we see a 3-4x speedup when using SimpleChains.jl. 
The conditions of diff --git a/examples/SimpleRNN/main.jl b/examples/SimpleRNN/main.jl index b85692fe4c..e0ba547138 100644 --- a/examples/SimpleRNN/main.jl +++ b/examples/SimpleRNN/main.jl @@ -128,11 +128,11 @@ accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred) # ## Training the Model function main(model_type) - ## Get the dataloaders - (train_loader, val_loader) = get_dataloaders() - dev = gpu_device() + ## Get the dataloaders + train_loader, val_loader = get_dataloaders() .|> dev + ## Create the model model = model_type(2, 8, 1) rng = Xoshiro(0) @@ -143,9 +143,6 @@ function main(model_type) for epoch in 1:25 ## Train the model for (x, y) in train_loader - x = x |> dev - y = y |> dev - (_, loss, _, train_state) = Training.single_train_step!( AutoZygote(), lossfn, (x, y), train_state) @@ -155,8 +152,6 @@ function main(model_type) ## Validate the model st_ = Lux.testmode(train_state.states) for (x, y) in val_loader - x = x |> dev - y = y |> dev ŷ, st_ = model(x, train_state.parameters, st_) loss = lossfn(ŷ, y) acc = accuracy(ŷ, y) From 7b96d5d08410c3018d58033921204f208624e532 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 4 Sep 2024 19:08:44 -0400 Subject: [PATCH 55/95] fix: mark `Utils.eltype` as non-differentiable --- docs/src/introduction/updating_to_v1.md | 2 +- src/utils.jl | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/src/introduction/updating_to_v1.md b/docs/src/introduction/updating_to_v1.md index bbdffd98e3..72feb2a76e 100644 --- a/docs/src/introduction/updating_to_v1.md +++ b/docs/src/introduction/updating_to_v1.md @@ -95,7 +95,7 @@ abstraction. - `Lux.Experimental.@layer_map` is not longer needed and has been removed. The name of the variable prevents writing generic functions and is no longer pre-pended to the `KeyPath`. See the docstring of [`Lux.Experimental.layer_map`](@ref) for more details. -- `allow_fast_activation` kwarg has been removed completely. Pass an annonymous function +- `allow_fast_activation` kwarg has been removed completely. Pass an anonymous function as the activation to prevent internal modivations to the activation function. ### Breaking Changes (Moved Functionality) diff --git a/src/utils.jl b/src/utils.jl index e27642eda6..a4177f7275 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -126,6 +126,8 @@ eltype(x) = eltype(Base.eltype(x)) eltype(::Type{T}) where {T} = T eltype(::Type{<:Dual{T, V}}) where {T, V} = V +@non_differentiable eltype(::Any) + ofeltype_array(::Type{T}, x::AbstractArray) where {T} = broadcast(T, x) function ofeltype_array(::Type{T}, x::AbstractArray{<:Dual{Tag, V, N}}) where {Tag, T, V, N} return Dual{Tag, T, N}.(x) From 8c964f4494847a434ff1004a298bc2cd01543493 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 4 Sep 2024 20:31:52 -0400 Subject: [PATCH 56/95] fix: misc docs issues --- docs/src/api/Lux/utilities.md | 6 ++++++ docs/src/introduction/updating_to_v1.md | 4 ++-- docs/src/manual/distributed_utils.md | 8 ++++---- docs/src/manual/freezing_model_parameters.md | 2 +- test/runtests.jl | 2 +- 5 files changed, 14 insertions(+), 8 deletions(-) diff --git a/docs/src/api/Lux/utilities.md b/docs/src/api/Lux/utilities.md index 19489e766d..744624fa1e 100644 --- a/docs/src/api/Lux/utilities.md +++ b/docs/src/api/Lux/utilities.md @@ -119,3 +119,9 @@ StatefulLuxLayer @init_fn @non_trainable ``` + +## Miscellaneous + +```@docs +Lux.set_dispatch_doctor_preferences! 
+``` diff --git a/docs/src/introduction/updating_to_v1.md b/docs/src/introduction/updating_to_v1.md index 72feb2a76e..0d0629a8e2 100644 --- a/docs/src/introduction/updating_to_v1.md +++ b/docs/src/introduction/updating_to_v1.md @@ -108,7 +108,7 @@ abstraction. - `Experimental.StatefulLuxLayer` has been moved to [`Lux.StatefulLuxLayer`](@ref). - `st_fixed_path` kwarg has been removed from [`Lux.StatefulLuxLayer`](@ref), instead use it as `StatefulLuxLayer{st_fixed_path}(...)`. -- Strings as inputs to [`Experimental.layer_map`](@ref) and - [`Experimental.@debug_mode`](@ref) are removed, use `Functors.KeyPath` instead. +- Strings as inputs to [`Lux.Experimental.layer_map`](@ref) and + [`Lux.Experimental.@debug_mode`](@ref) are removed, use `Functors.KeyPath` instead. ### Breaking Changes (Changes in Defaults) diff --git a/docs/src/manual/distributed_utils.md b/docs/src/manual/distributed_utils.md index dbee8ab110..677e473777 100644 --- a/docs/src/manual/distributed_utils.md +++ b/docs/src/manual/distributed_utils.md @@ -87,10 +87,10 @@ And that's pretty much it! as input. 3. We don't automatically determine if the MPI Implementation is CUDA or ROCM aware. See [GPU-aware MPI](@ref gpu-aware-mpi-preferences) for more information. -4. Older [`Lux.gpu`](@ref) implementations used to "just work" with `FluxMPI.jl`. We expect - [`gpu_device`](@ref) to continue working as expected, however, we recommend using - [`gpu_device`](@ref) after calling [`DistributedUtils.initialize`](@ref) to avoid any - mismatch between the device set via `DistributedUtils` and the device stores in +4. Older (now non-existent) `Lux.gpu` implementations used to "just work" with `FluxMPI.jl`. + We expect [`gpu_device`](@ref) to continue working as expected, however, we recommend + using [`gpu_device`](@ref) after calling [`DistributedUtils.initialize`](@ref) to avoid + any mismatch between the device set via `DistributedUtils` and the device stores in `CUDADevice` or `AMDGPUDevice`. ## Known Shortcomings diff --git a/docs/src/manual/freezing_model_parameters.md b/docs/src/manual/freezing_model_parameters.md index 0d5258cfa7..5f2f4055e6 100644 --- a/docs/src/manual/freezing_model_parameters.md +++ b/docs/src/manual/freezing_model_parameters.md @@ -13,7 +13,7 @@ We can use [`Lux.Experimental.layer_map`](@ref) and freeze layers if they are of `Dense`. ```@example freezing_model_parameters -using Lux, Functors, Random +using Lux, Random rng = Xoshiro(0) diff --git a/test/runtests.jl b/test/runtests.jl index 5ac9215c02..a6160b119d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -93,7 +93,7 @@ const RETESTITEMS_NWORKERS = parse( @info "Running tests for group: [$(i)/$(length(LUX_TEST_GROUP))] $tag" ReTestItems.runtests(Lux; tags=(tag == "all" ? 
nothing : [Symbol(tag)]), - nworkers=RETESTITEMS_NWORKERS, testitem_timeout=1800, retries=1) + nworkers=RETESTITEMS_NWORKERS, testitem_timeout=2400, retries=1) end end From 5e8cda13d5660578230aa151e16e9bcbfa35828d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 4 Sep 2024 22:27:17 -0400 Subject: [PATCH 57/95] chore: remove old compat --- examples/ConvMixer/Project.toml | 2 +- examples/DDIM/Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/ConvMixer/Project.toml b/examples/ConvMixer/Project.toml index 35d1b5fa95..ca93123123 100644 --- a/examples/ConvMixer/Project.toml +++ b/examples/ConvMixer/Project.toml @@ -22,7 +22,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] Comonicon = "1.0.8" ConcreteStructs = "0.2.3" -DataAugmentation = "0.2.12, 0.3" +DataAugmentation = "0.3" ImageCore = "0.10.2" ImageShow = "0.3.8" Interpolations = "0.15.1" diff --git a/examples/DDIM/Project.toml b/examples/DDIM/Project.toml index 60166460e8..42a76263b4 100644 --- a/examples/DDIM/Project.toml +++ b/examples/DDIM/Project.toml @@ -31,7 +31,7 @@ CairoMakie = "0.12" ChainRulesCore = "1.23" Comonicon = "1" ConcreteStructs = "0.2.3" -DataAugmentation = "0.2.12, 0.3" +DataAugmentation = "0.3" DataDeps = "0.7.13" FileIO = "1.16" ImageCore = "0.9, 0.10" From 4d3f99c4cdb38eb8eac7f54cbb4ecc9d81e77a29 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 5 Sep 2024 09:57:31 -0400 Subject: [PATCH 58/95] feat: track running statistics in InstanceNorm --- Project.toml | 4 ++-- src/layers/normalize.jl | 44 ++++++++++++++++++++++++++++------ test/layers/normalize_tests.jl | 10 ++++---- 3 files changed, 44 insertions(+), 14 deletions(-) diff --git a/Project.toml b/Project.toml index 3ebf6b75a3..537c8574ea 100644 --- a/Project.toml +++ b/Project.toml @@ -85,8 +85,8 @@ Functors = "0.4.12" GPUArraysCore = "0.1.6" LinearAlgebra = "1.10" LossFunctions = "0.11.1" -LuxCore = "1.0" -LuxLib = "1.0" +LuxCore = "1" +LuxLib = "1.2" MLDataDevices = "1.1" MLUtils = "0.4.4" MPI = "0.20.19" diff --git a/src/layers/normalize.jl b/src/layers/normalize.jl index 4d483e8a8b..50e5e24295 100644 --- a/src/layers/normalize.jl +++ b/src/layers/normalize.jl @@ -31,7 +31,7 @@ slice and normalises the input accordingly. ## Inputs - - `x`: Array where `size(x, N - 1) = chs` and `ndims(x) > 2` + - `x`: Array where `size(x, N - 1) = chs` ## Returns @@ -257,7 +257,7 @@ end @doc doc""" InstanceNorm(chs::Integer, activation=identity; init_bias=zeros32, init_scale=ones32, - affine=True(), epsilon=1f-5) + affine=True(), track_stats=False(), epsilon=1f-5, momentum=0.1f0) Instance Normalization. For details see [1]. @@ -274,13 +274,19 @@ accordingly. ## Keyword Arguments + - If `track_stats=true`, accumulates mean and variance statistics in training phase that + will be used to renormalize the input in test phase. + - `epsilon`: a value added to the denominator for numerical stability + - `momentum`: the value used for the `running_mean` and `running_var` computation - If `affine=true`, it also applies a shift and a rescale to the input through to learnable per-channel bias and scale parameters. + `init_bias`: Controls how the `bias` is initialized + `init_scale`: Controls how the `scale` is initialized +# Extended Help + ## Inputs - `x`: Array where `size(x, N - 1) = chs` and `ndims(x) > 2` @@ -301,6 +307,15 @@ accordingly. 
## States + - Statistics if `track_stats=true` + + + `running_mean`: Running mean of shape `(chs,)` + + `running_var`: Running variance of shape `(chs,)` + + - Statistics if `track_stats=false` + + + `running_mean`: nothing + + `running_var`: nothing - `training`: Used to check if training/inference mode Use `Lux.testmode` during inference. @@ -328,16 +343,20 @@ See also [`BatchNorm`](@ref), [`GroupNorm`](@ref), [`LayerNorm`](@ref), [`Weight """ @concrete struct InstanceNorm <: AbstractLuxLayer activation - epsilon + epsilon <: Real + momentum <: Real chs <: IntegerType init_bias init_scale affine <: StaticBool + track_stats <: StaticBool end function InstanceNorm(chs::IntegerType, activation=identity; init_bias=zeros32, - init_scale=ones32, affine::BoolType=True(), epsilon=1.0f-5) - return InstanceNorm(activation, epsilon, chs, init_bias, init_scale, static(affine)) + init_scale=ones32, affine::BoolType=True(), + track_stats::BoolType=False(), epsilon=1.0f-5, momentum=0.1f0) + return InstanceNorm(activation, epsilon, momentum, chs, init_bias, + init_scale, static(affine), static(track_stats)) end function initialparameters(rng::AbstractRNG, l::InstanceNorm) @@ -345,15 +364,25 @@ function initialparameters(rng::AbstractRNG, l::InstanceNorm) return (;) end -initialstates(::AbstractRNG, ::InstanceNorm) = (; training=Val(true)) +function initialstates(rng::AbstractRNG, l::InstanceNorm) + if has_track_stats(l) + return (running_mean=zeros32(rng, l.chs), + running_var=ones32(rng, l.chs), training=Val(true)) + end + return (; training=Val(true)) +end + parameterlength(l::InstanceNorm) = ifelse(has_affine(l), l.chs * 2, 0) +statelength(l::InstanceNorm) = ifelse(has_track_stats(l), l.chs * 2, 0) + 1 function (IN::InstanceNorm)(x::AbstractArray, ps, st::NamedTuple) x′ = match_eltype(IN, ps, st, x) σ = NNlib.fast_act(IN.activation, x′) y, _ = instancenorm( x′, safe_getproperty(ps, Val(:scale)), safe_getproperty(ps, Val(:bias)), - st.training, σ, convert(unwrapped_eltype(x′), IN.epsilon)) + safe_getproperty(st, Val(:running_mean)), safe_getproperty(st, Val(:running_var)), + st.training, σ, convert(unwrapped_eltype(x′), IN.momentum), + convert(unwrapped_eltype(x′), IN.epsilon)) return y, st end @@ -361,6 +390,7 @@ function Base.show(io::IO, l::InstanceNorm) print(io, "InstanceNorm($(l.chs)") (l.activation == identity) || print(io, ", $(l.activation)") print(io, ", affine=$(has_affine(l))") + print(io, ", track_stats=$(has_track_stats(l))") return print(io, ")") end diff --git a/test/layers/normalize_tests.jl b/test/layers/normalize_tests.jl index 567a6a72ac..b7b27e7a97 100644 --- a/test/layers/normalize_tests.jl +++ b/test/layers/normalize_tests.jl @@ -367,10 +367,10 @@ end for x in (randn(rng, Float32, 3, 3, 3, 2), randn(rng, Float32, 3, 3, 2), randn(rng, Float32, 3, 3, 3, 3, 2)) x = x |> aType - for affine in (true, false) - layer = InstanceNorm(3; affine) + for affine in (true, false), track_stats in (true, false) + layer = InstanceNorm(3; affine, track_stats) display(layer) - ps, st = Lux.setup(rng, layer) .|> device + ps, st = Lux.setup(rng, layer) |> device y, st_ = layer(x, ps, st) @@ -387,9 +387,9 @@ end end for act in (sigmoid, tanh) - layer = InstanceNorm(3, act; affine) + layer = InstanceNorm(3, act; affine, track_stats) display(layer) - ps, st = Lux.setup(rng, layer) .|> device + ps, st = Lux.setup(rng, layer) |> device y, st_ = layer(x, ps, st) From 95953aac0f5709e72abfac4385f1973fc3d6f72e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 5 Sep 2024 13:54:21 -0400 Subject: [PATCH 59/95] 
feat!: match initialization of convolution layers with Pytorch --- benchmarks/setup.jl | 2 +- docs/src/api/Lux/layers.md | 1 - docs/src/introduction/updating_to_v1.md | 16 + ext/LuxFluxExt.jl | 8 +- src/Lux.jl | 2 +- src/layers/conv.jl | 444 ++++++++++-------------- src/utils.jl | 15 + test/layers/conv_tests.jl | 13 +- test/shared_testsetup.jl | 7 +- test/zygote_type_stability.jl | 2 +- 10 files changed, 231 insertions(+), 279 deletions(-) diff --git a/benchmarks/setup.jl b/benchmarks/setup.jl index 2ed92f2e0f..e2d05bc889 100644 --- a/benchmarks/setup.jl +++ b/benchmarks/setup.jl @@ -1,6 +1,6 @@ using ADTypes: ADTypes, AutoEnzyme, AutoZygote using Adapt: adapt -using Lux: Lux, BatchNorm, Chain, Conv, CrossCor, Dense, Dropout, FlattenLayer, MaxPool +using Lux: Lux, BatchNorm, Chain, Conv, Dense, Dropout, FlattenLayer, MaxPool using MLDataDevices: AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice using NNlib: relu, gelu using Random: Random diff --git a/docs/src/api/Lux/layers.md b/docs/src/api/Lux/layers.md index 1eb7da0f50..6041984940 100644 --- a/docs/src/api/Lux/layers.md +++ b/docs/src/api/Lux/layers.md @@ -22,7 +22,6 @@ RepeatedLayer ```@docs Conv ConvTranspose -CrossCor ``` ## Dropout Layers diff --git a/docs/src/introduction/updating_to_v1.md b/docs/src/introduction/updating_to_v1.md index 0d0629a8e2..70cd4b2d6f 100644 --- a/docs/src/introduction/updating_to_v1.md +++ b/docs/src/introduction/updating_to_v1.md @@ -12,6 +12,11 @@ list out some new exciting features that were added as part of this release. API](@ref LuxLib-API) for more details. - Default for [`layernorm`](@ref) dims has been changed to exclude the batch dimension. +### New Major Features + +- Dense layers now support CUDA backend for Enzyme (starting `v1.1`). Wider support for + other operations with Enzyme + CUDA is being actively worked on. + ## `LuxCore.jl` ### Breaking Changes @@ -110,5 +115,16 @@ abstraction. as `StatefulLuxLayer{st_fixed_path}(...)`. - Strings as inputs to [`Lux.Experimental.layer_map`](@ref) and [`Lux.Experimental.@debug_mode`](@ref) are removed, use `Functors.KeyPath` instead. +- `CrossCor` has been removed. Use `Conv(args...; kwargs..., cross_correlation=true)` + instead. ### Breaking Changes (Changes in Defaults) + +- [`Conv`](@ref) and [`ConvTranspose`](@ref) use an initialization based on the activation + function, taken from Pytorch. Pytorch assumes the activation function is `leakyrelu` to + compute the gain, however, we compute the gain based on the activation function passed in + to the layer. + +### New Features + +- [`InstanceNorm`](@ref) now supports tracking statistics. diff --git a/ext/LuxFluxExt.jl b/ext/LuxFluxExt.jl index 1f73de3d78..0116c200af 100644 --- a/ext/LuxFluxExt.jl +++ b/ext/LuxFluxExt.jl @@ -131,12 +131,12 @@ function Lux.convert_flux_model(l::Flux.CrossCor; preserve_ps_st::Bool=false, kw pad = l.pad isa Flux.SamePad ? SamePad() : l.pad if preserve_ps_st _bias = l.bias isa Bool ? 
nothing : vec(copy(l.bias)) - return Lux.CrossCor(k, in_chs => out_chs, l.σ; l.stride, pad, + return Lux.Conv(k, in_chs => out_chs, l.σ; l.stride, pad, l.dilation, init_weight=Returns(copy(l.weight)), - init_bias=Returns(_bias), use_bias=!(l.bias isa Bool)) + init_bias=Returns(_bias), use_bias=!(l.bias isa Bool), cross_correlation=true) else - return Lux.CrossCor(k, in_chs => out_chs, l.σ; l.stride, pad, - l.dilation, use_bias=!(l.bias isa Bool)) + return Lux.Conv(k, in_chs => out_chs, l.σ; l.stride, pad, + l.dilation, use_bias=!(l.bias isa Bool), cross_correlation=true) end end diff --git a/src/Lux.jl b/src/Lux.jl index c8d668bec1..a0650c673f 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -87,7 +87,7 @@ include("distributed/public_api.jl") # Layers export Chain, Parallel, SkipConnection, PairwiseFusion, BranchLayer, Maxout, RepeatedLayer export Bilinear, Dense, Embedding, Scale -export Conv, ConvTranspose, CrossCor, MaxPool, MeanPool, GlobalMaxPool, GlobalMeanPool, +export Conv, ConvTranspose, MaxPool, MeanPool, GlobalMaxPool, GlobalMeanPool, AdaptiveMaxPool, AdaptiveMeanPool, Upsample, PixelShuffle export AlphaDropout, Dropout, VariationalHiddenDropout export BatchNorm, GroupNorm, InstanceNorm, LayerNorm diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 4e0aebeb3b..0fdf3cc01c 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -51,18 +51,35 @@ end CRC.@non_differentiable compute_adaptive_pooling_dims(::Any, ::Any) -function init_conv_filter(rng::AbstractRNG, filter::NTuple{N, Integer}, - ch::Pair{<:Integer, <:Integer}; init=glorot_uniform, groups=1) where {N} - cin, cout = ch - @argcheck cin % groups==0 DimensionMismatch("Input channel dimension must be divisible by groups.") - @argcheck cout % groups==0 DimensionMismatch("Output channel dimension must be divisible by groups.") - return init(rng, filter..., cin ÷ groups, cout) +function init_conv_weight( + rng::AbstractRNG, init_weight::F, filter::NTuple{N, <:IntegerType}, + in_chs::IntegerType, out_chs::IntegerType, groups, σ::A) where {F, N, A} + if init_weight === nothing # Default from PyTorch + gain = Utils.calculate_gain(σ, √5.0f0) + return kaiming_uniform(rng, Float32, filter..., in_chs ÷ groups, out_chs; gain) + end + return init_weight(rng, filter..., in_chs ÷ groups, out_chs) +end + +function init_conv_bias(rng::AbstractRNG, init_bias::F, filter::NTuple{N, <:IntegerType}, + in_chs::IntegerType, out_chs::IntegerType, groups) where {F, N} + if init_bias === nothing # Default from PyTorch + fan_in = prod(filter) * (in_chs ÷ groups) + bound = inv(sqrt(fan_in)) + y = rand32(rng, out_chs) + @. y = y * 2bound - bound + return y + end + return init_bias(rng, out_chs) end +construct_crosscor_convdims(::False, cdims::DenseConvDims) = cdims +construct_crosscor_convdims(::True, cdims::DenseConvDims) = DenseConvDims(cdims; F=true) + @doc doc""" Conv(k::NTuple{N,Integer}, (in_chs => out_chs)::Pair{<:Integer,<:Integer}, - activation=identity; init_weight=glorot_uniform, init_bias=zeros32, stride=1, - pad=0, dilation=1, groups=1, use_bias=True()) + activation=identity; init_weight=nothing, init_bias=nothing, stride=1, + pad=0, dilation=1, groups=1, use_bias=True(), cross_correlation=False()) Standard convolutional layer. @@ -79,7 +96,8 @@ Standard convolutional layer. !!! warning Frameworks like [`Pytorch`](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d) - perform cross-correlation in their convolution layers + perform cross-correlation in their convolution layers. 
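In practical terms, the `init_conv_weight` / `init_conv_bias` helpers added in this patch change what an unconfigured `Conv` samples at `Lux.setup` time. A short sketch of the observable effect; the `(3, 3)` kernel, `4 => 8` channels, and `relu` activation are arbitrary choices used only for illustration.

```julia
using Lux, Random

rng = Xoshiro(0)

# Defaults (no `init_weight` / `init_bias` passed): the weight is drawn from
# `kaiming_uniform` with a gain derived from the activation, and the bias uniformly
# from (-bound, bound) where bound = 1 / sqrt(fan_in) and
# fan_in = prod(kernel) * in_chs ÷ groups = 3 * 3 * 4 = 36.
conv = Conv((3, 3), 4 => 8, relu)
ps, _ = Lux.setup(rng, conv)
extrema(ps.bias)  # lies inside (-1/6, 1/6)

# Passing initializers explicitly opts out of the new defaults; this recovers the
# previous `glorot_uniform` weights and zero bias.
conv_old = Conv((3, 3), 4 => 8, relu; init_weight=glorot_uniform, init_bias=zeros32)
```

Scripts that relied on the previous implicit defaults will draw different initial parameters after this change unless they pass the initializers explicitly.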
Pass `cross_correlation=true` to + use cross-correlation instead. ## Arguments @@ -93,8 +111,13 @@ Standard convolutional layer. ## Keyword Arguments - - `init_weight`: Controls the initialization of the weight parameter - - `init_bias`: Controls the initialization of the bias parameter + - `init_weight`: Controls the initialization of the weight parameter. If `nothing`, then + we use [`kaiming_uniform`](@ref) with gain computed on the basis of the activation + function (taken from Pytorch + [`nn.init.calculate_gain`](https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.calculate_gain)). + - `init_bias`: Controls the initialization of the bias parameter. If `nothing`, then we + use uniform distribution with bounds `-bound` and `bound` where + `bound = inv(sqrt(fan_in))`. - `stride`: Should each be either single integer, or a tuple with `N` integers - `dilation`: Should each be either single integer, or a tuple with `N` integers @@ -115,6 +138,9 @@ Standard convolutional layer. convolution into (set `groups = in_chs` for Depthwise Convolutions). `in_chs` and `out_chs` must be divisible by `groups`. - `use_bias`: Trainable bias can be disabled entirely by setting this to `false`. + - `cross_correlation`: If `true`, perform cross-correlation instead of convolution. Prior + to `v1`, Lux used to have a `CrossCor` layer which performed cross-correlation. This + was removed in `v1` in favor of `Conv` with `cross_correlation=true`. ## Inputs @@ -148,25 +174,30 @@ O_i = \left\lfloor\frac{I_i + p_i + p_{(i + N) \% |p|} - d_i \times (k_i - 1)}{s init_weight init_bias use_bias <: StaticBool + cross_correlation <: StaticBool end function Conv(k::Tuple{Vararg{IntegerType}}, ch::Pair{<:IntegerType, <:IntegerType}, - activation=identity; init_weight=glorot_uniform, init_bias=zeros32, - stride=1, pad=0, dilation=1, groups=1, use_bias::BoolType=True()) + activation=identity; init_weight=nothing, + init_bias=nothing, stride=1, pad=0, dilation=1, groups=1, + use_bias::BoolType=True(), cross_correlation::BoolType=False()) stride = Utils.expand(Val(length(k)), stride) dilation = Utils.expand(Val(length(k)), dilation) pad = calc_padding(pad, k, dilation, stride) + + @argcheck ch[1] % groups==0 DimensionMismatch("Input channel dimension must be divisible by groups.") + @argcheck ch[2] % groups==0 DimensionMismatch("Output channel dimension must be divisible by groups.") @argcheck allequal(length, (stride, dilation, k)) - return Conv(activation, first(ch), last(ch), k, stride, pad, dilation, - groups, init_weight, init_bias, static(use_bias)) + return Conv(activation, first(ch), last(ch), k, stride, pad, dilation, groups, + init_weight, init_bias, static(use_bias), static(cross_correlation)) end function initialparameters(rng::AbstractRNG, c::Conv) - weight = init_conv_filter( - rng, c.kernel_size, c.in_chs => c.out_chs; init=c.init_weight, c.groups) + args = (c.kernel_size, c.in_chs, c.out_chs, c.groups) + weight = init_conv_weight(rng, c.init_weight, args..., c.activation) has_bias(c) || return (; weight) - return (; weight, bias=c.init_bias(rng, c.out_chs)) + return (; weight, bias=init_conv_bias(rng, c.init_bias, args...)) end function parameterlength(c::Conv) @@ -175,7 +206,8 @@ end function (c::Conv)(x::AbstractArray, ps, st::NamedTuple) y = match_eltype(c, ps, st, x) - cdims = DenseConvDims(y, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups) + cdims = construct_crosscor_convdims(c.cross_correlation, + DenseConvDims(y, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups)) bias = 
safe_getproperty(ps, Val(:bias)) σ = NNlib.fast_act(c.activation, y) return fused_conv_bias_activation(σ, ps.weight, y, bias, cdims), st @@ -191,9 +223,136 @@ function Base.show(io::IO, l::Conv) print(io, ", dilation=", PrettyPrinting.tuple_string(l.dilation)) (l.groups == 1) || print(io, ", groups=", l.groups) has_bias(l) || print(io, ", use_bias=false") + known(l.cross_correlation) && print(io, ", cross_correlation=true") print(io, ")") end +@doc doc""" + ConvTranspose(k::NTuple{N,Integer}, (in_chs => out_chs)::Pair{<:Integer,<:Integer}, + activation=identity; init_weight=glorot_uniform, init_bias=zeros32, + stride=1, pad=0, dilation=1, groups=1, use_bias=True()) + +Standard convolutional transpose layer. + +## Arguments + + - `k`: Tuple of integers specifying the size of the convolutional kernel. Eg, for 2D + convolutions `length(k) == 2` + - `in_chs`: Number of input channels + - `out_chs`: Number of input and output channels + - `activation`: Activation Function + +## Keyword Arguments + + - `init_weight`: Controls the initialization of the weight parameter. If `nothing`, then + we use [`kaiming_uniform`](@ref) with gain computed on the basis of the activation + function (taken from Pytorch + [`nn.init.calculate_gain`](https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.calculate_gain)). + - `init_bias`: Controls the initialization of the bias parameter. If `nothing`, then we + use uniform distribution with bounds `-bound` and `bound` where + `bound = inv(sqrt(fan_in))`. + + - `stride`: Should each be either single integer, or a tuple with `N` integers + - `dilation`: Should each be either single integer, or a tuple with `N` integers + - `pad`: Specifies the number of elements added to the borders of the data array. It can + be + + + a single integer for equal padding all around, + + a tuple of `N` integers, to apply the same padding at begin/end of each spatial + dimension, + + a tuple of `2*N` integers, for asymmetric padding, or + + the singleton `SamePad()`, to calculate padding such that + `size(output,d) == size(x,d) * stride` (possibly rounded) for each spatial + dimension. + + - `groups`: Expected to be an `Int`. It specifies the number of groups to divide a + convolution into (set `groups = in_chs` for Depthwise Convolutions). `in_chs` + and `out_chs` must be divisible by `groups`. + - `use_bias`: Trainable bias can be disabled entirely by setting this to `false`. + +# Extended Help + +## Inputs + + - `x`: Data satisfying `ndims(x) == N + 2 && size(x, N - 1) == in_chs`, i.e. 
+ `size(x) = (I_N, ..., I_1, C_in, N)` + +## Returns + + - Output of the convolution transpose `y` of size `(O_N, ..., O_1, C_out, N)` where + - Empty `NamedTuple()` + +## Parameters + + - `weight`: Convolution Transpose kernel + - `bias`: Bias (present if `use_bias=true`) +""" +@concrete struct ConvTranspose <: AbstractLuxLayer + activation + in_chs <: IntegerType + out_chs <: IntegerType + kernel_size <: Tuple{Vararg{IntegerType}} + stride <: Tuple{Vararg{IntegerType}} + pad <: Tuple{Vararg{IntegerType}} + dilation <: Tuple{Vararg{IntegerType}} + groups <: IntegerType + init_weight + init_bias + use_bias <: StaticBool +end + +function ConvTranspose( + k::Tuple{Vararg{IntegerType}}, ch::Pair{<:IntegerType, <:IntegerType}, + activation=identity; init_weight=glorot_uniform, init_bias=zeros32, + stride=1, pad=0, dilation=1, groups=1, use_bias::BoolType=True()) + stride = Utils.expand(Val(length(k)), stride) + dilation = Utils.expand(Val(length(k)), dilation) + pad = if pad isa SamePad + calc_padding(pad, k .- stride .+ 1, dilation, stride) + else + calc_padding(pad, k, dilation, stride) + end + + @argcheck ch[2] % groups==0 DimensionMismatch("Input channel dimension must be divisible by groups.") + @argcheck ch[1] % groups==0 DimensionMismatch("Output channel dimension must be divisible by groups.") + @argcheck allequal(length, (stride, dilation, k)) + + return ConvTranspose(activation, first(ch), last(ch), k, stride, pad, dilation, + groups, init_weight, init_bias, static(use_bias)) +end + +function initialparameters(rng::AbstractRNG, c::ConvTranspose) + args = (c.kernel_size, c.out_chs, c.in_chs, c.groups) + weight = init_conv_weight(rng, c.init_weight, args..., c.activation) + has_bias(c) || return (; weight) + return (; weight, bias=init_conv_bias(rng, c.init_bias, args...)) +end + +function parameterlength(c::ConvTranspose) + return prod(c.kernel_size) * c.in_chs * c.out_chs ÷ c.groups + has_bias(c) * c.out_chs +end + +function (c::ConvTranspose)(x::AbstractArray, ps, st::NamedTuple) + y = match_eltype(c, ps, st, x) + cdims = conv_transpose_dims(y, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups) + bias = safe_getproperty(ps, Val(:bias)) + σ = NNlib.fast_act(c.activation, y) + return bias_activation!!(σ, conv_transpose(y, ps.weight, cdims), bias), st +end + +function Base.show(io::IO, l::ConvTranspose) + print(io, "ConvTranspose(", l.kernel_size) + print(io, ", ", l.in_chs, " => ", l.out_chs) + l.activation == identity || print(io, ", ", l.activation) + all(==(0), l.pad) || print(io, ", pad=", PrettyPrinting.tuple_string(l.pad)) + all(==(1), l.stride) || print(io, ", stride=", PrettyPrinting.tuple_string(l.stride)) + all(==(1), l.dilation) || + print(io, ", dilation=", PrettyPrinting.tuple_string(l.dilation)) + (l.groups == 1) || print(io, ", groups=", l.groups) + has_bias(l) || print(io, ", use_bias=false") + return print(io, ")") +end + @doc doc""" MaxPool(window::NTuple; pad=0, stride=window) @@ -566,249 +725,10 @@ function set to `Base.Fix2(pixel_shuffle, r)` - Output of size `(r x W, r x H, C, N)` for 4D-arrays, and `(r x W, r x H, ..., C, N)` for D-dimensional data, where `D = ndims(x) - 2` """ -PixelShuffle(r::IntegerType) = WrappedFunction(Base.Fix2(pixel_shuffle, r)) - -@doc doc""" - CrossCor(k::NTuple{N,Integer}, (in_chs => out_chs)::Pair{<:Integer,<:Integer}, - activation=identity; init_weight=glorot_uniform, init_bias=zeros32, stride=1, - pad=0, dilation=1, groups=1, use_bias=True()) - -Cross Correlation layer. 
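Since `CrossCor` is removed in this patch in favour of `Conv(...; cross_correlation=true)`, as noted in the `updating_to_v1.md` hunk above, a tiny before/after migration sketch may help; the kernel size, channel counts, and input shape are made up for illustration.

```julia
using Lux, Random

# Pre-1.0:
#   layer = CrossCor((3, 3), 1 => 8, relu)
# v1 replacement, with identical positional arguments plus the new keyword:
layer = Conv((3, 3), 1 => 8, relu; cross_correlation=true)

ps, st = Lux.setup(Xoshiro(0), layer)
x = rand(Float32, 28, 28, 1, 4)
y, _ = layer(x, ps, st)
size(y)  # (26, 26, 8, 4); no padding, so each spatial dimension shrinks by 2
```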
- -Image data should be stored in WHCN order (width, height, channels, batch). In other words, -a `100 x 100` RGB image would be a `100 x 100 x 3 x 1` array, and a batch of 50 would be a -`100 x 100 x 3 x 50` array. This has `N = 2` spatial dimensions, and needs a kernel size -like `(5, 5)`, a 2-tuple of integers. To take convolutions along `N` feature dimensions, -this layer expects as input an array with `ndims(x) == N + 2`, where -`size(x, N + 1) == in_chs` is the number of input channels, and `size(x, ndims(x))` is the -number of observations in a batch. - -## Arguments - - - `k`: Tuple of integers specifying the size of the convolutional kernel. Eg, for 2D - convolutions `length(k) == 2` - - `in_chs`: Number of input channels - - `out_chs`: Number of input and output channels - - `activation`: Activation Function - -## Keyword Arguments - - - `init_weight`: Controls the initialization of the weight parameter - - `init_bias`: Controls the initialization of the bias parameter - - - `stride`: Should each be either single integer, or a tuple with `N` integers - - `dilation`: Should each be either single integer, or a tuple with `N` integers - - `groups`: Expected to be an `Int`. It specifies the number of groups to divide a - convolution into (set `groups = in_chs` for Depthwise Convolutions). `in_chs` - and `out_chs` must be divisible by `groups`. - - `pad`: Specifies the number of elements added to the borders of the data array. It can - be - - + a single integer for equal padding all around, - + a tuple of `N` integers, to apply the same padding at begin/end of each spatial - dimension, - + a tuple of `2*N` integers, for asymmetric padding, or - + the singleton `SamePad()`, to calculate padding such that - `size(output,d) == size(x,d) / stride` (possibly rounded) for each spatial - dimension. - - - `use_bias`: Trainable bias can be disabled entirely by setting this to `false`. - -# Extended Help - -## Inputs - - - `x`: Data satisfying `ndims(x) == N + 2 && size(x, N - 1) == in_chs`, i.e. 
- `size(x) = (I_N, ..., I_1, C_in, N)` - -## Returns - - - Output of the convolution `y` of size `(O_N, ..., O_1, C_out, N)` where - -```math -O_i = \left\lfloor\frac{I_i + p_i + p_{(i + N) \% |p|} - d_i \times (k_i - 1)}{s_i} + 1\right\rfloor -``` - - - Empty `NamedTuple()` - -## Parameters - - - `weight`: Convolution kernel - - `bias`: Bias (present if `use_bias=true`) -""" -@concrete struct CrossCor <: AbstractLuxLayer - activation - in_chs <: IntegerType - out_chs <: IntegerType - kernel_size <: Tuple{Vararg{IntegerType}} - stride <: Tuple{Vararg{IntegerType}} - pad <: Tuple{Vararg{IntegerType}} - dilation <: Tuple{Vararg{IntegerType}} - groups <: IntegerType - init_weight - init_bias - use_bias <: StaticBool -end - -function CrossCor(k::Tuple{Vararg{IntegerType}}, ch::Pair{<:IntegerType, <:IntegerType}, - activation=identity; init_weight=glorot_uniform, init_bias=zeros32, - stride=1, pad=0, dilation=1, groups=1, use_bias::BoolType=True()) - stride = Utils.expand(Val(length(k)), stride) - dilation = Utils.expand(Val(length(k)), dilation) - pad = calc_padding(pad, k, dilation, stride) - @argcheck allequal(length, (stride, dilation, k)) - - return CrossCor(activation, first(ch), last(ch), k, stride, pad, dilation, - groups, init_weight, init_bias, static(use_bias)) -end - -function initialparameters(rng::AbstractRNG, c::CrossCor) - weight = init_conv_filter( - rng, c.kernel_size, c.in_chs => c.out_chs; init=c.init_weight, c.groups) - has_bias(c) || return (; weight) - return (; weight, bias=c.init_bias(rng, c.out_chs)) -end - -function parameterlength(c::CrossCor) - return prod(c.kernel_size) * c.in_chs * c.out_chs ÷ c.groups + has_bias(c) * c.out_chs -end - -function (c::CrossCor)(x::AbstractArray, ps, st::NamedTuple) - y = match_eltype(c, ps, st, x) - cdims = DenseConvDims( - DenseConvDims(y, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups); F=true) - bias = safe_getproperty(ps, Val(:bias)) - σ = NNlib.fast_act(c.activation, y) - return fused_conv_bias_activation(σ, ps.weight, y, bias, cdims), st +@concrete struct PixelShuffle <: AbstractLuxWrapperLayer{:layer} + layer <: AbstractLuxLayer end -function Base.show(io::IO, l::CrossCor) - print(io, "CrossCor(", l.kernel_size) - print(io, ", ", l.in_chs, " => ", l.out_chs) - l.activation == identity || print(io, ", ", l.activation) - all(==(0), l.pad) || print(io, ", pad=", PrettyPrinting.tuple_string(l.pad)) - all(==(1), l.stride) || print(io, ", stride=", PrettyPrinting.tuple_string(l.stride)) - all(==(1), l.dilation) || - print(io, ", dilation=", PrettyPrinting.tuple_string(l.dilation)) - (l.groups == 1) || print(io, ", groups=", l.groups) - has_bias(l) || print(io, ", use_bias=false") - return print(io, ")") -end - -@doc doc""" - ConvTranspose(k::NTuple{N,Integer}, (in_chs => out_chs)::Pair{<:Integer,<:Integer}, - activation=identity; init_weight=glorot_uniform, init_bias=zeros32, - stride=1, pad=0, dilation=1, groups=1, use_bias=True()) - -Standard convolutional transpose layer. - -## Arguments - - - `k`: Tuple of integers specifying the size of the convolutional kernel. 
Eg, for 2D - convolutions `length(k) == 2` - - `in_chs`: Number of input channels - - `out_chs`: Number of input and output channels - - `activation`: Activation Function - -## Keyword Arguments - - - `init_weight`: Controls the initialization of the weight parameter - - `init_bias`: Controls the initialization of the bias parameter - - - `stride`: Should each be either single integer, or a tuple with `N` integers - - `dilation`: Should each be either single integer, or a tuple with `N` integers - - `pad`: Specifies the number of elements added to the borders of the data array. It can - be - - + a single integer for equal padding all around, - + a tuple of `N` integers, to apply the same padding at begin/end of each spatial - dimension, - + a tuple of `2*N` integers, for asymmetric padding, or - + the singleton `SamePad()`, to calculate padding such that - `size(output,d) == size(x,d) * stride` (possibly rounded) for each spatial - dimension. - - - `groups`: Expected to be an `Int`. It specifies the number of groups to divide a - convolution into (set `groups = in_chs` for Depthwise Convolutions). `in_chs` - and `out_chs` must be divisible by `groups`. - - `use_bias`: Trainable bias can be disabled entirely by setting this to `false`. - -# Extended Help - -## Inputs - - - `x`: Data satisfying `ndims(x) == N + 2 && size(x, N - 1) == in_chs`, i.e. - `size(x) = (I_N, ..., I_1, C_in, N)` - -## Returns - - - Output of the convolution transpose `y` of size `(O_N, ..., O_1, C_out, N)` where - - Empty `NamedTuple()` - -## Parameters - - - `weight`: Convolution Transpose kernel - - `bias`: Bias (present if `use_bias=true`) -""" -@concrete struct ConvTranspose <: AbstractLuxLayer - activation - in_chs <: IntegerType - out_chs <: IntegerType - kernel_size <: Tuple{Vararg{IntegerType}} - stride <: Tuple{Vararg{IntegerType}} - pad <: Tuple{Vararg{IntegerType}} - dilation <: Tuple{Vararg{IntegerType}} - groups <: IntegerType - init_weight - init_bias - use_bias <: StaticBool -end - -function ConvTranspose( - k::Tuple{Vararg{IntegerType}}, ch::Pair{<:IntegerType, <:IntegerType}, - activation=identity; init_weight=glorot_uniform, init_bias=zeros32, - stride=1, pad=0, dilation=1, groups=1, use_bias::BoolType=True()) - stride = Utils.expand(Val(length(k)), stride) - dilation = Utils.expand(Val(length(k)), dilation) - pad = if pad isa SamePad - calc_padding(pad, k .- stride .+ 1, dilation, stride) - else - calc_padding(pad, k, dilation, stride) - end - @argcheck allequal(length, (stride, dilation, k)) - - return ConvTranspose(activation, first(ch), last(ch), k, stride, pad, dilation, - groups, init_weight, init_bias, static(use_bias)) -end - -function initialparameters(rng::AbstractRNG, c::ConvTranspose) - weight = init_conv_filter( - rng, c.kernel_size, c.out_chs => c.in_chs; init=c.init_weight, c.groups) - has_bias(c) || return (; weight) - return (; weight, bias=c.init_bias(rng, c.out_chs)) -end - -function parameterlength(c::ConvTranspose) - return prod(c.kernel_size) * c.in_chs * c.out_chs ÷ c.groups + has_bias(c) * c.out_chs -end - -function (c::ConvTranspose)(x::AbstractArray, ps, st::NamedTuple) - y = match_eltype(c, ps, st, x) - cdims = conv_transpose_dims(y, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups) - bias = safe_getproperty(ps, Val(:bias)) - σ = NNlib.fast_act(c.activation, y) - return bias_activation!!(σ, conv_transpose(y, ps.weight, cdims), bias), st -end - -function Base.show(io::IO, l::ConvTranspose) - print(io, "ConvTranspose(", l.kernel_size) - print(io, ", ", l.in_chs, " => 
", l.out_chs) - l.activation == identity || print(io, ", ", l.activation) - all(==(0), l.pad) || print(io, ", pad=", PrettyPrinting.tuple_string(l.pad)) - all(==(1), l.stride) || print(io, ", stride=", PrettyPrinting.tuple_string(l.stride)) - all(==(1), l.dilation) || - print(io, ", dilation=", PrettyPrinting.tuple_string(l.dilation)) - (l.groups == 1) || print(io, ", groups=", l.groups) - has_bias(l) || print(io, ", use_bias=false") - return print(io, ")") +function PixelShuffle(r::IntegerType) + return PixelShuffle(WrappedFunction(Base.Fix2(pixel_shuffle, r))) end diff --git a/src/utils.jl b/src/utils.jl index a4177f7275..74973abe6c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -12,6 +12,7 @@ using Static: Static, StaticBool, StaticInteger, StaticSymbol using LuxCore: LuxCore, AbstractLuxLayer using MLDataDevices: get_device +using NNlib: NNlib const CRC = ChainRulesCore @@ -203,6 +204,20 @@ matrix_to_array(x::AbstractMatrix, ::AbstractVector) = vec(x) matrix_to_array(x::AbstractMatrix, ::AbstractMatrix) = x matrix_to_array(x::AbstractMatrix, y::AbstractArray) = reshape(x, :, size(y)[2:end]...) +# This should probably be in WeightInitializers.jl +calculate_gain(_, __) = 1.0f0 +calculate_gain(::typeof(identity), _) = 1.0f0 +calculate_gain(::typeof(NNlib.sigmoid), _) = 1.0f0 +calculate_gain(::typeof(NNlib.sigmoid_fast), _) = 1.0f0 +calculate_gain(::typeof(NNlib.relu), _) = 2.0f0 +calculate_gain(::typeof(tanh), _) = 5.0f0 / 3.0f0 +calculate_gain(::typeof(NNlib.tanh_fast), _) = 5.0f0 / 3.0f0 +function calculate_gain(::typeof(NNlib.leakyrelu), ::Nothing) + return calculate_gain(NNlib.leakyrelu, 0.1f0) +end +calculate_gain(::typeof(NNlib.leakyrelu), x::Real) = typeof(x)(√(2 / (1 + x^2))) +calculate_gain(::typeof(NNlib.selu), _) = 3.0f0 / 4 + end using .Utils: Utils, BoolType, IntegerType, SymbolType, make_abstract_matrix, diff --git a/test/layers/conv_tests.jl b/test/layers/conv_tests.jl index 2cb7b609ac..861bd85331 100644 --- a/test/layers/conv_tests.jl +++ b/test/layers/conv_tests.jl @@ -410,12 +410,12 @@ end end end -@testitem "CrossCor" setup=[SharedTestSetup] tags=[:core_layers] begin +@testitem "Conv(cross_correlation=true)" setup=[SharedTestSetup] tags=[:core_layers] begin rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES @testset "Asymmetric Padding" begin - layer = CrossCor((3, 3), 1 => 1, relu; pad=(0, 1, 1, 2)) + layer = Conv((3, 3), 1 => 1, relu; pad=(0, 1, 1, 2), cross_correlation=true) display(layer) x = ones(Float32, 28, 28, 1, 1) |> aType ps, st = Lux.setup(rng, layer) |> dev @@ -436,23 +436,24 @@ end end @testset "Variable BitWidth Parameters FluxML/Flux.jl#1421" begin - layer = CrossCor((5, 5), 10 => 20, identity; + layer = Conv((5, 5), 10 => 20, identity; init_weight=(rng, dims...) -> aType(randn(rng, Float64, dims...)), - init_bias=(rng, dims...) -> aType(randn(rng, Float16, dims...))) + init_bias=(rng, dims...) -> aType(randn(rng, Float16, dims...)), + cross_correlation=true) display(layer) ps, st = Lux.setup(rng, layer) @test ps.weight isa aType{Float64, 4} @test ps.bias isa aType{Float16, 1} end - @testset "CrossCor SamePad kernelsize $k" for k in ( + @testset "SamePad kernelsize $k" for k in ( (1,), (2,), (3,), (2, 3), (1, 2, 3)) x = ones(Float32, (k .+ 3)..., 1, 1) |> aType @testset "Kwargs: $kwarg" for kwarg in ( (; stride=1), (; dilation=max.(k .÷ 2, 1), stride=1), (; stride=3), (; stride=1, use_bias=false)) - layer = CrossCor(k, 1 => 1; pad=Lux.SamePad(), kwarg...) 
+ layer = Conv(k, 1 => 1; pad=Lux.SamePad(), kwarg..., cross_correlation=true) display(layer) ps, st = Lux.setup(rng, layer) |> dev diff --git a/test/shared_testsetup.jl b/test/shared_testsetup.jl index 4ba455da8f..aba3646de4 100644 --- a/test/shared_testsetup.jl +++ b/test/shared_testsetup.jl @@ -5,11 +5,13 @@ include("setup_modes.jl") import Reexport: @reexport using Lux, Functors +using Setfield: @set using DispatchDoctor: allow_unstable @reexport using ComponentArrays, LuxCore, LuxLib, LuxTestUtils, Random, StableRNGs, Test, Zygote, Statistics, Enzyme, LinearAlgebra, ForwardDiff using MLDataDevices: default_device_rng, CPUDevice, CUDADevice, AMDGPUDevice using LuxTestUtils: check_approx +using Static: True LuxTestUtils.jet_target_modules!(["Lux", "LuxCore", "LuxLib"]) LinearAlgebra.BLAS.set_num_threads(Threads.nthreads()) @@ -24,9 +26,8 @@ end maybe_rewrite_to_crosscor(layer) = layer function maybe_rewrite_to_crosscor(layer::Conv) - return CrossCor(layer.activation, layer.in_chs, layer.out_chs, layer.kernel_size, - layer.stride, layer.pad, layer.dilation, layer.groups, - layer.init_weight, layer.init_bias, layer.use_bias) + @set layer.cross_correlation = True() + return layer end function maybe_rewrite_to_crosscor(mode, model) diff --git a/test/zygote_type_stability.jl b/test/zygote_type_stability.jl index 517ba590f1..b8d0d22c3d 100644 --- a/test/zygote_type_stability.jl +++ b/test/zygote_type_stability.jl @@ -77,7 +77,7 @@ include("setup_modes.jl") @test @inferred(model(x, ps, st)) isa Any @test @inferred(loss_function(model, x, ps, st)) isa Any - if mode == "amdgpu" && (model isa Conv || model isa CrossCor) + if mode == "amdgpu" && model isa Conv @test_broken @inferred(Zygote.gradient(loss_function, model, x, ps, st)) isa Any else From c6e6b042a075e72de43ab604417742d7dddd33ae Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 5 Sep 2024 13:55:16 -0400 Subject: [PATCH 60/95] fix: docstrings in InstanceNorm --- src/layers/normalize.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/layers/normalize.jl b/src/layers/normalize.jl index 50e5e24295..047f581417 100644 --- a/src/layers/normalize.jl +++ b/src/layers/normalize.jl @@ -327,9 +327,9 @@ julia> Chain(Dense(784 => 64), InstanceNorm(64, relu), Dense(64 => 10), InstanceNorm(10, relu)) Chain( layer_1 = Dense(784 => 64), # 50_240 parameters - layer_2 = InstanceNorm(64, relu, affine=true), # 128 parameters, plus 1 + layer_2 = InstanceNorm(64, relu, affine=true, track_stats=false), # 128 parameters, plus 1 layer_3 = Dense(64 => 10), # 650 parameters - layer_4 = InstanceNorm(10, relu, affine=true), # 20 parameters, plus 1 + layer_4 = InstanceNorm(10, relu, affine=true, track_stats=false), # 20 parameters, plus 1 ) # Total: 51_038 parameters, # plus 2 states. ``` From 3219ae52f4bee3aa95581a20db9526d50b8f36dc Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 5 Sep 2024 13:56:07 -0400 Subject: [PATCH 61/95] chore: run formatter --- ext/LuxFluxExt.jl | 10 +++++----- test/layers/conv_tests.jl | 3 +-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/ext/LuxFluxExt.jl b/ext/LuxFluxExt.jl index 0116c200af..bd8ed31418 100644 --- a/ext/LuxFluxExt.jl +++ b/ext/LuxFluxExt.jl @@ -131,12 +131,12 @@ function Lux.convert_flux_model(l::Flux.CrossCor; preserve_ps_st::Bool=false, kw pad = l.pad isa Flux.SamePad ? SamePad() : l.pad if preserve_ps_st _bias = l.bias isa Bool ? 
nothing : vec(copy(l.bias)) - return Lux.Conv(k, in_chs => out_chs, l.σ; l.stride, pad, - l.dilation, init_weight=Returns(copy(l.weight)), - init_bias=Returns(_bias), use_bias=!(l.bias isa Bool), cross_correlation=true) + return Lux.Conv(k, in_chs => out_chs, l.σ; l.stride, pad, l.dilation, + init_weight=Returns(copy(l.weight)), init_bias=Returns(_bias), + use_bias=!(l.bias isa Bool), cross_correlation=true) else - return Lux.Conv(k, in_chs => out_chs, l.σ; l.stride, pad, - l.dilation, use_bias=!(l.bias isa Bool), cross_correlation=true) + return Lux.Conv(k, in_chs => out_chs, l.σ; l.stride, pad, l.dilation, + use_bias=!(l.bias isa Bool), cross_correlation=true) end end diff --git a/test/layers/conv_tests.jl b/test/layers/conv_tests.jl index 861bd85331..c6ddecbabd 100644 --- a/test/layers/conv_tests.jl +++ b/test/layers/conv_tests.jl @@ -446,8 +446,7 @@ end @test ps.bias isa aType{Float16, 1} end - @testset "SamePad kernelsize $k" for k in ( - (1,), (2,), (3,), (2, 3), (1, 2, 3)) + @testset "SamePad kernelsize $k" for k in ((1,), (2,), (3,), (2, 3), (1, 2, 3)) x = ones(Float32, (k .+ 3)..., 1, 1) |> aType @testset "Kwargs: $kwarg" for kwarg in ( From 4c8a7eebbc96387787d7bc081133d183db6955d6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 5 Sep 2024 14:17:11 -0400 Subject: [PATCH 62/95] feat!: upsampling now defaults to no align corners --- docs/src/introduction/updating_to_v1.md | 2 ++ ext/LuxFluxExt.jl | 2 +- src/layers/conv.jl | 36 ++++++++++++++++++------- 3 files changed, 30 insertions(+), 10 deletions(-) diff --git a/docs/src/introduction/updating_to_v1.md b/docs/src/introduction/updating_to_v1.md index 70cd4b2d6f..cbde4ffc8e 100644 --- a/docs/src/introduction/updating_to_v1.md +++ b/docs/src/introduction/updating_to_v1.md @@ -124,6 +124,8 @@ abstraction. function, taken from Pytorch. Pytorch assumes the activation function is `leakyrelu` to compute the gain, however, we compute the gain based on the activation function passed in to the layer. +- [`Upsample`](@ref) now has an `align_corners` keyword argument, which defaults to `false`. + Previously this was always `true`. ### New Features diff --git a/ext/LuxFluxExt.jl b/ext/LuxFluxExt.jl index bd8ed31418..617980d02f 100644 --- a/ext/LuxFluxExt.jl +++ b/ext/LuxFluxExt.jl @@ -173,7 +173,7 @@ Lux.convert_flux_model(::typeof(Flux.flatten); kwargs...) = Lux.FlattenLayer() Lux.convert_flux_model(l::Flux.PixelShuffle; kwargs...) = Lux.PixelShuffle(l.r) function Lux.convert_flux_model(l::Flux.Upsample{mode}; kwargs...) where {mode} - return Lux.Upsample(mode; l.scale, l.size) + return Lux.Upsample(mode; l.scale, l.size, align_corners=false) end function Lux.convert_flux_model( diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 0fdf3cc01c..01d1dab66b 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -494,7 +494,7 @@ function Base.show(io::IO, m::MeanPool) end """ - Upsample(mode = :nearest; [scale, size]) + Upsample(mode = :nearest; [scale, size, align_corners=false]) Upsample(scale, mode = :nearest) Upsampling Layer. @@ -526,6 +526,12 @@ Currently supported upsampling `mode`s and corresponding NNlib's methods are: # Extended Help +## Other Keyword Arguments + + - `align_corners`: If `true`, the corner pixels of the input and output tensors are + aligned, and thus preserving the values at those pixels. This only has effect when mode + is one of `:bilinear` or `:trilinear`. 
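A short usage sketch of the `Upsample` keyword described above; the array shapes and the `(2, 2)` scale are arbitrary illustrations. `align_corners` only has an effect for the `:bilinear` and `:trilinear` modes, and the output shapes shown below are independent of it.

```julia
using Lux, Random

# `align_corners` now defaults to `false`.
up = Upsample(:bilinear; scale=(2, 2))
ps, st = Lux.setup(Xoshiro(0), up)  # no trainable parameters

x = rand(Float32, 8, 8, 3, 2)
y, _ = up(x, ps, st)
size(y)  # (16, 16, 3, 2)

# Exactly one of `scale` or `size` may be given; a fixed output size also works:
up_to_size = Upsample(:nearest; size=(32, 32))
```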
+ ## Inputs - `x`: For the input dimensions look into the documentation for the corresponding `NNlib` @@ -544,38 +550,49 @@ Currently supported upsampling `mode`s and corresponding NNlib's methods are: scale size upsample_mode <: StaticSymbol + align_corners <: Bool end -function Upsample(mode::SymbolType=static(:nearest); scale=nothing, size=nothing) +function Upsample(mode::SymbolType=static(:nearest); scale=nothing, + size=nothing, align_corners::Bool=false) @argcheck dynamic(mode) in (:nearest, :bilinear, :trilinear) + if !xor(isnothing(scale), isnothing(size)) throw(ArgumentError("Either scale or size should be specified (but not both).")) end - return Upsample(scale, size, static(mode)) + return Upsample(scale, size, static(mode), align_corners) end Upsample(scale, mode::SymbolType=static(:nearest)) = Upsample(mode; scale) function (m::Upsample)(x::AbstractArray, _, st::NamedTuple) - return lux_upsample_scale_dispatch(m.upsample_mode, x, m.scale), st + return lux_upsample_scale_dispatch(m.upsample_mode, x, m.scale, m.align_corners), st end function (m::Upsample{Nothing})(x::AbstractArray, _, st::NamedTuple) - return lux_upsample_size_dispatch(m.upsample_mode, x, m.size), st + return lux_upsample_size_dispatch(m.upsample_mode, x, m.size, m.align_corners), st end -for interp in (:nearest, :bilinear, :trilinear) +for interp in (:bilinear, :trilinear) nnlib_interp_func = Symbol(:upsample_, interp) @eval begin - function lux_upsample_scale_dispatch(::StaticSymbol{$(Meta.quot(interp))}, x, scale) + function lux_upsample_scale_dispatch( + ::StaticSymbol{$(Meta.quot(interp))}, x, scale, align_corners) return $(nnlib_interp_func)(x, scale) end - function lux_upsample_size_dispatch(::StaticSymbol{$(Meta.quot(interp))}, x, size) + function lux_upsample_size_dispatch( + ::StaticSymbol{$(Meta.quot(interp))}, x, size, align_corners) return $(nnlib_interp_func)(x; size) end end end -function lux_upsample_scale_dispatch(::StaticSymbol{:nearest}, x, scale::Integer) +function lux_upsample_size_dispatch(::StaticSymbol{:nearest}, x, size, _) + return NNlib.upsample_nearest(x; size) +end +function lux_upsample_scale_dispatch(::StaticSymbol{:nearest}, x, scale, _) + return NNlib.upsample_nearest(x, scale) +end +function lux_upsample_scale_dispatch(::StaticSymbol{:nearest}, x, scale::Integer, _) return NNlib.upsample_nearest(x, ntuple(i -> scale, ndims(x) - 2)) end @@ -583,6 +600,7 @@ function Base.show(io::IO, u::Upsample) print(io, "Upsample(", u.upsample_mode) u.scale !== nothing && print(io, ", scale = $(u.scale)") u.size !== nothing && print(io, ", size = $(u.size)") + u.align_corners && print(io, ", align_corners = $(u.align_corners)") print(io, ")") end From 9ef520ef3278c74bdfb27f906a81ac59f270a1f6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 5 Sep 2024 14:25:40 -0400 Subject: [PATCH 63/95] fix: tests and update init assumptions in tests --- src/layers/conv.jl | 10 +++++++--- test/contrib/map_tests.jl | 4 ++-- test/layers/normalize_tests.jl | 2 +- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 01d1dab66b..4439233d2e 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -322,10 +322,14 @@ function ConvTranspose( end function initialparameters(rng::AbstractRNG, c::ConvTranspose) - args = (c.kernel_size, c.out_chs, c.in_chs, c.groups) - weight = init_conv_weight(rng, c.init_weight, args..., c.activation) + weight = init_conv_weight( + rng, c.init_weight, c.kernel_size, c.out_chs, c.in_chs, c.groups, c.activation) has_bias(c) || 
return (; weight) - return (; weight, bias=init_conv_bias(rng, c.init_bias, args...)) + # NOTE: The c.out_chs, c.out_chs is intentional, since it only affects the size of the + # bias vector + return (; weight, + bias=init_conv_bias( + rng, c.init_bias, c.kernel_size, c.out_chs, c.out_chs, c.groups)) end function parameterlength(c::ConvTranspose) diff --git a/test/contrib/map_tests.jl b/test/contrib/map_tests.jl index 2df37556a4..00b2f994f6 100644 --- a/test/contrib/map_tests.jl +++ b/test/contrib/map_tests.jl @@ -54,7 +54,7 @@ @test all(iszero, ps_.dense.weight) @test all(iszero, ps_.dense.bias) @test !all(iszero, ps_.conv.weight) - @test all(iszero, ps_.conv.bias) + @test !all(iszero, ps_.conv.bias) # Custom Wrapper struct SimpleWrapper{L} <: Lux.AbstractLuxWrapperLayer{:model} @@ -70,6 +70,6 @@ @test all(iszero, ps_.dense.weight) @test all(iszero, ps_.dense.bias) @test !all(iszero, ps_.conv.weight) - @test all(iszero, ps_.conv.bias) + @test !all(iszero, ps_.conv.bias) end end diff --git a/test/layers/normalize_tests.jl b/test/layers/normalize_tests.jl index b7b27e7a97..7d8c9db732 100644 --- a/test/layers/normalize_tests.jl +++ b/test/layers/normalize_tests.jl @@ -288,7 +288,7 @@ end # See https://github.com/LuxDL/Lux.jl/issues/95 @testset "Normalizing Zero Parameters" begin - c = Conv((3, 3), 3 => 3) + c = Conv((3, 3), 3 => 3; init_bias=zeros32) wn = WeightNorm(c, (:weight, :bias)) @test_throws ArgumentError Lux.setup(rng, wn) From e839347ebfc01d1f52038b3537ca43cd1cf11a4c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 5 Sep 2024 15:00:14 -0400 Subject: [PATCH 64/95] fix: update initialization of linear layers --- src/layers/basic.jl | 62 ++++++++++++++++++++++++++++++++------------- src/layers/conv.jl | 6 ++--- 2 files changed, 47 insertions(+), 21 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index e10974057f..a9f97173fc 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -1,3 +1,14 @@ +function init_linear_bias(rng::AbstractRNG, init_bias::F, fan_in::IntegerType, + bias_len::IntegerType) where {F} + if init_bias === nothing # Default from PyTorch + bound = inv(sqrt(fan_in)) + y = rand32(rng, bias_len) + @. y = (y - 0.5f0) * 2 * bound + return y + end + return init_bias(rng, bias_len) +end + """ ReshapeLayer(dims) @@ -247,8 +258,8 @@ end Base.show(io::IO, w::WrappedFunction) = print(io, "WrappedFunction(", w.func, ")") """ - Dense(in_dims => out_dims, activation=identity; init_weight=glorot_uniform, - init_bias=zeros32, use_bias=True()) + Dense(in_dims => out_dims, activation=identity; init_weight=nothing, + init_bias=nothing, use_bias=True()) Create a traditional fully connected layer, whose forward pass is given by: `y = activation.(weight * x .+ bias)` @@ -262,8 +273,13 @@ Create a traditional fully connected layer, whose forward pass is given by: ## Keyword Arguments - `init_weight`: initializer for the weight matrix - (`weight = init_weight(rng, out_dims, in_dims)`) - - `init_bias`: initializer for the bias vector (ignored if `use_bias=false`) + (`weight = init_weight(rng, out_dims, in_dims)`). If `nothing`, then we use + [`kaiming_uniform`](@ref) with gain computed on the basis of the activation + function (taken from Pytorch + [`nn.init.calculate_gain`](https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.calculate_gain)). + - `init_bias`: initializer for the bias vector (ignored if `use_bias=false`). If + `nothing`, then we use uniform distribution with bounds `-bound` and `bound` where + `bound = inv(sqrt(in_dims))`. 
- `use_bias`: Trainable bias can be disabled entirely by setting this to `false` ## Input @@ -306,12 +322,14 @@ function Dense(in_dims::IntegerType, out_dims::IntegerType, activation=identity; end function initialparameters(rng::AbstractRNG, d::Dense) - if has_bias(d) - return (weight=d.init_weight(rng, d.out_dims, d.in_dims), - bias=d.init_bias(rng, d.out_dims)) + weight = if d.init_weight === nothing + kaiming_uniform(rng, Float32, d.out_dims, d.in_dims; gain=Utils.calculate_gain( + d.activation, √5.0f0)) else - return (weight=d.init_weight(rng, d.out_dims, d.in_dims),) + d.init_weight(rng, d.out_dims, d.in_dims) end + has_bias(d) || return (; weight) + return (; weight, bias=init_linear_bias(rng, d.init_bias, d.in_dims, d.out_dims)) end parameterlength(d::Dense) = d.out_dims * d.in_dims + has_bias(d) * d.out_dims @@ -412,10 +430,10 @@ function (d::Scale{True})(x::AbstractArray, ps, st::NamedTuple) end """ - Bilinear((in1_dims, in2_dims) => out, activation=identity; init_weight=glorot_uniform, - init_bias=zeros32, use_bias=True()) - Bilinear(in12_dims => out, activation=identity; init_weight=glorot_uniform, - init_bias=zeros32, use_bias=True()) + Bilinear((in1_dims, in2_dims) => out, activation=identity; init_weight=nothing, + init_bias=nothing, use_bias=True()) + Bilinear(in12_dims => out, activation=identity; init_weight=nothing, + init_bias=nothing, use_bias=True()) Create a fully connected layer between two inputs and an output, and otherwise similar to [`Dense`](@ref). Its output, given vectors `x` & `y`, is another vector `z` with, for all @@ -437,8 +455,12 @@ with `B` the Bilinear layer. ## Keyword Arguments - `init_weight`: initializer for the weight matrix - (`weight = init_weight(rng, out_dims, in1_dims, in2_dims)`) - - `init_bias`: initializer for the bias vector (ignored if `use_bias=false`) + (`weight = init_weight(rng, out_dims, in1_dims, in2_dims)`). If `nothing`, then we + use uniform distribution with bounds `-bound` and `bound` where + `bound = inv(sqrt(in1_dims))`. + - `init_bias`: initializer for the bias vector (ignored if `use_bias=false`). If + `nothing`, then we use uniform distribution with bounds `-bound` and `bound` where + `bound = inv(sqrt(in1_dims))`. - `use_bias`: Trainable bias can be disabled entirely by setting this to `false` ## Input @@ -490,12 +512,16 @@ function Bilinear( end function initialparameters(rng::AbstractRNG, b::Bilinear) - if has_bias(b) - return (weight=b.init_weight(rng, b.out_dims, b.in1_dims, b.in2_dims), - bias=b.init_bias(rng, b.out_dims)) + weight = if b.init_weight === nothing + bound = inv(sqrt(b.in1_dims)) + y = randn32(rng, b.out_dims, b.in1_dims, b.in2_dims) + @. 
y = (y - 0.5f0) * 2 * bound + y else - return (weight=b.init_weight(rng, b.out_dims, b.in1_dims, b.in2_dims),) + b.init_weight(rng, b.out_dims, b.in1_dims, b.in2_dims) end + has_bias(b) || return (; weight) + return (; weight, bias=init_linear_bias(rng, b.init_bias, b.in1_dims, b.out_dims)) end function parameterlength(b::Bilinear) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 4439233d2e..08b087e848 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -55,8 +55,8 @@ function init_conv_weight( rng::AbstractRNG, init_weight::F, filter::NTuple{N, <:IntegerType}, in_chs::IntegerType, out_chs::IntegerType, groups, σ::A) where {F, N, A} if init_weight === nothing # Default from PyTorch - gain = Utils.calculate_gain(σ, √5.0f0) - return kaiming_uniform(rng, Float32, filter..., in_chs ÷ groups, out_chs; gain) + return kaiming_uniform(rng, Float32, filter..., in_chs ÷ groups, + out_chs; gain=Utils.calculate_gain(σ, √5.0f0)) end return init_weight(rng, filter..., in_chs ÷ groups, out_chs) end @@ -67,7 +67,7 @@ function init_conv_bias(rng::AbstractRNG, init_bias::F, filter::NTuple{N, <:Inte fan_in = prod(filter) * (in_chs ÷ groups) bound = inv(sqrt(fan_in)) y = rand32(rng, out_chs) - @. y = y * 2bound - bound + @. y = (y - 0.5f0) * 2 * bound return y end return init_bias(rng, out_chs) From 8e68a1e52e6e973bce4e5a763d5f801e33a040b9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 5 Sep 2024 15:05:08 -0400 Subject: [PATCH 65/95] fix: update normalization defaults to match Pytorch --- docs/src/introduction/updating_to_v1.md | 3 + src/layers/normalize.jl | 196 ++++++++++++------------ 2 files changed, 101 insertions(+), 98 deletions(-) diff --git a/docs/src/introduction/updating_to_v1.md b/docs/src/introduction/updating_to_v1.md index cbde4ffc8e..a196fa8b7a 100644 --- a/docs/src/introduction/updating_to_v1.md +++ b/docs/src/introduction/updating_to_v1.md @@ -126,6 +126,9 @@ abstraction. to the layer. - [`Upsample`](@ref) now has an `align_corners` keyword argument, which defaults to `false`. Previously this was always `true`. +- [`Dense`](@ref) and [`Bilinear`](@ref) have updated default initializations to align with + the defaults from Pytorch. See the documentation for more details. +- [`InstanceNorm`](@ref) now defaults to `affine=false` instead of `affine=true`. ### New Features diff --git a/src/layers/normalize.jl b/src/layers/normalize.jl index 047f581417..1912922863 100644 --- a/src/layers/normalize.jl +++ b/src/layers/normalize.jl @@ -77,7 +77,7 @@ Chain( !!! warning - Passing a batch size of 1, during training will result in NaNs. + Passing a batch size of 1, during training will result in an error. See also [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`LayerNorm`](@ref), [`WeightNorm`](@ref) @@ -257,7 +257,7 @@ end @doc doc""" InstanceNorm(chs::Integer, activation=identity; init_bias=zeros32, init_scale=ones32, - affine=True(), track_stats=False(), epsilon=1f-5, momentum=0.1f0) + affine=False(), track_stats=False(), epsilon=1f-5, momentum=0.1f0) Instance Normalization. For details see [1]. 
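The linear- and convolution-layer changes above all follow the same PyTorch-style default: weights via `kaiming_uniform` with a gain derived from the activation, and biases drawn uniformly from `(-bound, bound)` with `bound = 1/√fan_in`. The rewritten broadcast `(y - 0.5f0) * 2 * bound` maps a `U(0, 1)` draw onto that interval, so it assumes a uniform sampler (`rand32`-style); shifting by `0.5` would not centre a normal draw. A minimal standalone sketch of the pattern (the helper name is illustrative, not part of the package):

```julia
using Random

# Illustrative helper mirroring the pattern above: draw from U(0, 1) and
# rescale to U(-bound, bound) with bound = 1 / sqrt(fan_in).
function uniform_fan_in_init(rng::AbstractRNG, fan_in::Int, dims::Int...)
    bound = inv(sqrt(Float32(fan_in)))
    y = rand(rng, Float32, dims...)      # uniform in [0, 1)
    @. y = (y - 0.5f0) * 2 * bound       # uniform in [-bound, bound)
    return y
end

rng = Random.default_rng()
b = uniform_fan_in_init(rng, 16, 8)      # e.g. a bias of length 8 with fan_in = 16
all(abs.(b) .<= 0.25f0)                  # bound = 1/√16 = 0.25
```

For the convolutional bias this is the same bound, just with `fan_in = prod(kernel_size) * (in_chs ÷ groups)` as computed in `init_conv_bias` above.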
@@ -353,7 +353,7 @@ See also [`BatchNorm`](@ref), [`GroupNorm`](@ref), [`LayerNorm`](@ref), [`Weight end function InstanceNorm(chs::IntegerType, activation=identity; init_bias=zeros32, - init_scale=ones32, affine::BoolType=True(), + init_scale=ones32, affine::BoolType=False(), track_stats::BoolType=False(), epsilon=1.0f-5, momentum=0.1f0) return InstanceNorm(activation, epsilon, momentum, chs, init_bias, init_scale, static(affine), static(track_stats)) @@ -394,6 +394,101 @@ function Base.show(io::IO, l::InstanceNorm) return print(io, ")") end +@doc doc""" + LayerNorm(shape::NTuple{N, Int}, activation=identity; epsilon=1f-5, dims=Colon(), + affine=true, init_bias=zeros32, init_scale=ones32) + +Computes mean and standard deviation over the whole input array, and uses these to +normalize the whole array. Optionally applies an elementwise affine transformation +afterwards. + +Given an input array ``x``, this layer computes + +```math +y = \frac{x - \mathbb{E}[x]}{\sqrt{Var[x] + \epsilon}} * \gamma + \beta +``` + +where ``\gamma`` & ``\beta`` are trainable parameters if `affine=true`. + +!!! warning "Inconsistent Defaults till v0.5.0" + + As of v0.5.0, the doc used to say `affine::Bool=false`, but the code actually had + `affine::Bool=true` as the default. Now the doc reflects the code, so please check + whether your assumptions about the default (if made) were invalid. + +## Arguments + + - `shape`: Broadcastable shape of input array excluding the batch dimension. + - `activation`: After normalization, elementwise activation `activation` is applied. + +## Keyword Arguments + + - `epsilon`: a value added to the denominator for numerical stability. + - `dims`: Dimensions to normalize the array over. + - If `affine=true`, it also applies a shift and a rescale to the input through to + learnable per-element bias and scale parameters. 
+ + + `init_bias`: Controls how the `bias` is initialized + + `init_scale`: Controls how the `scale` is initialized + +# Extended Help + +## Inputs + + - `x`: AbstractArray + +## Returns + + - `y`: Normalized Array + - Empty NamedTuple() + +## Parameters + + - `affine=false`: Empty `NamedTuple()` + - `affine=true` + + + `bias`: Bias of shape `(shape..., 1)` + + `scale`: Scale of shape `(shape..., 1)` +""" +@concrete struct LayerNorm <: AbstractLuxLayer + shape + activation + epsilon + init_bias + init_scale + dims + affine <: StaticBool +end + +function LayerNorm(shape, activation=identity; epsilon=1.0f-5, dims=Colon(), + affine::BoolType=True(), init_bias=zeros32, init_scale=ones32) + return LayerNorm( + shape, activation, epsilon, init_bias, init_scale, dims, static(affine)) +end + +function initialparameters(rng::AbstractRNG, ln::LayerNorm) + if has_affine(ln) + dims = (ln.shape..., 1) + return (; bias=ln.init_bias(rng, dims...), scale=ln.init_scale(rng, dims...)) + end + return (;) +end + +function (l::LayerNorm)(x::AbstractArray, ps, st::NamedTuple) + x′ = match_eltype(l, ps, st, x) + σ = NNlib.fast_act(l.activation, x′) + y = layernorm(x′, safe_getproperty(ps, Val(:scale)), safe_getproperty(ps, Val(:bias)), + σ, l.dims, convert(unwrapped_eltype(x′), l.epsilon)) + return y, st +end + +function Base.show(io::IO, l::LayerNorm) + print(io, "LayerNorm($(l.shape)") + (l.activation == identity) || print(io, ", $(l.activation)") + print(io, ", affine=$(has_affine(l)), dims=$(l.dims)") + return print(io, ")") +end + @doc doc""" WeightNorm(layer::AbstractLuxLayer, which_params::NTuple{N, Symbol}, dims::Union{Tuple, Nothing}=nothing) @@ -516,98 +611,3 @@ function Base.show(io::IO, ::MIME"text/plain", w::WeightNorm) return print(io, "WeightNorm(", w.layer, ", dims = ", known(w.dims), ", normalized_parameters = ", known(w.which_params), ")") end - -@doc doc""" - LayerNorm(shape::NTuple{N, Int}, activation=identity; epsilon=1f-5, dims=Colon(), - affine=true, init_bias=zeros32, init_scale=ones32) - -Computes mean and standard deviation over the whole input array, and uses these to -normalize the whole array. Optionally applies an elementwise affine transformation -afterwards. - -Given an input array ``x``, this layer computes - -```math -y = \frac{x - \mathbb{E}[x]}{\sqrt{Var[x] + \epsilon}} * \gamma + \beta -``` - -where ``\gamma`` & ``\beta`` are trainable parameters if `affine=true`. - -!!! warning "Inconsistent Defaults till v0.5.0" - - As of v0.5.0, the doc used to say `affine::Bool=false`, but the code actually had - `affine::Bool=true` as the default. Now the doc reflects the code, so please check - whether your assumptions about the default (if made) were invalid. - -## Arguments - - - `shape`: Broadcastable shape of input array excluding the batch dimension. - - `activation`: After normalization, elementwise activation `activation` is applied. - -## Keyword Arguments - - - `epsilon`: a value added to the denominator for numerical stability. - - `dims`: Dimensions to normalize the array over. - - If `affine=true`, it also applies a shift and a rescale to the input through to - learnable per-element bias and scale parameters. 
- - + `init_bias`: Controls how the `bias` is initialized - + `init_scale`: Controls how the `scale` is initialized - -# Extended Help - -## Inputs - - - `x`: AbstractArray - -## Returns - - - `y`: Normalized Array - - Empty NamedTuple() - -## Parameters - - - `affine=false`: Empty `NamedTuple()` - - `affine=true` - - + `bias`: Bias of shape `(shape..., 1)` - + `scale`: Scale of shape `(shape..., 1)` -""" -@concrete struct LayerNorm <: AbstractLuxLayer - shape - activation - epsilon - init_bias - init_scale - dims - affine <: StaticBool -end - -function LayerNorm(shape, activation=identity; epsilon=1.0f-5, dims=Colon(), - affine::BoolType=True(), init_bias=zeros32, init_scale=ones32) - return LayerNorm( - shape, activation, epsilon, init_bias, init_scale, dims, static(affine)) -end - -function initialparameters(rng::AbstractRNG, ln::LayerNorm) - if has_affine(ln) - dims = (ln.shape..., 1) - return (; bias=ln.init_bias(rng, dims...), scale=ln.init_scale(rng, dims...)) - end - return (;) -end - -function (l::LayerNorm)(x::AbstractArray, ps, st::NamedTuple) - x′ = match_eltype(l, ps, st, x) - σ = NNlib.fast_act(l.activation, x′) - y = layernorm(x′, safe_getproperty(ps, Val(:scale)), safe_getproperty(ps, Val(:bias)), - σ, l.dims, convert(unwrapped_eltype(x′), l.epsilon)) - return y, st -end - -function Base.show(io::IO, l::LayerNorm) - print(io, "LayerNorm($(l.shape)") - (l.activation == identity) || print(io, ", $(l.activation)") - print(io, ", affine=$(has_affine(l)), dims=$(l.dims)") - return print(io, ")") -end From 183b18cbfcaaca08dc6d5a6259edc30100cdd120 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 5 Sep 2024 15:06:52 -0400 Subject: [PATCH 66/95] fix: update Embedding defaults to match Pytorch --- docs/src/introduction/updating_to_v1.md | 1 + src/layers/basic.jl | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/src/introduction/updating_to_v1.md b/docs/src/introduction/updating_to_v1.md index a196fa8b7a..ed7a71227b 100644 --- a/docs/src/introduction/updating_to_v1.md +++ b/docs/src/introduction/updating_to_v1.md @@ -129,6 +129,7 @@ abstraction. - [`Dense`](@ref) and [`Bilinear`](@ref) have updated default initializations to align with the defaults from Pytorch. See the documentation for more details. - [`InstanceNorm`](@ref) now defaults to `affine=false` instead of `affine=true`. +- [`Embedding`](@ref) now defaults to `init_weight=rand32` instead of `init_weight=randn32`. ### New Features diff --git a/src/layers/basic.jl b/src/layers/basic.jl index a9f97173fc..7b0168f96e 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -323,8 +323,8 @@ end function initialparameters(rng::AbstractRNG, d::Dense) weight = if d.init_weight === nothing - kaiming_uniform(rng, Float32, d.out_dims, d.in_dims; gain=Utils.calculate_gain( - d.activation, √5.0f0)) + kaiming_uniform(rng, Float32, d.out_dims, d.in_dims; + gain=Utils.calculate_gain(d.activation, √5.0f0)) else d.init_weight(rng, d.out_dims, d.in_dims) end @@ -559,7 +559,7 @@ end (b::Bilinear)(x::AbstractArray, ps, st::NamedTuple) = b((x, x), ps, st) """ - Embedding(in_dims => out_dims; init_weight=randn32) + Embedding(in_dims => out_dims; init_weight=rand32) A lookup table that stores embeddings of dimension `out_dims` for a vocabulary of size `in_dims`. 
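With the Embedding default flipped from `randn32` to `rand32`, freshly initialized embedding weights are uniform in `[0, 1)` instead of standard normal. Callers that relied on the old behaviour can pass the initializer explicitly. A short usage sketch, assuming the `init_weight` keyword documented above and that the WeightInitializers helpers (`randn32`) are in scope via Lux:

```julia
using Lux, Random

rng = Random.default_rng()

emb = Embedding(10 => 4)                           # new default: init_weight=rand32
ps, st = Lux.setup(rng, emb)
all(0 .<= ps.weight .< 1)                          # uniform weights under the new default

emb_old = Embedding(10 => 4; init_weight=randn32)  # restore the pre-1.0 normal init

y, _ = emb([1, 3, 5], ps, st)                      # look up three indices
size(y)                                            # (4, 3)
```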
When the vocabulary is multi-dimensional, the input is expected to be a tuple @@ -598,7 +598,7 @@ This layer is often used to store word embeddings and retrieve them using indice init_weight end -function Embedding((in_dims, out_dims)::Pair; init_weight=randn32) +function Embedding((in_dims, out_dims)::Pair; init_weight=rand32) return Embedding(in_dims, out_dims, init_weight) end From be1e74dee9ae04d364d416b2473810737496c7e0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 5 Sep 2024 15:38:34 -0400 Subject: [PATCH 67/95] fix!: RNNCell defaults updated --- docs/src/introduction/updating_to_v1.md | 4 ++ src/layers/recurrent.jl | 79 +++++++++++++++++-------- src/utils.jl | 8 ++- 3 files changed, 64 insertions(+), 27 deletions(-) diff --git a/docs/src/introduction/updating_to_v1.md b/docs/src/introduction/updating_to_v1.md index ed7a71227b..a554093cf3 100644 --- a/docs/src/introduction/updating_to_v1.md +++ b/docs/src/introduction/updating_to_v1.md @@ -130,7 +130,11 @@ abstraction. the defaults from Pytorch. See the documentation for more details. - [`InstanceNorm`](@ref) now defaults to `affine=false` instead of `affine=true`. - [`Embedding`](@ref) now defaults to `init_weight=rand32` instead of `init_weight=randn32`. +- Recurrent Cells - [`RNNCell`](@ref), [`LSTMCell`](@ref), and [`GRUCell`](@ref) now have + different default initializations. See the documentation for more details. ### New Features - [`InstanceNorm`](@ref) now supports tracking statistics. +- [`RNNCell`](@ref) add `bias_ih` and `bias_hh` to the parameters to align with Pytorch. + Both are controlled using `init_bias` and `use_bias`. diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index fbde8afe49..d83e07cbd8 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -27,6 +27,20 @@ function LuxOps.eachslice(x::AbstractArray, ::BatchLastIndex) end LuxOps.eachslice(x::AbstractMatrix, ::BatchLastIndex) = LuxOps.eachslice(x, Val(ndims(x))) +function init_rnn_weight(rng::AbstractRNG, init_weight, hidden_dims, dims) + if init_weight === nothing + bound = inv(sqrt(hidden_dims)) + y = randn32(rng, Float32, dims...) + @. y = (y - 0.5f0) * 2 * bound + return y + end + return init_weight(rng, dims...) +end + +function init_rnn_bias(rng::AbstractRNG, init_bias, hidden_dims, bias_len) + return init_rnn_weight(rng, init_bias, hidden_dims, (bias_len,)) +end + """ Recurrence(cell; ordering::AbstractTimeSeriesDataBatchOrdering=BatchLastIndex(), @@ -171,11 +185,11 @@ applyrecurrentcell(l::AbstractRecurrentCell, x, ps, st, ::Nothing) = apply(l, x, @doc doc""" RNNCell(in_dims => out_dims, activation=tanh; use_bias=True(), train_state=False(), - init_bias=zeros32, init_weight=glorot_uniform, init_state=ones32) + init_bias=nothing, init_weight=nothing, init_state=zeros32) An Elman RNNCell cell with `activation` (typically set to `tanh` or `relu`). -``h_{new} = activation(weight_{ih} \times x + weight_{hh} \times h_{prev} + bias)`` +``h_{new} = activation(weight_{ih} \times x + bias_{ih} + weight_{hh} \times h_{prev} + bias_{hh})`` ## Arguments @@ -184,8 +198,10 @@ An Elman RNNCell cell with `activation` (typically set to `tanh` or `relu`). - `activation`: Activation function - `use_bias`: Set to false to deactivate bias - `train_state`: Trainable initial hidden state can be activated by setting this to `true` - - `init_bias`: Initializer for bias - - `init_weight`: Initializer for weight + - `init_bias`: Initializer for bias. 
If `nothing`, then we use uniform distribution with + bounds `-bound` and `bound` where `bound = inv(sqrt(out_dims))`. + - `init_weight`: Initializer for weight. If `nothing`, then we use uniform distribution + with bounds `-bound` and `bound` where `bound = inv(sqrt(out_dims))`. - `init_state`: Initializer for hidden state ## Inputs @@ -199,6 +215,7 @@ An Elman RNNCell cell with `activation` (typically set to `tanh` or `relu`). updated hidden state is returned. ## Returns + - Tuple containing + Output ``h_{new}`` of shape `(out_dims, batch_size)` @@ -210,7 +227,8 @@ An Elman RNNCell cell with `activation` (typically set to `tanh` or `relu`). - `weight_ih`: Maps the input to the hidden state. - `weight_hh`: Maps the hidden state to the hidden state. - - `bias`: Bias vector (not present if `use_bias=false`) + - `bias_ih`: Bias vector for the input-hidden connection (not present if `use_bias=false`) + - `bias_hh`: Bias vector for the hidden-hidden connection (not present if `use_bias=false`) - `hidden_state`: Initial hidden state vector (not present if `train_state=false`) ## States @@ -230,15 +248,22 @@ end function RNNCell((in_dims, out_dims)::Pair{<:IntegerType, <:IntegerType}, activation=tanh; use_bias::BoolType=True(), train_state::BoolType=False(), - init_bias=zeros32, init_weight=glorot_uniform, init_state=ones32) + init_bias=nothing, init_weight=nothing, init_state=zeros32) return RNNCell(static(train_state), activation, in_dims, out_dims, init_bias, init_weight, init_state, static(use_bias)) end function initialparameters(rng::AbstractRNG, rnn::RNNCell) - ps = (weight_ih=rnn.init_weight(rng, rnn.out_dims, rnn.in_dims), - weight_hh=rnn.init_weight(rng, rnn.out_dims, rnn.out_dims)) - has_bias(rnn) && (ps = merge(ps, (bias=rnn.init_bias(rng, rnn.out_dims),))) + weight_ih = init_rnn_weight( + rng, rnn.init_weight, rnn.out_dims, (rnn.out_dims, rnn.in_dims)) + weight_hh = init_rnn_weight( + rng, rnn.init_weight, rnn.out_dims, (rnn.out_dims, rnn.out_dims)) + ps = (; weight_ih, weight_hh) + if has_bias(rnn) + bias_ih = init_rnn_bias(rng, rnn.init_bias, rnn.out_dims, rnn.out_dims) + bias_hh = init_rnn_bias(rng, rnn.init_bias, rnn.out_dims, rnn.out_dims) + ps = merge(ps, (; bias_ih, bias_hh)) + end has_train_state(rnn) && (ps = merge(ps, (hidden_state=rnn.init_state(rng, rnn.out_dims),))) return ps @@ -248,12 +273,12 @@ initialstates(rng::AbstractRNG, ::RNNCell) = (rng=Utils.sample_replicate(rng),) function (rnn::RNNCell{False})(x::AbstractMatrix, ps, st::NamedTuple) rng = replicate(st.rng) - hidden_state = Utils.init_hidden_state(rng, rnn, x) + hidden_state = init_rnn_hidden_state(rng, rnn, x) return rnn((x, (hidden_state,)), ps, merge(st, (; rng))) end function (rnn::RNNCell{True})(x::AbstractMatrix, ps, st::NamedTuple) - hidden_state = Utils.init_trainable_hidden_state(ps.hidden_state, x) + hidden_state = init_trainable_rnn_hidden_state(ps.hidden_state, x) return rnn((x, (hidden_state,)), ps, st) end @@ -261,9 +286,15 @@ function (rnn::RNNCell)( (x, (hidden_state,))::Tuple{<:AbstractMatrix, Tuple{<:AbstractMatrix}}, ps, st::NamedTuple) y, hidden_stateₙ = match_eltype(rnn, ps, st, x, hidden_state) - bias = safe_getproperty(ps, Val(:bias)) - z = fused_dense_bias_activation(identity, ps.weight_hh, hidden_stateₙ, bias) - hₙ = fast_activation!!(rnn.activation, LuxLib.Impl.matmul(ps.weight_ih, y) .+ z) + + bias_hh = safe_getproperty(ps, Val(:bias_hh)) + z₁ = fused_dense_bias_activation(identity, ps.weight_hh, hidden_stateₙ, bias_hh) + + bias_ih = safe_getproperty(ps, Val(:bias_ih)) + z₂ = 
fused_dense_bias_activation(identity, ps.weight_ih, y, bias_ih) + + # TODO: This operation can be fused instead of doing add then activation + hₙ = fast_activation!!(rnn.activation, z₁ .+ z₂) return (hₙ, (hₙ,)), st end @@ -393,28 +424,28 @@ initialstates(rng::AbstractRNG, ::LSTMCell) = (rng=Utils.sample_replicate(rng),) function (lstm::LSTMCell{False, False})(x::AbstractMatrix, ps, st::NamedTuple) rng = replicate(st.rng) - hidden_state = Utils.init_hidden_state(rng, lstm, x) - memory = Utils.init_hidden_state(rng, lstm, x) + hidden_state = init_rnn_hidden_state(rng, lstm, x) + memory = init_rnn_hidden_state(rng, lstm, x) return lstm((x, (hidden_state, memory)), ps, merge(st, (; rng))) end function (lstm::LSTMCell{True, False})(x::AbstractMatrix, ps, st::NamedTuple) rng = replicate(st.rng) - hidden_state = Utils.init_trainable_hidden_state(ps.hidden_state, x) - memory = Utils.init_hidden_state(rng, lstm, x) + hidden_state = init_trainable_rnn_hidden_state(ps.hidden_state, x) + memory = init_rnn_hidden_state(rng, lstm, x) return lstm((x, (hidden_state, memory)), ps, merge(st, (; rng))) end function (lstm::LSTMCell{False, True})(x::AbstractMatrix, ps, st::NamedTuple) rng = replicate(st.rng) - hidden_state = Utils.init_hidden_state(rng, lstm, x) - memory = Utils.init_trainable_hidden_state(ps.memory, x) + hidden_state = init_rnn_hidden_state(rng, lstm, x) + memory = init_trainable_rnn_hidden_state(ps.memory, x) return lstm((x, (hidden_state, memory)), ps, merge(st, (; rng))) end function (lstm::LSTMCell{True, True})(x::AbstractMatrix, ps, st::NamedTuple) - hidden_state = Utils.init_trainable_hidden_state(ps.hidden_state, x) - memory = Utils.init_trainable_hidden_state(ps.memory, x) + hidden_state = init_trainable_rnn_hidden_state(ps.hidden_state, x) + memory = init_trainable_rnn_hidden_state(ps.memory, x) return lstm((x, (hidden_state, memory)), ps, st) end @@ -543,14 +574,14 @@ end initialstates(rng::AbstractRNG, ::GRUCell) = (rng=Utils.sample_replicate(rng),) function (gru::GRUCell{True})(x::AbstractMatrix, ps, st::NamedTuple) - hidden_state = Utils.init_trainable_hidden_state(ps.hidden_state, x) + hidden_state = init_trainable_rnn_hidden_state(ps.hidden_state, x) return gru((x, (hidden_state,)), ps, st) end function (gru::GRUCell{False})(x::AbstractMatrix, ps, st::NamedTuple) rng = replicate(st.rng) st = merge(st, (; rng)) - hidden_state = Utils.init_hidden_state(rng, gru, x) + hidden_state = init_rnn_hidden_state(rng, gru, x) return gru((x, (hidden_state,)), ps, st) end diff --git a/src/utils.jl b/src/utils.jl index 74973abe6c..13e442d087 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -160,11 +160,13 @@ end add!!(x::Number, y::Number) = x + y add!!(::Nothing, ::Nothing) = nothing -function init_hidden_state(rng::AbstractRNG, rnn, x::AbstractMatrix) +function init_rnn_hidden_state(rng::AbstractRNG, rnn, x::AbstractMatrix) + # TODO: Once we support moving `rng` to the device, we can directly initialize on the + # device return rnn.init_state(rng, rnn.out_dims, Base.size(x, 2)) |> get_device(x) end -function init_trainable_hidden_state(hidden_state::AbstractVector, x::AbstractMatrix) +function init_trainable_rnn_hidden_state(hidden_state::AbstractVector, x::AbstractMatrix) return repeat(hidden_state, 1, Base.size(x, 2)) end @@ -221,7 +223,7 @@ calculate_gain(::typeof(NNlib.selu), _) = 3.0f0 / 4 end using .Utils: Utils, BoolType, IntegerType, SymbolType, make_abstract_matrix, - matrix_to_array + matrix_to_array, init_trainable_rnn_hidden_state, init_rnn_hidden_state const safe_reverse = 
Utils.reverse const safe_vec = Utils.vec From 780de12f51627afcc9e4ce34c9045f7298691dc6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 5 Sep 2024 15:46:24 -0400 Subject: [PATCH 68/95] fix: testing failures due to non-zero bias --- src/layers/conv.jl | 2 ++ test/layers/conv_tests.jl | 20 +++++++------------- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 08b087e848..577e3e643d 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -208,8 +208,10 @@ function (c::Conv)(x::AbstractArray, ps, st::NamedTuple) y = match_eltype(c, ps, st, x) cdims = construct_crosscor_convdims(c.cross_correlation, DenseConvDims(y, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups)) + @show cdims bias = safe_getproperty(ps, Val(:bias)) σ = NNlib.fast_act(c.activation, y) + @show σ return fused_conv_bias_activation(σ, ps.weight, y, bias, cdims), st end diff --git a/test/layers/conv_tests.jl b/test/layers/conv_tests.jl index c6ddecbabd..7248b18254 100644 --- a/test/layers/conv_tests.jl +++ b/test/layers/conv_tests.jl @@ -141,12 +141,8 @@ end test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3) # Test that we cannot ask for non-integer multiplication factors - layer = Conv((2, 2), 3 => 10; groups=2) - display(layer) - @test_throws DimensionMismatch Lux.setup(rng, layer) - layer = Conv((2, 2), 2 => 9; groups=2) - display(layer) - @test_throws DimensionMismatch Lux.setup(rng, layer) + @test_throws DimensionMismatch Conv((2, 2), 3 => 10; groups=2) + @test_throws DimensionMismatch Conv((2, 2), 2 => 9; groups=2) @testset "Segfault Test LuxDL/Lux.jl#386" begin layer = Conv((5,), 32 => 32, tanh; groups=32) @@ -228,9 +224,7 @@ end test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3) # Test that we cannot ask for non-integer multiplication factors - layer = Conv((2, 2), 3 => 10; groups=3) - display(layer) - @test_throws DimensionMismatch Lux.setup(rng, layer) + @test_throws DimensionMismatch Conv((2, 2), 3 => 10; groups=3) end @testset "Conv SamePad kernelsize $k" for k in ((1,), (2,), (3,), (2, 3), (1, 2, 3)) @@ -261,7 +255,7 @@ end x[4, 4, 1, 1] = 1 x = x |> aType - layer = Conv((3, 3), 1 => 1) + layer = Conv((3, 3), 1 => 1; use_bias=false) display(layer) ps, st = Lux.setup(rng, layer) |> dev @@ -271,7 +265,7 @@ end @jet layer(x, ps, st) - layer = Conv((3, 1), 1 => 1) + layer = Conv((3, 1), 1 => 1; use_bias=false) display(layer) ps, st = Lux.setup(rng, layer) |> dev @@ -281,7 +275,7 @@ end @jet layer(x, ps, st) - layer = Conv((1, 3), 1 => 1) + layer = Conv((1, 3), 1 => 1; use_bias=false) display(layer) ps, st = Lux.setup(rng, layer) |> dev @@ -291,7 +285,7 @@ end @jet layer(x, ps, st) - layer = Conv((1, 3), 1 => 1; init_weight=Lux.glorot_normal) + layer = Conv((1, 3), 1 => 1; init_weight=Lux.glorot_normal, use_bias=false) display(layer) ps, st = Lux.setup(rng, layer) |> dev From 01ce048d6aee54e98d905f65d1c90138837dc1f6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 5 Sep 2024 15:55:54 -0400 Subject: [PATCH 69/95] feat: update bias in LSTMCell --- src/layers/recurrent.jl | 62 ++++++++++++++++++++++++----------------- 1 file changed, 36 insertions(+), 26 deletions(-) diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index d83e07cbd8..cd8a95d00f 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -308,10 +308,8 @@ end @doc doc""" LSTMCell(in_dims => out_dims; use_bias::Bool=true, train_state::Bool=false, - train_memory::Bool=false, - init_weight=(glorot_uniform, glorot_uniform, glorot_uniform, glorot_uniform), - 
init_bias=(zeros32, zeros32, ones32, zeros32), init_state=zeros32, - init_memory=zeros32) + train_memory::Bool=false, init_weight=nothing, init_bias=nothing, + init_state=zeros32, init_memory=zeros32) Long Short-Term (LSTM) Cell @@ -333,8 +331,14 @@ Long Short-Term (LSTM) Cell - `use_bias`: Set to false to deactivate bias - `train_state`: Trainable initial hidden state can be activated by setting this to `true` - `train_memory`: Trainable initial memory can be activated by setting this to `true` - - `init_bias`: Initializer for bias. Must be a tuple containing 4 functions - - `init_weight`: Initializer for weight. Must be a tuple containing 4 functions + - `init_bias`: Initializer for bias. Must be a tuple containing 4 functions. If a single + value is passed, it is copied into a 4 element tuple. If `nothing`, then we use + uniform distribution with bounds `-bound` and `bound` where + `bound = inv(sqrt(out_dims))`. + - `init_weight`: Initializer for weight. Must be a tuple containing 4 functions. If a + single value is passed, it is copied into a 4 element tuple. If `nothing`, then we use + uniform distribution with bounds `-bound` and `bound` where + `bound = inv(sqrt(out_dims))`. - `init_state`: Initializer for hidden state - `init_memory`: Initializer for memory @@ -369,11 +373,13 @@ Long Short-Term (LSTM) Cell ## Parameters - - `weight_i`: Concatenated Weights to map from input space - ``\{ W_{ii}, W_{if}, W_{ig}, W_{io} \}``. - - `weight_h`: Concatenated Weights to map from hidden space - ``\{ W_{hi}, W_{hf}, W_{hg}, W_{ho} \}`` - - `bias`: Bias vector (not present if `use_bias=false`) + - `weight_ih`: Concatenated Weights to map from input space + ``\{ W_{ii}, W_{if}, W_{ig}, W_{io} \}``. + - `weight_hh`: Concatenated Weights to map from hidden space + ``\{ W_{hi}, W_{hf}, W_{hg}, W_{ho} \}`` + - `bias_ih`: Bias vector for the input-hidden connection (not present if `use_bias=false`) + - `bias_hh`: Concatenated Bias vector for the hidden-hidden connection (not present if + `use_bias=false`) - `hidden_state`: Initial hidden state vector (not present if `train_state=false`) - `memory`: Initial memory vector (not present if `train_memory=false`) @@ -393,10 +399,10 @@ Long Short-Term (LSTM) Cell use_bias <: StaticBool end -function LSTMCell((in_dims, out_dims)::Pair{<:IntegerType, <:IntegerType}; - use_bias::BoolType=True(), train_state::BoolType=False(), - train_memory::BoolType=False(), init_weight=glorot_uniform, - init_bias=zeros32, init_state=zeros32, init_memory=zeros32) +function LSTMCell( + (in_dims, out_dims)::Pair{<:IntegerType, <:IntegerType}; use_bias::BoolType=True(), + train_state::BoolType=False(), train_memory::BoolType=False(), + init_weight=nothing, init_bias=nothing, init_state=zeros32, init_memory=zeros32) init_weight isa NTuple{4} || (init_weight = ntuple(Returns(init_weight), 4)) init_bias isa NTuple{4} || (init_bias = ntuple(Returns(init_bias), 4)) return LSTMCell(static(train_state), static(train_memory), in_dims, out_dims, @@ -404,14 +410,17 @@ function LSTMCell((in_dims, out_dims)::Pair{<:IntegerType, <:IntegerType}; end function initialparameters(rng::AbstractRNG, lstm::LSTMCell) - weight_i = vcat([init_weight(rng, lstm.out_dims, lstm.in_dims) - for init_weight in lstm.init_weight]...) - weight_h = vcat([init_weight(rng, lstm.out_dims, lstm.out_dims) - for init_weight in lstm.init_weight]...) 
- ps = (; weight_i, weight_h) + weight_ih = vcat([init_rnn_weight( + rng, init_weight, lstm.out_dims, (lstm.out_dims, lstm.in_dins)) + for init_weight in lstm.init_weight]...) + weight_hh = vcat([init_rnn_weight( + rng, init_weight, lstm.out_dims, (lstm.out_dims, lstm.out_dims)) + for init_weight in lstm.init_weight]...) + ps = (; weight_ih, weight_hh) if has_bias(lstm) - bias = vcat([init_bias(rng, lstm.out_dims) for init_bias in lstm.init_bias]...) - ps = merge(ps, (bias=bias,)) + bias_ih = vcat([init_bias(rng, lstm.out_dims) for init_bias in lstm.init_bias]...) + bias_hh = vcat([init_bias(rng, lstm.out_dims) for init_bias in lstm.init_bias]...) + ps = merge(ps, (bias_ih, bias_hh)) end has_train_state(lstm) && (ps = merge(ps, (hidden_state=lstm.init_state(rng, lstm.out_dims),))) @@ -455,10 +464,11 @@ const _LSTMCellInputType = Tuple{ function (lstm::LSTMCell)( (x, (hidden_state, memory))::_LSTMCellInputType, ps, st::NamedTuple) y, hidden_stateₙ, memoryₙ = match_eltype(lstm, ps, st, x, hidden_state, memory) - bias = safe_getproperty(ps, Val(:bias)) - z = fused_dense_bias_activation(identity, ps.weight_h, hidden_stateₙ, bias) - g = LuxLib.Impl.matmul(ps.weight_i, y) .+ z - + bias_hh = safe_getproperty(ps, Val(:bias_hh)) + z₁ = fused_dense_bias_activation(identity, ps.weight_hh, hidden_stateₙ, bias_hh) + bias_ih = safe_getproperty(ps, Val(:bias_ih)) + z₂ = fused_dense_bias_activation(identity, ps.weight_ih, y, bias_ih) + g = z₁ .+ z₂ input, forget, cell, output = multigate(g, Val(4)) memory₂ = @. sigmoid_fast(forget) * memoryₙ + sigmoid_fast(input) * tanh_fast(cell) hidden_state₂ = @. sigmoid_fast(output) * tanh_fast(memory₂) From 965c9b3551d82c336b0851fe7353053ba6d3370e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 5 Sep 2024 16:08:21 -0400 Subject: [PATCH 70/95] feat: update bias in GRUCell --- docs/src/introduction/updating_to_v1.md | 4 +- src/layers/recurrent.jl | 87 ++++++++++++++----------- 2 files changed, 51 insertions(+), 40 deletions(-) diff --git a/docs/src/introduction/updating_to_v1.md b/docs/src/introduction/updating_to_v1.md index a554093cf3..cc9f2279ba 100644 --- a/docs/src/introduction/updating_to_v1.md +++ b/docs/src/introduction/updating_to_v1.md @@ -136,5 +136,5 @@ abstraction. ### New Features - [`InstanceNorm`](@ref) now supports tracking statistics. -- [`RNNCell`](@ref) add `bias_ih` and `bias_hh` to the parameters to align with Pytorch. - Both are controlled using `init_bias` and `use_bias`. +- [`RNNCell`](@ref) and [`LSTMCell`](@ref) add `bias_ih` and `bias_hh` to the parameters to + align with Pytorch. Both are controlled using `init_bias` and `use_bias`. diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index cd8a95d00f..2ba4756066 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -411,15 +411,17 @@ end function initialparameters(rng::AbstractRNG, lstm::LSTMCell) weight_ih = vcat([init_rnn_weight( - rng, init_weight, lstm.out_dims, (lstm.out_dims, lstm.in_dins)) + rng, init_weight, lstm.out_dims, (lstm.out_dims, lstm.in_dims)) for init_weight in lstm.init_weight]...) weight_hh = vcat([init_rnn_weight( rng, init_weight, lstm.out_dims, (lstm.out_dims, lstm.out_dims)) for init_weight in lstm.init_weight]...) ps = (; weight_ih, weight_hh) if has_bias(lstm) - bias_ih = vcat([init_bias(rng, lstm.out_dims) for init_bias in lstm.init_bias]...) - bias_hh = vcat([init_bias(rng, lstm.out_dims) for init_bias in lstm.init_bias]...) 
+ bias_ih = vcat([init_rnn_bias(rng, init_bias, lstm.out_dims, lstm.out_dims) + for init_bias in lstm.init_bias]...) + bias_hh = vcat([init_rnn_bias(rng, init_bias, lstm.out_dims, lstm.out_dims) + for init_bias in lstm.init_bias]...) ps = merge(ps, (bias_ih, bias_hh)) end has_train_state(lstm) && @@ -485,17 +487,14 @@ end @doc doc""" GRUCell((in_dims, out_dims)::Pair{<:Int,<:Int}; use_bias=true, train_state::Bool=false, - init_weight::Tuple{Function,Function,Function}=(glorot_uniform, glorot_uniform, - glorot_uniform), - init_bias::Tuple{Function,Function,Function}=(zeros32, zeros32, zeros32), - init_state::Function=zeros32) + init_weight=nothing, init_bias=nothing, init_state=zeros32) Gated Recurrent Unit (GRU) Cell ```math \begin{align} - r &= \sigma(W_{ir} \times x + W_{hr} \times h_{prev} + b_{hr})\\ - z &= \sigma(W_{iz} \times x + W_{hz} \times h_{prev} + b_{hz})\\ + r &= \sigma(W_{ir} \times x + b_{ir} + W_{hr} \times h_{prev} + b_{hr})\\ + z &= \sigma(W_{iz} \times x + b_{iz} + W_{hz} \times h_{prev} + b_{hz})\\ n &= \tanh(W_{in} \times x + b_{in} + r \cdot (W_{hn} \times h_{prev} + b_{hn}))\\ h_{new} &= (1 - z) \cdot n + z \cdot h_{prev} \end{align} @@ -507,8 +506,14 @@ Gated Recurrent Unit (GRU) Cell - `out_dims`: Output (Hidden State) Dimension - `use_bias`: Set to false to deactivate bias - `train_state`: Trainable initial hidden state can be activated by setting this to `true` - - `init_bias`: Initializer for bias. Must be a tuple containing 3 functions - - `init_weight`: Initializer for weight. Must be a tuple containing 3 functions + - `init_bias`: Initializer for bias. Must be a tuple containing 3 functions. If a single + value is passed, it is copied into a 3 element tuple. If `nothing`, then we use + uniform distribution with bounds `-bound` and `bound` where + `bound = inv(sqrt(out_dims))`. + - `init_weight`: Initializer for weight. Must be a tuple containing 3 functions. If a + single value is passed, it is copied into a 3 element tuple. If `nothing`, then we use + uniform distribution with bounds `-bound` and `bound` where + `bound = inv(sqrt(out_dims))`. - `init_state`: Initializer for hidden state ## Inputs @@ -532,13 +537,14 @@ Gated Recurrent Unit (GRU) Cell ## Parameters - - `weight_i`: Concatenated Weights to map from input space - ``\{ W_{ir}, W_{iz}, W_{in} \}``. - - `weight_h`: Concatenated Weights to map from hidden space - ``\{ W_{hr}, W_{hz}, W_{hn} \}``. - - `bias_i`: Bias vector (``b_{in}``; not present if `use_bias=false`). - - `bias_h`: Concatenated Bias vector for the hidden space - ``\{ b_{hr}, b_{hz}, b_{hn} \}`` (not present if `use_bias=false`). + - `weight_ih`: Concatenated Weights to map from input space + ``\{ W_{ir}, W_{iz}, W_{in} \}``. + - `weight_hh`: Concatenated Weights to map from hidden space + ``\{ W_{hr}, W_{hz}, W_{hn} \}``. + - `bias_ih`: Concatenated Bias vector for the input space + ``\{ b_{ir}, b_{iz}, b_{in} \}`` (not present if `use_bias=false`). + - `bias_hh`: Concatenated Bias vector for the hidden space + ``\{ b_{hr}, b_{hz}, b_{hn} \}`` (not present if `use_bias=false`). - `hidden_state`: Initial hidden state vector (not present if `train_state=false`) ``\{ b_{hr}, b_{hz}, b_{hn} \}``. @@ -566,15 +572,19 @@ function GRUCell((in_dims, out_dims)::Pair{<:IntegerType, <:IntegerType}; end function initialparameters(rng::AbstractRNG, gru::GRUCell) - weight_i = vcat([init_weight(rng, gru.out_dims, gru.in_dims) - for init_weight in gru.init_weight]...) 
- weight_h = vcat([init_weight(rng, gru.out_dims, gru.out_dims) - for init_weight in gru.init_weight]...) - ps = (; weight_i, weight_h) + weight_ih = vcat([init_rnn_weight( + rng, init_weight, gru.out_dims, (gru.out_dims, gru.in_dims)) + for init_weight in gru.init_weight]...) + weight_hh = vcat([init_rnn_weight( + rng, init_weight, gru.out_dims, (gru.out_dims, gru.out_dims)) + for init_weight in gru.init_weight]...) + ps = (; weight_ih, weight_hh) if has_bias(gru) - bias_i = gru.init_bias[1](rng, gru.out_dims, 1) - bias_h = vcat([init_bias(rng, gru.out_dims) for init_bias in gru.init_bias]...) - ps = merge(ps, (bias_i=bias_i, bias_h=bias_h)) + bias_ih = vcat([init_rnn_bias(rng, init_bias, gru.out_dims, gru.out_dims) + for init_bias in gru.init_bias]...) + bias_hh = vcat([init_rnn_bias(rng, init_bias, gru.out_dims, gru.out_dims) + for init_bias in gru.init_bias]...) + ps = merge(ps, (; bias_ih, bias_hh)) end has_train_state(gru) && (ps = merge(ps, (hidden_state=gru.init_state(rng, gru.out_dims),))) @@ -599,21 +609,22 @@ const _GRUCellInputType = Tuple{<:AbstractMatrix, Tuple{<:AbstractMatrix}} function (gru::GRUCell)((x, (hidden_state,))::_GRUCellInputType, ps, st::NamedTuple) y, hidden_stateₙ = match_eltype(gru, ps, st, x, hidden_state) - gxs = multigate(ps.weight_i * y, Val(3)) - bias_h = safe_getproperty(ps, Val(:bias_h)) - ghbs = multigate( - fused_dense_bias_activation(identity, ps.weight_h, hidden_stateₙ, bias_h), Val(3)) - r = @. sigmoid_fast(gxs[1] + ghbs[1]) - z = @. sigmoid_fast(gxs[2] + ghbs[2]) - n = gru_cell_compute(gxs[3], r, ghbs[3], safe_getproperty(ps, Val(:bias_i))) - hidden_state₂ = @. (1 - z) * n + z * hidden_stateₙ + z₁ = fused_dense_bias_activation( + identity, ps.weight_ih, y, safe_getproperty(ps, Val(:bias_ih))) + z₂ = fused_dense_bias_activation( + identity, ps.weight_hh, hidden_stateₙ, safe_getproperty(ps, Val(:bias_hh))) - return (hidden_state₂, (hidden_state₂,)), st -end + gxs₁, gxs₂, gxs₃ = multigate(z₁, Val(3)) + ghbs₁, ghbs₂, ghbs₃ = multigate(z₂, Val(3)) + + r = @. sigmoid_fast(gxs₁ + ghbs₁) + z = @. sigmoid_fast(gxs₂ + ghbs₂) + n = @. tanh_fast(gxs₃ + r * ghbs₃) + h′ = @. (1 - z) * n + z * hidden_stateₙ -gru_cell_compute(x, r, y, ::Nothing) = @. tanh_fast(x + r * y) -gru_cell_compute(x, r, y, bias) = @. tanh_fast(x + r * y + bias) + return (h′, (h′,)), st +end function Base.show(io::IO, g::GRUCell) print(io, "GRUCell($(g.in_dims) => $(g.out_dims)") From 0787b5bb79acb68092dd14c3aaf09337cd2064b9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 5 Sep 2024 16:30:41 -0400 Subject: [PATCH 71/95] feat: add cross correlation option to ConvTranspose --- src/layers/conv.jl | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 577e3e643d..ed8fa404be 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -232,7 +232,8 @@ end @doc doc""" ConvTranspose(k::NTuple{N,Integer}, (in_chs => out_chs)::Pair{<:Integer,<:Integer}, activation=identity; init_weight=glorot_uniform, init_bias=zeros32, - stride=1, pad=0, dilation=1, groups=1, use_bias=True()) + stride=1, pad=0, dilation=1, groups=1, use_bias=True(), + cross_correlation=False()) Standard convolutional transpose layer. @@ -271,6 +272,8 @@ Standard convolutional transpose layer. convolution into (set `groups = in_chs` for Depthwise Convolutions). `in_chs` and `out_chs` must be divisible by `groups`. - `use_bias`: Trainable bias can be disabled entirely by setting this to `false`. 
+ - `cross_correlation`: If `true`, perform transposed cross-correlation instead of + transposed convolution. # Extended Help @@ -301,12 +304,14 @@ Standard convolutional transpose layer. init_weight init_bias use_bias <: StaticBool + cross_correlation <: StaticBool end function ConvTranspose( k::Tuple{Vararg{IntegerType}}, ch::Pair{<:IntegerType, <:IntegerType}, - activation=identity; init_weight=glorot_uniform, init_bias=zeros32, - stride=1, pad=0, dilation=1, groups=1, use_bias::BoolType=True()) + activation=identity; init_weight=glorot_uniform, + init_bias=zeros32, stride=1, pad=0, dilation=1, groups=1, + use_bias::BoolType=True(), cross_correlation::BoolType=False()) stride = Utils.expand(Val(length(k)), stride) dilation = Utils.expand(Val(length(k)), dilation) pad = if pad isa SamePad @@ -319,8 +324,8 @@ function ConvTranspose( @argcheck ch[1] % groups==0 DimensionMismatch("Output channel dimension must be divisible by groups.") @argcheck allequal(length, (stride, dilation, k)) - return ConvTranspose(activation, first(ch), last(ch), k, stride, pad, dilation, - groups, init_weight, init_bias, static(use_bias)) + return ConvTranspose(activation, first(ch), last(ch), k, stride, pad, dilation, groups, + init_weight, init_bias, static(use_bias), static(cross_correlation)) end function initialparameters(rng::AbstractRNG, c::ConvTranspose) @@ -340,7 +345,8 @@ end function (c::ConvTranspose)(x::AbstractArray, ps, st::NamedTuple) y = match_eltype(c, ps, st, x) - cdims = conv_transpose_dims(y, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups) + cdims = construct_crosscor_convdims(c.cross_correlation, + conv_transpose_dims(y, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups)) bias = safe_getproperty(ps, Val(:bias)) σ = NNlib.fast_act(c.activation, y) return bias_activation!!(σ, conv_transpose(y, ps.weight, cdims), bias), st @@ -356,7 +362,7 @@ function Base.show(io::IO, l::ConvTranspose) print(io, ", dilation=", PrettyPrinting.tuple_string(l.dilation)) (l.groups == 1) || print(io, ", groups=", l.groups) has_bias(l) || print(io, ", use_bias=false") - return print(io, ")") + print(io, ")") end @doc doc""" From 6271c009103a597cb935175a707cea8e804a2f3b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 5 Sep 2024 16:59:13 -0400 Subject: [PATCH 72/95] fix: accidental type to rand32 --- docs/src/introduction/updating_to_v1.md | 2 ++ src/layers/conv.jl | 4 ++-- src/layers/recurrent.jl | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/src/introduction/updating_to_v1.md b/docs/src/introduction/updating_to_v1.md index cc9f2279ba..4178ae38bf 100644 --- a/docs/src/introduction/updating_to_v1.md +++ b/docs/src/introduction/updating_to_v1.md @@ -138,3 +138,5 @@ abstraction. - [`InstanceNorm`](@ref) now supports tracking statistics. - [`RNNCell`](@ref) and [`LSTMCell`](@ref) add `bias_ih` and `bias_hh` to the parameters to align with Pytorch. Both are controlled using `init_bias` and `use_bias`. +- [`ConvTranspose`](@ref) allows `flipkernel=true` via `cross_correlation=true`. This makes + it efficient for MIOpen. 
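Because the new keyword only changes how the `ConvDims` are constructed (kernel flipping, mirroring the option `Conv` already has), parameter shapes and output sizes are unaffected. A construction sketch, assuming the defaults shown in the docstring above:

```julia
using Lux, Random

rng = Random.default_rng()

# Same layer, two kernel orientations; `cross_correlation=true` uses the
# flipped-kernel (cross-correlation) convention.
tconv    = ConvTranspose((3, 3), 16 => 8)
tconv_cc = ConvTranspose((3, 3), 16 => 8; cross_correlation=true)

x = randn(Float32, 8, 8, 16, 2)
ps, st = Lux.setup(rng, tconv_cc)
y, _ = tconv_cc(x, ps, st)
size(y)                     # (10, 10, 8, 2) with the default stride=1, pad=0
```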
diff --git a/src/layers/conv.jl b/src/layers/conv.jl index ed8fa404be..e28c78e1a1 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -432,7 +432,7 @@ function Base.show(io::IO, m::MaxPool) print(io, "MaxPool(", m.k) all(==(0), m.pad) || print(io, ", pad=", PrettyPrinting.tuple_string(m.pad)) m.stride == m.k || print(io, ", stride=", PrettyPrinting.tuple_string(m.stride)) - return print(io, ")") + print(io, ")") end @doc doc""" @@ -502,7 +502,7 @@ function Base.show(io::IO, m::MeanPool) print(io, "MeanPool(", m.k) all(==(0), m.pad) || print(io, ", pad=", PrettyPrinting.tuple_string(m.pad)) m.stride == m.k || print(io, ", stride=", PrettyPrinting.tuple_string(m.stride)) - return print(io, ")") + print(io, ")") end """ diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 2ba4756066..fe039bfd62 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -30,7 +30,7 @@ LuxOps.eachslice(x::AbstractMatrix, ::BatchLastIndex) = LuxOps.eachslice(x, Val( function init_rnn_weight(rng::AbstractRNG, init_weight, hidden_dims, dims) if init_weight === nothing bound = inv(sqrt(hidden_dims)) - y = randn32(rng, Float32, dims...) + y = randn32(rng, dims...) @. y = (y - 0.5f0) * 2 * bound return y end From 1122d4083d211ac194772b9e614ce9942c4c0c44 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 5 Sep 2024 16:59:58 -0400 Subject: [PATCH 73/95] fix: unwanted printing --- src/layers/conv.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index e28c78e1a1..900d8d3b23 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -208,10 +208,8 @@ function (c::Conv)(x::AbstractArray, ps, st::NamedTuple) y = match_eltype(c, ps, st, x) cdims = construct_crosscor_convdims(c.cross_correlation, DenseConvDims(y, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups)) - @show cdims bias = safe_getproperty(ps, Val(:bias)) σ = NNlib.fast_act(c.activation, y) - @show σ return fused_conv_bias_activation(σ, ps.weight, y, bias, cdims), st end From e47f06362eb54dcc802abf8a5759ee2cf90bf74a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 5 Sep 2024 17:04:51 -0400 Subject: [PATCH 74/95] refactor: move the Upsample layer --- src/layers/conv.jl | 292 ++++++++++++++++++++-------------------- src/layers/recurrent.jl | 2 +- 2 files changed, 147 insertions(+), 147 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 900d8d3b23..66c5e61787 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -363,6 +363,152 @@ function Base.show(io::IO, l::ConvTranspose) print(io, ")") end +""" + Upsample(mode = :nearest; [scale, size, align_corners=false]) + Upsample(scale, mode = :nearest) + +Upsampling Layer. + +## Layer Construction + +### Option 1 + + - `mode`: Set to `:nearest`, `:linear`, `:bilinear` or `:trilinear` + +Exactly one of two keywords must be specified: + + - If `scale` is a number, this applies to all but the last two dimensions (channel and + batch) of the input. It may also be a tuple, to control dimensions individually. + - Alternatively, keyword `size` accepts a tuple, to directly specify the leading + dimensions of the output. + +### Option 2 + + - If `scale` is a number, this applies to all but the last two dimensions (channel and + batch) of the input. It may also be a tuple, to control dimensions individually. 
+ - `mode`: Set to `:nearest`, `:bilinear` or `:trilinear` + +Currently supported upsampling `mode`s and corresponding NNlib's methods are: + + - `:nearest` -> `NNlib.upsample_nearest` + - `:bilinear` -> `NNlib.upsample_bilinear` + - `:trilinear` -> `NNlib.upsample_trilinear` + +# Extended Help + +## Other Keyword Arguments + + - `align_corners`: If `true`, the corner pixels of the input and output tensors are + aligned, and thus preserving the values at those pixels. This only has effect when mode + is one of `:bilinear` or `:trilinear`. + +## Inputs + + - `x`: For the input dimensions look into the documentation for the corresponding `NNlib` + function + + + As a rule of thumb, `:nearest` should work with arrays of arbitrary dimensions + + `:bilinear` works with 4D Arrays + + `:trilinear` works with 5D Arrays + +## Returns + + - Upsampled Input of size `size` or of size `(I_1 x scale[1], ..., I_N x scale[N], C, N)` + - Empty `NamedTuple()` +""" +@concrete struct Upsample <: AbstractLuxLayer + scale + size + upsample_mode <: StaticSymbol + align_corners <: Bool +end + +function Upsample(mode::SymbolType=static(:nearest); scale=nothing, + size=nothing, align_corners::Bool=false) + @argcheck dynamic(mode) in (:nearest, :bilinear, :trilinear) + + if !xor(isnothing(scale), isnothing(size)) + throw(ArgumentError("Either scale or size should be specified (but not both).")) + end + return Upsample(scale, size, static(mode), align_corners) +end + +Upsample(scale, mode::SymbolType=static(:nearest)) = Upsample(mode; scale) + +function (m::Upsample)(x::AbstractArray, _, st::NamedTuple) + return lux_upsample_scale_dispatch(m.upsample_mode, x, m.scale, m.align_corners), st +end +function (m::Upsample{Nothing})(x::AbstractArray, _, st::NamedTuple) + return lux_upsample_size_dispatch(m.upsample_mode, x, m.size, m.align_corners), st +end + +for interp in (:bilinear, :trilinear) + nnlib_interp_func = Symbol(:upsample_, interp) + @eval begin + function lux_upsample_scale_dispatch( + ::StaticSymbol{$(Meta.quot(interp))}, x, scale, align_corners) + return $(nnlib_interp_func)(x, scale) + end + function lux_upsample_size_dispatch( + ::StaticSymbol{$(Meta.quot(interp))}, x, size, align_corners) + return $(nnlib_interp_func)(x; size) + end + end +end + +function lux_upsample_size_dispatch(::StaticSymbol{:nearest}, x, size, _) + return NNlib.upsample_nearest(x; size) +end +function lux_upsample_scale_dispatch(::StaticSymbol{:nearest}, x, scale, _) + return NNlib.upsample_nearest(x, scale) +end +function lux_upsample_scale_dispatch(::StaticSymbol{:nearest}, x, scale::Integer, _) + return NNlib.upsample_nearest(x, ntuple(i -> scale, ndims(x) - 2)) +end + +function Base.show(io::IO, u::Upsample) + print(io, "Upsample(", u.upsample_mode) + u.scale !== nothing && print(io, ", scale = $(u.scale)") + u.size !== nothing && print(io, ", size = $(u.size)") + u.align_corners && print(io, ", align_corners = $(u.align_corners)") + print(io, ")") +end + +""" + PixelShuffle(r::Int) + +Pixel shuffling layer with upscale factor `r`. Usually used for generating higher +resolution images while upscaling them. + +See `NNlib.pixel_shuffle` for more details. + +PixelShuffle is not a Layer, rather it returns a [`WrappedFunction`](@ref) with the +function set to `Base.Fix2(pixel_shuffle, r)` + +## Arguments + + - `r`: Upscale factor + +## Inputs + + - `x`: For 4D-arrays representing N images, the operation converts input + `size(x) == (W, H, r² x C, N)` to output of size `(r x W, r x H, C, N)`. 
For + D-dimensional data, it expects `ndims(x) == D + 2` with channel and batch dimensions, and + divides the number of channels by `rᴰ`. + +## Returns + + - Output of size `(r x W, r x H, C, N)` for 4D-arrays, and `(r x W, r x H, ..., C, N)` + for D-dimensional data, where `D = ndims(x) - 2` +""" +@concrete struct PixelShuffle <: AbstractLuxWrapperLayer{:layer} + layer <: AbstractLuxLayer +end + +function PixelShuffle(r::IntegerType) + return PixelShuffle(WrappedFunction(Base.Fix2(pixel_shuffle, r))) +end + @doc doc""" MaxPool(window::NTuple; pad=0, stride=window) @@ -503,117 +649,6 @@ function Base.show(io::IO, m::MeanPool) print(io, ")") end -""" - Upsample(mode = :nearest; [scale, size, align_corners=false]) - Upsample(scale, mode = :nearest) - -Upsampling Layer. - -## Layer Construction - -### Option 1 - - - `mode`: Set to `:nearest`, `:linear`, `:bilinear` or `:trilinear` - -Exactly one of two keywords must be specified: - - - If `scale` is a number, this applies to all but the last two dimensions (channel and - batch) of the input. It may also be a tuple, to control dimensions individually. - - Alternatively, keyword `size` accepts a tuple, to directly specify the leading - dimensions of the output. - -### Option 2 - - - If `scale` is a number, this applies to all but the last two dimensions (channel and - batch) of the input. It may also be a tuple, to control dimensions individually. - - `mode`: Set to `:nearest`, `:bilinear` or `:trilinear` - -Currently supported upsampling `mode`s and corresponding NNlib's methods are: - - - `:nearest` -> `NNlib.upsample_nearest` - - `:bilinear` -> `NNlib.upsample_bilinear` - - `:trilinear` -> `NNlib.upsample_trilinear` - -# Extended Help - -## Other Keyword Arguments - - - `align_corners`: If `true`, the corner pixels of the input and output tensors are - aligned, and thus preserving the values at those pixels. This only has effect when mode - is one of `:bilinear` or `:trilinear`. 
- -## Inputs - - - `x`: For the input dimensions look into the documentation for the corresponding `NNlib` - function - - + As a rule of thumb, `:nearest` should work with arrays of arbitrary dimensions - + `:bilinear` works with 4D Arrays - + `:trilinear` works with 5D Arrays - -## Returns - - - Upsampled Input of size `size` or of size `(I_1 x scale[1], ..., I_N x scale[N], C, N)` - - Empty `NamedTuple()` -""" -@concrete struct Upsample <: AbstractLuxLayer - scale - size - upsample_mode <: StaticSymbol - align_corners <: Bool -end - -function Upsample(mode::SymbolType=static(:nearest); scale=nothing, - size=nothing, align_corners::Bool=false) - @argcheck dynamic(mode) in (:nearest, :bilinear, :trilinear) - - if !xor(isnothing(scale), isnothing(size)) - throw(ArgumentError("Either scale or size should be specified (but not both).")) - end - return Upsample(scale, size, static(mode), align_corners) -end - -Upsample(scale, mode::SymbolType=static(:nearest)) = Upsample(mode; scale) - -function (m::Upsample)(x::AbstractArray, _, st::NamedTuple) - return lux_upsample_scale_dispatch(m.upsample_mode, x, m.scale, m.align_corners), st -end -function (m::Upsample{Nothing})(x::AbstractArray, _, st::NamedTuple) - return lux_upsample_size_dispatch(m.upsample_mode, x, m.size, m.align_corners), st -end - -for interp in (:bilinear, :trilinear) - nnlib_interp_func = Symbol(:upsample_, interp) - @eval begin - function lux_upsample_scale_dispatch( - ::StaticSymbol{$(Meta.quot(interp))}, x, scale, align_corners) - return $(nnlib_interp_func)(x, scale) - end - function lux_upsample_size_dispatch( - ::StaticSymbol{$(Meta.quot(interp))}, x, size, align_corners) - return $(nnlib_interp_func)(x; size) - end - end -end - -function lux_upsample_size_dispatch(::StaticSymbol{:nearest}, x, size, _) - return NNlib.upsample_nearest(x; size) -end -function lux_upsample_scale_dispatch(::StaticSymbol{:nearest}, x, scale, _) - return NNlib.upsample_nearest(x, scale) -end -function lux_upsample_scale_dispatch(::StaticSymbol{:nearest}, x, scale::Integer, _) - return NNlib.upsample_nearest(x, ntuple(i -> scale, ndims(x) - 2)) -end - -function Base.show(io::IO, u::Upsample) - print(io, "Upsample(", u.upsample_mode) - u.scale !== nothing && print(io, ", scale = $(u.scale)") - u.size !== nothing && print(io, ", size = $(u.size)") - u.align_corners && print(io, ", align_corners = $(u.align_corners)") - print(io, ")") -end - """ GlobalMaxPool() @@ -725,38 +760,3 @@ function (a::AdaptiveMeanPool{S})(x::AbstractArray{T, S}, _, st::NamedTuple) whe end Base.show(io::IO, a::AdaptiveMeanPool) = print(io, "AdaptiveMeanPool(", a.out, ")") - -""" - PixelShuffle(r::Int) - -Pixel shuffling layer with upscale factor `r`. Usually used for generating higher -resolution images while upscaling them. - -See `NNlib.pixel_shuffle` for more details. - -PixelShuffle is not a Layer, rather it returns a [`WrappedFunction`](@ref) with the -function set to `Base.Fix2(pixel_shuffle, r)` - -## Arguments - - - `r`: Upscale factor - -## Inputs - - - `x`: For 4D-arrays representing N images, the operation converts input - `size(x) == (W, H, r² x C, N)` to output of size `(r x W, r x H, C, N)`. For - D-dimensional data, it expects `ndims(x) == D + 2` with channel and batch dimensions, and - divides the number of channels by `rᴰ`. 
- -## Returns - - - Output of size `(r x W, r x H, C, N)` for 4D-arrays, and `(r x W, r x H, ..., C, N)` - for D-dimensional data, where `D = ndims(x) - 2` -""" -@concrete struct PixelShuffle <: AbstractLuxWrapperLayer{:layer} - layer <: AbstractLuxLayer -end - -function PixelShuffle(r::IntegerType) - return PixelShuffle(WrappedFunction(Base.Fix2(pixel_shuffle, r))) -end diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index fe039bfd62..878c713beb 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -422,7 +422,7 @@ function initialparameters(rng::AbstractRNG, lstm::LSTMCell) for init_bias in lstm.init_bias]...) bias_hh = vcat([init_rnn_bias(rng, init_bias, lstm.out_dims, lstm.out_dims) for init_bias in lstm.init_bias]...) - ps = merge(ps, (bias_ih, bias_hh)) + ps = merge(ps, (; bias_ih, bias_hh)) end has_train_state(lstm) && (ps = merge(ps, (hidden_state=lstm.init_state(rng, lstm.out_dims),))) From fd667807f0f600e17fa556f78d038069e1a604ec Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 5 Sep 2024 18:17:43 -0400 Subject: [PATCH 75/95] feat: generalize pooling implementation and add LP versions --- docs/src/api/Lux/layers.md | 3 + ext/LuxSimpleChainsExt.jl | 5 +- src/Lux.jl | 6 +- src/layers/conv.jl | 261 ------------------------------------- src/layers/normalize.jl | 4 +- src/layers/pooling.jl | 240 ++++++++++++++++++++++++++++++++++ 6 files changed, 252 insertions(+), 267 deletions(-) create mode 100644 src/layers/pooling.jl diff --git a/docs/src/api/Lux/layers.md b/docs/src/api/Lux/layers.md index 6041984940..b591844aab 100644 --- a/docs/src/api/Lux/layers.md +++ b/docs/src/api/Lux/layers.md @@ -35,10 +35,13 @@ VariationalHiddenDropout ## Pooling Layers ```@docs +AdaptiveLPPool AdaptiveMaxPool AdaptiveMeanPool +GlobalLPPool GlobalMaxPool GlobalMeanPool +LPPool MaxPool MeanPool ``` diff --git a/ext/LuxSimpleChainsExt.jl b/ext/LuxSimpleChainsExt.jl index 1d10fc106a..a311559ac7 100644 --- a/ext/LuxSimpleChainsExt.jl +++ b/ext/LuxSimpleChainsExt.jl @@ -61,8 +61,9 @@ function Lux.make_simplechain_network(layer::FlattenLayer) end function Lux.make_simplechain_network(layer::MaxPool) - if layer.stride == layer.k && (!(layer.pad isa SamePad) && all(==(0), layer.pad)) - return SimpleChains.MaxPool(layer.k) + if layer.layer.mode.stride == layer.layer.mode.kernel_size && + all(==(0), layer.layer.mode.pad) + return SimpleChains.MaxPool(layer.layer.mode.kernel_size) end throw(SimpleChainsModelConversionException("MaxPool with non-standard parameters not \ supported.")) diff --git a/src/Lux.jl b/src/Lux.jl index a0650c673f..37506abcaa 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -58,6 +58,7 @@ include("layers/basic.jl") include("layers/containers.jl") include("layers/normalize.jl") include("layers/conv.jl") +include("layers/pooling.jl") include("layers/dropout.jl") include("layers/recurrent.jl") include("layers/extension.jl") @@ -87,8 +88,9 @@ include("distributed/public_api.jl") # Layers export Chain, Parallel, SkipConnection, PairwiseFusion, BranchLayer, Maxout, RepeatedLayer export Bilinear, Dense, Embedding, Scale -export Conv, ConvTranspose, MaxPool, MeanPool, GlobalMaxPool, GlobalMeanPool, - AdaptiveMaxPool, AdaptiveMeanPool, Upsample, PixelShuffle +export Conv, ConvTranspose, Upsample, PixelShuffle +export MaxPool, MeanPool, LPPool, GlobalMaxPool, GlobalMeanPool, GlobalLPPool, + AdaptiveMaxPool, AdaptiveMeanPool, AdaptiveLPPool export AlphaDropout, Dropout, VariationalHiddenDropout export BatchNorm, GroupNorm, InstanceNorm, LayerNorm export WeightNorm diff --git 
a/src/layers/conv.jl b/src/layers/conv.jl index 66c5e61787..b026384069 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -42,15 +42,6 @@ CRC.@non_differentiable conv_transpose_dims(::Any...) conv_transpose(x, weight, cdims) = LuxLib.Impl.∇conv_data(x, weight, cdims) -function compute_adaptive_pooling_dims(x::AbstractArray, outsize) - insize = size(x)[1:(end - 2)] - stride = insize .÷ outsize - k = insize .- (outsize .- 1) .* stride - return PoolDims(x, k; padding=0, stride=stride) -end - -CRC.@non_differentiable compute_adaptive_pooling_dims(::Any, ::Any) - function init_conv_weight( rng::AbstractRNG, init_weight::F, filter::NTuple{N, <:IntegerType}, in_chs::IntegerType, out_chs::IntegerType, groups, σ::A) where {F, N, A} @@ -508,255 +499,3 @@ end function PixelShuffle(r::IntegerType) return PixelShuffle(WrappedFunction(Base.Fix2(pixel_shuffle, r))) end - -@doc doc""" - MaxPool(window::NTuple; pad=0, stride=window) - -Max pooling layer, which replaces all pixels in a block of size `window` with the maximum -value. - -# Arguments - - - `window`: Tuple of integers specifying the size of the window. Eg, for 2D pooling - `length(window) == 2` - -## Keyword Arguments - - - `stride`: Should each be either single integer, or a tuple with `N` integers - - - `pad`: Specifies the number of elements added to the borders of the data array. It can - be - - + a single integer for equal padding all around, - + a tuple of `N` integers, to apply the same padding at begin/end of each spatial - dimension, - + a tuple of `2*N` integers, for asymmetric padding, or - + the singleton `SamePad()`, to calculate padding such that - `size(output,d) == size(x,d) / stride` (possibly rounded) for each spatial - dimension. - -# Extended Help - -## Inputs - - - `x`: Data satisfying `ndims(x) == N + 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)` - -## Returns - - - Output of the pooling `y` of size `(O_N, ..., O_1, C, N)` where - -```math - O_i = \left\lfloor\frac{I_i + p_i + p_{(i + N) \% |p|} - d_i \times (k_i - 1)}{s_i} + 1\right\rfloor -``` - - - Empty `NamedTuple()` - -See also [`Conv`](@ref), [`MeanPool`](@ref), [`GlobalMaxPool`](@ref), -[`AdaptiveMaxPool`](@ref) -""" -@concrete struct MaxPool <: AbstractLuxLayer - k <: Tuple{Vararg{IntegerType}} - pad <: Tuple{Vararg{IntegerType}} - stride <: Tuple{Vararg{IntegerType}} -end - -function MaxPool(k::Tuple{Vararg{IntegerType}}; pad=0, stride=k) - stride = Utils.expand(Val(length(k)), stride) - pad = calc_padding(pad, k, 1, stride) - @argcheck allequal(length, (stride, k)) - - return MaxPool(k, pad, stride) -end - -function (m::MaxPool)(x, _, st::NamedTuple) - return maxpool(x, PoolDims(x, m.k; padding=m.pad, m.stride)), st -end - -function Base.show(io::IO, m::MaxPool) - print(io, "MaxPool(", m.k) - all(==(0), m.pad) || print(io, ", pad=", PrettyPrinting.tuple_string(m.pad)) - m.stride == m.k || print(io, ", stride=", PrettyPrinting.tuple_string(m.stride)) - print(io, ")") -end - -@doc doc""" - MeanPool(window::NTuple; pad=0, stride=window) - -Mean pooling layer, which replaces all pixels in a block of size `window` with the mean -value. - -# Arguments - - - `window`: Tuple of integers specifying the size of the window. Eg, for 2D pooling - `length(window) == 2` - -## Keyword Arguments - - - `stride`: Should each be either single integer, or a tuple with `N` integers - - - `pad`: Specifies the number of elements added to the borders of the data array. 
It can - be - - + a single integer for equal padding all around, - + a tuple of `N` integers, to apply the same padding at begin/end of each spatial - dimension, - + a tuple of `2*N` integers, for asymmetric padding, or - + the singleton `SamePad()`, to calculate padding such that - `size(output,d) == size(x,d) / stride` (possibly rounded) for each spatial - dimension. - -# Extended Help - -## Inputs - - - `x`: Data satisfying `ndims(x) == N + 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)` - -## Returns - - - Output of the pooling `y` of size `(O_N, ..., O_1, C, N)` where - -```math - O_i = \left\lfloor\frac{I_i + p_i + p_{(i + N) \% |p|} - d_i \times (k_i - 1)}{s_i} + 1\right\rfloor -``` - - - Empty `NamedTuple()` - -See also [`Conv`](@ref), [`MaxPool`](@ref), [`GlobalMeanPool`](@ref), -[`AdaptiveMeanPool`](@ref) -""" -@concrete struct MeanPool <: AbstractLuxLayer - k <: Tuple{Vararg{IntegerType}} - pad <: Tuple{Vararg{IntegerType}} - stride <: Tuple{Vararg{IntegerType}} -end - -function MeanPool(k::Tuple{Vararg{IntegerType}}; pad=0, stride=k) - stride = Utils.expand(Val(length(k)), stride) - pad = calc_padding(pad, k, 1, stride) - @argcheck allequal(length, (stride, k)) - - return MeanPool(k, pad, stride) -end - -function (m::MeanPool)(x, _, st::NamedTuple) - return meanpool(x, PoolDims(x, m.k; padding=m.pad, m.stride)), st -end - -function Base.show(io::IO, m::MeanPool) - print(io, "MeanPool(", m.k) - all(==(0), m.pad) || print(io, ", pad=", PrettyPrinting.tuple_string(m.pad)) - m.stride == m.k || print(io, ", stride=", PrettyPrinting.tuple_string(m.stride)) - print(io, ")") -end - -""" - GlobalMaxPool() - -Global Max Pooling layer. Transforms (w,h,c,b)-shaped input into (1,1,c,b)-shaped output, -by performing max pooling on the complete (w,h)-shaped feature maps. - -## Inputs - - - `x`: Data satisfying `ndims(x) > 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)` - -## Returns - - - Output of the pooling `y` of size `(1, ..., 1, C, N)` - - Empty `NamedTuple()` - -See also [`MaxPool`](@ref), [`AdaptiveMaxPool`](@ref), [`GlobalMeanPool`](@ref) -""" -struct GlobalMaxPool <: AbstractLuxLayer end - -function (g::GlobalMaxPool)(x, _, st::NamedTuple) - return maxpool(x, PoolDims(x, size(x)[1:(end - 2)])), st -end - -""" - GlobalMeanPool() - -Global Mean Pooling layer. Transforms (w,h,c,b)-shaped input into (1,1,c,b)-shaped output, -by performing mean pooling on the complete (w,h)-shaped feature maps. - -## Inputs - - - `x`: Data satisfying `ndims(x) > 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)` - -## Returns - - - Output of the pooling `y` of size `(1, ..., 1, C, N)` - - Empty `NamedTuple()` - -See also [`MeanPool`](@ref), [`AdaptiveMeanPool`](@ref), [`GlobalMaxPool`](@ref) -""" -struct GlobalMeanPool <: AbstractLuxLayer end - -function (g::GlobalMeanPool)(x, _, st::NamedTuple) - return meanpool(x, PoolDims(x, size(x)[1:(end - 2)])), st -end - -""" - AdaptiveMaxPool(out::NTuple) - -Adaptive Max Pooling layer. Calculates the necessary window size such that its output has -`size(y)[1:N] == out`. - -## Arguments - - - `out`: Size of the first `N` dimensions for the output - -## Inputs - - - `x`: Expects as input an array with `ndims(x) == N+2`, i.e. channel and batch - dimensions, after the `N` feature dimensions, where `N = length(out)`. - -## Returns - - - Output of size `(out..., C, N)` - - Empty `NamedTuple()` - -See also [`MaxPool`](@ref), [`AdaptiveMeanPool`](@ref). 
-""" -struct AdaptiveMaxPool{S, O <: Tuple{Vararg{IntegerType}}} <: AbstractLuxLayer - out::O - AdaptiveMaxPool(out) = new{length(out) + 2, typeof(out)}(out) -end - -function (a::AdaptiveMaxPool{S})(x::AbstractArray{T, S}, _, st::NamedTuple) where {S, T} - return maxpool(x, compute_adaptive_pooling_dims(x, a.out)), st -end - -Base.show(io::IO, a::AdaptiveMaxPool) = print(io, "AdaptiveMaxPool(", a.out, ")") - -""" - AdaptiveMeanPool(out::NTuple) - -Adaptive Mean Pooling layer. Calculates the necessary window size such that its output has -`size(y)[1:N] == out`. - -## Arguments - - - `out`: Size of the first `N` dimensions for the output - -## Inputs - - - `x`: Expects as input an array with `ndims(x) == N+2`, i.e. channel and batch - dimensions, after the `N` feature dimensions, where `N = length(out)`. - -## Returns - - - Output of size `(out..., C, N)` - - Empty `NamedTuple()` - -See also [`MeanPool`](@ref), [`AdaptiveMaxPool`](@ref). -""" -struct AdaptiveMeanPool{S, O <: Tuple{Vararg{IntegerType}}} <: AbstractLuxLayer - out::O - AdaptiveMeanPool(out) = new{length(out) + 2, typeof(out)}(out) -end - -function (a::AdaptiveMeanPool{S})(x::AbstractArray{T, S}, _, st::NamedTuple) where {S, T} - return meanpool(x, compute_adaptive_pooling_dims(x, a.out)), st -end - -Base.show(io::IO, a::AdaptiveMeanPool) = print(io, "AdaptiveMeanPool(", a.out, ")") diff --git a/src/layers/normalize.jl b/src/layers/normalize.jl index 1912922863..a355b03a1c 100644 --- a/src/layers/normalize.jl +++ b/src/layers/normalize.jl @@ -323,8 +323,8 @@ Use `Lux.testmode` during inference. ## Example ```jldoctest -julia> Chain(Dense(784 => 64), InstanceNorm(64, relu), Dense(64 => 10), - InstanceNorm(10, relu)) +julia> Chain(Dense(784 => 64), InstanceNorm(64, relu; affine=true), Dense(64 => 10), + InstanceNorm(10, relu; affine=true)) Chain( layer_1 = Dense(784 => 64), # 50_240 parameters layer_2 = InstanceNorm(64, relu, affine=true, track_stats=false), # 128 parameters, plus 1 diff --git a/src/layers/pooling.jl b/src/layers/pooling.jl new file mode 100644 index 0000000000..815fb0e4c4 --- /dev/null +++ b/src/layers/pooling.jl @@ -0,0 +1,240 @@ +abstract type AbstractPoolMode end + +CRC.@non_differentiable (::AbstractPoolMode)(::Any...) 
+ +@concrete struct GenericPoolMode <: AbstractPoolMode + kernel_size <: Tuple{Vararg{IntegerType}} + stride <: Tuple{Vararg{IntegerType}} + pad <: Tuple{Vararg{IntegerType}} + dilation <: Tuple{Vararg{IntegerType}} +end + +(m::GenericPoolMode)(x) = PoolDims(x, m.kernel_size; padding=m.pad, m.stride, m.dilation) + +struct GlobalPoolMode <: AbstractPoolMode end + +(::GlobalPoolMode)(x) = PoolDims(x, size(x)[1:(end - 2)]) + +@concrete struct AdaptivePoolMode <: AbstractPoolMode + out_size <: Tuple{Vararg{IntegerType}} +end + +function (m::AdaptivePoolMode)(x) + in_size = size(x)[1:(end - 2)] + stride = in_size .÷ m.out_size + kernel_size = in_size .- (m.out_size .- 1) .* stride + return PoolDims(x, kernel_size; padding=0, stride, dilation=1) +end + +symbol_to_pool_mode(::StaticSymbol{:generic}) = GenericPoolMode +symbol_to_pool_mode(::StaticSymbol{:global}) = GlobalPoolMode +symbol_to_pool_mode(::StaticSymbol{:adaptive}) = AdaptivePoolMode + +abstract type AbstractPoolOp end + +struct MaxPoolOp <: AbstractPoolOp end +(m::MaxPoolOp)(x, pdims) = maxpool(x, pdims) + +struct MeanPoolOp <: AbstractPoolOp end +(m::MeanPoolOp)(x, pdims) = meanpool(x, pdims) + +@concrete struct LpPoolOp <: AbstractPoolOp + p +end +(m::LpPoolOp)(x, pdims) = lpnormpool(x, pdims; m.p) + +symbol_to_pool_op(::StaticSymbol{:max}, _) = MaxPoolOp() +symbol_to_pool_op(::StaticSymbol{:mean}, _) = MeanPoolOp() +symbol_to_pool_op(::StaticSymbol{:lp}, p) = LpPoolOp(p) + +@concrete struct PoolingLayer <: AbstractLuxLayer + mode <: AbstractPoolMode + op <: AbstractPoolOp +end + +function PoolingLayer(mode::SymbolType, op::SymbolType, + arg::Union{Nothing, Tuple{Vararg{IntegerType}}}=nothing; + stride=arg, pad=0, dilation=1, p=2) + return PoolingLayer(symbol_to_pool_mode(static(mode)), + symbol_to_pool_op(static(op), p), arg; stride, pad, dilation) +end + +function PoolingLayer(::Type{GenericPoolMode}, op::AbstractPoolOp, + kernel_size::Tuple{Vararg{IntegerType}}; stride=kernel_size, pad=0, dilation=1) + stride = Utils.expand(Val(length(kernel_size)), stride) + pad = calc_padding(pad, kernel_size, dilation, stride) + dilation = Utils.expand(Val(length(kernel_size)), dilation) + @argcheck allequal(length, (stride, kernel_size, dilation)) + + return PoolingLayer(GenericPoolMode(kernel_size, stride, pad, dilation), op) +end + +function PoolingLayer(::Type{AdaptivePoolMode}, op::AbstractPoolOp, + out_size::Tuple{Vararg{IntegerType}}; kwargs...) + return PoolingLayer(AdaptivePoolMode(out_size), op) +end + +function PoolingLayer(::Type{GlobalPoolMode}, op::AbstractPoolOp, ::Nothing; kwargs...) + return PoolingLayer(GlobalPoolMode(), op) +end + +(m::PoolingLayer)(x, _, st::NamedTuple) = m.op(x, m.mode(x)), st + +for layer_op in (:Max, :Mean, :LP) + op = Symbol(lowercase(string(layer_op))) + + layer_name = Symbol(layer_op, :Pool) + extra_kwargs = layer_op == :LP ? ", p=2" : "" + layer_docstring = """ + $(layer_name)(window; stride=window, pad=0, dilation=1$(extra_kwargs)) + + $(layer_op) Pooling layer, which replaces all pixels in a block of size `window` with + the reduction operation: $(op). + + ## Arguments + + - `window`: Tuple of integers specifying the size of the window. Eg, for 2D pooling + `length(window) == 2` + + ## Keyword Arguments + + - `stride`: Should each be either single integer, or a tuple with `N` integers + - `dilation`: Should each be either single integer, or a tuple with `N` integers + + - `pad`: Specifies the number of elements added to the borders of the data array. 
It can + be + + + a single integer for equal padding all around, + + a tuple of `N` integers, to apply the same padding at begin/end of each spatial + dimension, + + a tuple of `2*N` integers, for asymmetric padding, or + + the singleton `SamePad()`, to calculate padding such that + `size(output,d) == size(x,d) / stride` (possibly rounded) for each spatial + dimension. + + # Extended Help + + ## Inputs + + - `x`: Data satisfying `ndims(x) == N + 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)` + + ## Returns + + - Output of the pooling `y` of size `(O_N, ..., O_1, C, N)` where + + ```math + O_i = \\left\\lfloor\\frac{I_i + p_i + p_{(i + N) \\% |p|} - d_i \\times (k_i - 1)}{s_i} + 1\\right\\rfloor + ``` + + - Empty `NamedTuple()` + """ + + global_layer_name = Symbol(:Global, layer_name) + extra_kwargs = layer_op == :LP ? "; p=2" : "" + global_pooling_docstring = """ + $(global_layer_name)($(extra_kwargs)) + + Global $(layer_op) Pooling layer. Transforms `(w, h, c, b)`-shaped input into + `(1, 1, c, b)`-shaped output, by performing mean pooling on the complete `(w, h)`-shaped + feature maps. + + ## Inputs + + - `x`: Data satisfying `ndims(x) > 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)` + + ## Returns + + - Output of the pooling `y` of size `(1, ..., 1, C, N)` + - Empty `NamedTuple()` + """ + + adaptive_layer_name = Symbol(:Adaptive, layer_name) + adaptive_pooling_docstring = """ + $(adaptive_layer_name)(output_size$(extra_kwargs)) + + Adaptive $(layer_op) Pooling layer. Calculates the necessary window size such that + its output has `size(y)[1:N] == output_size`. + + ## Arguments + + - `output_size`: Size of the first `N` dimensions for the output + + ## Inputs + + - `x`: Expects as input an array with `ndims(x) == N + 2`, i.e. channel and batch + dimensions, after the `N` feature dimensions, where `N = length(output_size)`. 
+ + ## Returns + + - Output of size `(out..., C, N)` + - Empty `NamedTuple()` + """ + + @eval begin + # Generic Pooling Layer + @doc $(layer_docstring) @concrete struct $(layer_name) <: + AbstractLuxWrapperLayer{:layer} + layer <: PoolingLayer + end + + function $(layer_name)( + window::Tuple{Vararg{IntegerType}}; stride=window, pad=0, dilation=1, p=2) + return $(layer_name)(PoolingLayer(static(:generic), static($(Meta.quot(op))), + window; stride, pad, dilation, p)) + end + + function Base.show(io::IO, ::MIME"text/plain", m::$(layer_name)) + kernel_size = m.layer.mode.kernel_size + print(io, string($(Meta.quot(layer_name))), "($(kernel_size)") + pad = m.layer.mode.pad + all(==(0), pad) || print(io, ", pad=", PrettyPrinting.tuple_string(pad)) + stride = m.layer.mode.stride + stride == kernel_size || + print(io, ", stride=", PrettyPrinting.tuple_string(stride)) + dilation = m.layer.mode.dilation + all(==(1), dilation) || + print(io, ", dilation=", PrettyPrinting.tuple_string(dilation)) + if $(Meta.quot(op)) == :lp + a.layer.op.p == 2 || print(io, ", p=", a.layer.op.p) + end + print(io, ")") + end + + # Global Pooling Layer + @doc $(global_pooling_docstring) @concrete struct $(global_layer_name) <: + AbstractLuxWrapperLayer{:layer} + layer <: PoolingLayer + end + + function $(global_layer_name)(; p=2) + return $(global_layer_name)(PoolingLayer(static(:global), $(Meta.quot(op)); p)) + end + + function Base.show(io::IO, ::MIME"text/plain", g::$(global_layer_name)) + print(io, string($(Meta.quot(global_layer_name))), "(") + if $(Meta.quot(op)) == :lp + a.layer.op.p == 2 || print(io, ", p=", a.layer.op.p) + end + print(io, ")") + end + + # Adaptive Pooling Layer + @doc $(adaptive_pooling_docstring) @concrete struct $(adaptive_layer_name) <: + AbstractLuxWrapperLayer{:layer} + layer <: PoolingLayer + end + + function $(adaptive_layer_name)(out_size::Tuple{Vararg{IntegerType}}; p=2) + return $(adaptive_layer_name)(PoolingLayer( + static(:adaptive), $(Meta.quot(op)), out_size; p)) + end + + function Base.show(io::IO, ::MIME"text/plain", a::$(adaptive_layer_name)) + print(io, string($(Meta.quot(adaptive_layer_name))), "(", a.layer.mode.out_size) + if $(Meta.quot(op)) == :lp + a.layer.op.p == 2 || print(io, ", p=", a.layer.op.p) + end + print(io, ")") + end + end +end From 15b20e4b9a5cd08db2e5f9dece8dcf12c69c75ca Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 6 Sep 2024 09:18:19 -0400 Subject: [PATCH 76/95] fix: tests using old naming --- test/layers/recurrent_tests.jl | 24 ++++++++++++++---------- test/utils_tests.jl | 4 ++-- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/test/layers/recurrent_tests.jl b/test/layers/recurrent_tests.jl index e2970027bb..b5f2867d07 100644 --- a/test/layers/recurrent_tests.jl +++ b/test/layers/recurrent_tests.jl @@ -109,7 +109,8 @@ end l, back = Zygote.pullback( p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) gs = back(one(l))[1] - @test_throws ErrorException gs.bias + @test_throws ErrorException gs.bias_ih + @test_throws ErrorException gs.bias_hh @test_throws ErrorException gs.hidden_state @test_throws ErrorException gs.memory @@ -122,7 +123,8 @@ end l, back = Zygote.pullback( p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) gs = back(one(l))[1] - @test_throws ErrorException gs.bias + @test_throws ErrorException gs.bias_ih + @test_throws ErrorException gs.bias_hh @test !isnothing(gs.hidden_state) @test_throws ErrorException gs.memory @@ -135,7 +137,8 @@ end l, back = Zygote.pullback( p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) gs = 
back(one(l))[1] - @test_throws ErrorException gs.bias + @test_throws ErrorException gs.bias_ih + @test_throws ErrorException gs.bias_hh @test_throws ErrorException gs.hidden_state @test !isnothing(gs.memory) @@ -147,19 +150,20 @@ end l, back = Zygote.pullback( p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) gs = back(one(l))[1] - @test_throws ErrorException gs.bias + @test_throws ErrorException gs.bias_ih + @test_throws ErrorException gs.bias_hh @test !isnothing(gs.hidden_state) @test !isnothing(gs.memory) lstm = LSTMCell(3 => 5; use_bias=true, train_state=true, train_memory=true) ps, st = Lux.setup(rng, lstm) .|> dev - ps = merge( - _ps, (bias=ps.bias, hidden_state=ps.hidden_state, memory=ps.memory)) + ps = merge(_ps, (; ps.bias_ih, ps.bias_hh, ps.hidden_state, ps.memory)) (y, carry), _ = Lux.apply(lstm, x, ps, st) l, back = Zygote.pullback( p -> sum(abs2, 0 .- sum(lstm(x, p, st)[1][1])), ps) gs = back(one(l))[1] - @test !isnothing(gs.bias) + @test !isnothing(gs.bias_ih) + @test !isnothing(gs.bias_hh) @test !isnothing(gs.hidden_state) @test !isnothing(gs.memory) end @@ -213,7 +217,8 @@ end @test carry == _carry l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(gru(x, p, st)[1][1])), ps) gs = back(one(l))[1] - @test_throws ErrorException gs.bias + @test_throws ErrorException gs.bias_ih + @test_throws ErrorException gs.bias_hh @test_throws ErrorException gs.hidden_state gru = GRUCell(3 => 5; use_bias=false, train_state=true) @@ -227,8 +232,7 @@ end gru = GRUCell(3 => 5; use_bias=true, train_state=true) ps, st = Lux.setup(rng, gru) .|> dev - ps = merge( - _ps, (bias_h=ps.bias_h, bias_i=ps.bias_i, hidden_state=ps.hidden_state)) + ps = merge(_ps, (; ps.bias_ih, ps.bias_hh, ps.hidden_state)) (y, carry), _ = Lux.apply(gru, x, ps, st) @test carry == _carry l, back = Zygote.pullback(p -> sum(abs2, 0 .- sum(gru(x, p, st)[1][1])), ps) diff --git a/test/utils_tests.jl b/test/utils_tests.jl index 5c6c74b750..fe7811cffa 100644 --- a/test/utils_tests.jl +++ b/test/utils_tests.jl @@ -113,13 +113,13 @@ end @test length(Zygote.gradient(l2reg, ps)) == 1 end -@testitem "Utils.init_hidden_state" setup=[SharedTestSetup] tags=[:recurrent_layers] begin +@testitem "Utils.init_rnn_hidden_state" setup=[SharedTestSetup] tags=[:recurrent_layers] begin rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES rnn = RNNCell(3 => 5; init_state=Lux.zeros32) x = randn(rng, Float32, 3, 2, 2) - @test Lux.Utils.init_hidden_state(rng, rnn, view(dev(x), :, 1, :)) == + @test Lux.Utils.init_rnn_hidden_state(rng, rnn, view(dev(x), :, 1, :)) == aType(zeros(Float32, 5, 2)) end end From 51956a93356df4876b23aac0e1b5f21c0f5f3d82 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 6 Sep 2024 09:47:22 -0400 Subject: [PATCH 77/95] test: remove unnecessary Enzyme runtime API --- docs/src/introduction/updating_to_v1.md | 2 ++ test/layers/recurrent_tests.jl | 48 +++++++++++-------------- 2 files changed, 23 insertions(+), 27 deletions(-) diff --git a/docs/src/introduction/updating_to_v1.md b/docs/src/introduction/updating_to_v1.md index 4178ae38bf..845905586c 100644 --- a/docs/src/introduction/updating_to_v1.md +++ b/docs/src/introduction/updating_to_v1.md @@ -140,3 +140,5 @@ abstraction. align with Pytorch. Both are controlled using `init_bias` and `use_bias`. - [`ConvTranspose`](@ref) allows `flipkernel=true` via `cross_correlation=true`. This makes it efficient for MIOpen. +- Pooling Layers based on lpnorm have been added -- [`LPPool`](@ref), + [`GlobalLPPool`](@ref), and [`AdaptiveLPPool`](@ref). 
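A minimal usage sketch of the new lpnorm pooling layers referenced in the note above, assuming the constructors and the `(x, ps, st)` call signature added in `src/layers/pooling.jl` earlier in this series (defaults `stride=window`, `pad=0`, `p=2` as defined there; per the new docstrings the lp variants are currently CPU-only). This is an illustration, not part of the patches:

```julia
using Lux, Random

rng = Random.default_rng()
x = randn(rng, Float32, 8, 8, 3, 2)   # (W, H, C, N)

# Fixed-window lp-norm pooling; `p=2` and `stride=window` are the defaults
lp = LPPool((2, 2))
ps, st = Lux.setup(rng, lp)
y, _ = lp(x, ps, st)                  # size(y) == (4, 4, 3, 2)

# Global variant collapses the spatial dimensions down to (1, 1)
glp = GlobalLPPool(; p=2)
ps, st = Lux.setup(rng, glp)
y, _ = glp(x, ps, st)                 # size(y) == (1, 1, 3, 2)

# Adaptive variant picks the window so the output matches the requested size
alp = AdaptiveLPPool((4, 4))
ps, st = Lux.setup(rng, alp)
y, _ = alp(x, ps, st)                 # size(y) == (4, 4, 3, 2)
```
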
diff --git a/test/layers/recurrent_tests.jl b/test/layers/recurrent_tests.jl index b5f2867d07..590c7338e9 100644 --- a/test/layers/recurrent_tests.jl +++ b/test/layers/recurrent_tests.jl @@ -1,5 +1,4 @@ @testitem "RNNCell" setup=[SharedTestSetup] tags=[:recurrent_layers] begin - Enzyme.API.runtimeActivity!(true) rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES @@ -8,7 +7,7 @@ RNNCell(3 => 5, identity; use_bias=false), RNNCell(3 => 5, identity; use_bias=false, train_state=false)) display(rnncell) - ps, st = Lux.setup(rng, rnncell) .|> dev + ps, st = Lux.setup(rng, rnncell) |> dev for x_size in ((3, 2), (3,)) x = randn(rng, Float32, x_size...) |> aType (y, carry), st_ = Lux.apply(rnncell, x, ps, st) @@ -36,10 +35,10 @@ RNNCell(3 => 5, identity; use_bias=true, train_state=true)) rnn_no_trainable_state = RNNCell( 3 => 5, identity; use_bias=false, train_state=false) - _ps, _st = Lux.setup(rng, rnn_no_trainable_state) .|> dev + _ps, _st = Lux.setup(rng, rnn_no_trainable_state) |> dev rnncell = RNNCell(3 => 5, identity; use_bias=false, train_state=true) - ps, st = Lux.setup(rng, rnncell) .|> dev + ps, st = Lux.setup(rng, rnncell) |> dev ps = merge(_ps, (hidden_state=ps.hidden_state,)) for x_size in ((3, 2), (3,)) @@ -60,14 +59,13 @@ end @testitem "LSTMCell" setup=[SharedTestSetup] tags=[:recurrent_layers] begin - Enzyme.API.runtimeActivity!(true) rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES for lstmcell in (LSTMCell(3 => 5), LSTMCell(3 => 5; use_bias=true), LSTMCell(3 => 5; use_bias=false)) display(lstmcell) - ps, st = Lux.setup(rng, lstmcell) .|> dev + ps, st = Lux.setup(rng, lstmcell) |> dev for x_size in ((3, 2), (3,)) x = randn(rng, Float32, x_size...) |> aType @@ -97,12 +95,12 @@ end x = randn(rng, Float32, x_size...) 
|> aType _lstm = LSTMCell( 3 => 5; use_bias=false, train_state=false, train_memory=false) - _ps, _st = Lux.setup(rng, _lstm) .|> dev + _ps, _st = Lux.setup(rng, _lstm) |> dev (_y, _carry), _ = Lux.apply(_lstm, x, _ps, _st) lstm = LSTMCell( 3 => 5; use_bias=false, train_state=false, train_memory=false) - ps, st = Lux.setup(rng, lstm) .|> dev + ps, st = Lux.setup(rng, lstm) |> dev ps = _ps (y, carry), _ = Lux.apply(lstm, x, ps, st) @test carry == _carry @@ -116,7 +114,7 @@ end lstm = LSTMCell( 3 => 5; use_bias=false, train_state=true, train_memory=false) - ps, st = Lux.setup(rng, lstm) .|> dev + ps, st = Lux.setup(rng, lstm) |> dev ps = merge(_ps, (hidden_state=ps.hidden_state,)) (y, carry), _ = Lux.apply(lstm, x, ps, st) @test carry == _carry @@ -130,7 +128,7 @@ end lstm = LSTMCell( 3 => 5; use_bias=false, train_state=false, train_memory=true) - ps, st = Lux.setup(rng, lstm) .|> dev + ps, st = Lux.setup(rng, lstm) |> dev ps = merge(_ps, (memory=ps.memory,)) (y, carry), _ = Lux.apply(lstm, x, ps, st) @test carry == _carry @@ -143,7 +141,7 @@ end @test !isnothing(gs.memory) lstm = LSTMCell(3 => 5; use_bias=false, train_state=true, train_memory=true) - ps, st = Lux.setup(rng, lstm) .|> dev + ps, st = Lux.setup(rng, lstm) |> dev ps = merge(_ps, (hidden_state=ps.hidden_state, memory=ps.memory)) (y, carry), _ = Lux.apply(lstm, x, ps, st) @test carry == _carry @@ -156,7 +154,7 @@ end @test !isnothing(gs.memory) lstm = LSTMCell(3 => 5; use_bias=true, train_state=true, train_memory=true) - ps, st = Lux.setup(rng, lstm) .|> dev + ps, st = Lux.setup(rng, lstm) |> dev ps = merge(_ps, (; ps.bias_ih, ps.bias_hh, ps.hidden_state, ps.memory)) (y, carry), _ = Lux.apply(lstm, x, ps, st) l, back = Zygote.pullback( @@ -172,14 +170,13 @@ end end @testitem "GRUCell" setup=[SharedTestSetup] tags=[:recurrent_layers] begin - Enzyme.API.runtimeActivity!(true) rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES for grucell in (GRUCell(3 => 5), GRUCell(3 => 5; use_bias=true), GRUCell(3 => 5; use_bias=false)) display(grucell) - ps, st = Lux.setup(rng, grucell) .|> dev + ps, st = Lux.setup(rng, grucell) |> dev for x_size in ((3, 2), (3,)) x = randn(rng, Float32, x_size...) |> aType @@ -207,11 +204,11 @@ end for x_size in ((3, 2), (3,)) x = randn(rng, Float32, x_size...) 
|> aType _gru = GRUCell(3 => 5; use_bias=false, train_state=false) - _ps, _st = Lux.setup(rng, _gru) .|> dev + _ps, _st = Lux.setup(rng, _gru) |> dev (_y, _carry), _ = Lux.apply(_gru, x, _ps, _st) gru = GRUCell(3 => 5; use_bias=false, train_state=false) - ps, st = Lux.setup(rng, gru) .|> dev + ps, st = Lux.setup(rng, gru) |> dev ps = _ps (y, carry), _ = Lux.apply(gru, x, ps, st) @test carry == _carry @@ -222,7 +219,7 @@ end @test_throws ErrorException gs.hidden_state gru = GRUCell(3 => 5; use_bias=false, train_state=true) - ps, st = Lux.setup(rng, gru) .|> dev + ps, st = Lux.setup(rng, gru) |> dev ps = merge(_ps, (hidden_state=ps.hidden_state,)) (y, carry), _ = Lux.apply(gru, x, ps, st) @test carry == _carry @@ -231,7 +228,7 @@ end @test !isnothing(gs.hidden_state) gru = GRUCell(3 => 5; use_bias=true, train_state=true) - ps, st = Lux.setup(rng, gru) .|> dev + ps, st = Lux.setup(rng, gru) |> dev ps = merge(_ps, (; ps.bias_ih, ps.bias_hh, ps.hidden_state)) (y, carry), _ = Lux.apply(gru, x, ps, st) @test carry == _carry @@ -244,7 +241,6 @@ end end @testitem "StatefulRecurrentCell" setup=[SharedTestSetup] tags=[:recurrent_layers] begin - Enzyme.API.runtimeActivity!(true) rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES @@ -258,7 +254,7 @@ end for x_size in ((3, 2), (3,)) x = randn(rng, Float32, x_size...) |> aType - ps, st = Lux.setup(rng, rnn) .|> dev + ps, st = Lux.setup(rng, rnn) |> dev y, st_ = rnn(x, ps, st) @@ -292,7 +288,6 @@ end end @testitem "Recurrence" setup=[SharedTestSetup] tags=[:recurrent_layers] begin - Enzyme.API.runtimeActivity!(true) rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES @@ -318,7 +313,7 @@ end (ntuple(identity, ndims(x) - 2)..., ndims(x), ndims(x) - 1)) end - ps, st = Lux.setup(rng, rnn) .|> dev + ps, st = Lux.setup(rng, rnn) |> dev y, st_ = rnn(x, ps, st) y_, st__ = rnn_seq(x, ps, st) @@ -343,7 +338,7 @@ end randn(rng, Float32, 3, 4) |> aType, Tuple(randn(rng, Float32, 3) for _ in 1:4) .|> aType, [randn(rng, Float32, 3) for _ in 1:4] .|> aType) - ps, st = Lux.setup(rng, rnn) .|> dev + ps, st = Lux.setup(rng, rnn) |> dev y, st_ = rnn(x, ps, st) y_, st__ = rnn_seq(x, ps, st) @@ -383,7 +378,7 @@ end init_state=(rng, args...; kwargs...) -> zeros(args...; kwargs...), init_bias=(rng, args...; kwargs...) 
-> zeros(args...; kwargs...)); return_sequence=true) - ps, st = Lux.setup(rng, encoder) .|> dev + ps, st = Lux.setup(rng, encoder) |> dev m2 = reshape([0.5, 0.0, 0.7, 0.8], 1, :, 1) |> aType res, _ = encoder(m2, ps, st) @@ -392,7 +387,6 @@ end end @testitem "Bidirectional" setup=[SharedTestSetup] tags=[:recurrent_layers] begin - Enzyme.API.runtimeActivity!(true) rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES @@ -404,7 +398,7 @@ end # Batched Time Series x = randn(rng, Float32, 3, 4, 2) |> aType - ps, st = Lux.setup(rng, bi_rnn) .|> dev + ps, st = Lux.setup(rng, bi_rnn) |> dev y, st_ = bi_rnn(x, ps, st) y_, st__ = bi_rnn_no_merge(x, ps, st) @@ -440,7 +434,7 @@ end # Batched Time Series x = randn(rng, Float32, 3, 4, 2) |> aType - ps, st = Lux.setup(rng, bi_rnn) .|> dev + ps, st = Lux.setup(rng, bi_rnn) |> dev y, st_ = bi_rnn(x, ps, st) y_, st__ = bi_rnn_no_merge(x, ps, st) From 8931b199dd918916963265f20a13c47a362e7830 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 6 Sep 2024 10:18:55 -0400 Subject: [PATCH 78/95] test: Enzyme with runtimeActivity enabled --- test/contrib/freeze_tests.jl | 2 -- test/helpers/training_tests.jl | 6 +----- test/layers/dropout_tests.jl | 3 --- test/shared_testsetup.jl | 3 +++ 4 files changed, 4 insertions(+), 10 deletions(-) diff --git a/test/contrib/freeze_tests.jl b/test/contrib/freeze_tests.jl index 96c449135d..22fef5bbb1 100644 --- a/test/contrib/freeze_tests.jl +++ b/test/contrib/freeze_tests.jl @@ -1,5 +1,4 @@ @testitem "All Parameter Freezing" setup=[SharedTestSetup] tags=[:contrib] begin - Enzyme.API.runtimeActivity!(true) rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES @@ -66,7 +65,6 @@ end @testitem "Partial Freezing" setup=[SharedTestSetup] tags=[:contrib] begin using Lux.Experimental: FrozenLayer - Enzyme.API.runtimeActivity!(true) rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES diff --git a/test/helpers/training_tests.jl b/test/helpers/training_tests.jl index fc6c307d9f..67897b17fc 100644 --- a/test/helpers/training_tests.jl +++ b/test/helpers/training_tests.jl @@ -20,7 +20,7 @@ end @testitem "AbstractADTypes" setup=[SharedTestSetup] tags=[:helpers] begin - using ADTypes, Optimisers, Enzyme + using ADTypes, Optimisers function _loss_function(model, ps, st, data) y, st = model(data, ps, st) @@ -52,7 +52,6 @@ end @testitem "Training API" setup=[SharedTestSetup] tags=[:helpers] begin using ADTypes, Optimisers - import Enzyme, Tracker, ReverseDiff, Zygote mse = MSELoss() @@ -128,9 +127,6 @@ end @testitem "Enzyme: Invalidate Cache on State Update" setup=[SharedTestSetup] tags=[:helpers] skip=:(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED) begin using ADTypes, Optimisers - using Enzyme - - Enzyme.API.runtimeActivity!(true) mse = MSELoss() diff --git a/test/layers/dropout_tests.jl b/test/layers/dropout_tests.jl index 5d377d9b6e..62ecafcd0b 100644 --- a/test/layers/dropout_tests.jl +++ b/test/layers/dropout_tests.jl @@ -64,9 +64,6 @@ end end @testitem "VariationalHiddenDropout" setup=[SharedTestSetup] tags=[:normalize_layers] begin - using Enzyme - Enzyme.API.runtimeActivity!(true) - rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES diff --git a/test/shared_testsetup.jl b/test/shared_testsetup.jl index aba3646de4..85abd32042 100644 --- a/test/shared_testsetup.jl +++ b/test/shared_testsetup.jl @@ -1,5 +1,8 @@ @testsetup module SharedTestSetup +using Enzyme +Enzyme.API.runtimeActivity!(true) + include("setup_modes.jl") import 
Reexport: @reexport From 77dda0af5fd81c40eba7a6bee0a52988aebf1fc3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 6 Sep 2024 11:10:58 -0400 Subject: [PATCH 79/95] feat: add outpad to conv transpose --- docs/src/introduction/updating_to_v1.md | 2 + ext/LuxFluxExt.jl | 70 ++++++------------------- src/layers/conv.jl | 24 ++++++--- test/layers/conv_tests.jl | 22 ++++++++ test/layers/normalize_tests.jl | 2 +- test/transform/flux_tests.jl | 39 ++------------ 6 files changed, 59 insertions(+), 100 deletions(-) diff --git a/docs/src/introduction/updating_to_v1.md b/docs/src/introduction/updating_to_v1.md index 845905586c..665b4c0747 100644 --- a/docs/src/introduction/updating_to_v1.md +++ b/docs/src/introduction/updating_to_v1.md @@ -140,5 +140,7 @@ abstraction. align with Pytorch. Both are controlled using `init_bias` and `use_bias`. - [`ConvTranspose`](@ref) allows `flipkernel=true` via `cross_correlation=true`. This makes it efficient for MIOpen. +- [`ConvTranspose`](@ref) now has an `outpad` keyword argument, which is used to increase + the size of the output in the desired dimensions. - Pooling Layers based on lpnorm have been added -- [`LPPool`](@ref), [`GlobalLPPool`](@ref), and [`AdaptiveLPPool`](@ref). diff --git a/ext/LuxFluxExt.jl b/ext/LuxFluxExt.jl index 617980d02f..50a5491c05 100644 --- a/ext/LuxFluxExt.jl +++ b/ext/LuxFluxExt.jl @@ -113,15 +113,16 @@ function Lux.convert_flux_model( out_chs, in_chs = size(l.weight)[(end - 1):end] groups = l.groups pad = l.pad isa Flux.SamePad ? SamePad() : l.pad + outpad = hasfield(typeof(l), :outpad) ? l.outpad : 0 if preserve_ps_st _bias = l.bias isa Bool ? nothing : vec(copy(l.bias)) - return Lux.ConvTranspose(k, in_chs * groups => out_chs, l.σ; l.stride, - pad, l.dilation, groups, use_bias=!(l.bias isa Bool), + return Lux.ConvTranspose(k, in_chs * groups => out_chs, l.σ; l.stride, pad, + outpad, l.dilation, groups, use_bias=!(l.bias isa Bool), init_weight=Returns(Lux.maybe_flip_conv_weight(l.weight)), init_bias=Returns(_bias)) else return Lux.ConvTranspose(k, in_chs * groups => out_chs, l.σ; l.stride, pad, - l.dilation, groups, use_bias=!(l.bias isa Bool)) + outpad, l.dilation, groups, use_bias=!(l.bias isa Bool)) end end @@ -176,58 +177,6 @@ function Lux.convert_flux_model(l::Flux.Upsample{mode}; kwargs...) where {mode} return Lux.Upsample(mode; l.scale, l.size, align_corners=false) end -function Lux.convert_flux_model( - l::Flux.RNNCell; preserve_ps_st::Bool=false, force_preserve::Bool=false) - out_dims, in_dims = size(l.Wi) - if preserve_ps_st - if force_preserve - throw(FluxModelConversionException("Recurrent Cell: $(typeof(l)) for Flux uses a `reset!` mechanism which hasn't been extensively tested with `FluxLayer`. Rewrite the model manually to use `RNNCell`.")) - end - @warn "Preserving Parameters: `Wh` & `Wi` for `Flux.RNNCell` is ambiguous in Lux and hence not supported. Ignoring these parameters." maxlog=1 - return Lux.RNNCell(in_dims => out_dims, l.σ; init_bias=Returns(copy(l.b)), - init_state=Returns(copy(l.state0))) - else - return Lux.RNNCell(in_dims => out_dims, l.σ) - end -end - -function Lux.convert_flux_model( - l::Flux.LSTMCell; preserve_ps_st::Bool=false, force_preserve::Bool=false) - _out_dims, in_dims = size(l.Wi) - out_dims = _out_dims ÷ 4 - if preserve_ps_st - if force_preserve - throw(FluxModelConversionException("Recurrent Cell: $(typeof(l)) for Flux uses a `reset!` mechanism which hasn't been extensively tested with `FluxLayer`. 
Rewrite the model manually to use `LSTMCell`.")) - end - @warn "Preserving Parameters: `Wh` & `Wi` for `Flux.LSTMCell` is ambiguous in Lux \ - and hence not supported. Ignoring these parameters." maxlog=1 - bs = LuxOps.multigate(l.b, Val(4)) - _s, _m = copy.(l.state0) - return Lux.LSTMCell(in_dims => out_dims; init_bias=Returns.(bs), - init_state=Returns(_s), init_memory=Returns(_m)) - else - return Lux.LSTMCell(in_dims => out_dims) - end -end - -function Lux.convert_flux_model( - l::Flux.GRUCell; preserve_ps_st::Bool=false, force_preserve::Bool=false) - _out_dims, in_dims = size(l.Wi) - out_dims = _out_dims ÷ 3 - if preserve_ps_st - if force_preserve - throw(FluxModelConversionException("Recurrent Cell: $(typeof(l)) for Flux uses a `reset!` mechanism which hasn't been extensively tested with `FluxLayer`. Rewrite the model manually to use `GRUCell`.")) - end - @warn "Preserving Parameters: `Wh` & `Wi` for `Flux.GRUCell` is ambiguous in Lux \ - and hence not supported. Ignoring these parameters." maxlog=1 - bs = LuxOps.multigate(l.b, Val(3)) - return Lux.GRUCell( - in_dims => out_dims; init_bias=Returns.(bs), init_state=Returns(copy(l.state0))) - else - return Lux.GRUCell(in_dims => out_dims) - end -end - function Lux.convert_flux_model( l::Flux.BatchNorm; preserve_ps_st::Bool=false, force_preserve::Bool=false) if preserve_ps_st @@ -268,4 +217,15 @@ function Lux.convert_flux_model(l::T; kwargs...) where {T <: _INVALID_TRANSFORMA throw(FluxModelConversionException("Transformation of type $(T) is not supported.")) end +for cell in (:RNNCell, :LSTMCell, :GRUCell) + msg = "Recurrent Cell: $(cell) for Flux has semantical difference with Lux, \ + mostly in-terms of how the bias term is dealt with. Lux aligns with the Pytorch \ + definition of these models and hence converting `Flux.$(cell)` to `Lux.$(cell) \ + is not possible. Rewrite the model manually." + @eval function Lux.convert_flux_model(::Flux.$(cell); preserve_ps_st::Bool=false, + force_preserve::Bool=false) + throw(FluxModelConversionException($msg)) + end +end + end diff --git a/src/layers/conv.jl b/src/layers/conv.jl index b026384069..59e9fe021d 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -17,14 +17,14 @@ end CRC.@non_differentiable calc_padding(::Any...) function conv_transpose_dims( - x::AbstractArray, weight::AbstractArray; padding, stride, dilation, groups) + x::AbstractArray, weight::AbstractArray; padding, stride, dilation, groups, outpad) # Calculate size of "input", from ∇conv_data()'s perspective... - function calc_dim(xsz, wsz, stride, dilation, pad) - return (xsz - 1) * stride + 1 + (wsz - 1) * dilation - pad + function calc_dim(xsz, wsz, stride, dilation, pad, outpad) + return (xsz - 1) * stride + 1 + (wsz - 1) * dilation - pad + outpad end combined_pad = ntuple(i -> padding[2i - 1] + padding[2i], length(padding) ÷ 2) I = map(calc_dim, size(x)[1:(end - 2)], size(weight)[1:(end - 2)], - stride, dilation, combined_pad) + stride, dilation, combined_pad, outpad) C_in = size(weight)[end - 1] * groups C_out = size(weight)[end] batch_size = size(x)[end] @@ -221,7 +221,7 @@ end @doc doc""" ConvTranspose(k::NTuple{N,Integer}, (in_chs => out_chs)::Pair{<:Integer,<:Integer}, activation=identity; init_weight=glorot_uniform, init_bias=zeros32, - stride=1, pad=0, dilation=1, groups=1, use_bias=True(), + stride=1, pad=0, outpad=0, dilation=1, groups=1, use_bias=True(), cross_correlation=False()) Standard convolutional transpose layer. @@ -263,6 +263,9 @@ Standard convolutional transpose layer. 
- `use_bias`: Trainable bias can be disabled entirely by setting this to `false`. - `cross_correlation`: If `true`, perform transposed cross-correlation instead of transposed convolution. + - `outpad`: To converse [`Conv`](@ref) inversability when `stride > 1`, `outpad` can be + used to increase the size of the output in the desired dimensions. Whereas `pad` is used + to zero-pad the input, `outpad` only affects the output shape. # Extended Help @@ -288,6 +291,7 @@ Standard convolutional transpose layer. kernel_size <: Tuple{Vararg{IntegerType}} stride <: Tuple{Vararg{IntegerType}} pad <: Tuple{Vararg{IntegerType}} + outpad <: Tuple{Vararg{IntegerType}} dilation <: Tuple{Vararg{IntegerType}} groups <: IntegerType init_weight @@ -299,7 +303,7 @@ end function ConvTranspose( k::Tuple{Vararg{IntegerType}}, ch::Pair{<:IntegerType, <:IntegerType}, activation=identity; init_weight=glorot_uniform, - init_bias=zeros32, stride=1, pad=0, dilation=1, groups=1, + init_bias=zeros32, stride=1, pad=0, outpad=0, dilation=1, groups=1, use_bias::BoolType=True(), cross_correlation::BoolType=False()) stride = Utils.expand(Val(length(k)), stride) dilation = Utils.expand(Val(length(k)), dilation) @@ -308,12 +312,14 @@ function ConvTranspose( else calc_padding(pad, k, dilation, stride) end + outpad = Utils.expand(Val(length(k)), outpad) @argcheck ch[2] % groups==0 DimensionMismatch("Input channel dimension must be divisible by groups.") @argcheck ch[1] % groups==0 DimensionMismatch("Output channel dimension must be divisible by groups.") @argcheck allequal(length, (stride, dilation, k)) - return ConvTranspose(activation, first(ch), last(ch), k, stride, pad, dilation, groups, + return ConvTranspose( + activation, first(ch), last(ch), k, stride, pad, outpad, dilation, groups, init_weight, init_bias, static(use_bias), static(cross_correlation)) end @@ -335,7 +341,8 @@ end function (c::ConvTranspose)(x::AbstractArray, ps, st::NamedTuple) y = match_eltype(c, ps, st, x) cdims = construct_crosscor_convdims(c.cross_correlation, - conv_transpose_dims(y, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups)) + conv_transpose_dims( + y, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups, c.outpad)) bias = safe_getproperty(ps, Val(:bias)) σ = NNlib.fast_act(c.activation, y) return bias_activation!!(σ, conv_transpose(y, ps.weight, cdims), bias), st @@ -350,6 +357,7 @@ function Base.show(io::IO, l::ConvTranspose) all(==(1), l.dilation) || print(io, ", dilation=", PrettyPrinting.tuple_string(l.dilation)) (l.groups == 1) || print(io, ", groups=", l.groups) + all(==(0), l.outpad) || print(io, ", outpad=", PrettyPrinting.tuple_string(l.outpad)) has_bias(l) || print(io, ", use_bias=false") print(io, ")") end diff --git a/test/layers/conv_tests.jl b/test/layers/conv_tests.jl index 7248b18254..788cf13112 100644 --- a/test/layers/conv_tests.jl +++ b/test/layers/conv_tests.jl @@ -617,5 +617,27 @@ end @test_throws DimensionMismatch layer(x, ps, st) end + + @testest "with Output Padding" begin + m1 = ConvTranspose((3, 5), 3 => 6; stride=3) + m2 = ConvTranspose((3, 5), 3 => 6; stride=3, outpad=(1, 0)) + + ps1, st1 = Lux.setup(rng, m1) |> dev + ps2, st2 = Lux.setup(rng, m2) |> dev + + x = randn(Float32, 10, 11, 3, 2) |> aType + @test size(m1(x, ps1, st1)[1])[1:2] .+ (1, 0) == size(m2(x, ps2, st2)[1])[1:2] + + m1 = ConvTranspose((3, 5, 3), 3 => 6; stride=3) + m2 = ConvTranspose((3, 5, 3), 3 => 6; stride=3, outpad=(1, 0, 1)) + + ps1, st1 = Lux.setup(rng, m1) |> dev + ps2, st2 = Lux.setup(rng, m2) |> dev + + x = randn(Float32, 10, 
11, 12, 3, 2) |> aType + + @test size(m1(x, ps1, st1)[1])[1:3] .+ (1, 0, 1) == + size(m2(x, ps2, st2)[1])[1:3] + end end end diff --git a/test/layers/normalize_tests.jl b/test/layers/normalize_tests.jl index 7d8c9db732..46474dbf1a 100644 --- a/test/layers/normalize_tests.jl +++ b/test/layers/normalize_tests.jl @@ -128,7 +128,7 @@ end __f = let m = m, x = x, st = st ps -> sum(first(m(x, ps, st))) end - test_gradients(__f, ps; atol=1.0f-3, rtol=1.0f-3, broken_backends=[AutoEnzyme()]) + test_gradients(__f, ps; atol=1.0f-3, rtol=1.0f-3) @testset "affine: $affine" for affine in (true, false) m = GroupNorm(2, 2; affine) diff --git a/test/transform/flux_tests.jl b/test/transform/flux_tests.jl index a6530e27bb..fe7e5fceef 100644 --- a/test/transform/flux_tests.jl +++ b/test/transform/flux_tests.jl @@ -327,53 +327,20 @@ @testset "Recurrent" begin @testset "RNNCell" begin model = Flux.RNNCell(2 => 3) |> fdev(dev) - x = rand(Float32, 2, 4) |> aType - - model_lux = tolux(model) - ps, st = Lux.setup(StableRNG(12345), model_lux) .|> dev - - @test size(model_lux(x, ps, st)[1][1]) == (3, 4) - + @test_throws Lux.FluxModelConversionException tolux(model) @test_throws Lux.FluxModelConversionException toluxforce(model) - - model_lux = toluxpsst(model) - ps, st = Lux.setup(StableRNG(12345), model_lux) .|> dev - - @test size(model_lux(x, ps, st)[1][1]) == (3, 4) end @testset "LSTMCell" begin model = Flux.LSTMCell(2 => 3) |> fdev(dev) - x = rand(Float32, 2, 4) |> aType - - model_lux = tolux(model) - ps, st = Lux.setup(StableRNG(12345), model_lux) .|> dev - - @test size(model_lux(x, ps, st)[1][1]) == (3, 4) - + @test_throws Lux.FluxModelConversionException tolux(model) @test_throws Lux.FluxModelConversionException toluxforce(model) - - model_lux = toluxpsst(model) - ps, st = Lux.setup(StableRNG(12345), model_lux) .|> dev - - @test size(model_lux(x, ps, st)[1][1]) == (3, 4) end @testset "GRUCell" begin model = Flux.GRUCell(2 => 3) |> fdev(dev) - x = rand(Float32, 2, 4) |> aType - - model_lux = tolux(model) - ps, st = Lux.setup(StableRNG(12345), model_lux) .|> dev - - @test size(model_lux(x, ps, st)[1][1]) == (3, 4) - + @test_throws Lux.FluxModelConversionException tolux(model) @test_throws Lux.FluxModelConversionException toluxforce(model) - - model_lux = toluxpsst(model) - ps, st = Lux.setup(StableRNG(12345), model_lux) .|> dev - - @test size(model_lux(x, ps, st)[1][1]) == (3, 4) end end From 03c57e5b7ef7da237c6a58784f3b09b25b6aba95 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 6 Sep 2024 11:19:53 -0400 Subject: [PATCH 80/95] docs: move docs around --- README.md | 71 +++++++++++++++++++++++++++++-------------------------- 1 file changed, 38 insertions(+), 33 deletions(-) diff --git a/README.md b/README.md index f21a9ae7e5..033be3ebdd 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,43 @@ st_opt, ps = Optimisers.update(st_opt, ps, gs) Look in the [examples](/examples/) directory for self-contained usage examples. The [documentation](https://lux.csail.mit.edu) has examples sorted into proper categories. -## 🧪 Testing +## 🆘 Getting Help + +For usage related questions, please use [Github Discussions](https://github.com/orgs/LuxDL/discussions) which allows questions and answers to be indexed. To report bugs use [github issues](https://github.com/LuxDL/Lux.jl/issues) or even better send in a [pull request](https://github.com/LuxDL/Lux.jl/pulls). 
+ +## 🧑‍🔬 Citation + +If you found this library to be useful in academic work, then please cite: + +```bibtex +@software{pal2023lux, + author = {Pal, Avik}, + title = {{Lux: Explicit Parameterization of Deep Neural Networks in Julia}}, + month = apr, + year = 2023, + note = {If you use this software, please cite it as below.}, + publisher = {Zenodo}, + version = {v0.5.0}, + doi = {10.5281/zenodo.7808904}, + url = {https://doi.org/10.5281/zenodo.7808904} +} + +@thesis{pal2023efficient, + title = {{On Efficient Training \& Inference of Neural Differential Equations}}, + author = {Pal, Avik}, + year = {2023}, + school = {Massachusetts Institute of Technology} +} +``` + +Also consider starring [our github repo](https://github.com/LuxDL/Lux.jl/). + +## 🧑‍💻 Contributing + +This section is somewhat incomplete. You can contribute by contributing to finishing this +section 😜. + +### 🧪 Testing The full test of `Lux.jl` takes a long time, here's how to test a portion of the code. @@ -125,36 +161,5 @@ ReTestItems.runtests("tests/"; name = "NAME OF THE TEST") For the `SkipConnection` tests that would be: ```julia -ReTestItems.runtests("tests/"; name = SkipConnection) -``` - -## 🆘 Getting Help - -For usage related questions, please use [Github Discussions](https://github.com/orgs/LuxDL/discussions) which allows questions and answers to be indexed. To report bugs use [github issues](https://github.com/LuxDL/Lux.jl/issues) or even better send in a [pull request](https://github.com/LuxDL/Lux.jl/pulls). - -## 🧑‍🔬 Citation - -If you found this library to be useful in academic work, then please cite: - -```bibtex -@software{pal2023lux, - author = {Pal, Avik}, - title = {{Lux: Explicit Parameterization of Deep Neural Networks in Julia}}, - month = apr, - year = 2023, - note = {If you use this software, please cite it as below.}, - publisher = {Zenodo}, - version = {v0.5.0}, - doi = {10.5281/zenodo.7808904}, - url = {https://doi.org/10.5281/zenodo.7808904} -} - -@thesis{pal2023efficient, - title = {{On Efficient Training \& Inference of Neural Differential Equations}}, - author = {Pal, Avik}, - year = {2023}, - school = {Massachusetts Institute of Technology} -} +ReTestItems.runtests("tests/"; name = "SkipConnection") ``` - -Also consider starring [our github repo](https://github.com/LuxDL/Lux.jl/). From abf57f54c1fbd80ac0e9d63a0897d3131f9ce8db Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 6 Sep 2024 11:20:35 -0400 Subject: [PATCH 81/95] chore: run formatter --- ext/LuxFluxExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/LuxFluxExt.jl b/ext/LuxFluxExt.jl index 50a5491c05..d0f89b2b0d 100644 --- a/ext/LuxFluxExt.jl +++ b/ext/LuxFluxExt.jl @@ -222,8 +222,8 @@ for cell in (:RNNCell, :LSTMCell, :GRUCell) mostly in-terms of how the bias term is dealt with. Lux aligns with the Pytorch \ definition of these models and hence converting `Flux.$(cell)` to `Lux.$(cell) \ is not possible. Rewrite the model manually." 
- @eval function Lux.convert_flux_model(::Flux.$(cell); preserve_ps_st::Bool=false, - force_preserve::Bool=false) + @eval function Lux.convert_flux_model( + ::Flux.$(cell); preserve_ps_st::Bool=false, force_preserve::Bool=false) throw(FluxModelConversionException($msg)) end end From 74d22c4ab84124e6bf09432b0f8ceec5a0500fb4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 6 Sep 2024 11:57:21 -0400 Subject: [PATCH 82/95] test: more testing for ConvTranspose --- src/layers/conv.jl | 1 + test/layers/conv_tests.jl | 267 ++++++++++++++++++++------------------ 2 files changed, 143 insertions(+), 125 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 59e9fe021d..5a1d8a586f 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -359,6 +359,7 @@ function Base.show(io::IO, l::ConvTranspose) (l.groups == 1) || print(io, ", groups=", l.groups) all(==(0), l.outpad) || print(io, ", outpad=", PrettyPrinting.tuple_string(l.outpad)) has_bias(l) || print(io, ", use_bias=false") + known(l.cross_correlation) && print(io, ", cross_correlation=true") print(io, ")") end diff --git a/test/layers/conv_tests.jl b/test/layers/conv_tests.jl index 788cf13112..0db8b91144 100644 --- a/test/layers/conv_tests.jl +++ b/test/layers/conv_tests.jl @@ -470,174 +470,191 @@ end rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES - x = randn(Float32, 5, 5, 1, 1) |> aType - layer = Conv((3, 3), 1 => 1) - ps, st = Lux.setup(rng, layer) |> dev - y = layer(x, ps, st)[1] - - layer = ConvTranspose((3, 3), 1 => 1) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev - - @jet layer(y, ps, st) + @testset for cross_correlation in (true, false) + x = randn(Float32, 5, 5, 1, 1) |> aType + layer = Conv((3, 3), 1 => 1) + ps, st = Lux.setup(rng, layer) |> dev + y = layer(x, ps, st)[1] - x_hat1 = layer(y, ps, st)[1] + layer = ConvTranspose((3, 3), 1 => 1; cross_correlation) + display(layer) + ps, st = Lux.setup(rng, layer) |> dev - layer = ConvTranspose((3, 3), 1 => 1; use_bias=false) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev + @jet layer(y, ps, st) - @jet layer(y, ps, st) + x_hat1 = layer(y, ps, st)[1] - x_hat2 = layer(y, ps, st)[1] + layer = ConvTranspose((3, 3), 1 => 1; use_bias=false, cross_correlation) + display(layer) + ps, st = Lux.setup(rng, layer) |> dev - @test size(x_hat1) == size(x_hat2) == size(x) + @jet layer(y, ps, st) - layer = ConvTranspose((3, 3), 1 => 1) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev - x = rand(Float32, 5, 5, 1, 1) |> aType + x_hat2 = layer(y, ps, st)[1] - @jet layer(x, ps, st) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3) + @test size(x_hat1) == size(x_hat2) == size(x) - x = rand(Float32, 5, 5, 2, 4) |> aType - layer = ConvTranspose((3, 3), 2 => 3) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev + layer = ConvTranspose((3, 3), 1 => 1; cross_correlation) + display(layer) + ps, st = Lux.setup(rng, layer) |> dev + x = rand(Float32, 5, 5, 1, 1) |> aType - @jet layer(x, ps, st) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3) - - # test ConvTranspose supports groups argument - x = randn(Float32, 10, 10, 2, 3) |> aType - layer1 = ConvTranspose((3, 3), 2 => 4; pad=SamePad()) - display(layer1) - ps1, st1 = Lux.setup(rng, layer1) |> dev - @test size(ps1.weight) == (3, 3, 4, 2) - @test size(layer1(x, ps1, st1)[1]) == (10, 10, 4, 3) - - layer2 = ConvTranspose((3, 3), 2 => 4; groups=2, pad=SamePad()) - display(layer2) - 
ps2, st2 = Lux.setup(rng, layer2) |> dev - @test size(ps2.weight) == (3, 3, 2, 2) - @test size(layer1(x, ps1, st1)[1]) == size(layer2(x, ps2, st2)[1]) - - __f = (x, ps) -> sum(first(layer1(x, ps, st1))) - test_gradients(__f, x, ps1; atol=1.0f-3, rtol=1.0f-3) - - __f = (x, ps) -> sum(first(layer2(x, ps, st2))) - test_gradients(__f, x, ps2; atol=1.0f-3, rtol=1.0f-3) - - x = randn(Float32, 10, 2, 1) |> aType - layer = ConvTranspose((3,), 2 => 4; pad=SamePad(), groups=2) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3) - @jet layer(x, ps, st) + x = rand(Float32, 5, 5, 2, 4) |> aType + layer = ConvTranspose((3, 3), 2 => 3; cross_correlation) + display(layer) + ps, st = Lux.setup(rng, layer) |> dev - @test size(layer(x, ps, st)[1]) == (10, 4, 1) - @test length(ps.weight) == 3 * (2 * 4) / 2 + @jet layer(x, ps, st) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3) + # test ConvTranspose supports groups argument + x = randn(Float32, 10, 10, 2, 3) |> aType + layer1 = ConvTranspose((3, 3), 2 => 4; pad=SamePad(), cross_correlation) + display(layer1) + ps1, st1 = Lux.setup(rng, layer1) |> dev + @test size(ps1.weight) == (3, 3, 4, 2) + @test size(layer1(x, ps1, st1)[1]) == (10, 10, 4, 3) + + layer2 = ConvTranspose( + (3, 3), 2 => 4; groups=2, pad=SamePad(), cross_correlation) + display(layer2) + ps2, st2 = Lux.setup(rng, layer2) |> dev + @test size(ps2.weight) == (3, 3, 2, 2) + @test size(layer1(x, ps1, st1)[1]) == size(layer2(x, ps2, st2)[1]) + + __f = (x, ps) -> sum(first(layer1(x, ps, st1))) + test_gradients(__f, x, ps1; atol=1.0f-3, rtol=1.0f-3) + + __f = (x, ps) -> sum(first(layer2(x, ps, st2))) + test_gradients(__f, x, ps2; atol=1.0f-3, rtol=1.0f-3) + + x = randn(Float32, 10, 2, 1) |> aType + layer = ConvTranspose((3,), 2 => 4; pad=SamePad(), groups=2, cross_correlation) + display(layer) + ps, st = Lux.setup(rng, layer) |> dev - x = randn(Float32, 10, 11, 4, 2) |> aType - layer = ConvTranspose((3, 5), 4 => 4; pad=SamePad(), groups=4) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev + @jet layer(x, ps, st) - @jet layer(x, ps, st) + @test size(layer(x, ps, st)[1]) == (10, 4, 1) + @test length(ps.weight) == 3 * (2 * 4) / 2 - @test size(layer(x, ps, st)[1]) == (10, 11, 4, 2) - @test length(ps.weight) == (3 * 5) * (4 * 4) / 4 + __f = (x, ps) -> sum(first(layer(x, ps, st))) + test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3) - __f = (x, ps) -> sum(first(layer(x, ps, st))) - test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3) + x = randn(Float32, 10, 11, 4, 2) |> aType + layer = ConvTranspose( + (3, 5), 4 => 4; pad=SamePad(), groups=4, cross_correlation) + display(layer) + ps, st = Lux.setup(rng, layer) |> dev - x = randn(Float32, 10, 11, 4, 2) |> aType - layer = ConvTranspose((3, 5), 4 => 4, tanh; pad=SamePad(), groups=4) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev + @jet layer(x, ps, st) - @jet layer(x, ps, st) - @test size(layer(x, ps, st)[1]) == (10, 11, 4, 2) - @test length(ps.weight) == (3 * 5) * (4 * 4) / 4 + @test size(layer(x, ps, st)[1]) == (10, 11, 4, 2) + @test length(ps.weight) == (3 * 5) * (4 * 4) / 4 - __f = (x, ps) -> sum(first(layer(x, ps, st))) - test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3) + __f = (x, ps) -> sum(first(layer(x, ps, st))) + test_gradients(__f, x, ps; 
atol=1.0f-3, rtol=1.0f-3) - x = randn(Float32, 10, 11, 12, 3, 2) |> aType - layer = ConvTranspose((3, 5, 3), 3 => 6; pad=SamePad(), groups=3) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev + x = randn(Float32, 10, 11, 4, 2) |> aType + layer = ConvTranspose( + (3, 5), 4 => 4, tanh; pad=SamePad(), groups=4, cross_correlation) + display(layer) + ps, st = Lux.setup(rng, layer) |> dev - @jet layer(x, ps, st) - @test size(layer(x, ps, st)[1]) == (10, 11, 12, 6, 2) - @test length(ps.weight) == (3 * 5 * 3) * (3 * 6) / 3 + @jet layer(x, ps, st) + @test size(layer(x, ps, st)[1]) == (10, 11, 4, 2) + @test length(ps.weight) == (3 * 5) * (4 * 4) / 4 - x = randn(Float32, 10, 11, 12, 3, 2) |> aType - layer = ConvTranspose((3, 5, 3), 3 => 6, tanh; pad=SamePad(), groups=3) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev + __f = (x, ps) -> sum(first(layer(x, ps, st))) + test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3) - @jet layer(x, ps, st) - @test size(layer(x, ps, st)[1]) == (10, 11, 12, 6, 2) - @test length(ps.weight) == (3 * 5 * 3) * (3 * 6) / 3 + x = randn(Float32, 10, 11, 12, 3, 2) |> aType + layer = ConvTranspose( + (3, 5, 3), 3 => 6; pad=SamePad(), groups=3, cross_correlation) + display(layer) + ps, st = Lux.setup(rng, layer) |> dev - @test occursin("groups=2", sprint(show, ConvTranspose((3, 3), 2 => 4; groups=2))) - @test occursin("2 => 4", sprint(show, ConvTranspose((3, 3), 2 => 4; groups=2))) + @jet layer(x, ps, st) + @test size(layer(x, ps, st)[1]) == (10, 11, 12, 6, 2) + @test length(ps.weight) == (3 * 5 * 3) * (3 * 6) / 3 - @testset "SamePad size mismatch LuxDL/Lux.jl#534" begin - layer = ConvTranspose((3,), 2 => 1; pad=SamePad(), stride=2) + x = randn(Float32, 10, 11, 12, 3, 2) |> aType + layer = ConvTranspose( + (3, 5, 3), 3 => 6, tanh; pad=SamePad(), groups=3, cross_correlation) display(layer) - x = ones(Float32, 2, 2, 1) |> aType ps, st = Lux.setup(rng, layer) |> dev - y = first(layer(x, ps, st)) - @test size(y) == (4, 1, 1) @jet layer(x, ps, st) - end + @test size(layer(x, ps, st)[1]) == (10, 11, 12, 6, 2) + @test length(ps.weight) == (3 * 5 * 3) * (3 * 6) / 3 - @testset "Catch Channel Mismatch Early: LuxDL/Lux.jl#455" begin - layer = ConvTranspose((4, 4), 42 => 16; stride=2, pad=1) + @test occursin("groups=2", + sprint(show, ConvTranspose((3, 3), 2 => 4; groups=2, cross_correlation))) + @test occursin("2 => 4", + sprint(show, ConvTranspose((3, 3), 2 => 4; groups=2, cross_correlation))) - x = randn(Float32, 28, 28, 42, 3) |> aType - ps, st = Lux.setup(rng, layer) |> dev + @testset "SamePad size mismatch LuxDL/Lux.jl#534" begin + layer = ConvTranspose( + (3,), 2 => 1; pad=SamePad(), stride=2, cross_correlation) + display(layer) + x = ones(Float32, 2, 2, 1) |> aType + ps, st = Lux.setup(rng, layer) |> dev - @test layer(x, ps, st) isa Any + y = first(layer(x, ps, st)) + @test size(y) == (4, 1, 1) + @jet layer(x, ps, st) + end - x = randn(Float32, 28, 28, 46, 3) |> aType + @testset "Catch Channel Mismatch Early: LuxDL/Lux.jl#455" begin + layer = ConvTranspose((4, 4), 42 => 16; stride=2, pad=1, cross_correlation) - @test_throws DimensionMismatch layer(x, ps, st) + x = randn(Float32, 28, 28, 42, 3) |> aType + ps, st = Lux.setup(rng, layer) |> dev - x = randn(Float32, 28, 28, 23, 3) |> aType + @test layer(x, ps, st) isa Any - @test_throws DimensionMismatch layer(x, ps, st) - end + x = randn(Float32, 28, 28, 46, 3) |> aType - @testest "with Output Padding" begin - m1 = ConvTranspose((3, 5), 3 => 6; stride=3) - m2 = ConvTranspose((3, 5), 3 => 6; stride=3, outpad=(1, 0)) + 
@test_throws DimensionMismatch layer(x, ps, st) - ps1, st1 = Lux.setup(rng, m1) |> dev - ps2, st2 = Lux.setup(rng, m2) |> dev + x = randn(Float32, 28, 28, 23, 3) |> aType - x = randn(Float32, 10, 11, 3, 2) |> aType - @test size(m1(x, ps1, st1)[1])[1:2] .+ (1, 0) == size(m2(x, ps2, st2)[1])[1:2] + @test_throws DimensionMismatch layer(x, ps, st) + end - m1 = ConvTranspose((3, 5, 3), 3 => 6; stride=3) - m2 = ConvTranspose((3, 5, 3), 3 => 6; stride=3, outpad=(1, 0, 1)) + @testset "with Output Padding" begin + m1 = ConvTranspose((3, 5), 3 => 6; stride=3, cross_correlation) + display(m1) + m2 = ConvTranspose( + (3, 5), 3 => 6; stride=3, outpad=(1, 0), cross_correlation) + display(m2) - ps1, st1 = Lux.setup(rng, m1) |> dev - ps2, st2 = Lux.setup(rng, m2) |> dev + ps1, st1 = Lux.setup(rng, m1) |> dev + ps2, st2 = Lux.setup(rng, m2) |> dev - x = randn(Float32, 10, 11, 12, 3, 2) |> aType + x = randn(Float32, 10, 11, 3, 2) |> aType + @test size(m1(x, ps1, st1)[1])[1:2] .+ (1, 0) == + size(m2(x, ps2, st2)[1])[1:2] - @test size(m1(x, ps1, st1)[1])[1:3] .+ (1, 0, 1) == - size(m2(x, ps2, st2)[1])[1:3] + m1 = ConvTranspose((3, 5, 3), 3 => 6; stride=3, cross_correlation) + display(m1) + m2 = ConvTranspose( + (3, 5, 3), 3 => 6; stride=3, outpad=(1, 0, 1), cross_correlation) + display(m2) + + ps1, st1 = Lux.setup(rng, m1) |> dev + ps2, st2 = Lux.setup(rng, m2) |> dev + + x = randn(Float32, 10, 11, 12, 3, 2) |> aType + + @test size(m1(x, ps1, st1)[1])[1:3] .+ (1, 0, 1) == + size(m2(x, ps2, st2)[1])[1:3] + end end end end From 5afdd5a406ca3970e7ac3eacec1b0595c7b452c4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 6 Sep 2024 12:35:28 -0400 Subject: [PATCH 83/95] test: more comprehensive testing for Pooling operations --- src/layers/pooling.jl | 17 ++++++- test/layers/conv_tests.jl | 98 ------------------------------------ test/layers/pooling_tests.jl | 78 ++++++++++++++++++++++++++++ 3 files changed, 93 insertions(+), 100 deletions(-) create mode 100644 test/layers/pooling_tests.jl diff --git a/src/layers/pooling.jl b/src/layers/pooling.jl index 815fb0e4c4..bc4da7b089 100644 --- a/src/layers/pooling.jl +++ b/src/layers/pooling.jl @@ -83,6 +83,13 @@ end for layer_op in (:Max, :Mean, :LP) op = Symbol(lowercase(string(layer_op))) + no_gpu_danger = layer_op == :LP ? """ + + !!! danger "GPU Support" + + This layer is currently only supported on CPU. + """ : "" + layer_name = Symbol(layer_op, :Pool) extra_kwargs = layer_op == :LP ? ", p=2" : "" layer_docstring = """ @@ -112,6 +119,8 @@ for layer_op in (:Max, :Mean, :LP) `size(output,d) == size(x,d) / stride` (possibly rounded) for each spatial dimension. + $(no_gpu_danger) + # Extended Help ## Inputs @@ -138,6 +147,8 @@ for layer_op in (:Max, :Mean, :LP) `(1, 1, c, b)`-shaped output, by performing mean pooling on the complete `(w, h)`-shaped feature maps. + $(no_gpu_danger) + ## Inputs - `x`: Data satisfying `ndims(x) > 2`, i.e. `size(x) = (I_N, ..., I_1, C, N)` @@ -159,6 +170,8 @@ for layer_op in (:Max, :Mean, :LP) - `output_size`: Size of the first `N` dimensions for the output + $(no_gpu_danger) + ## Inputs - `x`: Expects as input an array with `ndims(x) == N + 2`, i.e. 
channel and batch @@ -195,7 +208,7 @@ for layer_op in (:Max, :Mean, :LP) all(==(1), dilation) || print(io, ", dilation=", PrettyPrinting.tuple_string(dilation)) if $(Meta.quot(op)) == :lp - a.layer.op.p == 2 || print(io, ", p=", a.layer.op.p) + m.layer.op.p == 2 || print(io, ", p=", m.layer.op.p) end print(io, ")") end @@ -213,7 +226,7 @@ for layer_op in (:Max, :Mean, :LP) function Base.show(io::IO, ::MIME"text/plain", g::$(global_layer_name)) print(io, string($(Meta.quot(global_layer_name))), "(") if $(Meta.quot(op)) == :lp - a.layer.op.p == 2 || print(io, ", p=", a.layer.op.p) + g.layer.op.p == 2 || print(io, ", p=", g.layer.op.p) end print(io, ")") end diff --git a/test/layers/conv_tests.jl b/test/layers/conv_tests.jl index 0db8b91144..18a4dd4b39 100644 --- a/test/layers/conv_tests.jl +++ b/test/layers/conv_tests.jl @@ -1,101 +1,3 @@ -@testitem "Pooling" setup=[SharedTestSetup] tags=[:core_layers] begin - rng = StableRNG(12345) - - @testset "$mode" for (mode, aType, dev, ongpu) in MODES - x = randn(rng, Float32, 10, 10, 3, 2) |> aType - y = randn(rng, Float32, 20, 20, 3, 2) |> aType - - layer = AdaptiveMaxPool((5, 5)) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev - - @test layer(x, ps, st)[1] == maxpool(x, PoolDims(x, 2)) - @jet layer(x, ps, st) - __f = x -> sum(first(layer(x, ps, st))) - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3) - - layer = AdaptiveMeanPool((5, 5)) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev - - @test layer(x, ps, st)[1] == meanpool(x, PoolDims(x, 2)) - @jet layer(x, ps, st) - __f = x -> sum(first(layer(x, ps, st))) - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3) - - layer = AdaptiveMaxPool((10, 5)) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev - - @test layer(y, ps, st)[1] == maxpool(y, PoolDims(y, (2, 4))) - @jet layer(y, ps, st) - __f = x -> sum(first(layer(x, ps, st))) - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3) - - layer = AdaptiveMeanPool((10, 5)) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev - - @test layer(y, ps, st)[1] == meanpool(y, PoolDims(y, (2, 4))) - @jet layer(y, ps, st) - __f = x -> sum(first(layer(x, ps, st))) - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3) - - layer = GlobalMaxPool() - display(layer) - ps, st = Lux.setup(rng, layer) |> dev - - @test size(layer(x, ps, st)[1]) == (1, 1, 3, 2) - @jet layer(x, ps, st) - __f = x -> sum(first(layer(x, ps, st))) - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3) - - layer = GlobalMeanPool() - display(layer) - ps, st = Lux.setup(rng, layer) |> dev - - @test size(layer(x, ps, st)[1]) == (1, 1, 3, 2) - @jet layer(x, ps, st) - __f = x -> sum(first(layer(x, ps, st))) - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3) - - layer = MaxPool((2, 2)) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev - - @test layer(x, ps, st)[1] == maxpool(x, PoolDims(x, 2)) - @jet layer(x, ps, st) - __f = x -> sum(first(layer(x, ps, st))) - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3) - - layer = MeanPool((2, 2)) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev - - @test layer(x, ps, st)[1] == meanpool(x, PoolDims(x, 2)) - @jet layer(x, ps, st) - __f = x -> sum(first(layer(x, ps, st))) - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3) - - @testset "$ltype SamePad windowsize $k" for ltype in (MeanPool, MaxPool), - k in ((1,), (2,), (3,), (4, 5), (6, 7, 8)) - - x = ones(Float32, (k .+ 3)..., 1, 1) |> aType - - layer = ltype(k; pad=Lux.SamePad()) - display(layer) - ps, st = Lux.setup(rng, layer) |> dev - - @test size(layer(x, ps, st)[1])[1:(end - 2)] 
== cld.(size(x)[1:(end - 2)], k) - @jet layer(x, ps, st) - __f = x -> sum(first(layer(x, ps, st))) - - soft_fail = ltype == MaxPool ? [AutoFiniteDiff()] : [] - test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, soft_fail) - end - end -end - @testitem "CNN" setup=[SharedTestSetup] tags=[:core_layers] begin rng = StableRNG(12345) diff --git a/test/layers/pooling_tests.jl b/test/layers/pooling_tests.jl new file mode 100644 index 0000000000..123fe1d052 --- /dev/null +++ b/test/layers/pooling_tests.jl @@ -0,0 +1,78 @@ +@testitem "Pooling" setup=[SharedTestSetup] tags=[:core_layers] begin + rng = StableRNG(12345) + + nnlib_op = Dict(:LPPool => (args...) -> lpnormpool(args...; p=2), + :MeanPool => meanpool, :MaxPool => maxpool) + + @testset "$mode" for (mode, aType, dev, ongpu) in MODES + @testset for ltype in (:LPPool, :MeanPool, :MaxPool) + if ongpu && ltype == :LPPool + @test_broken false + continue + end + + broken_backends = ltype == :LPPool ? [AutoTracker()] : [] + + adaptive_ltype = Symbol(:Adaptive, ltype) + global_ltype = Symbol(:Global, ltype) + + x = randn(rng, Float32, 10, 10, 3, 2) |> aType + y = randn(rng, Float32, 20, 20, 3, 2) |> aType + + layer = getfield(Lux, adaptive_ltype)((5, 5)) + display(layer) + ps, st = Lux.setup(rng, layer) |> dev + + @test size(layer(x, ps, st)[1]) == (5, 5, 3, 2) + @test layer(x, ps, st)[1] == nnlib_op[ltype](x, PoolDims(x, 2)) + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, broken_backends) + + layer = getfield(Lux, adaptive_ltype)((10, 5)) + display(layer) + ps, st = Lux.setup(rng, layer) |> dev + + @test size(layer(y, ps, st)[1]) == (10, 5, 3, 2) + @test layer(y, ps, st)[1] == nnlib_op[ltype](y, PoolDims(y, (2, 4))) + @jet layer(y, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, broken_backends) + + layer = getfield(Lux, global_ltype)() + display(layer) + ps, st = Lux.setup(rng, layer) |> dev + + @test size(layer(x, ps, st)[1]) == (1, 1, 3, 2) + @test layer(x, ps, st)[1] == nnlib_op[ltype](x, PoolDims(x, size(x)[1:2])) + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, broken_backends) + + layer = getfield(Lux, ltype)((2, 2)) + display(layer) + ps, st = Lux.setup(rng, layer) |> dev + + @test layer(x, ps, st)[1] == nnlib_op[ltype](x, PoolDims(x, 2)) + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, broken_backends) + + @testset "SamePad windowsize $k" for k in ((1,), (2,), (3,), (4, 5), (6, 7, 8)) + x = ones(Float32, (k .+ 3)..., 1, 1) |> aType + + layer = getfield(Lux, ltype)(k; pad=Lux.SamePad()) + display(layer) + ps, st = Lux.setup(rng, layer) |> dev + + @test size(layer(x, ps, st)[1])[1:(end - 2)] == + cld.(size(x)[1:(end - 2)], k) + @jet layer(x, ps, st) + __f = x -> sum(first(layer(x, ps, st))) + + soft_fail = ltype == :MaxPool ? 
[AutoFiniteDiff()] : [] + test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, soft_fail, broken_backends) + end + end + end +end From c548e64ea5b40949eda536467f4eefadd7f0f1d8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 6 Sep 2024 13:02:35 -0400 Subject: [PATCH 84/95] test: minor test fixes --- test/autodiff/nested_autodiff_tests.jl | 2 +- test/layers/pooling_tests.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/autodiff/nested_autodiff_tests.jl b/test/autodiff/nested_autodiff_tests.jl index 1131446430..1a6a0bf41e 100644 --- a/test/autodiff/nested_autodiff_tests.jl +++ b/test/autodiff/nested_autodiff_tests.jl @@ -29,7 +29,7 @@ function test_nested_ad_input_gradient_jacobian(aType, dev, ongpu, loss_fn, X, m allow_unstable() do test_gradients((x, ps) -> loss_fn(model, x, ps, st), X, ps; - atol=1.0f-3, rtol=1.0f-1, soft_fail=[AutoForwardDiff()], + atol=1.0f-2, rtol=1.0f-1, soft_fail=[AutoForwardDiff()], skip_backends=[AutoReverseDiff(), AutoTracker(), AutoEnzyme()]) end end diff --git a/test/layers/pooling_tests.jl b/test/layers/pooling_tests.jl index 123fe1d052..2824de30ee 100644 --- a/test/layers/pooling_tests.jl +++ b/test/layers/pooling_tests.jl @@ -11,7 +11,7 @@ continue end - broken_backends = ltype == :LPPool ? [AutoTracker()] : [] + broken_backends = ltype == :LPPool ? [AutoTracker(), AutoEnzyme()] : [] adaptive_ltype = Symbol(:Adaptive, ltype) global_ltype = Symbol(:Global, ltype) From 3ccb1cc3610bd3bb971738d7c5909c8e1714c52b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 6 Sep 2024 15:03:39 -0400 Subject: [PATCH 85/95] fix: change in init --- src/layers/basic.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 7b0168f96e..d05ff8050d 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -317,7 +317,7 @@ function Dense(mapping::Pair{<:IntegerType, <:IntegerType}, activation=identity; end function Dense(in_dims::IntegerType, out_dims::IntegerType, activation=identity; - init_weight=glorot_uniform, init_bias=zeros32, use_bias::BoolType=True()) + init_weight=nothing, init_bias=nothing, use_bias::BoolType=True()) return Dense(activation, in_dims, out_dims, init_weight, init_bias, static(use_bias)) end @@ -506,7 +506,7 @@ end function Bilinear( ((in1_dims, in2_dims), out)::Pair{<:Tuple, <:IntegerType}, activation=identity; - init_weight=glorot_uniform, init_bias=zeros32, use_bias::BoolType=True()) + init_weight=nothing, init_bias=nothing, use_bias::BoolType=True()) return Bilinear( activation, in1_dims, in2_dims, out, init_weight, init_bias, static(use_bias)) end From bf11707168833374422847adf1ed5563493dfef2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 6 Sep 2024 15:40:20 -0400 Subject: [PATCH 86/95] fix: DDIM updates and fix argument ordering --- examples/DDIM/Project.toml | 4 --- examples/DDIM/main.jl | 68 +++++++++++++++++++------------------- src/helpers/stateful.jl | 10 +++--- 3 files changed, 40 insertions(+), 42 deletions(-) diff --git a/examples/DDIM/Project.toml b/examples/DDIM/Project.toml index 42a76263b4..461bf2222d 100644 --- a/examples/DDIM/Project.toml +++ b/examples/DDIM/Project.toml @@ -1,8 +1,6 @@ [deps] -AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" DataAugmentation = 
"88a5189c-e7ff-4f85-ac6b-e6158070f02e" @@ -25,10 +23,8 @@ TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -AMDGPU = "0.9.6, 1" ArgCheck = "2.3.0" CairoMakie = "0.12" -ChainRulesCore = "1.23" Comonicon = "1" ConcreteStructs = "0.2.3" DataAugmentation = "0.3" diff --git a/examples/DDIM/main.jl b/examples/DDIM/main.jl index 6e81b88f8d..1a0039541f 100644 --- a/examples/DDIM/main.jl +++ b/examples/DDIM/main.jl @@ -6,11 +6,10 @@ # ## Package Imports -using ArgCheck, CairoMakie, ChainRulesCore, ConcreteStructs, Comonicon, DataAugmentation, - DataDeps, FileIO, ImageCore, JLD2, Lux, LuxCUDA, MLUtils, Optimisers, - ParameterSchedulers, ProgressBars, Random, Setfield, StableRNGs, Statistics, Zygote +using ArgCheck, CairoMakie, ConcreteStructs, Comonicon, DataAugmentation, DataDeps, FileIO, + ImageCore, JLD2, Lux, LuxCUDA, MLUtils, Optimisers, ParameterSchedulers, ProgressBars, + Random, Setfield, StableRNGs, Statistics, Zygote using TensorBoardLogger: TBLogger, log_value, log_images -const CRC = ChainRulesCore CUDA.allowscalar(false) @@ -130,24 +129,22 @@ function ddim(rng::AbstractRNG, args...; min_signal_rate=0.02f0, max_signal_rate, dispatch=:DDIM) do x::AbstractArray{<:Real, 4} images = bn(x) rng = Lux.replicate(rng) - T = eltype(x) - noises = CRC.@ignore_derivatives randn!(rng, similar(images, T, size(images)...)) - diffusion_times = CRC.@ignore_derivatives rand!( - rng, similar(images, T, 1, 1, 1, size(images, 4))) + noises = rand_like(rng, images) + diffusion_times = rand_like(rng, images, (1, 1, 1, size(images, 4))) - noise_rates, signal_rates = __diffusion_schedules( + noise_rates, signal_rates = diffusion_schedules( diffusion_times, min_signal_rate, max_signal_rate) noisy_images = @. signal_rates * images + noise_rates * noises - pred_noises, pred_images = __denoise(unet, noisy_images, noise_rates, signal_rates) + pred_noises, pred_images = denoise(unet, noisy_images, noise_rates, signal_rates) @return noises, images, pred_noises, pred_images end end -function __diffusion_schedules(diffusion_times::AbstractArray{T, 4}, min_signal_rate::T, +function diffusion_schedules(diffusion_times::AbstractArray{T, 4}, min_signal_rate::T, max_signal_rate::T) where {T <: Real} start_angle = acos(max_signal_rate) end_angle = acos(min_signal_rate) @@ -160,8 +157,7 @@ function __diffusion_schedules(diffusion_times::AbstractArray{T, 4}, min_signal_ return noise_rates, signal_rates end -function __denoise( - unet, noisy_images::AbstractArray{T, 4}, noise_rates::AbstractArray{T, 4}, +function denoise(unet, noisy_images::AbstractArray{T, 4}, noise_rates::AbstractArray{T, 4}, signal_rates::AbstractArray{T, 4}) where {T <: Real} pred_noises = unet((noisy_images, noise_rates .^ 2)) pred_images = @. 
(noisy_images - pred_noises * noise_rates) / signal_rates @@ -170,7 +166,7 @@ end # ## Helper Functions for Image Generation -function __reverse_diffusion( +function reverse_diffusion( model, initial_noise::AbstractArray{T, 4}, diffusion_steps::Int) where {T <: Real} num_images = size(initial_noise, 4) step_size = one(T) / diffusion_steps @@ -188,15 +184,15 @@ function __reverse_diffusion( # We start t = 1, and gradually decreases to t=0 diffusion_times = (ones(T, 1, 1, 1, num_images) .- step_size * step) |> dev - noise_rates, signal_rates = __diffusion_schedules( + noise_rates, signal_rates = diffusion_schedules( diffusion_times, min_signal_rate, max_signal_rate) - pred_noises, pred_images = __denoise( + pred_noises, pred_images = denoise( StatefulLuxLayer{true}(model.model.layers.unet, model.ps.unet, model.st.unet), noisy_images, noise_rates, signal_rates) next_diffusion_times = diffusion_times .- step_size - next_noisy_rates, next_signal_rates = __diffusion_schedules( + next_noisy_rates, next_signal_rates = diffusion_schedules( next_diffusion_times, min_signal_rate, max_signal_rate) next_noisy_images = next_signal_rates .* pred_images .+ @@ -206,14 +202,14 @@ function __reverse_diffusion( return pred_images end -function __denormalize(model::StatefulLuxLayer{true}, x::AbstractArray{<:Real, 4}) +function denormalize(model::StatefulLuxLayer, x::AbstractArray{<:Real, 4}) mean = reshape(model.st.bn.running_mean, 1, 1, 3, 1) var = reshape(model.st.bn.running_var, 1, 1, 3, 1) std = sqrt.(var .+ model.model.layers.bn.epsilon) return std .* x .+ mean end -function __save_images(output_dir, images::AbstractArray{<:Real, 4}) +function save_images(output_dir, images::AbstractArray{<:Real, 4}) imgs = Vector{Array{RGB, 2}}(undef, size(images, 4)) for i in axes(images, 4) img = @view images[:, :, :, i] @@ -224,7 +220,7 @@ function __save_images(output_dir, images::AbstractArray{<:Real, 4}) return imgs end -function __generate_and_save_image_grid(output_dir, imgs::Vector{<:AbstractArray{<:RGB, 2}}) +function generate_and_save_image_grid(output_dir, imgs::Vector{<:AbstractArray{<:RGB, 2}}) fig = Figure() nrows, ncols = 3, 4 for r in 1:nrows, c in 1:ncols @@ -238,11 +234,11 @@ function __generate_and_save_image_grid(output_dir, imgs::Vector{<:AbstractArray return end -function __generate( +function generate( model::StatefulLuxLayer, rng, image_size::NTuple{4, Int}, diffusion_steps::Int, dev) initial_noise = randn(rng, Float32, image_size...) 
|> dev - generated_images = __reverse_diffusion(model, initial_noise, diffusion_steps) - generated_images = __denormalize(model, generated_images) + generated_images = reverse_diffusion(model, initial_noise, diffusion_steps) + generated_images = denormalize(model, generated_images) return clamp01.(generated_images) end @@ -287,21 +283,23 @@ function Base.getindex(ds::FlowersDataset, i::Int) end function preprocess_image(image::Matrix{<:RGB}, image_size::Int) - return apply(CenterResizeCrop((image_size, image_size)), Image(image)) |> itemdata + return apply( + CenterResizeCrop((image_size, image_size)), DataAugmentation.Image(image)) |> + itemdata end const maeloss = MAELoss() function loss_function(model, ps, st, data) (noises, images, pred_noises, pred_images), st = Lux.apply(model, data, ps, st) - noise_loss = maeloss(noises, pred_noises) - image_loss = maeloss(images, pred_images) + noise_loss = maeloss(pred_noises, noises) + image_loss = maeloss(pred_images, images) return noise_loss, st, (; image_loss, noise_loss) end # ## Entry Point for our code -@main function main(; epochs::Int=100, image_size::Int=128, +Comonicon.@main function main(; epochs::Int=100, image_size::Int=128, batchsize::Int=128, learning_rate_start::Float32=1.0f-3, learning_rate_end::Float32=1.0f-5, weight_decay::Float32=1.0f-6, checkpoint_interval::Int=25, expt_dir=tempname(@__DIR__), @@ -316,7 +314,8 @@ end @info "Experiment directory: $(expt_dir)" - rng = StableRNG(1234) + rng = Random.default_rng() + Random.seed!(rng, 1234) image_dir = joinpath(expt_dir, "images") isdir(image_dir) || mkpath(image_dir) @@ -339,19 +338,20 @@ end states = states |> gdev model = StatefulLuxLayer{true}(model, parameters, Lux.testmode(states)) - generated_images = __generate(model, StableRNG(generate_image_seed), + generated_images = generate(model, StableRNG(generate_image_seed), (image_size, image_size, 3, generate_n_images), diffusion_steps, gdev) |> cpu_device() path = joinpath(image_dir, "inference") @info "Saving generated images to $(path)" - imgs = __save_images(path, generated_images) - __generate_and_save_image_grid(path, imgs) + imgs = save_images(path, generated_images) + generate_and_save_image_grid(path, imgs) return end tb_dir = joinpath(expt_dir, "tb_logs") - @info "Logging Tensorboard logs to $(tb_dir). Run tensorboard with `tensorboard --logdir $(dirname(tb_dir))`" + @info "Tensorboard logs being saved to $(tb_dir). 
Run tensorboard with \ + `tensorboard --logdir $(dirname(tb_dir))`" tb_logger = TBLogger(tb_dir) tstate = Training.TrainState( @@ -393,13 +393,13 @@ end if epoch % generate_image_interval == 0 || epoch == epochs model_test = StatefulLuxLayer{true}( tstate.model, tstate.parameters, Lux.testmode(tstate.states)) - generated_images = __generate(model_test, StableRNG(generate_image_seed), + generated_images = generate(model_test, StableRNG(generate_image_seed), (image_size, image_size, 3, generate_n_images), diffusion_steps, gdev) |> cpu_device() path = joinpath(image_dir, "epoch_$(epoch)") @info "Saving generated images to $(path)" - imgs = __save_images(path, generated_images) + imgs = save_images(path, generated_images) log_images(tb_logger, "Generated Images", imgs; step) end diff --git a/src/helpers/stateful.jl b/src/helpers/stateful.jl index 02c57eeaf4..0fdf475ee6 100644 --- a/src/helpers/stateful.jl +++ b/src/helpers/stateful.jl @@ -121,10 +121,12 @@ function (s::StatefulLuxLayer)(x, p=s.ps) return y end -function CRC.rrule( - ::Type{<:StatefulLuxLayer{FT}}, model::AbstractLuxLayer, ps, st, st_any) where {FT} - slayer = StatefulLuxLayer{FT}(model, ps, st, st_any) - ∇StatefulLuxLayer(Δ) = NoTangent(), NoTangent(), Δ.ps, NoTangent(), NoTangent() +function CRC.rrule(::Type{<:StatefulLuxLayer}, model::AbstractLuxLayer, + ps, st, st_any, fixed_state_type) + slayer = StatefulLuxLayer(model, ps, st, st_any, fixed_state_type) + function ∇StatefulLuxLayer(Δ) + return NoTangent(), NoTangent(), Δ.ps, NoTangent(), NoTangent(), NoTangent() + end return slayer, ∇StatefulLuxLayer end From f23be5f4db0fb7603bda25e14d005c259d362f14 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 6 Sep 2024 15:44:37 -0400 Subject: [PATCH 87/95] fix: testing using old init assumptions --- test/contrib/map_tests.jl | 10 +++++----- test/layers/basic_tests.jl | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/test/contrib/map_tests.jl b/test/contrib/map_tests.jl index 00b2f994f6..8badcf358c 100644 --- a/test/contrib/map_tests.jl +++ b/test/contrib/map_tests.jl @@ -8,16 +8,16 @@ function zero_dense_params_1(l, ps, st, name) if l isa Dense && occurs_in(KeyPath(:chain), name) - @set! ps.weight = zero.(ps.weight) - @set! ps.bias = zero.(ps.bias) + @set! ps.weight = zero(ps.weight) + @set! ps.bias = zero(ps.bias) end return l, ps, st end function zero_dense_params_2(l, ps, st, name) if l isa Dense - @set! ps.weight = zero.(ps.weight) - @set! ps.bias = zero.(ps.bias) + @set! ps.weight = zero(ps.weight) + @set! 
ps.bias = zero(ps.bias) end return l, ps, st end @@ -37,7 +37,7 @@ @test all(iszero, ps_.chain.dense_2.weight) @test all(iszero, ps_.chain.dense_2.bias) @test !all(iszero, ps_.dense_3.weight) - @test all(iszero, ps_.dense_3.bias) + @test !all(iszero, ps_.dense_3.bias) # Custom Layers -- See https://github.com/LuxDL/Lux.jl/issues/187 struct SimpleCustom{L1, L2} <: Lux.AbstractLuxContainerLayer{(:dense, :conv)} diff --git a/test/layers/basic_tests.jl b/test/layers/basic_tests.jl index 1535ab3b04..78bcf97e80 100644 --- a/test/layers/basic_tests.jl +++ b/test/layers/basic_tests.jl @@ -302,7 +302,7 @@ end @testset "Two-streams zero sum" begin x = zeros(Float32, 2, 1) |> aType y = zeros(Float32, 1, 1) |> aType - layer = Bilinear((2, 1) => 3) + layer = Bilinear((2, 1) => 3; init_bias=zeros32) display(layer) ps, st = Lux.setup(rng, layer) |> dev From 06b20b87e43b89a7540d02796d3ca778ccdf820a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 6 Sep 2024 16:01:01 -0400 Subject: [PATCH 88/95] fix: ConvMixer minor updates --- examples/ConvMixer/README.md | 2 ++ examples/ConvMixer/main.jl | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/ConvMixer/README.md b/examples/ConvMixer/README.md index fbea290333..f16d8850db 100644 --- a/examples/ConvMixer/README.md +++ b/examples/ConvMixer/README.md @@ -78,6 +78,8 @@ Flags 1. Weight-Decay with Adam in Optimisers.jl works differently from `torch.optim.AdamW`, so you might need to adjust the value of `--weight-decay` to get the same results. + Pytorch multiplies the weight decay with the learning rate, whereas in Optimisers.jl + the learning rate is decoupled from the weight decay. 2. To match the results from the original repo, we need more augmentation strategies, that are currently not implemented in DataAugmentation.jl. 3. Don't compare the reported timings in that repo against the numbers here. 
They time the diff --git a/examples/ConvMixer/main.jl b/examples/ConvMixer/main.jl index 2c9dc2824c..0f8e9600ba 100644 --- a/examples/ConvMixer/main.jl +++ b/examples/ConvMixer/main.jl @@ -68,7 +68,7 @@ function accuracy(model, ps, st, dataloader) return total_correct / total end -@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8, +Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8, patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-5, clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01) rng = StableRNG(seed) From 1039d979bbc7523655699e5dd669cc7709e45c28 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 6 Sep 2024 16:16:46 -0400 Subject: [PATCH 89/95] fix: onehot supports GPUArrays --- examples/ConvMixer/main.jl | 5 ++--- examples/HyperNet/main.jl | 5 ++--- examples/NeuralODE/main.jl | 5 ++--- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/examples/ConvMixer/main.jl b/examples/ConvMixer/main.jl index 0f8e9600ba..56ca4115f1 100644 --- a/examples/ConvMixer/main.jl +++ b/examples/ConvMixer/main.jl @@ -58,10 +58,9 @@ end function accuracy(model, ps, st, dataloader) total_correct, total = 0, 0 st = Lux.testmode(st) - cpu_dev = cpu_device() for (x, y) in dataloader - target_class = onecold(cpu_dev(y)) - predicted_class = onecold(cpu_dev(first(model(x, ps, st)))) + target_class = onecold(y) + predicted_class = onecold(first(model(x, ps, st))) total_correct += sum(target_class .== predicted_class) total += length(target_class) end diff --git a/examples/HyperNet/main.jl b/examples/HyperNet/main.jl index e2d12e7c91..e0d96d4d72 100644 --- a/examples/HyperNet/main.jl +++ b/examples/HyperNet/main.jl @@ -62,10 +62,9 @@ const loss = CrossEntropyLoss(; logits=Val(true)) function accuracy(model, ps, st, dataloader, data_idx) total_correct, total = 0, 0 st = Lux.testmode(st) - cpu_dev = cpu_device() for (x, y) in dataloader - target_class = onecold(cpu_dev(y)) - predicted_class = onecold(cpu_dev(model((data_idx, x), ps, st)[1])) + target_class = onecold(y) + predicted_class = onecold(first(model((data_idx, x), ps, st))) total_correct += sum(target_class .== predicted_class) total += length(target_class) end diff --git a/examples/NeuralODE/main.jl b/examples/NeuralODE/main.jl index 2b83e13bab..085bcafedd 100644 --- a/examples/NeuralODE/main.jl +++ b/examples/NeuralODE/main.jl @@ -109,10 +109,9 @@ const logitcrossentropy = CrossEntropyLoss(; logits=Val(true)) function accuracy(model, ps, st, dataloader) total_correct, total = 0, 0 st = Lux.testmode(st) - cpu_dev = cpu_device() for (x, y) in dataloader - target_class = onecold(cpu_dev(y)) - predicted_class = onecold(cpu_dev(first(model(x, ps, st)))) + target_class = onecold(y) + predicted_class = onecold(first(model(x, ps, st))) total_correct += sum(target_class .== predicted_class) total += length(target_class) end From 9516a8ddb5ff6bf141b5445de6cf40b3a125bbd2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 6 Sep 2024 16:16:22 -0400 Subject: [PATCH 90/95] test: explicitly zero init bias --- test/layers/basic_tests.jl | 26 ++++++++------------------ 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/test/layers/basic_tests.jl b/test/layers/basic_tests.jl index 78bcf97e80..7c06662971 100644 --- a/test/layers/basic_tests.jl +++ b/test/layers/basic_tests.jl @@ -138,37 +138,31 @@ end @testset "zeros" begin @test begin - layer = Dense(10, 1, identity; - init_weight=(rng, args...; kwargs...) 
-> ones(args...; kwargs...)) + layer = Dense(10, 1, identity; init_weight=ones32, init_bias=zeros32) first(Lux.apply( layer, ones(10, 1) |> aType, dev.(Lux.setup(rng, layer))...)) end == 10 * aType(ones(1, 1)) @test begin - layer = Dense(10, 1, identity; - init_weight=(rng, args...; kwargs...) -> ones(args...; kwargs...)) + layer = Dense(10, 1, identity; init_weight=ones32, init_bias=zeros32) first(Lux.apply( layer, ones(10, 2) |> aType, dev.(Lux.setup(rng, layer))...)) end == 10 * aType(ones(1, 2)) @test begin - layer = Dense(10, 2, identity; - init_weight=(rng, args...; kwargs...) -> ones(args...; kwargs...)) + layer = Dense(10, 2, identity; init_weight=ones32, init_bias=zeros32) first(Lux.apply( layer, ones(10, 1) |> aType, dev.(Lux.setup(rng, layer))...)) end == 10 * aType(ones(2, 1)) @test begin - layer = Dense(10, 2, identity; - init_weight=(rng, args...; kwargs...) -> ones(args...; kwargs...)) + layer = Dense(10, 2, identity; init_weight=ones32, init_bias=zeros32) first(Lux.apply(layer, aType([ones(10, 1) 2 * ones(10, 1)]), dev.(Lux.setup(rng, layer))...)) end == aType([10 20; 10 20]) @test begin - layer = Dense(10, 2, identity; - init_weight=(rng, args...; kwargs...) -> ones(args...; kwargs...), - use_bias=false) + layer = Dense(10, 2, identity; init_weight=ones32, use_bias=false) first(Lux.apply(layer, aType([ones(10, 1) 2 * ones(10, 1)]), dev.(Lux.setup(rng, layer))...)) end == aType([10 20; 10 20]) @@ -214,23 +208,19 @@ end @testset "zeros" begin @test begin - layer = Scale(10, 1, identity; - init_weight=(rng, args...; kwargs...) -> ones(args...; kwargs...)) + layer = Scale(10, 1, identity; init_weight=ones32) first(Lux.apply( layer, ones(10, 1) |> aType, dev.(Lux.setup(rng, layer))...)) end == aType(ones(10, 1)) @test begin - layer = Scale(10, 1, identity; - init_weight=(rng, args...; kwargs...) -> ones(args...; kwargs...)) + layer = Scale(10, 1, identity; init_weight=ones32) first(Lux.apply( layer, ones(10, 2) |> aType, dev.(Lux.setup(rng, layer))...)) end == aType(ones(10, 2)) @test begin - layer = Scale(2, identity; - init_weight=(rng, args...; kwargs...) -> ones(args...; kwargs...), - init_bias=(rng, args...; kwargs...) 
-> ones(args...; kwargs...)) + layer = Scale(2, identity; init_weight=ones32, init_bias=ones32) first(Lux.apply( layer, [1 2; 3 4] |> aType, dev.(Lux.setup(rng, layer))...)) end == aType([2.0 3.0; 4.0 5.0]) From 9ca0d38cc79a4d138cc24894ea7270adae7b97a7 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 6 Sep 2024 17:03:24 -0400 Subject: [PATCH 91/95] fix: optionally test with FiniteDiff if ForwardDiff fails --- test/autodiff/nested_autodiff_tests.jl | 38 +++++++++++++++--- test/layers/recurrent_tests.jl | 53 +++++++++++++------------- 2 files changed, 59 insertions(+), 32 deletions(-) diff --git a/test/autodiff/nested_autodiff_tests.jl b/test/autodiff/nested_autodiff_tests.jl index 1a6a0bf41e..8344006767 100644 --- a/test/autodiff/nested_autodiff_tests.jl +++ b/test/autodiff/nested_autodiff_tests.jl @@ -27,10 +27,23 @@ function test_nested_ad_input_gradient_jacobian(aType, dev, ongpu, loss_fn, X, m !iszero(ComponentArray(∂ps |> cpu_device())) && all(x -> x === nothing || isfinite(x), ComponentArray(∂ps |> cpu_device())) + __f = (x, ps) -> loss_fn(model, x, ps, st) + allow_unstable() do - test_gradients((x, ps) -> loss_fn(model, x, ps, st), X, ps; - atol=1.0f-2, rtol=1.0f-1, soft_fail=[AutoForwardDiff()], - skip_backends=[AutoReverseDiff(), AutoTracker(), AutoEnzyme()]) + FDIFF_WORKS = try + LuxTestUtils.gradient(__f, AutoForwardDiff(), X, ps) + true + catch + false + end + skip_backends = [AutoReverseDiff(), AutoTracker(), AutoEnzyme()] + if FDIFF_WORKS + test_gradients((x, ps) -> loss_fn(model, x, ps, st), X, ps; atol=1.0f-3, + rtol=1.0f-1, soft_fail=[AutoFiniteDiff()], skip_backends) + else + test_gradients((x, ps) -> loss_fn(model, x, ps, st), X, ps; atol=1.0f-3, + rtol=1.0f-1, skip_backends=vcat(skip_backends, [AutoFiniteDiff()])) + end end end @@ -152,10 +165,23 @@ function test_nested_ad_parameter_gradient_jacobian(aType, dev, ongpu, loss_fn, !iszero(ComponentArray(∂ps |> cpu_device())) && all(x -> x === nothing || isfinite(x), ComponentArray(∂ps |> cpu_device())) + __f = (x, ps) -> loss_fn(model, x, ps, st) + allow_unstable() do - test_gradients((x, ps) -> loss_fn(model, x, ps, st), X, ps; - atol=1.0f-3, rtol=1.0f-1, soft_fail=[AutoForwardDiff()], - skip_backends=[AutoReverseDiff(), AutoTracker(), AutoEnzyme()]) + FDIFF_WORKS = try + LuxTestUtils.gradient(__f, AutoForwardDiff(), X, ps) + true + catch + false + end + skip_backends = [AutoReverseDiff(), AutoTracker(), AutoEnzyme()] + if FDIFF_WORKS + test_gradients((x, ps) -> loss_fn(model, x, ps, st), X, ps; atol=1.0f-3, + rtol=1.0f-1, soft_fail=[AutoFiniteDiff()], skip_backends) + else + test_gradients((x, ps) -> loss_fn(model, x, ps, st), X, ps; atol=1.0f-3, + rtol=1.0f-1, skip_backends=vcat(skip_backends, [AutoFiniteDiff()])) + end end end diff --git a/test/layers/recurrent_tests.jl b/test/layers/recurrent_tests.jl index 590c7338e9..91b5ace684 100644 --- a/test/layers/recurrent_tests.jl +++ b/test/layers/recurrent_tests.jl @@ -2,13 +2,13 @@ rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES - for rnncell in (RNNCell(3 => 5, identity), RNNCell(3 => 5, tanh), + @testset for rnncell in (RNNCell(3 => 5, identity), RNNCell(3 => 5, tanh), RNNCell(3 => 5, tanh; use_bias=false), RNNCell(3 => 5, identity; use_bias=false), RNNCell(3 => 5, identity; use_bias=false, train_state=false)) display(rnncell) ps, st = Lux.setup(rng, rnncell) |> dev - for x_size in ((3, 2), (3,)) + @testset for x_size in ((3, 2), (3,)) x = randn(rng, Float32, x_size...) 
|> aType (y, carry), st_ = Lux.apply(rnncell, x, ps, st) @@ -17,7 +17,7 @@ function loss_loop_rnncell(p) (y, carry), st_ = rnncell(x, p, st) - for i in 1:10 + for _ in 1:10 (y, carry), st_ = rnncell((x, carry), p, st_) end return sum(abs2, y) @@ -25,13 +25,14 @@ @test_throws ErrorException ps.train_state - test_gradients(loss_loop_rnncell, ps; atol=1.0f-3, - rtol=1.0f-3, broken_backends=[AutoEnzyme()]) + test_gradients(loss_loop_rnncell, ps; atol=1.0f-3, rtol=1.0f-3, + soft_fail=[AutoFiniteDiff()], broken_backends=[AutoEnzyme()]) end end @testset "Trainable hidden states" begin - for rnncell in (RNNCell(3 => 5, identity; use_bias=false, train_state=true), + @testset for rnncell in ( + RNNCell(3 => 5, identity; use_bias=false, train_state=true), RNNCell(3 => 5, identity; use_bias=true, train_state=true)) rnn_no_trainable_state = RNNCell( 3 => 5, identity; use_bias=false, train_state=false) @@ -62,12 +63,12 @@ end rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES - for lstmcell in (LSTMCell(3 => 5), LSTMCell(3 => 5; use_bias=true), + @testset for lstmcell in (LSTMCell(3 => 5), LSTMCell(3 => 5; use_bias=true), LSTMCell(3 => 5; use_bias=false)) display(lstmcell) ps, st = Lux.setup(rng, lstmcell) |> dev - for x_size in ((3, 2), (3,)) + @testset for x_size in ((3, 2), (3,)) x = randn(rng, Float32, x_size...) |> aType (y, carry), st_ = Lux.apply(lstmcell, x, ps, st) @@ -82,8 +83,8 @@ end return sum(abs2, y) end - test_gradients(loss_loop_lstmcell, ps; atol=1.0f-3, - rtol=1.0f-3, broken_backends=[AutoEnzyme()]) + test_gradients(loss_loop_lstmcell, ps; atol=1.0f-3, rtol=1.0f-3, + soft_fail=[AutoFiniteDiff()], broken_backends=[AutoEnzyme()]) @test_throws ErrorException ps.train_state @test_throws ErrorException ps.train_memory @@ -91,7 +92,7 @@ end end @testset "Trainable hidden states" begin - for x_size in ((3, 2), (3,)) + @testset for x_size in ((3, 2), (3,)) x = randn(rng, Float32, x_size...) |> aType _lstm = LSTMCell( 3 => 5; use_bias=false, train_state=false, train_memory=false) @@ -173,12 +174,12 @@ end rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES - for grucell in (GRUCell(3 => 5), GRUCell(3 => 5; use_bias=true), + @testset for grucell in (GRUCell(3 => 5), GRUCell(3 => 5; use_bias=true), GRUCell(3 => 5; use_bias=false)) display(grucell) ps, st = Lux.setup(rng, grucell) |> dev - for x_size in ((3, 2), (3,)) + @testset for x_size in ((3, 2), (3,)) x = randn(rng, Float32, x_size...) |> aType (y, carry), st_ = Lux.apply(grucell, x, ps, st) @@ -193,15 +194,15 @@ end return sum(abs2, y) end - test_gradients(loss_loop_grucell, ps; atol=1e-3, - rtol=1e-3, broken_backends=[AutoEnzyme()]) + test_gradients(loss_loop_grucell, ps; atol=1e-3, rtol=1e-3, + soft_fail=[AutoFiniteDiff()], broken_backends=[AutoEnzyme()]) @test_throws ErrorException ps.train_state end end @testset "Trainable hidden states" begin - for x_size in ((3, 2), (3,)) + @testset for x_size in ((3, 2), (3,)) x = randn(rng, Float32, x_size...) 
|> aType _gru = GRUCell(3 => 5; use_bias=false, train_state=false) _ps, _st = Lux.setup(rng, _gru) |> dev @@ -280,8 +281,8 @@ end return sum(abs2, y) end - test_gradients( - loss_loop_rnn, ps; atol=1e-3, rtol=1e-3, broken_backends=[AutoEnzyme()]) + test_gradients(loss_loop_rnn, ps; atol=1e-3, rtol=1e-3, + broken_backends=[AutoEnzyme()], soft_fail=[AutoFiniteDiff()]) end end end @@ -325,12 +326,12 @@ end @test all(x -> size(x) == (5, 2), y_) __f = p -> sum(first(rnn(x, p, st))) - test_gradients( - __f, ps; atol=1e-3, rtol=1e-3, skip_backends=[AutoEnzyme()]) + test_gradients(__f, ps; atol=1e-3, rtol=1e-3, + skip_backends=[AutoEnzyme()], soft_fail=[AutoFiniteDiff()]) __f = p -> sum(Base.Fix1(sum, abs2), first(rnn_seq(x, p, st))) - test_gradients( - __f, ps; atol=1e-3, rtol=1e-3, skip_backends=[AutoEnzyme()]) + test_gradients(__f, ps; atol=1e-3, rtol=1e-3, + skip_backends=[AutoEnzyme()], soft_fail=[AutoFiniteDiff()]) end # Batched Time Series without data batches @@ -360,12 +361,12 @@ end end __f = p -> sum(first(rnn(x, p, st))) - test_gradients( - __f, ps; atol=1e-3, rtol=1e-3, skip_backends=[AutoEnzyme()]) + test_gradients(__f, ps; atol=1e-3, rtol=1e-3, + skip_backends=[AutoEnzyme()], soft_fail=[AutoFiniteDiff()]) __f = p -> sum(Base.Fix1(sum, abs2), first(rnn_seq(x, p, st))) - test_gradients( - __f, ps; atol=1e-3, rtol=1e-3, skip_backends=[AutoEnzyme()]) + test_gradients(__f, ps; atol=1e-3, rtol=1e-3, + skip_backends=[AutoEnzyme()], soft_fail=[AutoFiniteDiff()]) end end end From e395ed9ec512e5deebe3013e48be24a1ece17e47 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 6 Sep 2024 18:39:56 -0400 Subject: [PATCH 92/95] ci(buildkite): run some of the tutorials on CPU runners (#879) [skip tests] --- .buildkite/documentation.yml | 42 ++++++++++++++++++--- .buildkite/pipeline.yml | 2 +- docs/tutorials.jl | 41 +++++++++++++------- examples/GravitationalWaveForm/Project.toml | 2 - examples/GravitationalWaveForm/main.jl | 6 +-- 5 files changed, 67 insertions(+), 26 deletions(-) diff --git a/.buildkite/documentation.yml b/.buildkite/documentation.yml index 9895b6ef15..ecdb2c7d56 100644 --- a/.buildkite/documentation.yml +++ b/.buildkite/documentation.yml @@ -2,9 +2,9 @@ steps: - group: ":open_book: Build & Deploy Documentation" if: build.message !~ /\[skip docs\]/ && !build.pull_request.draft steps: - - label: "Tutorial Build [%N/%t]" - key: "tutorial-build" - parallelism: 6 + - label: "Tutorial Build [%N/%t] CUDA Runners" + key: "tutorial-build-cuda" + parallelism: 4 plugins: - JuliaCI/julia#v1: version: "1" @@ -14,6 +14,8 @@ steps: - src - ext command: julia --code-coverage=user --color=yes --project=docs docs/tutorials.jl + env: + TUTORIAL_BACKEND_GROUP: "CUDA" agents: queue: "juliagpu" cuda: "*" @@ -22,10 +24,40 @@ steps: - "docs/src/tutorials/intermediate/**/*" - "docs/src/tutorials/advanced/**/*" - "tutorial_deps/*" + - "**/*.cov" + timeout_in_minutes: 60 + + - label: "Tutorial Build [%N/%t] CPU Runners" + if: build.message !~ /\[skip docs\]/ && !build.pull_request.draft + key: "tutorial-build-cpu" + parallelism: 4 + plugins: + - JuliaCI/julia#v1: + version: "1" + # - JuliaCI/julia-coverage#v1: + # codecov: true + # dirs: + # - src + # - ext + command: julia --code-coverage=user --color=yes --project=docs docs/tutorials.jl + env: + TUTORIAL_BACKEND_GROUP: "CPU" + agents: + queue: "juliaecosystem" + os: "macos" + arch: "aarch64" + artifact_paths: + - "docs/src/tutorials/beginner/**/*" + - "docs/src/tutorials/intermediate/**/*" + - "docs/src/tutorials/advanced/**/*" + - "tutorial_deps/*" + - 
"**/*.cov" timeout_in_minutes: 60 - label: "Final Documentation Build" - depends_on: [tutorial-build] + depends_on: + - "tutorial-build-cuda" + - "tutorial-build-cpu" plugins: - JuliaCI/julia#v1: version: "1" @@ -65,4 +97,4 @@ env: JULIA_NUM_THREADS: 4 GKSwstype: "100" # https://discourse.julialang.org/t/generation-of-documentation-fails-qt-qpa-xcb-could-not-connect-to-display/60988 SECRET_CODECOV_TOKEN: "jQ0BMTQgyZx7QGyU0Q2Ec7qB9mtE2q/tDu0FsfxvEG7/zOAGvXkyXrzIFFOQxvDoFcP+K2+hYZKMxicYdNqzr5wcxu505aNGN2GM3wyegAr+hO6q12bCFYx6qXzU9FLCCdeqINqn9gUSSOlGtWNFrbAlrTyz/D4Yo66TqBDzvaLL63FMnhCLaXW/zJt3hNuEAJaPY2O6Ze1rX2WZ3Y+i+s3uQ8aLImtoCJhPe8CRx+OhuYiTzGhynFfGntZ0738/1RN4gNM0S/hTC4gLE7XMVBanJpGh32rFaiDwW4zAyXKBrDkL3QA3MS1RvLTJxGJ085S16hCk0C4ddAhZCvIM9Q==;U2FsdGVkX1+bXdFeKMs5G79catOCyby2n07A2fg0FjVAvrjQLZ0yfvDS4paJiFikLkodho0khz2YALKb2Y0K6w==" - SECRET_DOCUMENTER_KEY: "iRC4P/r5o9pARB670eK9jPlKQKgkTMDAyvp2GbLG8WwLuT8T1VcWx/o4+ofGlzbTh5Z+LuFgPXfgqkjGuoWLcocHNm78xQMNMywB4rcLB2shqp8xG2vhglgnTBBS4EiyPAtVqGyi5AKmfF95PfkJvnI0Lqg5P/RWQvNGywLAR0Ikgr/lqocm2CvkFGbpMzpGxGvj76JYOusVeKvGAp698TXqPabSZR2oZQLfYnEZnaO8ivkqvMGQSXfgzoIMjCOrN1rSa84SWeI9BDeBslzDHwaYGlvjpfCyviiLtKj4t5Acl1gVE0qxxZxWuALIU6z+C1W8TbW7ZDCBUFs6UTIT+Q==;U2FsdGVkX1+/HSgg1skLszz835vSO6mEtXMhG62ohQQUc5opdo7kEIAG2wCoJPQrqGyaF9kKDVvrN5G2MdjUyaLBYlv90RzXhjTiMNFdgI3M4K500xKq3itY/aEL7hUSMRKxTos8u4xhdbRboY4rPcqgtCJ2LHEjNxmml/NfEo/8lk291rGoEYQLTvKP9cuo4enmEVVRhqmabBzt1MDz0m4c8RufJWW2Ni4osaKRkYPjl/ijJ38wvRUZIiyCX7uofh+3iCKWn0111q5xFhn256Pm79Cx2ZP+yTp9sMsVNMJZ3UJ5r18F3H+zFHWWQSoiWpHn2WNB/2VUEyt0Lp1LnogKru96P2oYkXi6kqrA+qlLISUUU7R7ggJU0IRS6MjSGDyVzlaZG8m+RmY0bmQKrDwSeq1JMGkBpjwPY1o4yOnFRB7Rj1bzToLtd2IFSa8x0a2dUSyL5pBlyWklzZCxPp05R53RNSOi2KfhNfdZU2H7xEj5+z2aV5OidzowXIyYH8FlusMdk3NAOsvTbmBGiwvN4Zub9Exli06ZwARu/oJHLRh+hgOErIJ7DoX6nPrAtofSy6Etydpt+c4HkVZtGPWFSTMNWIGNx2NB1IfveOTU60H5emQ7zow5grXz4VTczqvCIh2hoQdSR4Oplr6+tDDLhtcGGHchHt473o2ygQ1m1tg7oSvMN7jmkUV1N6GniQofmlbr8d5LK4i/QtfC5GHCKIg3ohRlDvuvvKzvVWofgHX3NhXFTKK/CWAIp76iOaCWJcI562SpKyn+pFqYKpatJ42WfF3VbNpJYVMYMai5BwAE2RyZ6FhHbsaHq/NXO/dRJwHeDm4Pc/LFlGFdzpdbuf+w2DoePc56PlNmKsLNlZVlwbWcExKttI8nz3Th3aHNNtbIbD9awf1RdDspudQrTPWkyEopDVm7TkOj/J891U5p24PF5dasIJR19Tqpic3LVJuBXYRbL/Z79VRjeE3wBGLTDdhzJMA8TrS+yMSCF80bIw/F44o4WbA3Ya425mph9MIt/a137osRKATYqbustmVW/LfIyVhuHCOCRQsqTyFU+ff6Tp0EE2i1du90wosr+UutXiubYphCmuKkZONPbiXjpW1CAi40iAwxfgOVqAl13y4FlUp4EiGS7hPBUbvvEXMqT3ssfL+mlideH/v08PQCRcyG03zcCjCTmjXCggqHd+eEXhnsNZ4PFKCKiN+znR5SW+/p+kJTaBrX2e/kMU6kzjwb4NyNmZie0hHSneVtwJ1FuXJk/Zph4quv5KugCCx21xb5pePqxfKRW5jtW6r2Rc7OSNN4BHjwAcj8fOVV+12Ak7//o8mRh0aveYfoEvjCdaI8OPfjduDGfmzPUvXiqV9kGpovdlDUATyoVa3l1CowJ5r8KDOD6Ps89OG7TV2c7Wzxq2FQVjMFXxv/4wMZR1F/0zyH+ofPLVZjK3039z35GD4uoOW9Uc7WSr4FbxxuCDwOXWgstuk3rk6ASZFSe7RIwE/Y16d/aqzI+LG8pHqaEdhg6o6Y6JxBYNQo/JoglUOHwD+N5g5n9vfBNzf0xTlE/r0yjO3LCHyWzCnWr3QdKgzm6EDyL8GO+yQIbtXtw6lRQB/UEZ+ayt175r08Yhey95IsPwLVDFRRlG6pYwmzTlQOEwvqDI8SDMWboU+jp6a5jrbaAmqiIkaoiIzrV1QDp1x+Sqj0veqN+RtcpXLawJevz8dm76H+Mmp1br61nwvGcBaOKukICVj3iLeeu5tV5NoEJznWPwveHrcarZtKvOOeJbydmNAz286i0F1ocX337dt17jIkRv9sHbfqAVapob+eT7F3N/UY99GWGDVbXzaruQwsuPPR6MbLolG6buHQaKX3OZ/zJqGWfEAHw5yJKoKNe8aSgY2DsoITqPlbNRQQmOIMuF8ffD8L1stD/P5Ohth5Nql2W+l6y87/nqxkJ9y4FFS4QzrMrl9ztugfsRoYyeSWRydLUHlTCv155VsGAxjCMBQg1rP99Smfd02EbCFlWlypIw/zem0LZ1zVuz/Wjb03n+dzi2GIKRlTrt6YMrGGAcKI+3Pf1D0rsDhXNkdFUjOeofUkDbBr/splYCKLucDHFVdN88XyaQoj2fBymNJ4BqvK64TVOLwPGAQvh/rHZ5PkJR3lMI4fg+Kxdl9/5xDjkD9aV+yRvfqVGodNW/qofq34nrdb3co1tZ4BxtSANKdJg3Fv6U0I4DOMVsJTeOn/918M31rif0rKAwnHAkeyQVbZyEsFoqxvE8gUFs1zTRwZJWlmY0xnuVcM8pOh6hULeYGiF57ZlbvymygYqObe58YgrChRnF4NhKIIYzuz
7mOSKRXqF3Cr0LNYHcktUH9wrqISxiHbaUQceYZ1D0q8UfiayeK9yppMkltcDUL9M93xjTGJK8pVzARXn6ETuEsNTtLvbU/KMDY7bnVc7n08suLCk1YeJB/sn0wuTbPt+27NeYIG1YXBEE0dsgJW4z64489h71v4xws856gFOHZx0L/nkW7l328HA3jltbgJFl52mQHAJwUZrt5sJef/k7gsTdX1zQtjKN8lFjo4qpvJUpenmO9nT+Wty5cjohlETBos8CdSqj4SjEu7/UhDt52evt33EayoWJ8TjKd4VRFYCXnM6eGnSMDqUU5f7DxVjrwHnT26jtq9ijKTiAxls7fYjN8TGT/S3CHZZAK1u5gSbWfkFOcE+mioboNwDvuvysjL6de+bsc7r35w4hLFnPmKemcde4pNQfEnuelBFJqwYZbcAkhN8AmtqIWPXBw9n3eUx/TJgMFEIoB/frNDRbB0WJKdBkjdE1NVvAUl3jDnZbWjG6rqE+6UvyGqKBpd0FRYAfg3ss3hVB70uluULKUBVazlNIQlqX+qYEMBXaDIkxcftre8KYebQyJnxiOB5V+eELvm6L28bK4Xh2tpXzJL7aDlQnL8dRNvQdZgDL62EXYhrc3mz0I/p7br3KMcnei/LaPRAgcsW7WKLwzE5id6JnpOJj4VXdkX7IUB4xQjDRsGKxhjbklMVFA8g/801khNlwzU/IoXsHBgTs7yZoFX/oo4Jyp514hwqPlvJEgci0OHiSA6Mx3le2nUh0SQH+AzFJ2vi7Bn1a4psiuqd+vJJ1iuNw5CBCZlV+GO8sG93BBGnLzZDoRvkIMbzwESFP3JYZ/lKs29CB2Adobl9YbwP3he0I9cD0A/RPC70gzTdVEfL6T4iPUhBr1Bn3YlUPeC2QvCTbpKkxDsfzchuq/y0xlmL4E7Rdb+4TSMlViXfnc6aoD9vvPMWLJFF2qrxRLKhUTse5V6RoE+EVmHSiX0Vd7sd/bYp7asOC0b1xL+zjfJ5DSrtMA/P8L1p+CoLNXgVfgzCB3sCa+GLSLS2INsL1Qtnfkl8IGaMDeV+VAyHjY0HCj0l1X99f/RzD6TYrZAkLS8h1EM/JjomglhVG9/HTKS20BBJeos5ifrVd38rhONJy0HCP28pn4rCIyIE4bNG+1tEsHAg4FDYgh/OYuBsaGYgha9TGV5lGIxmVCECq3IPpkPN1CsLqv3KuDvNeH6XOOAzVtFj4VoIV6QgRLP8+94ZiiEDaPQxQ7BZoqrqFYrxWHDtEuon46VtQ3Nfq/1Rq/HvszJv6JE77w7qvKlxG9sXgxzCDRqNrG83cwY2hpDBr8U0hPMrEx977Weja1aG/rG6uirNBcY5qAAOLDo+9RvV1xqvWFF8SkT97tzNUHbzw8tuUlCT9m4rshCG+jBw59rpUZwW+eR1ih9qU7Nyr3oNgi/zmkORF1duym8VSfW5dxtRBIqxxM0oSWoHti+HSd0VLdHw8jRpbQddMBr1sjD1jIgp3w2dU4oEthzStKCPY2/lAWBm+1Es1okGhEM3I939DRcYOjfJnTCtJLJ9DTKycVDMerXvHnCgImZ0Oh4mtLF+63hn+9wUc56owFeNqs+NJHqmBBFX2uNr/Rj9mzYkRRPsYYSyCB7jIS+Z8Zall6W3dwLcsE3uw/oPKx5bJDAhnp7kZgzLC0zlS2D0ZcNZuW2uUtwhZJM6OOyV+FUFgizmpIQAQ8Nm6n/1yk0asB4jZFf221a9ZmzvUfWKmmIR7OxX3qBH9x2uMMhemv9LZdEHMcjTeIXRYciMLWUNeWagYhDgV1cRBGCDTh2EhHvYX7ZXfpsHjLOR+sAEr7uR3siitf/mRkiLfT2YBgTACKKoj05UuC8aknEV4T5bWiye+gKGioml5G/fWYHyHow37g6D84n0cBTWmI0oPlg+rqpeRLOeYaTeCXOtM/7M1FHuGvzmBnag2vhKY2tpjVrg2nI3p4SRlzTyoQkyMfRXN87v5nAheVcLgrYtkv9aX7R6VMZ1UIsxn62ZHFa2IR6skB/xw7RRuJY5r5FIWs1LqIQDaon5L4C4v9rnBxMYoUM" \ No newline at end of file + SECRET_DOCUMENTER_KEY: 
"iRC4P/r5o9pARB670eK9jPlKQKgkTMDAyvp2GbLG8WwLuT8T1VcWx/o4+ofGlzbTh5Z+LuFgPXfgqkjGuoWLcocHNm78xQMNMywB4rcLB2shqp8xG2vhglgnTBBS4EiyPAtVqGyi5AKmfF95PfkJvnI0Lqg5P/RWQvNGywLAR0Ikgr/lqocm2CvkFGbpMzpGxGvj76JYOusVeKvGAp698TXqPabSZR2oZQLfYnEZnaO8ivkqvMGQSXfgzoIMjCOrN1rSa84SWeI9BDeBslzDHwaYGlvjpfCyviiLtKj4t5Acl1gVE0qxxZxWuALIU6z+C1W8TbW7ZDCBUFs6UTIT+Q==;U2FsdGVkX1+/HSgg1skLszz835vSO6mEtXMhG62ohQQUc5opdo7kEIAG2wCoJPQrqGyaF9kKDVvrN5G2MdjUyaLBYlv90RzXhjTiMNFdgI3M4K500xKq3itY/aEL7hUSMRKxTos8u4xhdbRboY4rPcqgtCJ2LHEjNxmml/NfEo/8lk291rGoEYQLTvKP9cuo4enmEVVRhqmabBzt1MDz0m4c8RufJWW2Ni4osaKRkYPjl/ijJ38wvRUZIiyCX7uofh+3iCKWn0111q5xFhn256Pm79Cx2ZP+yTp9sMsVNMJZ3UJ5r18F3H+zFHWWQSoiWpHn2WNB/2VUEyt0Lp1LnogKru96P2oYkXi6kqrA+qlLISUUU7R7ggJU0IRS6MjSGDyVzlaZG8m+RmY0bmQKrDwSeq1JMGkBpjwPY1o4yOnFRB7Rj1bzToLtd2IFSa8x0a2dUSyL5pBlyWklzZCxPp05R53RNSOi2KfhNfdZU2H7xEj5+z2aV5OidzowXIyYH8FlusMdk3NAOsvTbmBGiwvN4Zub9Exli06ZwARu/oJHLRh+hgOErIJ7DoX6nPrAtofSy6Etydpt+c4HkVZtGPWFSTMNWIGNx2NB1IfveOTU60H5emQ7zow5grXz4VTczqvCIh2hoQdSR4Oplr6+tDDLhtcGGHchHt473o2ygQ1m1tg7oSvMN7jmkUV1N6GniQofmlbr8d5LK4i/QtfC5GHCKIg3ohRlDvuvvKzvVWofgHX3NhXFTKK/CWAIp76iOaCWJcI562SpKyn+pFqYKpatJ42WfF3VbNpJYVMYMai5BwAE2RyZ6FhHbsaHq/NXO/dRJwHeDm4Pc/LFlGFdzpdbuf+w2DoePc56PlNmKsLNlZVlwbWcExKttI8nz3Th3aHNNtbIbD9awf1RdDspudQrTPWkyEopDVm7TkOj/J891U5p24PF5dasIJR19Tqpic3LVJuBXYRbL/Z79VRjeE3wBGLTDdhzJMA8TrS+yMSCF80bIw/F44o4WbA3Ya425mph9MIt/a137osRKATYqbustmVW/LfIyVhuHCOCRQsqTyFU+ff6Tp0EE2i1du90wosr+UutXiubYphCmuKkZONPbiXjpW1CAi40iAwxfgOVqAl13y4FlUp4EiGS7hPBUbvvEXMqT3ssfL+mlideH/v08PQCRcyG03zcCjCTmjXCggqHd+eEXhnsNZ4PFKCKiN+znR5SW+/p+kJTaBrX2e/kMU6kzjwb4NyNmZie0hHSneVtwJ1FuXJk/Zph4quv5KugCCx21xb5pePqxfKRW5jtW6r2Rc7OSNN4BHjwAcj8fOVV+12Ak7//o8mRh0aveYfoEvjCdaI8OPfjduDGfmzPUvXiqV9kGpovdlDUATyoVa3l1CowJ5r8KDOD6Ps89OG7TV2c7Wzxq2FQVjMFXxv/4wMZR1F/0zyH+ofPLVZjK3039z35GD4uoOW9Uc7WSr4FbxxuCDwOXWgstuk3rk6ASZFSe7RIwE/Y16d/aqzI+LG8pHqaEdhg6o6Y6JxBYNQo/JoglUOHwD+N5g5n9vfBNzf0xTlE/r0yjO3LCHyWzCnWr3QdKgzm6EDyL8GO+yQIbtXtw6lRQB/UEZ+ayt175r08Yhey95IsPwLVDFRRlG6pYwmzTlQOEwvqDI8SDMWboU+jp6a5jrbaAmqiIkaoiIzrV1QDp1x+Sqj0veqN+RtcpXLawJevz8dm76H+Mmp1br61nwvGcBaOKukICVj3iLeeu5tV5NoEJznWPwveHrcarZtKvOOeJbydmNAz286i0F1ocX337dt17jIkRv9sHbfqAVapob+eT7F3N/UY99GWGDVbXzaruQwsuPPR6MbLolG6buHQaKX3OZ/zJqGWfEAHw5yJKoKNe8aSgY2DsoITqPlbNRQQmOIMuF8ffD8L1stD/P5Ohth5Nql2W+l6y87/nqxkJ9y4FFS4QzrMrl9ztugfsRoYyeSWRydLUHlTCv155VsGAxjCMBQg1rP99Smfd02EbCFlWlypIw/zem0LZ1zVuz/Wjb03n+dzi2GIKRlTrt6YMrGGAcKI+3Pf1D0rsDhXNkdFUjOeofUkDbBr/splYCKLucDHFVdN88XyaQoj2fBymNJ4BqvK64TVOLwPGAQvh/rHZ5PkJR3lMI4fg+Kxdl9/5xDjkD9aV+yRvfqVGodNW/qofq34nrdb3co1tZ4BxtSANKdJg3Fv6U0I4DOMVsJTeOn/918M31rif0rKAwnHAkeyQVbZyEsFoqxvE8gUFs1zTRwZJWlmY0xnuVcM8pOh6hULeYGiF57ZlbvymygYqObe58YgrChRnF4NhKIIYzuz7mOSKRXqF3Cr0LNYHcktUH9wrqISxiHbaUQceYZ1D0q8UfiayeK9yppMkltcDUL9M93xjTGJK8pVzARXn6ETuEsNTtLvbU/KMDY7bnVc7n08suLCk1YeJB/sn0wuTbPt+27NeYIG1YXBEE0dsgJW4z64489h71v4xws856gFOHZx0L/nkW7l328HA3jltbgJFl52mQHAJwUZrt5sJef/k7gsTdX1zQtjKN8lFjo4qpvJUpenmO9nT+Wty5cjohlETBos8CdSqj4SjEu7/UhDt52evt33EayoWJ8TjKd4VRFYCXnM6eGnSMDqUU5f7DxVjrwHnT26jtq9ijKTiAxls7fYjN8TGT/S3CHZZAK1u5gSbWfkFOcE+mioboNwDvuvysjL6de+bsc7r35w4hLFnPmKemcde4pNQfEnuelBFJqwYZbcAkhN8AmtqIWPXBw9n3eUx/TJgMFEIoB/frNDRbB0WJKdBkjdE1NVvAUl3jDnZbWjG6rqE+6UvyGqKBpd0FRYAfg3ss3hVB70uluULKUBVazlNIQlqX+qYEMBXaDIkxcftre8KYebQyJnxiOB5V+eELvm6L28bK4Xh2tpXzJL7aDlQnL8dRNvQdZgDL62EXYhrc3mz0I/p7br3KMcnei/LaPRAgcsW7WKLwzE5id6JnpOJj4VXdkX7IUB4xQjDRsGKxhjbklMVFA8g/801khNlwzU/IoXsHBgTs7yZoFX/oo4Jyp514hwqPlvJEgci0OHiSA6Mx3le2nUh0SQH+AzFJ2vi7Bn1a4psiuqd+vJJ1iuNw5CBCZlV+GO8sG93BBGnLzZDoRvkIMbzwESFP3JYZ/lKs29CB2Adobl9YbwP3he0I9cD0A/R
PC70gzTdVEfL6T4iPUhBr1Bn3YlUPeC2QvCTbpKkxDsfzchuq/y0xlmL4E7Rdb+4TSMlViXfnc6aoD9vvPMWLJFF2qrxRLKhUTse5V6RoE+EVmHSiX0Vd7sd/bYp7asOC0b1xL+zjfJ5DSrtMA/P8L1p+CoLNXgVfgzCB3sCa+GLSLS2INsL1Qtnfkl8IGaMDeV+VAyHjY0HCj0l1X99f/RzD6TYrZAkLS8h1EM/JjomglhVG9/HTKS20BBJeos5ifrVd38rhONJy0HCP28pn4rCIyIE4bNG+1tEsHAg4FDYgh/OYuBsaGYgha9TGV5lGIxmVCECq3IPpkPN1CsLqv3KuDvNeH6XOOAzVtFj4VoIV6QgRLP8+94ZiiEDaPQxQ7BZoqrqFYrxWHDtEuon46VtQ3Nfq/1Rq/HvszJv6JE77w7qvKlxG9sXgxzCDRqNrG83cwY2hpDBr8U0hPMrEx977Weja1aG/rG6uirNBcY5qAAOLDo+9RvV1xqvWFF8SkT97tzNUHbzw8tuUlCT9m4rshCG+jBw59rpUZwW+eR1ih9qU7Nyr3oNgi/zmkORF1duym8VSfW5dxtRBIqxxM0oSWoHti+HSd0VLdHw8jRpbQddMBr1sjD1jIgp3w2dU4oEthzStKCPY2/lAWBm+1Es1okGhEM3I939DRcYOjfJnTCtJLJ9DTKycVDMerXvHnCgImZ0Oh4mtLF+63hn+9wUc56owFeNqs+NJHqmBBFX2uNr/Rj9mzYkRRPsYYSyCB7jIS+Z8Zall6W3dwLcsE3uw/oPKx5bJDAhnp7kZgzLC0zlS2D0ZcNZuW2uUtwhZJM6OOyV+FUFgizmpIQAQ8Nm6n/1yk0asB4jZFf221a9ZmzvUfWKmmIR7OxX3qBH9x2uMMhemv9LZdEHMcjTeIXRYciMLWUNeWagYhDgV1cRBGCDTh2EhHvYX7ZXfpsHjLOR+sAEr7uR3siitf/mRkiLfT2YBgTACKKoj05UuC8aknEV4T5bWiye+gKGioml5G/fWYHyHow37g6D84n0cBTWmI0oPlg+rqpeRLOeYaTeCXOtM/7M1FHuGvzmBnag2vhKY2tpjVrg2nI3p4SRlzTyoQkyMfRXN87v5nAheVcLgrYtkv9aX7R6VMZ1UIsxn62ZHFa2IR6skB/xw7RRuJY5r5FIWs1LqIQDaon5L4C4v9rnBxMYoUM" diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index a2acf191fa..237293a0f3 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -1,6 +1,6 @@ steps: - label: "Triggering Pipelines (Pull Request)" - if: "build.pull_request.base_branch == 'main'" + if: build.branch != "main" || build.tag == null agents: queue: "juliagpu" plugins: diff --git a/docs/tutorials.jl b/docs/tutorials.jl index 9d11a55637..ab49870308 100644 --- a/docs/tutorials.jl +++ b/docs/tutorials.jl @@ -1,36 +1,49 @@ #! format: off const BEGINNER_TUTORIALS = [ - "Basics/main.jl", - "PolynomialFitting/main.jl", - "SimpleRNN/main.jl", - "SimpleChains/main.jl" + "Basics/main.jl" => "CUDA", + "PolynomialFitting/main.jl" => "CUDA", + "SimpleRNN/main.jl" => "CUDA", + "SimpleChains/main.jl" => "CPU" ] const INTERMEDIATE_TUTORIALS = [ - "NeuralODE/main.jl", - "BayesianNN/main.jl", - "HyperNet/main.jl" + "NeuralODE/main.jl" => "CUDA", + "BayesianNN/main.jl" => "CPU", + "HyperNet/main.jl" => "CUDA", ] const ADVANCED_TUTORIALS = [ - "GravitationalWaveForm/main.jl", + "GravitationalWaveForm/main.jl" => "CPU", ] const TUTORIALS = [ - collect(enumerate(Iterators.product(["beginner"], BEGINNER_TUTORIALS)))..., - collect(enumerate(Iterators.product(["intermediate"], INTERMEDIATE_TUTORIALS)))..., - collect(enumerate(Iterators.product(["advanced"], ADVANCED_TUTORIALS)))... + collect(enumerate(Iterators.product(["beginner"], first.(BEGINNER_TUTORIALS))))..., + collect(enumerate(Iterators.product(["intermediate"], first.(INTERMEDIATE_TUTORIALS))))..., + collect(enumerate(Iterators.product(["advanced"], first.(ADVANCED_TUTORIALS))))... ] +const BACKEND_LIST = lowercase.([ + last.(BEGINNER_TUTORIALS)..., + last.(INTERMEDIATE_TUTORIALS)..., + last.(ADVANCED_TUTORIALS)... +]) #! 
format: on +const BACKEND_GROUP = lowercase(get(ENV, "TUTORIAL_BACKEND_GROUP", "all")) + const BUILDKITE_PARALLEL_JOB_COUNT = parse( Int, get(ENV, "BUILDKITE_PARALLEL_JOB_COUNT", "-1")) +const TUTORIALS_WITH_BACKEND = if BACKEND_GROUP == "all" + TUTORIALS +else + TUTORIALS[BACKEND_LIST .== BACKEND_GROUP] +end + const TUTORIALS_BUILDING = if BUILDKITE_PARALLEL_JOB_COUNT > 0 id = parse(Int, ENV["BUILDKITE_PARALLEL_JOB"]) + 1 # Index starts from 0 - splits = collect(Iterators.partition( - TUTORIALS, cld(length(TUTORIALS), BUILDKITE_PARALLEL_JOB_COUNT))) + splits = collect(Iterators.partition(TUTORIALS_WITH_BACKEND, + cld(length(TUTORIALS_WITH_BACKEND), BUILDKITE_PARALLEL_JOB_COUNT))) id > length(splits) ? [] : splits[id] else - TUTORIALS + TUTORIALS_WITH_BACKEND end const NTASKS = min( diff --git a/examples/GravitationalWaveForm/Project.toml b/examples/GravitationalWaveForm/Project.toml index 73f3948089..67e420f41c 100644 --- a/examples/GravitationalWaveForm/Project.toml +++ b/examples/GravitationalWaveForm/Project.toml @@ -5,7 +5,6 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" -LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" @@ -19,7 +18,6 @@ ComponentArrays = "0.15" LineSearches = "7" Literate = "2" Lux = "1" -LuxCUDA = "0.3" Optimization = "3" OptimizationOptimJL = "0.3" OrdinaryDiffEq = "6" diff --git a/examples/GravitationalWaveForm/main.jl b/examples/GravitationalWaveForm/main.jl index 4bbed4952b..56bbb23018 100644 --- a/examples/GravitationalWaveForm/main.jl +++ b/examples/GravitationalWaveForm/main.jl @@ -7,12 +7,10 @@ # ## Package Imports -using Lux, ComponentArrays, LineSearches, LuxCUDA, OrdinaryDiffEq, Optimization, - OptimizationOptimJL, Printf, Random, SciMLSensitivity +using Lux, ComponentArrays, LineSearches, OrdinaryDiffEq, Optimization, OptimizationOptimJL, + Printf, Random, SciMLSensitivity using CairoMakie -CUDA.allowscalar(false) - # ## Define some Utility Functions # !!! tip From 52c8880c76949e347f7cbfc84fda35f7d2cc1079 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 6 Sep 2024 18:38:45 -0400 Subject: [PATCH 93/95] docs: try fixing nested autodiff --- docs/Project.toml | 1 + docs/src/manual/nested_autodiff.md | 28 +++++------ test/autodiff/nested_autodiff_tests.jl | 64 ++++++++------------------ 3 files changed, 35 insertions(+), 58 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 2126c3122d..85ac205ae7 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -21,6 +21,7 @@ Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/docs/src/manual/nested_autodiff.md b/docs/src/manual/nested_autodiff.md index 0a5e074a47..497179c11d 100644 --- a/docs/src/manual/nested_autodiff.md +++ b/docs/src/manual/nested_autodiff.md @@ -22,7 +22,7 @@ Let's explore this using some questions that were posted on the [Julia Discourse forum](https://discourse.julialang.org/). 
```@example nested_ad -using ADTypes, Lux, LinearAlgebra, Zygote, ForwardDiff, Random +using ADTypes, Lux, LinearAlgebra, Zygote, ForwardDiff, Random, StableRNGs using ComponentArrays, FiniteDiff ``` @@ -70,15 +70,15 @@ function loss_function1(model, x, ps, st, y) loss_emp = sum(abs2, ŷ .- y) # You can use `Zygote.jacobian` as well but ForwardDiff tends to be more efficient here J = ForwardDiff.jacobian(smodel, x) - loss_reg = abs2(norm(J)) + loss_reg = abs2(norm(J .* 0.01f0)) return loss_emp + loss_reg end # Using Batchnorm to show that it is possible model = Chain(Dense(2 => 4, tanh), BatchNorm(4), Dense(4 => 2)) -ps, st = Lux.setup(Xoshiro(0), model) -x = rand(Xoshiro(0), Float32, 2, 10) -y = rand(Xoshiro(11), Float32, 2, 10) +ps, st = Lux.setup(StableRNG(0), model) +x = randn(StableRNG(0), Float32, 2, 10) +y = randn(StableRNG(11), Float32, 2, 10) loss_function1(model, x, ps, st, y) ``` @@ -97,9 +97,9 @@ Now let's verify the gradients using finite differences: ComponentArray(ps)) println("∞-norm(∂x - ∂x_fd): ", norm(∂x .- ∂x_fd, Inf)) -@assert norm(∂x .- ∂x_fd, Inf) < 1e-1 # hide +@assert norm(∂x .- ∂x_fd, Inf) < 1e-2 # hide println("∞-norm(∂ps - ∂ps_fd): ", norm(ComponentArray(∂ps) .- ∂ps_fd, Inf)) -@assert norm(ComponentArray(∂ps) .- ∂ps_fd, Inf) < 1e-1 # hide +@assert norm(ComponentArray(∂ps) .- ∂ps_fd, Inf) < 1e-2 # hide nothing; # hide ``` @@ -123,8 +123,8 @@ end model = Chain(Dense(1 => 12,tanh), Dense(12 => 12,tanh), Dense(12 => 12,tanh), Dense(12 => 1)) -ps, st = Lux.setup(Xoshiro(0), model) -t = rand(Xoshiro(0), Float32, 1, 16) +ps, st = Lux.setup(StableRNG(0), model) +t = rand(StableRNG(0), Float32, 1, 16) ``` Now the moment of truth: @@ -164,9 +164,9 @@ end model = Chain(Dense(1 => 12,tanh), Dense(12 => 12,tanh), Dense(12 => 12,tanh), Dense(12 => 1)) -ps, st = Lux.setup(Xoshiro(0), model) +ps, st = Lux.setup(StableRNG(0), model) ps = ComponentArray(ps) # needs to be an AbstractArray for most jacobian functions -x = rand(Xoshiro(0), Float32, 1, 16) +x = rand(StableRNG(0), Float32, 1, 16) ``` We can as usual compute the gradient/jacobian of the loss function: @@ -260,9 +260,9 @@ Now let's compute the trace and compare the results: ```@example nested_ad model = Chain(Dense(4 => 12,tanh), Dense(12 => 12,tanh), Dense(12 => 12,tanh), Dense(12 => 4)) -ps, st = Lux.setup(Xoshiro(0), model) -x = rand(Xoshiro(0), Float32, 4, 12) -v = (rand(Xoshiro(12), Float32, 4, 12) .> 0.5f0) * 2.0f0 .- 1.0f0 # rademacher sample +ps, st = Lux.setup(StableRNG(0), model) +x = rand(StableRNG(0), Float32, 4, 12) +v = (rand(StableRNG(12), Float32, 4, 12) .> 0.5f0) * 2.0f0 .- 1.0f0 # rademacher sample nothing; # hide ``` diff --git a/test/autodiff/nested_autodiff_tests.jl b/test/autodiff/nested_autodiff_tests.jl index 8344006767..850c313878 100644 --- a/test/autodiff/nested_autodiff_tests.jl +++ b/test/autodiff/nested_autodiff_tests.jl @@ -27,28 +27,15 @@ function test_nested_ad_input_gradient_jacobian(aType, dev, ongpu, loss_fn, X, m !iszero(ComponentArray(∂ps |> cpu_device())) && all(x -> x === nothing || isfinite(x), ComponentArray(∂ps |> cpu_device())) - __f = (x, ps) -> loss_fn(model, x, ps, st) - allow_unstable() do - FDIFF_WORKS = try - LuxTestUtils.gradient(__f, AutoForwardDiff(), X, ps) - true - catch - false - end - skip_backends = [AutoReverseDiff(), AutoTracker(), AutoEnzyme()] - if FDIFF_WORKS - test_gradients((x, ps) -> loss_fn(model, x, ps, st), X, ps; atol=1.0f-3, - rtol=1.0f-1, soft_fail=[AutoFiniteDiff()], skip_backends) - else - test_gradients((x, ps) -> loss_fn(model, x, ps, st), X, ps; 
atol=1.0f-3, - rtol=1.0f-1, skip_backends=vcat(skip_backends, [AutoFiniteDiff()])) - end + test_gradients((x, ps) -> loss_fn(model, x, ps, st), X, ps; + atol=1.0f-3, rtol=1.0f-1, soft_fail=[AutoFiniteDiff()], + skip_backends=[AutoReverseDiff(), AutoTracker(), AutoEnzyme()]) end end -const Xs = (randn(rng, Float32, 3, 3, 2, 4), randn(rng, Float32, 2, 4), - randn(rng, Float32, 2, 4), randn(rng, Float32, 3, 3, 2, 4)) +const Xs = (randn(rng, Float32, 3, 3, 2, 2), randn(rng, Float32, 2, 2), + randn(rng, Float32, 2, 2), randn(rng, Float32, 3, 3, 2, 2)) const models = ( Chain(Conv((3, 3), 2 => 4, gelu; pad=SamePad()), BatchNorm(4), @@ -63,25 +50,25 @@ const models = ( # smodel | ForwardDiff.jacobian function loss_function1(model, x, ps, st) smodel = StatefulLuxLayer{true}(model, ps, st) - return sum(abs2, ForwardDiff.jacobian(smodel, x)) + return sum(abs2, ForwardDiff.jacobian(smodel, x) .* 0.01f0) end # smodel | Zygote.jacobian function loss_function2(model, x, ps, st) smodel = StatefulLuxLayer{true}(model, ps, st) - return sum(abs2, only(Zygote.jacobian(smodel, x))) + return sum(abs2, only(Zygote.jacobian(smodel, x)) .* 0.01f0) end # sum(abs2) ∘ smodel | ForwardDiff.gradient function loss_function3(model, x, ps, st) smodel = StatefulLuxLayer{true}(model, ps, st) - return sum(abs2, ForwardDiff.gradient(Base.Fix1(sum, abs2) ∘ smodel, x)) + return sum(abs2, ForwardDiff.gradient(Base.Fix1(sum, abs2) ∘ smodel, x) .* 0.01f0) end # sum(abs2) ∘ smodel | Zygote.gradient function loss_function4(model, x, ps, st) smodel = StatefulLuxLayer{true}(model, ps, st) - return sum(abs2, only(Zygote.gradient(Base.Fix1(sum, abs2) ∘ smodel, x))) + return sum(abs2, only(Zygote.gradient(Base.Fix1(sum, abs2) ∘ smodel, x)) .* 0.01f0) end const ALL_TEST_CONFIGS = Iterators.product( @@ -165,28 +152,15 @@ function test_nested_ad_parameter_gradient_jacobian(aType, dev, ongpu, loss_fn, !iszero(ComponentArray(∂ps |> cpu_device())) && all(x -> x === nothing || isfinite(x), ComponentArray(∂ps |> cpu_device())) - __f = (x, ps) -> loss_fn(model, x, ps, st) - allow_unstable() do - FDIFF_WORKS = try - LuxTestUtils.gradient(__f, AutoForwardDiff(), X, ps) - true - catch - false - end - skip_backends = [AutoReverseDiff(), AutoTracker(), AutoEnzyme()] - if FDIFF_WORKS - test_gradients((x, ps) -> loss_fn(model, x, ps, st), X, ps; atol=1.0f-3, - rtol=1.0f-1, soft_fail=[AutoFiniteDiff()], skip_backends) - else - test_gradients((x, ps) -> loss_fn(model, x, ps, st), X, ps; atol=1.0f-3, - rtol=1.0f-1, skip_backends=vcat(skip_backends, [AutoFiniteDiff()])) - end + test_gradients((x, ps) -> loss_fn(model, x, ps, st), X, ps; + atol=1.0f-3, rtol=1.0f-1, soft_fail=[AutoFiniteDiff()], + skip_backends=[AutoReverseDiff(), AutoTracker(), AutoEnzyme()]) end end -const Xs = (randn(rng, Float32, 3, 3, 2, 4), randn(rng, Float32, 2, 4), - randn(rng, Float32, 2, 4), randn(rng, Float32, 3, 3, 2, 4)) +const Xs = (randn(rng, Float32, 3, 3, 2, 2), randn(rng, Float32, 2, 2), + randn(rng, Float32, 2, 2), randn(rng, Float32, 3, 3, 2, 2)) const models = ( Chain(Conv((3, 3), 2 => 4, gelu; pad=SamePad()), BatchNorm(4), @@ -201,25 +175,27 @@ const models = ( # smodel | ForwardDiff.jacobian function loss_function1(model, x, ps, st) smodel = StatefulLuxLayer{true}(model, ps, st) - return sum(abs2, ForwardDiff.jacobian(Base.Fix1(smodel, x), ps)) + return sum(abs2, ForwardDiff.jacobian(Base.Fix1(smodel, x), ps) .* 0.01f0) end # smodel | Zygote.jacobian function loss_function2(model, x, ps, st) smodel = StatefulLuxLayer{true}(model, ps, st) - return sum(abs2, 
only(Zygote.jacobian(Base.Fix1(smodel, x), ps))) + return sum(abs2, only(Zygote.jacobian(Base.Fix1(smodel, x), ps)) .* 0.01f0) end # sum(abs2) ∘ smodel | ForwardDiff.gradient function loss_function3(model, x, ps, st) smodel = StatefulLuxLayer{true}(model, ps, st) - return sum(abs2, ForwardDiff.gradient(Base.Fix1(sum, abs2) ∘ Base.Fix1(smodel, x), ps)) + return sum(abs2, + ForwardDiff.gradient(Base.Fix1(sum, abs2) ∘ Base.Fix1(smodel, x), ps) .* 0.01f0) end # sum(abs2) ∘ smodel | Zygote.gradient function loss_function4(model, x, ps, st) smodel = StatefulLuxLayer{true}(model, ps, st) - return sum(abs2, only(Zygote.gradient(Base.Fix1(sum, abs2) ∘ Base.Fix1(smodel, x), ps))) + return sum(abs2, + only(Zygote.gradient(Base.Fix1(sum, abs2) ∘ Base.Fix1(smodel, x), ps)) .* 0.01f0) end const ALL_TEST_CONFIGS = Iterators.product( From 5010f100fde8f6bb24aa9b513c5c2da2734d3f8d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 6 Sep 2024 22:00:56 -0400 Subject: [PATCH 94/95] docs: use the linux runners [skip tests] --- .buildkite/documentation.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.buildkite/documentation.yml b/.buildkite/documentation.yml index ecdb2c7d56..2a4398a12a 100644 --- a/.buildkite/documentation.yml +++ b/.buildkite/documentation.yml @@ -13,7 +13,7 @@ steps: dirs: - src - ext - command: julia --code-coverage=user --color=yes --project=docs docs/tutorials.jl + command: julia --code-coverage=user --color=yes --project=docs --threads=auto docs/tutorials.jl env: TUTORIAL_BACKEND_GROUP: "CUDA" agents: @@ -39,13 +39,13 @@ steps: # dirs: # - src # - ext - command: julia --code-coverage=user --color=yes --project=docs docs/tutorials.jl + command: julia --code-coverage=user --color=yes --project=docs --threads=auto docs/tutorials.jl env: TUTORIAL_BACKEND_GROUP: "CPU" agents: queue: "juliaecosystem" - os: "macos" - arch: "aarch64" + os: "linux" + arch: "x86_64" artifact_paths: - "docs/src/tutorials/beginner/**/*" - "docs/src/tutorials/intermediate/**/*" From 0bd7099da3f377d8345dd9fbf1105b34334056cf Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 6 Sep 2024 22:27:03 -0400 Subject: [PATCH 95/95] fix: update simplechains layer API --- src/layers/extension.jl | 41 ++++++++++++++--------------------- src/transform/simplechains.jl | 2 +- 2 files changed, 17 insertions(+), 26 deletions(-) diff --git a/src/layers/extension.jl b/src/layers/extension.jl index e4d7298ca7..8242790a86 100644 --- a/src/layers/extension.jl +++ b/src/layers/extension.jl @@ -48,8 +48,8 @@ Base.show(io::IO, ::MIME"text/plain", l::FluxLayer) = print(io, "FluxLayer($(l.l ## SimpleChains.jl """ - SimpleChainsLayer{ToArray}(layer, lux_layer=nothing) - SimpleChainsLayer(layer, ToArray::Union{Bool, Val}=Val(false)) + SimpleChainsLayer(layer, to_array::Union{Bool, Val}=Val(false)) + SimpleChainsLayer(layer, lux_layer, to_array) Wraps a `SimpleChains` layer into a `Lux` layer. All operations are performed using `SimpleChains` but the layer satisfies the `AbstractLuxLayer` interface. @@ -62,39 +62,30 @@ regular `Array` or not. Default is `false`. 
- `layer`: SimpleChains layer - `lux_layer`: Potentially equivalent Lux layer that is used for printing """ -struct SimpleChainsLayer{ToArray, SL, LL <: Union{Nothing, AbstractLuxLayer}} <: - AbstractLuxLayer - to_array::ToArray - layer::SL - lux_layer::LL - - function SimpleChainsLayer{ToArray}(layer, lux_layer=nothing) where {ToArray} - to_array = static(ToArray) - return new{typeof(to_array), typeof(layer), typeof(lux_layer)}( - to_array, layer, lux_layer) - end - function SimpleChainsLayer(layer, ToArray::BoolType=False()) - to_array = static(ToArray) - return new{typeof(to_array), typeof(layer), Nothing}(to_array, layer, nothing) - end +@concrete struct SimpleChainsLayer <: AbstractLuxLayer + layer + lux_layer <: Union{Nothing, AbstractLuxLayer} + to_array <: StaticBool +end + +function SimpleChainsLayer(layer, to_array::BoolType=False()) + return SimpleChainsLayer(layer, nothing, static(to_array)) end -function Base.show( - io::IO, ::MIME"text/plain", s::SimpleChainsLayer{ToArray}) where {ToArray} - PrettyPrinting.print_wrapper_model( - io, "SimpleChainsLayer{to_array=$ToArray}", s.lux_layer) +function Base.show(io::IO, ::MIME"text/plain", s::SimpleChainsLayer) + PrettyPrinting.print_wrapper_model(io, "SimpleChainsLayer", s.lux_layer) end function (sc::SimpleChainsLayer)(x, ps, st) y = match_eltype(sc, ps, st, x) return ( - simple_chain_output( - sc, apply_simple_chain(sc.layer, y, ps.params, MLDataDevices.get_device(x))), + to_array(sc.to_array, + apply_simple_chain(sc.layer, y, ps.params, MLDataDevices.get_device(x))), st) end -simple_chain_output(::SimpleChainsLayer{False}, y) = y -simple_chain_output(::SimpleChainsLayer{True}, y) = convert(Array, y) +to_array(::False, y) = y +to_array(::True, y) = convert(Array, y) apply_simple_chain(layer, x, ps, ::CPUDevice) = layer(x, ps) diff --git a/src/transform/simplechains.jl b/src/transform/simplechains.jl index f6e2ecb7e9..18c840d590 100644 --- a/src/transform/simplechains.jl +++ b/src/transform/simplechains.jl @@ -69,7 +69,7 @@ function Adapt.adapt(to::ToSimpleChainsAdaptor, L::AbstractLuxLayer) error("`ToSimpleChainsAdaptor` requires `SimpleChains.jl` to be loaded.") end sc_layer = fix_simplechain_input_dims(make_simplechain_network(L), to.input_dims) - return SimpleChainsLayer{to.convert_to_array}(sc_layer, L) + return SimpleChainsLayer(sc_layer, L, static(to.convert_to_array)) end function make_simplechain_network end
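
For context on the last patch, here is a minimal sketch of how the reworked `SimpleChainsLayer` is usually reached in practice, i.e. through `ToSimpleChainsAdaptor` rather than by calling the constructor directly. This is illustrative only: it assumes the adaptor keeps its existing positional `convert_to_array` argument, that statically-sized input dims come from Static.jl, and the model/data names are made up for the example.

```julia
# Sketch (not part of the patch): exercising the new field-based SimpleChainsLayer.
using Adapt, Lux, Random, SimpleChains
using Static: static   # assumption: `static` sizes are what the adaptor expects

# An ordinary Lux model; SimpleChains supports Dense + tanh so conversion should work.
lux_model = Chain(Dense(2 => 4, tanh), Dense(4 => 2))

# Assumption: second positional argument is `convert_to_array`, which now lands in the
# layer's `to_array::StaticBool` field instead of a type parameter.
adaptor = ToSimpleChainsAdaptor((static(2),), true)
sc_model = Adapt.adapt(adaptor, lux_model)   # returns a `SimpleChainsLayer` per the patch

ps, st = Lux.setup(Xoshiro(0), sc_model)     # `ps.params` holds the SimpleChains parameters
x = randn(Float32, 2, 8)
y, _ = sc_model(x, ps, st)                   # output converted to `Array` since `to_array` is true
```

The design change itself is visible in the diff: keeping `to_array` as a `StaticBool` field means dispatch happens on the field value (`to_array(::True, y)` / `to_array(::False, y)`) instead of specializing the layer type on a `{ToArray}` parameter.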