Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

fix!: remove deprecations for 1.0 release #82

Merged
merged 7 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.3.51"
version = "1.0.0"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down Expand Up @@ -67,9 +67,9 @@ Hwloc = "3.2"
KernelAbstractions = "0.9.22"
LinearAlgebra = "1.10"
LoopVectorization = "0.12.171"
LuxCore = "0.1.13, 1"
LuxCore = "1"
MKL = "0.7"
MLDataDevices = "1.0.0"
MLDataDevices = "1"
Markdown = "1.10"
NNlib = "0.9.21"
Octavian = "0.3.28"
Expand Down
1 change: 1 addition & 0 deletions benchmarks/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Expand Down
1 change: 1 addition & 0 deletions benchmarks/setup.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using MLDataDevices, StableRNGs, Random
using NNlib
using Zygote

synchronize(::CPUDevice) = nothing
Expand Down
6 changes: 1 addition & 5 deletions src/LuxLib.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
module LuxLib

using Compat: @compat
using Random: AbstractRNG
using Reexport: @reexport
using Static: Static, known
using UnrolledUtilities: unrolled_filter
Expand All @@ -10,9 +9,7 @@ using ChainRulesCore: ChainRulesCore, NoTangent

using LuxCore: LuxCore
using MLDataDevices: get_device_type, AbstractGPUDevice
using NNlib: NNlib, ConvDims, σ

@reexport using NNlib
using NNlib: NNlib

const Optional{T} = Union{Nothing, T}
const Numeric = Union{AbstractArray{<:T}, T} where {T <: Number}
Expand All @@ -23,7 +20,6 @@ include("utils.jl")
include("traits.jl")
include("impl/Impl.jl")
include("api/API.jl")
include("deprecations.jl")

@compat(public,
(internal_operation_mode, GenericBroadcastOp, GPUBroadcastOp, LoopedArrayOp))
Expand Down
22 changes: 9 additions & 13 deletions src/api/layernorm.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
@doc doc"""
layernorm(x, scale, bias, σ = identity, dims=Colon(),
epsilon = eps(eltype(x)) ^ (5 / 7))
layernorm(x::AbstractArray{xT, N}, scale, bias, σ = identity, dims=1:(N - 1),
epsilon = eps(eltype(x)) ^ (5 / 7)) where {xT, N}

Layer Normalization. For details see [1].

Expand All @@ -18,17 +18,13 @@ and applies the activation function `σ` elementwise to `y`.
- `scale`: Scale factor (``\gamma``) (can be `nothing`)
- `bias`: Bias factor (``\beta``) (can be `nothing`)
- `σ`: Activation function (default: `identity`)
- `dims`: Dimensions along which the mean and std of `x` is computed (default: `Colon()`).
If `nothing` is passed, the dims are inferred based on the dimensions of scale and
bias. For example, if `x` is `N` dimensional and `scale` and `bias` are `M`
dimensional, then the dims will be `1:(N - M)`.
- `dims`: Dimensions along which the mean and std of `x` is computed. If `nothing` is
passed, the dims are inferred based on the dimensions of scale and bias. For example,
if `x` is `N` dimensional and `scale` and `bias` are `M` dimensional, then the dims
will be `1:(N - M)`.
- `epsilon`: Value added to the denominator for numerical stability
(default: `eps(eltype(x)) ^ (5 / 7)`)

!!! danger "Default `dims` to be changed in v1"

By default, `dims` will exclude the batch dimension.

## Returns

Normalized Array of same size as `x`.
Expand All @@ -38,9 +34,9 @@ Normalized Array of same size as `x`.
[1] Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton. "Layer normalization." arXiv
preprint arXiv:1607.06450 (2016).
"""
function layernorm(x::AbstractArray{xT}, scale::Optional{<:AbstractArray},
bias::Optional{<:AbstractArray}, σ::F=identity, dims=Colon(),
epsilon::Real=default_epsilon(x)) where {F, xT}
function layernorm(x::AbstractArray{xT, N}, scale::Optional{<:AbstractArray},
bias::Optional{<:AbstractArray}, σ::F=identity, dims=1:(N - 1),
epsilon::Real=default_epsilon(x)) where {F, xT, N}
return layernorm_impl(
x, scale, bias, select_fastest_activation(σ, x, scale, bias), dims, epsilon)
end
2 changes: 1 addition & 1 deletion src/impl/Impl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ using NNlib: NNlib, ConvDims
using ..LuxLib: Optional, Numeric, ∂∅, internal_operation_mode, AbstractInternalArrayOpMode,
GenericBroadcastOp, GPUBroadcastOp, LoopedArrayOp
using ..Utils: Utils, NotaNumber, batchview, concrete_bias_act_output_eltype, contiguous,
copy_drop_gradients, depwarn, eltype_mismatch, expand_batchdim,
copy_drop_gradients, eltype_mismatch, expand_batchdim,
maybe_reduce_BLAS_threads, ofeltype_array, only_derivative, remove_tracking,
reset_BLAS_threads, run_ka_kernel, safe_eltype, safe_vec, safe_warning,
unsafe_known, @enzyme_alternative
Expand Down
19 changes: 9 additions & 10 deletions src/impl/dropout.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,8 @@ function dropout(rng::AbstractRNG, x::AbstractArray, ::AbstractArray, p::T,
end

function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray,
p::T, ::True, ::False, invp::T, dims) where {T}
if dropout_shape(x, dims) != size(mask)
depwarn(
"`update_mask` is `Val(false)` but `mask` is not of the same size \
as `LuxLib.dropout_shape(x, dims)`. This has been deprecated and \
will be removed in the next release. Set `update_mask` to \
`Val(true)` to avoid this.", :dropout)
mask, rngₙ = generate_dropout_mask(rng, x, p, invp, dims)
return dropout_dot_mul(x, mask), mask, rngₙ
end
::T, ::True, ::False, invp::T, dims) where {T}
check_dropout_mask_shape_mismatch(x, mask, dims)
return dropout_dot_mul(x, mask), mask, rng
end

Expand All @@ -31,6 +23,13 @@ function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray,
return (x, mask, rng)
end

function check_dropout_mask_shape_mismatch(x::AbstractArray, mask::AbstractArray, dims)
@assert dropout_shape(x, dims)==size(mask) "`mask` is not of the same size as `LuxLib.dropout_shape(x, dims)`."
return nothing
end

CRC.@non_differentiable check_dropout_mask_shape_mismatch(::Any...)

## alpha_dropout
function alpha_dropout(rng::AbstractRNG, x::AbstractArray{T}, p, ::True) where {T}
α = T(-1.7580993408473766)
Expand Down
4 changes: 2 additions & 2 deletions test/common_ops/dense_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ end
end

@testitem "Fused Dense: StaticArrays" tags=[:dense] begin
using StaticArrays
using StaticArrays, NNlib

x = @SArray rand(2, 4)
weight = @SArray rand(3, 2)
Expand All @@ -112,7 +112,7 @@ end
end

@testitem "Fused Dense: CPU No Scalar Indexing" tags=[:dense] begin
using JLArrays
using JLArrays, NNlib

x = JLArray(rand(Float32, 2, 4))
weight = JLArray(rand(Float32, 3, 2))
Expand Down
34 changes: 1 addition & 33 deletions test/common_ops/dropout_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@
end

@testitem "Dropout with Preset Mask" tags=[:other_ops] setup=[SharedTestSetup] begin
Enzyme.API.runtimeActivity!(true) # TODO: remove in 1.0 after deprecation

using Statistics

rng = StableRNG(12345)
Expand Down Expand Up @@ -100,8 +98,7 @@ end

__f = (x, mask) -> sum(first(dropout(
StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, Colon())))
# Branching based on runtime values
@test @inferred(Zygote.gradient(__f, x, mask)) isa Any broken=true
@test @inferred(Zygote.gradient(__f, x, mask)) isa Any

__f = let rng = rng, mask = mask
x -> sum(first(dropout(
Expand All @@ -115,35 +112,6 @@ end
rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType

# Try using mask if possible (not possible!!)
@test @inferred(dropout(
rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())) isa Any

y, mask_, rng_ = dropout(
rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())

@test y isa aType{T, length(x_shape)}
@test size(y) == x_shape
@test mask_ isa aType{T, length(x_shape)}
@test size(mask_) == x_shape
@test rng != rng_
@test mask != mask_

__f = (x, mask) -> sum(first(dropout(
StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, Colon())))
# Branching based on runtime activity
@test @inferred(Zygote.gradient(__f, x, mask)) isa Any broken=true

__f = let rng = rng, mask = mask
x -> sum(first(dropout(
rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
end
test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3,
soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []),
broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : []))

@jet sum(first(dropout(
rng, x, mask, T(0.5), Val(true), Val(false), T(2), Colon())))
# Testing Mode
@test @inferred(dropout(
rng, x, mask, T(0.5), Val(false), Val(false), T(2), Colon())) isa Any
Expand Down
2 changes: 1 addition & 1 deletion test/others/qa_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@testitem "Aqua: Quality Assurance" tags=[:others] begin
using Aqua, ChainRulesCore, EnzymeCore
using Aqua, ChainRulesCore, EnzymeCore, NNlib
using EnzymeCore: EnzymeRules

Aqua.test_all(
Expand Down
2 changes: 1 addition & 1 deletion test/shared_testsetup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import Reexport: @reexport

using LuxLib, MLDataDevices
@reexport using LuxTestUtils, StableRNGs, Test, Enzyme, Zygote
@reexport using LuxTestUtils, StableRNGs, Test, Enzyme, Zygote, NNlib

LuxTestUtils.jet_target_modules!(["LuxLib"])

Expand Down
Loading