diff --git a/.buildkite/benchmarks.yml b/.buildkite/benchmarks.yml index 1ba075194..52a4a7660 100644 --- a/.buildkite/benchmarks.yml +++ b/.buildkite/benchmarks.yml @@ -15,7 +15,11 @@ steps: command: | julia --project=benchmarks -e 'println("--- :julia: Instantiating project") using Pkg - Pkg.develop([PackageSpec(path=pwd())])' + Pkg.develop([ + PackageSpec(path=pwd()), + PackageSpec(path="lib/LuxLib"), + PackageSpec(path="lib/MLDataDevices"), + ])' julia --project=benchmarks -e 'println("--- :julia: Run Benchmarks") include("benchmarks/runbenchmarks.jl")' @@ -36,12 +40,16 @@ steps: version: "1" command: | julia --project=benchmarks -e 'println("--- :julia: Instantiating project") - using Pkg - Pkg.develop([PackageSpec(path=pwd())])' + using Pkg; + Pkg.develop([ + PackageSpec(path=pwd()), + PackageSpec(path="lib/LuxLib"), + PackageSpec(path="lib/MLDataDevices"), + ])' julia --project=benchmarks -e 'println("--- :julia: Add CUDA to benchmarks environment") using Pkg - Pkg.add("LuxCUDA")' + Pkg.develop([PackageSpec(path="lib/LuxCUDA")])' julia --project=benchmarks -e 'println("--- :julia: Run Benchmarks") include("benchmarks/runbenchmarks.jl")' diff --git a/.buildkite/documentation.yml b/.buildkite/documentation.yml index 80dc74d65..fa7ddf121 100644 --- a/.buildkite/documentation.yml +++ b/.buildkite/documentation.yml @@ -69,7 +69,11 @@ steps: julia --code-coverage=user --color=yes --project=docs -e ' println("--- :julia: Instantiating project") using Pkg - Pkg.develop(PackageSpec(path=pwd())) + dev_pkgs = Pkg.PackageSpec[] + for pkg in ("lib/LuxLib", "lib/LuxCore", "lib/MLDataDevices", "lib/LuxTestUtils", "lib/WeightInitializers", ".") + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)); + end + Pkg.develop(dev_pkgs) Pkg.instantiate() println("+++ :julia: Building documentation") include("docs/make.jl")' diff --git a/.buildkite/downstream.yml b/.buildkite/downstream.yml new file mode 100644 index 000000000..1fb8c3283 --- /dev/null +++ b/.buildkite/downstream.yml @@ -0,0 +1,68 @@ +steps: + - group: ":telescope: Downstream CUDA" + steps: + - label: ":julia: {{matrix.repo}} (Julia 1 + CUDA GPU)" + plugins: + - JuliaCI/julia#v1: + version: "1.10" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + - lib/LuxCore/src + - lib/LuxCore/ext + - lib/LuxLib/src + - lib/LuxLib/ext + - lib/MLDataDevices/src + - lib/MLDataDevices/ext + - lib/WeightInitializers/src + - lib/WeightInitializers/ext + command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "CUDA" + agents: + queue: "juliagpu" + cuda: "*" + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test" + timeout_in_minutes: 60 + matrix: + setup: + repo: + - "Boltz" + - "NeuralPDE" + - "DeepEquilibriumNetworks" + - "NeuralOperators" + + - group: ":telescope: Downstream AMD GPU" + steps: + - label: ":julia: {{matrix.repo}} (Julia 1 + AMD GPU)" + plugins: + - JuliaCI/julia#v1: + version: "1.10" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + - lib/LuxCore/src + - lib/LuxCore/ext + - lib/LuxLib/src + - lib/LuxLib/ext + - lib/MLDataDevices/src + - lib/MLDataDevices/ext + - lib/WeightInitializers/src + - lib/WeightInitializers/ext + command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "AMDGPU" + agents: + queue: "juliagpu" + rocm: "*" + rocmgpu: "*" + if: build.message !~ /\[skip tests\]/ && 
build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test" + timeout_in_minutes: 60 + matrix: + setup: + repo: + - "Boltz" + - "NeuralOperators" + +env: + SECRET_CODECOV_TOKEN: "jQ0BMTQgyZx7QGyU0Q2Ec7qB9mtE2q/tDu0FsfxvEG7/zOAGvXkyXrzIFFOQxvDoFcP+K2+hYZKMxicYdNqzr5wcxu505aNGN2GM3wyegAr+hO6q12bCFYx6qXzU9FLCCdeqINqn9gUSSOlGtWNFrbAlrTyz/D4Yo66TqBDzvaLL63FMnhCLaXW/zJt3hNuEAJaPY2O6Ze1rX2WZ3Y+i+s3uQ8aLImtoCJhPe8CRx+OhuYiTzGhynFfGntZ0738/1RN4gNM0S/hTC4gLE7XMVBanJpGh32rFaiDwW4zAyXKBrDkL3QA3MS1RvLTJxGJ085S16hCk0C4ddAhZCvIM9Q==;U2FsdGVkX1+bXdFeKMs5G79catOCyby2n07A2fg0FjVAvrjQLZ0yfvDS4paJiFikLkodho0khz2YALKb2Y0K6w==" diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 4379ec8e1..d789d816a 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -8,6 +8,7 @@ steps: diff: ".buildkite/scripts/diff.sh $BUILDKITE_COMMIT" interpolation: false watch: + # Core Lux Testing - path: - "src/" - "ext/" @@ -18,6 +19,70 @@ steps: command: "buildkite-agent pipeline upload .buildkite/testing.yml" agents: queue: "juliagpu" + + # LuxCUDA Testing + - path: + - "lib/LuxCUDA/" + - ".buildkite/" + config: + command: "buildkite-agent pipeline upload .buildkite/testing_luxcuda.yml" + agents: + queue: "juliagpu" + + # WeightInitializers Testing + - path: + - "lib/WeightInitializers/" + - ".buildkite/" + config: + command: "buildkite-agent pipeline upload .buildkite/testing_weightinitializers.yml" + agents: + queue: "juliagpu" + + # MLDataDevices Testing + - path: + - "lib/MLDataDevices/" + - ".buildkite/" + config: + command: "buildkite-agent pipeline upload .buildkite/testing_mldatadevices.yml" + agents: + queue: "juliagpu" + + # LuxLib Testing + - path: + - "lib/LuxLib/" + - ".buildkite/" + - "lib/LuxTestUtils/" + - "lib/LuxCore/" + - "lib/MLDataDevices/" + config: + command: "buildkite-agent pipeline upload .buildkite/testing_luxlib.yml" + agents: + queue: "juliagpu" + + # LuxTestUtils Testing + - path: + - "lib/LuxTestUtils/" + - ".buildkite/" + config: + command: "buildkite-agent pipeline upload .buildkite/testing_luxtestutils.yml" + agents: + queue: "juliagpu" + + # Benchmarks + - path: + - "src/" + - "ext/" + - "test/" + - "Project.toml" + - ".buildkite/" + - "benchmarks/" + if: build.pull_request.labels includes "run benchmarks" + config: + command: "buildkite-agent pipeline upload .buildkite/benchmarks.yml" + agents: + queue: "juliagpu" + + # Documentation - path: - "src/" - "ext/" @@ -26,20 +91,22 @@ steps: - "docs/" - "examples/" - ".buildkite/" + - "lib" config: command: "buildkite-agent pipeline upload .buildkite/documentation.yml" agents: queue: "juliagpu" + + # Downstream - path: - "src/" - "ext/" - - "test/" + - "lib/" - "Project.toml" - ".buildkite/" - - "benchmarks/" - if: build.pull_request.labels includes "run benchmarks" + if: build.pull_request.labels includes "run downstream test" config: - command: "buildkite-agent pipeline upload .buildkite/benchmarks.yml" + command: "buildkite-agent pipeline upload .buildkite/downstream.yml" agents: queue: "juliagpu" @@ -48,6 +115,18 @@ steps: agents: queue: "juliagpu" command: | + # Core Lux Testing buildkite-agent pipeline upload .buildkite/testing.yml + + # Subpackage testing + buildkite-agent pipeline upload .buildkite/testing_luxcuda.yml + buildkite-agent pipeline upload .buildkite/testing_weightinitializers.yml + buildkite-agent pipeline upload .buildkite/testing_luxlib.yml + buildkite-agent pipeline upload .buildkite/testing_mldatadevices.yml + 
buildkite-agent pipeline upload .buildkite/testing_luxtestutils.yml + + # Documentation buildkite-agent pipeline upload .buildkite/documentation.yml + + # Benchmarks buildkite-agent pipeline upload .buildkite/benchmarks.yml diff --git a/.buildkite/testing.yml b/.buildkite/testing.yml index a4b85da1c..935e95fe3 100644 --- a/.buildkite/testing.yml +++ b/.buildkite/testing.yml @@ -1,20 +1,48 @@ steps: - - group: ":julia: CUDA GPU" + - group: ":julia: (Lux) CUDA GPU" steps: - label: ":julia: Julia {{matrix.julia}} + CUDA GPU" plugins: - JuliaCI/julia#v1: version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - JuliaCI/julia-coverage#v1: codecov: true dirs: - src - ext + - lib/LuxCore/src + - lib/LuxCore/ext + - lib/MLDataDevices/src + - lib/MLDataDevices/ext + - lib/WeightInitializers/src + - lib/WeightInitializers/ext + - lib/LuxLib/src + - lib/LuxLib/ext + - lib/LuxTestUtils/src agents: queue: "juliagpu" cuda: "*" + command: | + julia --color=yes --code-coverage=user --depwarn=yes --project=. -e ' + import Pkg; + dev_pkgs = Pkg.PackageSpec[]; + for pkg in ("lib/LuxCore", "lib/MLDataDevices", "lib/WeightInitializers", "lib/LuxLib",) + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)); + end + Pkg.develop(dev_pkgs); + Pkg.Registry.update(); + Pkg.instantiate(); + Pkg.activate("test"); + dev_pkgs = Pkg.PackageSpec[]; + for pkg in ("lib/LuxTestUtils", "lib/LuxLib", "lib/MLDataDevices", "lib/LuxCore", ".") + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) + end + Pkg.develop(dev_pkgs); + Pkg.instantiate();' + julia --color=yes --code-coverage=user --depwarn=yes --project=test -e ' + import Pkg, Lux; + dir = dirname(pathof(Lux)); + include(joinpath(dir, "../test/runtests.jl"))' env: BACKEND_GROUP: "CUDA" if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ @@ -23,44 +51,49 @@ steps: setup: julia: - "1.10" + - "1" - - group: ":telescope: Downstream CUDA" - steps: - - label: ":julia: {{matrix.repo}} (Julia 1 + CUDA GPU)" - plugins: - - JuliaCI/julia#v1: - version: "1.10" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "CUDA" - agents: - queue: "juliagpu" - cuda: "*" - if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test" - timeout_in_minutes: 60 - matrix: - setup: - repo: - - "Boltz" - - "NeuralPDE#GPU" - - "DeepEquilibriumNetworks" - - - group: ":julia: AMD GPU" + - group: ":julia: (Lux) AMD GPU" steps: - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" plugins: - JuliaCI/julia#v1: version: "{{matrix.julia}}" - - JuliaCI/julia-test#v1: - test_args: "--quickfail" - JuliaCI/julia-coverage#v1: codecov: true dirs: - src - ext + - lib/LuxCore/src + - lib/LuxCore/ext + - lib/MLDataDevices/src + - lib/MLDataDevices/ext + - lib/WeightInitializers/src + - lib/WeightInitializers/ext + - lib/LuxLib/src + - lib/LuxLib/ext + - lib/LuxTestUtils/src + command: | + julia --color=yes --code-coverage=user --depwarn=yes --project=. 
-e ' + import Pkg; + dev_pkgs = Pkg.PackageSpec[]; + for pkg in ("lib/LuxCore", "lib/MLDataDevices", "lib/WeightInitializers", "lib/LuxLib",) + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)); + end + Pkg.develop(dev_pkgs); + Pkg.Registry.update(); + Pkg.instantiate(); + Pkg.activate("test"); + dev_pkgs = Pkg.PackageSpec[]; + for pkg in ("lib/LuxTestUtils", "lib/LuxLib", "lib/MLDataDevices", "lib/LuxCore", ".") + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) + end + Pkg.develop(dev_pkgs); + Pkg.instantiate();' + julia --color=yes --code-coverage=user --depwarn=yes --project=test -e ' + import Pkg, Lux; + dir = dirname(pathof(Lux)); + include(joinpath(dir, "../test/runtests.jl"))' env: BACKEND_GROUP: "AMDGPU" agents: @@ -73,29 +106,7 @@ steps: setup: julia: - "1.10" - - - group: ":telescope: Downstream AMD GPU" - steps: - - label: ":julia: {{matrix.repo}} (Julia 1 + AMD GPU)" - plugins: - - JuliaCI/julia#v1: - version: "1.10" - - JuliaCI/julia-coverage#v1: - codecov: true - dirs: - - src - - ext - command: julia --code-coverage=user --color=yes --project .buildkite/scripts/downstream.jl "{{matrix.repo}}" "AMDGPU" - agents: - queue: "juliagpu" - rocm: "*" - rocmgpu: "*" - if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test" - timeout_in_minutes: 60 - matrix: - setup: - repo: - - "Boltz" + - "1" env: SECRET_CODECOV_TOKEN: "jQ0BMTQgyZx7QGyU0Q2Ec7qB9mtE2q/tDu0FsfxvEG7/zOAGvXkyXrzIFFOQxvDoFcP+K2+hYZKMxicYdNqzr5wcxu505aNGN2GM3wyegAr+hO6q12bCFYx6qXzU9FLCCdeqINqn9gUSSOlGtWNFrbAlrTyz/D4Yo66TqBDzvaLL63FMnhCLaXW/zJt3hNuEAJaPY2O6Ze1rX2WZ3Y+i+s3uQ8aLImtoCJhPe8CRx+OhuYiTzGhynFfGntZ0738/1RN4gNM0S/hTC4gLE7XMVBanJpGh32rFaiDwW4zAyXKBrDkL3QA3MS1RvLTJxGJ085S16hCk0C4ddAhZCvIM9Q==;U2FsdGVkX1+bXdFeKMs5G79catOCyby2n07A2fg0FjVAvrjQLZ0yfvDS4paJiFikLkodho0khz2YALKb2Y0K6w==" diff --git a/.buildkite/testing_luxcuda.yml b/.buildkite/testing_luxcuda.yml new file mode 100644 index 000000000..b5beec1b4 --- /dev/null +++ b/.buildkite/testing_luxcuda.yml @@ -0,0 +1,30 @@ +steps: + - group: ":julia: (LuxCUDA) CUDA GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}}" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - lib/LuxCUDA/src + command: | + julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxCUDA -e ' + import Pkg; + Pkg.Registry.update(); + Pkg.instantiate(); + Pkg.test(; coverage="user")' + agents: + queue: "juliagpu" + cuda: "*" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1.10" + - "1" + +env: + SECRET_CODECOV_TOKEN: "jQ0BMTQgyZx7QGyU0Q2Ec7qB9mtE2q/tDu0FsfxvEG7/zOAGvXkyXrzIFFOQxvDoFcP+K2+hYZKMxicYdNqzr5wcxu505aNGN2GM3wyegAr+hO6q12bCFYx6qXzU9FLCCdeqINqn9gUSSOlGtWNFrbAlrTyz/D4Yo66TqBDzvaLL63FMnhCLaXW/zJt3hNuEAJaPY2O6Ze1rX2WZ3Y+i+s3uQ8aLImtoCJhPe8CRx+OhuYiTzGhynFfGntZ0738/1RN4gNM0S/hTC4gLE7XMVBanJpGh32rFaiDwW4zAyXKBrDkL3QA3MS1RvLTJxGJ085S16hCk0C4ddAhZCvIM9Q==;U2FsdGVkX1+bXdFeKMs5G79catOCyby2n07A2fg0FjVAvrjQLZ0yfvDS4paJiFikLkodho0khz2YALKb2Y0K6w==" diff --git a/.buildkite/testing_luxlib.yml b/.buildkite/testing_luxlib.yml new file mode 100644 index 000000000..8a1607ec5 --- /dev/null +++ b/.buildkite/testing_luxlib.yml @@ -0,0 +1,102 @@ +steps: + - group: ":julia: (LuxLib) CUDA GPU" + steps: + - label: ":julia: Julia {{matrix.julia}} + CUDA GPU" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - 
lib/LuxLib/src + - lib/LuxLib/ext + - lib/LuxCore/src + - lib/LuxCore/ext + - lib/MLDataDevices/src + - lib/MLDataDevices/ext + - lib/LuxTestUtils/src + command: | + julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxLib -e ' + import Pkg; + Pkg.Registry.update(); + dev_pkgs = Pkg.PackageSpec[]; + for pkg in ("lib/LuxCore", "lib/MLDataDevices") + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)); + end; + Pkg.develop(dev_pkgs); + Pkg.instantiate(); + Pkg.activate("lib/LuxLib/test"); + dev_pkgs = Pkg.PackageSpec[]; + for pkg in ("lib/LuxTestUtils", "lib/LuxLib", "lib/MLDataDevices") + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)); + end; + Pkg.develop(dev_pkgs)' + julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxLib/test -e ' + import Pkg, LuxLib + dir = dirname(pathof(LuxLib)) + include(joinpath(dir, "../test/runtests.jl"))' + agents: + queue: "juliagpu" + cuda: "*" + env: + BACKEND_GROUP: "CUDA" + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ + timeout_in_minutes: 240 + matrix: + setup: + julia: + - "1.10" + - "1" + + - group: ":julia: (LuxLib) AMD GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - lib/LuxLib/src + - lib/LuxLib/ext + - lib/LuxCore/src + - lib/LuxCore/ext + - lib/MLDataDevices/src + - lib/MLDataDevices/ext + - lib/LuxTestUtils/src + command: | + julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxLib -e ' + import Pkg; + Pkg.Registry.update(); + dev_pkgs = Pkg.PackageSpec[]; + for pkg in ("lib/LuxCore", "lib/MLDataDevices") + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)); + end; + Pkg.develop(dev_pkgs); + Pkg.instantiate(); + Pkg.activate("lib/LuxLib/test"); + dev_pkgs = Pkg.PackageSpec[]; + for pkg in ("lib/LuxTestUtils", "lib/LuxLib", "lib/MLDataDevices") + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)); + end; + Pkg.develop(dev_pkgs)' + julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxLib/test -e ' + import Pkg, LuxLib + dir = dirname(pathof(LuxLib)) + include(joinpath(dir, "../test/runtests.jl"))' + agents: + queue: "juliagpu" + rocm: "*" + rocmgpu: "*" + env: + BACKEND_GROUP: "AMDGPU" + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ + timeout_in_minutes: 240 + matrix: + setup: + julia: + - "1.10" + - "1" + +env: + SECRET_CODECOV_TOKEN: "jQ0BMTQgyZx7QGyU0Q2Ec7qB9mtE2q/tDu0FsfxvEG7/zOAGvXkyXrzIFFOQxvDoFcP+K2+hYZKMxicYdNqzr5wcxu505aNGN2GM3wyegAr+hO6q12bCFYx6qXzU9FLCCdeqINqn9gUSSOlGtWNFrbAlrTyz/D4Yo66TqBDzvaLL63FMnhCLaXW/zJt3hNuEAJaPY2O6Ze1rX2WZ3Y+i+s3uQ8aLImtoCJhPe8CRx+OhuYiTzGhynFfGntZ0738/1RN4gNM0S/hTC4gLE7XMVBanJpGh32rFaiDwW4zAyXKBrDkL3QA3MS1RvLTJxGJ085S16hCk0C4ddAhZCvIM9Q==;U2FsdGVkX1+bXdFeKMs5G79catOCyby2n07A2fg0FjVAvrjQLZ0yfvDS4paJiFikLkodho0khz2YALKb2Y0K6w==" diff --git a/.buildkite/testing_luxtestutils.yml b/.buildkite/testing_luxtestutils.yml new file mode 100644 index 000000000..58ab71095 --- /dev/null +++ b/.buildkite/testing_luxtestutils.yml @@ -0,0 +1,32 @@ +steps: + - group: ":julia: (LuxTestUtils) CUDA GPU" + steps: + - label: ":julia: Julia {{matrix.julia}} + CUDA GPU" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - lib/LuxTestUtils/src + command: | + julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxTestUtils -e ' + import Pkg; + Pkg.Registry.update(); + Pkg.instantiate(); + Pkg.test(; coverage="user")' + agents: 
+ queue: "juliagpu" + cuda: "*" + env: + BACKEND_GROUP: "CUDA" + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1.10" + - "1" + +env: + SECRET_CODECOV_TOKEN: "jQ0BMTQgyZx7QGyU0Q2Ec7qB9mtE2q/tDu0FsfxvEG7/zOAGvXkyXrzIFFOQxvDoFcP+K2+hYZKMxicYdNqzr5wcxu505aNGN2GM3wyegAr+hO6q12bCFYx6qXzU9FLCCdeqINqn9gUSSOlGtWNFrbAlrTyz/D4Yo66TqBDzvaLL63FMnhCLaXW/zJt3hNuEAJaPY2O6Ze1rX2WZ3Y+i+s3uQ8aLImtoCJhPe8CRx+OhuYiTzGhynFfGntZ0738/1RN4gNM0S/hTC4gLE7XMVBanJpGh32rFaiDwW4zAyXKBrDkL3QA3MS1RvLTJxGJ085S16hCk0C4ddAhZCvIM9Q==;U2FsdGVkX1+bXdFeKMs5G79catOCyby2n07A2fg0FjVAvrjQLZ0yfvDS4paJiFikLkodho0khz2YALKb2Y0K6w==" diff --git a/.buildkite/testing_mldatadevices.yml b/.buildkite/testing_mldatadevices.yml new file mode 100644 index 000000000..07a1647be --- /dev/null +++ b/.buildkite/testing_mldatadevices.yml @@ -0,0 +1,129 @@ +steps: + - group: ":julia: (MLDataDevices) CUDA GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}} + CUDA GPU + {{matrix.group}}" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - lib/MLDataDevices/src + - lib/MLDataDevices/ext + command: | + julia --color=yes --code-coverage=user --depwarn=yes --project=lib/MLDataDevices -e ' + import Pkg; + Pkg.Registry.update(); + Pkg.instantiate(); + Pkg.test(; coverage="user")' + env: + BACKEND_GROUP: "{{matrix.group}}" + agents: + queue: "juliagpu" + cuda: "*" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1.10" + - "1" + group: + - "CPU" + - "XLA" + + - group: ":julia: (MLDataDevices) AMD GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - lib/MLDataDevices/src + - lib/MLDataDevices/ext + command: | + julia --color=yes --code-coverage=user --depwarn=yes --project=lib/MLDataDevices -e ' + import Pkg; + Pkg.Registry.update(); + Pkg.instantiate(); + Pkg.test(; coverage="user")' + env: + BACKEND_GROUP: "AMDGPU" + agents: + queue: "juliagpu" + rocm: "*" + rocmgpu: "*" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1.10" + - "1" + + - group: ":julia: (MLDataDevices) Metal GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}} + Metal" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - lib/MLDataDevices/src + - lib/MLDataDevices/ext + command: | + julia --color=yes --code-coverage=user --depwarn=yes --project=lib/MLDataDevices -e ' + import Pkg; + Pkg.Registry.update(); + Pkg.instantiate(); + Pkg.test(; coverage="user")' + env: + BACKEND_GROUP: "Metal" + agents: + queue: "juliaecosystem" + os: "macos" + arch: "aarch64" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1.10" + - "1" + + - group: ":julia: (MLDataDevices) oneAPI GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}} + oneAPI" + soft_fail: true + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - lib/MLDataDevices/src + - lib/MLDataDevices/ext + command: | + julia --color=yes --code-coverage=user --depwarn=yes --project=lib/MLDataDevices -e ' + import Pkg; + Pkg.Registry.update(); + Pkg.instantiate(); + Pkg.test(; coverage="user")' + env: + BACKEND_GROUP: "oneAPI" + agents: + queue: "juliagpu" + intel: "*" + if: 
build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1.10" + - "1" + +env: + SECRET_CODECOV_TOKEN: "jQ0BMTQgyZx7QGyU0Q2Ec7qB9mtE2q/tDu0FsfxvEG7/zOAGvXkyXrzIFFOQxvDoFcP+K2+hYZKMxicYdNqzr5wcxu505aNGN2GM3wyegAr+hO6q12bCFYx6qXzU9FLCCdeqINqn9gUSSOlGtWNFrbAlrTyz/D4Yo66TqBDzvaLL63FMnhCLaXW/zJt3hNuEAJaPY2O6Ze1rX2WZ3Y+i+s3uQ8aLImtoCJhPe8CRx+OhuYiTzGhynFfGntZ0738/1RN4gNM0S/hTC4gLE7XMVBanJpGh32rFaiDwW4zAyXKBrDkL3QA3MS1RvLTJxGJ085S16hCk0C4ddAhZCvIM9Q==;U2FsdGVkX1+bXdFeKMs5G79catOCyby2n07A2fg0FjVAvrjQLZ0yfvDS4paJiFikLkodho0khz2YALKb2Y0K6w==" diff --git a/.buildkite/testing_weightinitializers.yml b/.buildkite/testing_weightinitializers.yml new file mode 100644 index 000000000..7d6570bf6 --- /dev/null +++ b/.buildkite/testing_weightinitializers.yml @@ -0,0 +1,126 @@ +steps: + - group: ":julia: (WeightInitializers) CUDA GPU" + steps: + - label: ":julia: Julia {{matrix.julia}} + CUDA GPU" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - lib/WeightInitializers/src + - lib/WeightInitializers/ext + command: | + julia --color=yes --code-coverage=user --depwarn=yes --project=lib/WeightInitializers -e ' + import Pkg; + Pkg.Registry.update(); + Pkg.instantiate(); + Pkg.test(; coverage="user")' + agents: + queue: "juliagpu" + cuda: "*" + env: + BACKEND_GROUP: "CUDA" + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1.10" + - "1" + + - group: ":julia: (WeightInitializers) AMD GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}} + AMD GPU" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - lib/WeightInitializers/src + - lib/WeightInitializers/ext + command: | + julia --color=yes --code-coverage=user --depwarn=yes --project=lib/WeightInitializers -e ' + import Pkg; + Pkg.Registry.update(); + Pkg.instantiate(); + Pkg.test(; coverage="user")' + agents: + queue: "juliagpu" + rocm: "*" + rocmgpu: "*" + env: + BACKEND_GROUP: "AMDGPU" + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1.10" + - "1" + + - group: ":julia: (WeightInitializers) Metal GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}} + Metal" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - lib/WeightInitializers/src + - lib/WeightInitializers/ext + command: | + julia --color=yes --code-coverage=user --depwarn=yes --project=lib/WeightInitializers -e ' + import Pkg; + Pkg.Registry.update(); + Pkg.instantiate(); + Pkg.test(; coverage="user")' + agents: + queue: "juliaecosystem" + os: "macos" + arch: "aarch64" + env: + BACKEND_GROUP: "Metal" + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1.10" + - "1" + + - group: ":julia: (WeightInitializers) oneAPI GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}} + oneAPI" + soft_fail: true + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - lib/WeightInitializers/src + - lib/WeightInitializers/ext + command: | + julia --color=yes --code-coverage=user --depwarn=yes --project=lib/WeightInitializers -e ' + import Pkg; + Pkg.Registry.update(); + Pkg.instantiate(); + Pkg.test(; coverage="user")' + agents: + queue: "juliagpu" + intel: "*" + env: + 
BACKEND_GROUP: "oneAPI" + if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1.10" + - "1" + +env: + SECRET_CODECOV_TOKEN: "jQ0BMTQgyZx7QGyU0Q2Ec7qB9mtE2q/tDu0FsfxvEG7/zOAGvXkyXrzIFFOQxvDoFcP+K2+hYZKMxicYdNqzr5wcxu505aNGN2GM3wyegAr+hO6q12bCFYx6qXzU9FLCCdeqINqn9gUSSOlGtWNFrbAlrTyz/D4Yo66TqBDzvaLL63FMnhCLaXW/zJt3hNuEAJaPY2O6Ze1rX2WZ3Y+i+s3uQ8aLImtoCJhPe8CRx+OhuYiTzGhynFfGntZ0738/1RN4gNM0S/hTC4gLE7XMVBanJpGh32rFaiDwW4zAyXKBrDkL3QA3MS1RvLTJxGJ085S16hCk0C4ddAhZCvIM9Q==;U2FsdGVkX1+bXdFeKMs5G79catOCyby2n07A2fg0FjVAvrjQLZ0yfvDS4paJiFikLkodho0khz2YALKb2Y0K6w==" diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 33565d6c2..244726cd6 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -1,4 +1,4 @@ -name: CI +name: CI (Lux) on: pull_request: branches: @@ -9,6 +9,11 @@ on: - "test/**" - "Project.toml" - ".github/workflows/CI.yml" + - "lib/LuxTestUtils/**" + - "lib/LuxCore/**" + - "lib/MLDataDevices/**" + - "lib/WeightInitializers/**" + - "lib/LuxLib/**" push: branches: - main @@ -20,8 +25,7 @@ concurrency: cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} jobs: - ci: - name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.test_group }} + test: if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} runs-on: ${{ matrix.os }} strategy: @@ -31,6 +35,8 @@ jobs: - "1.10" os: - ubuntu-latest + - macos-latest + - windows-latest test_group: - "core_layers" - "contrib" @@ -43,13 +49,6 @@ jobs: - "eltype_match" - "fluxcompat" - "reactant" - include: - - version: "1.10" - os: macos-latest - test_group: "all" - - version: "1.10" - os: windows-latest - test_group: "all" steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -65,67 +64,35 @@ jobs: ${{ runner.os }}-test-${{ env.cache-name }}- ${{ runner.os }}-test- ${{ runner.os }}- - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-runtest@v1 + - name: "Install Dependencies" + run: | + import Pkg + dev_pkgs = Pkg.PackageSpec[] + for pkg in ("lib/LuxCore", "lib/MLDataDevices", "lib/WeightInitializers", "lib/LuxLib",) + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) + end + Pkg.develop(dev_pkgs) + Pkg.Registry.update() + Pkg.instantiate() + Pkg.activate("test") + dev_pkgs = Pkg.PackageSpec[] + for pkg in ("lib/LuxTestUtils", "lib/LuxLib", "lib/MLDataDevices", "lib/LuxCore", ".") + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) + end + Pkg.develop(dev_pkgs) + Pkg.instantiate() + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=. 
{0} + - name: "Run Tests" + run: | + import Pkg, Lux + dir = dirname(pathof(Lux)) + include(joinpath(dir, "../test/runtests.jl")) + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=test {0} env: LUX_TEST_GROUP: ${{ matrix.test_group }} - uses: julia-actions/julia-processcoverage@v1 with: - directories: src,ext - - uses: codecov/codecov-action@v4 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - verbose: true - fail_ci_if_error: true - - downstream: - name: Downstream ${{ matrix.package.repo }}/${{ matrix.package.group }} - if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && contains(github.event.pull_request.labels.*.name, 'run downstream test') }} - runs-on: ubuntu-latest - timeout-minutes: 240 - env: - GROUP: ${{ matrix.package.group }} - strategy: - fail-fast: false - matrix: - package: - - { user: SciML, repo: DiffEqFlux.jl, group: BasicNeuralDE } - - { user: SciML, repo: DiffEqFlux.jl, group: AdvancedNeuralDE } - - { user: SciML, repo: DeepEquilibriumNetworks.jl, group: All } - - { user: SciML, repo: NeuralPDE.jl, group: NNPDE1 } - - { user: SciML, repo: NeuralPDE.jl, group: NNPDE2 } - - { user: LuxDL, repo: Boltz.jl, group: CPU } - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: "1.10" - arch: x64 - - uses: julia-actions/julia-buildpkg@v1 - - name: Clone Downstream - uses: actions/checkout@v4 - with: - repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} - path: downstream - - name: Load this and run the downstream tests - shell: julia --code-coverage=user --color=yes --project=downstream {0} - run: | - using Pkg - try - # force it to use this PR's version of the package - Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps - Pkg.update() - Pkg.test(; coverage="user") # resolver may fail with test time deps - catch err - err isa Pkg.Resolve.ResolverError || rethrow() - # If we can't resolve that means this is incompatible by SemVer and this is fine - # It means we marked this as a breaking change, so we don't need to worry about - # Mistakenly introducing a breaking change, as we have intentionally made one - @info "Not compatible with this release. No problem." 
exception=err - exit(0) # Exit immediately, as a success - end - - uses: julia-actions/julia-processcoverage@v1 + directories: src,ext,lib/LuxCore/src,lib/LuxCore/ext,lib/MLDataDevices/src,lib/MLDataDevices/ext,lib/WeightInitializers/src,lib/WeightInitializers/ext,lib/LuxLib/src,lib/LuxLib/ext,lib/LuxTestUtils/src - uses: codecov/codecov-action@v4 with: files: lcov.info @@ -135,7 +102,6 @@ jobs: downgrade: if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} - name: Downgrade Julia 1.10 runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -143,11 +109,35 @@ jobs: with: version: "1.10" - uses: julia-actions/julia-downgrade-compat@v1 - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-runtest@v1 + with: + skip: "LuxCore,MLDataDevices,WeightInitializers,LuxLib" + - name: "Install Dependencies" + run: | + import Pkg + dev_pkgs = Pkg.PackageSpec[] + for pkg in ("lib/LuxCore", "lib/MLDataDevices", "lib/WeightInitializers", "lib/LuxLib",) + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) + end + Pkg.develop(dev_pkgs) + Pkg.Registry.update() + Pkg.instantiate() + Pkg.activate("test") + dev_pkgs = Pkg.PackageSpec[] + for pkg in ("lib/LuxTestUtils", "lib/LuxLib", "lib/MLDataDevices", "lib/LuxCore", ".") + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) + end + Pkg.develop(dev_pkgs) + Pkg.instantiate() + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=. {0} + - name: "Run Tests" + run: | + import Pkg, Lux + dir = dirname(pathof(Lux)) + include(joinpath(dir, "../test/runtests.jl")) + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=test {0} - uses: julia-actions/julia-processcoverage@v1 with: - directories: src,ext + directories: src,ext,lib/LuxCore/src,lib/LuxCore/ext,lib/MLDataDevices/src,lib/MLDataDevices/ext,lib/WeightInitializers/src,lib/WeightInitializers/ext,lib/LuxLib/src,lib/LuxLib/ext,lib/LuxTestUtils/src - uses: codecov/codecov-action@v4 with: files: lcov.info @@ -155,34 +145,5 @@ jobs: verbose: true fail_ci_if_error: true - invalidations: - # Only run on PRs to the default branch. 
- # In the PR trigger above branches can be specified only explicitly whereas this check should work for master, main, or any other default branch - if: github.base_ref == github.event.repository.default_branch - runs-on: ubuntu-latest - steps: - - uses: julia-actions/setup-julia@v2 - with: - version: "1" - - uses: actions/checkout@v4 - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-invalidations@v1 - id: invs_pr - - - uses: actions/checkout@v4 - with: - ref: ${{ github.event.repository.default_branch }} - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-invalidations@v1 - id: invs_default - - - name: Report invalidation counts - run: | - echo "Invalidations on default branch: ${{ steps.invs_default.outputs.total }} (${{ steps.invs_default.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY - echo "This branch: ${{ steps.invs_pr.outputs.total }} (${{ steps.invs_pr.outputs.deps }} via deps)" >> $GITHUB_STEP_SUMMARY - - name: Check if the PR does increase number of invalidations - if: steps.invs_pr.outputs.total > steps.invs_default.outputs.total - run: exit 1 - env: BACKEND_GROUP: "CPU" diff --git a/.github/workflows/CIPreRelease.yml b/.github/workflows/CIPreRelease.yml index 2587158fc..11a05b9f6 100644 --- a/.github/workflows/CIPreRelease.yml +++ b/.github/workflows/CIPreRelease.yml @@ -1,4 +1,4 @@ -name: CIPreRelease +name: CIPreRelease (Lux) on: pull_request: branches: @@ -9,6 +9,11 @@ on: - "test/**" - "Project.toml" - ".github/workflows/CI.yml" + - "lib/LuxTestUtils/**" + - "lib/LuxCore/**" + - "lib/MLDataDevices/**" + - "lib/WeightInitializers/**" + - "lib/LuxLib/**" push: branches: - main @@ -21,7 +26,6 @@ concurrency: jobs: ci: - name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.test_group }} if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} runs-on: ${{ matrix.os }} strategy: diff --git a/.github/workflows/CI_LuxCUDA.yml b/.github/workflows/CI_LuxCUDA.yml new file mode 100644 index 000000000..576886614 --- /dev/null +++ b/.github/workflows/CI_LuxCUDA.yml @@ -0,0 +1,87 @@ +name: CI (LuxCUDA) +on: + pull_request: + branches: + - main + paths: + - "lib/LuxCUDA/**" + - ".github/workflows/CI_LuxCUDA.yml" + push: + branches: + - main + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. 
+ group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +jobs: + test: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + version: + - "1.10" + - "1" + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + - uses: actions/cache@v4 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- + - name: "Install Dependencies and Run Tests" + run: | + import Pkg + Pkg.Registry.update() + Pkg.instantiate() + Pkg.test(; coverage="user") + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxCUDA {0} + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: lib/LuxCUDA/src + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true + + downgrade: + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: "1.10" + - uses: julia-actions/julia-downgrade-compat@v1 + - name: "Install Dependencies and Run Tests" + run: | + import Pkg + Pkg.Registry.update() + Pkg.instantiate() + Pkg.test(; coverage="user") + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxCUDA {0} + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: lib/LuxCUDA/src + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true + +env: + BACKEND_GROUP: "CPU" diff --git a/.github/workflows/CI_LuxCore.yml b/.github/workflows/CI_LuxCore.yml new file mode 100644 index 000000000..937b32a44 --- /dev/null +++ b/.github/workflows/CI_LuxCore.yml @@ -0,0 +1,119 @@ +name: CI (LuxCore) +on: + pull_request: + branches: + - main + paths: + - "lib/LuxCore/**" + - ".github/workflows/CI_LuxCore.yml" + - "lib/MLDataDevices/**" + push: + branches: + - main + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. 
+ group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +jobs: + test: + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + version: + - "1.10" + - "1" + os: + - ubuntu-latest + - macos-latest + - windows-latest + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + - uses: actions/cache@v4 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- + - name: "Install Dependencies" + run: | + import Pkg + Pkg.Registry.update() + Pkg.instantiate() + Pkg.activate("lib/LuxCore/test") + dev_pkgs = Pkg.PackageSpec[] + for pkg in ("lib/LuxCore",) + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) + end + Pkg.develop(dev_pkgs) + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxCore {0} + - name: "Run Tests" + run: | + import Pkg, LuxCore + dir = dirname(pathof(LuxCore)) + include(joinpath(dir, "../test/runtests.jl")) + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxCore/test {0} + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: lib/LuxCore/src,lib/LuxCore/ext,lib/MLDataDevices/src,lib/MLDataDevices/ext + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true + + downgrade: + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + version: ["1.10"] + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + - uses: julia-actions/julia-downgrade-compat@v1 + - name: "Install Dependencies" + run: | + import Pkg + Pkg.Registry.update() + Pkg.instantiate() + Pkg.activate("lib/LuxCore/test") + dev_pkgs = Pkg.PackageSpec[] + for pkg in ("lib/LuxCore",) + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) + end + Pkg.develop(dev_pkgs) + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxCore {0} + - name: "Run Tests" + run: | + import Pkg, LuxCore + dir = dirname(pathof(LuxCore)) + include(joinpath(dir, "../test/runtests.jl")) + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxCore/test {0} + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: lib/LuxCore/src,lib/LuxCore/ext,lib/MLDataDevices/src,lib/MLDataDevices/ext + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true + +env: + BACKEND_GROUP: "CPU" diff --git a/.github/workflows/CI_LuxLib.yml b/.github/workflows/CI_LuxLib.yml new file mode 100644 index 000000000..2ba26a789 --- /dev/null +++ b/.github/workflows/CI_LuxLib.yml @@ -0,0 +1,206 @@ +name: CI (LuxLib) +on: + pull_request: + branches: + - main + paths: + - "lib/LuxLib/**" + - ".github/workflows/CI_LuxLib.yml" + - "lib/LuxTestUtils/**" + - "lib/LuxCore/**" + - "lib/MLDataDevices/**" + push: + branches: + - main + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. 
+ group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +jobs: + test: + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + version: + - "1.10" + os: + - ubuntu-latest + test_group: + - "conv" + - "dense" + - "batch_norm" + - "group_norm" + - "instance_norm" + - "layer_norm" + - "other_ops" + - "batched_ops" + - "others" + blas_backend: + - "default" + loopvec: + - "true" + include: + - os: ubuntu-latest + test_group: "dense" + blas_backend: "blis" + version: "1.10" + loopvec: "true" + - os: ubuntu-latest + test_group: "dense" + blas_backend: "mkl" + version: "1.10" + loopvec: "true" + - os: ubuntu-latest + test_group: "dense" + blas_backend: "default" + version: "1.10" + loopvec: "false" + - os: ubuntu-latest + test_group: "batched_ops" + blas_backend: "default" + version: "1.10" + loopvec: "false" + - os: ubuntu-latest + test_group: "other_ops" + blas_backend: "default" + version: "1.10" + loopvec: "false" + - os: macos-latest + test_group: "dense" + blas_backend: "appleaccelerate" + version: "1.10" + loopvec: "true" + - os: macos-latest + test_group: "all" + blas_backend: "default" + version: "1.10" + loopvec: "true" + - os: windows-latest + test_group: "all" + blas_backend: "default" + version: "1.10" + loopvec: "true" + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + - uses: actions/cache@v4 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- + - name: "Install Dependencies" + run: | + import Pkg + dev_pkgs = Pkg.PackageSpec[] + for pkg in ("lib/LuxCore", "lib/MLDataDevices") + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) + end + Pkg.develop(dev_pkgs) + Pkg.Registry.update() + Pkg.instantiate() + Pkg.activate("lib/LuxLib/test") + dev_pkgs = Pkg.PackageSpec[] + for pkg in ("lib/LuxTestUtils", "lib/LuxLib") + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) + end + Pkg.develop(dev_pkgs) + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxLib {0} + - name: "Run Tests" + run: | + import Pkg, LuxLib + dir = dirname(pathof(LuxLib)) + include(joinpath(dir, "../test/runtests.jl")) + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxLib/test {0} + env: + LUXLIB_TEST_GROUP: ${{ matrix.test_group }} + LUXLIB_BLAS_BACKEND: ${{ matrix.blas_backend }} + LUXLIB_LOAD_LOOPVEC: ${{ matrix.loopvec }} + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: lib/LuxLib/src,lib/LuxLib/ext,lib/LuxCore/src,lib/LuxCore/ext,lib/MLDataDevices/src,lib/MLDataDevices/ext,lib/LuxTestUtils/src + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true + + downgrade: + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + test_group: + - "conv" + - "dense" + - "batch_norm" + - "group_norm" + - "instance_norm" + - "layer_norm" + - "other_ops" + - "batched_ops" + - "others" + blas_backend: + - "default" + loopvec: + - "true" + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 
+ with: + version: "1.10" + - uses: julia-actions/julia-downgrade-compat@v1 + with: + skip: "LuxCore,MLDataDevices" + - name: "Install Dependencies" + run: | + import Pkg + dev_pkgs = Pkg.PackageSpec[] + for pkg in ("lib/LuxCore", "lib/MLDataDevices") + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) + end + Pkg.develop(dev_pkgs) + Pkg.Registry.update() + Pkg.instantiate() + Pkg.activate("lib/LuxLib/test") + dev_pkgs = Pkg.PackageSpec[] + for pkg in ("lib/LuxTestUtils", "lib/LuxLib", "lib/MLDataDevices") + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) + end + Pkg.develop(dev_pkgs) + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxLib {0} + - name: "Run Tests" + run: | + import Pkg, LuxLib + dir = dirname(pathof(LuxLib)) + include(joinpath(dir, "../test/runtests.jl")) + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxLib/test {0} + env: + LUXLIB_TEST_GROUP: ${{ matrix.test_group }} + LUXLIB_BLAS_BACKEND: ${{ matrix.blas_backend }} + LUXLIB_LOAD_LOOPVEC: ${{ matrix.loopvec }} + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: lib/LuxLib/src,lib/LuxLib/ext,lib/LuxCore/src,lib/LuxCore/ext,lib/MLDataDevices/src,lib/MLDataDevices/ext,lib/LuxTestUtils/src + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true + +env: + BACKEND_GROUP: "CPU" diff --git a/.github/workflows/CI_LuxTestUtils.yml b/.github/workflows/CI_LuxTestUtils.yml new file mode 100644 index 000000000..2c77e711d --- /dev/null +++ b/.github/workflows/CI_LuxTestUtils.yml @@ -0,0 +1,96 @@ +name: CI (LuxTestUtils) +on: + pull_request: + branches: + - main + paths: + - "lib/LuxTestUtils/**" + - ".github/workflows/CI_LuxTestUtils.yml" + push: + branches: + - main + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. 
+ group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +jobs: + test: + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + version: + - "1.10" + - "1" + os: + - ubuntu-latest + - macos-latest + - windows-latest + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + - uses: actions/cache@v4 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- + - name: "Install Dependencies and Run Tests" + run: | + import Pkg + Pkg.Registry.update() + Pkg.instantiate() + Pkg.test(; coverage="user") + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxTestUtils {0} + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: lib/LuxTestUtils/src + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true + + downgrade: + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + version: ["1.10"] + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + - uses: julia-actions/julia-downgrade-compat@v1 + - name: "Install Dependencies and Run Tests" + run: | + import Pkg + Pkg.Registry.update() + Pkg.instantiate() + Pkg.test(; coverage="user") + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/LuxTestUtils {0} + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: lib/LuxTestUtils/src + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true + +env: + BACKEND_GROUP: "CPU" diff --git a/.github/workflows/CI_MLDataDevices.yml b/.github/workflows/CI_MLDataDevices.yml new file mode 100644 index 000000000..452a68320 --- /dev/null +++ b/.github/workflows/CI_MLDataDevices.yml @@ -0,0 +1,105 @@ +name: CI (MLDataDevices) +on: + pull_request: + branches: + - main + paths: + - "lib/MLDataDevices/**" + - ".github/workflows/CI_MLDataDevices.yml" + push: + branches: + - main + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. 
+ group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +jobs: + test: + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + version: + - "1.10" + - "1" + os: + - ubuntu-latest + - macos-latest + - windows-latest + group: + - CPU + - XLA + exclude: + - os: windows-latest + group: XLA + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + - uses: actions/cache@v4 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- + - name: "Install Dependencies and Run Tests" + run: | + import Pkg + Pkg.Registry.update() + Pkg.instantiate() + Pkg.test(; coverage="user") + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/MLDataDevices {0} + env: + BACKEND_GROUP: ${{ matrix.group }} + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: lib/MLDataDevices/src,lib/MLDataDevices/ext + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true + + downgrade: + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + version: + - "1.10" + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + - uses: julia-actions/julia-downgrade-compat@v1 + - name: "Install Dependencies and Run Tests" + run: | + import Pkg + Pkg.Registry.update() + Pkg.instantiate() + Pkg.test(; coverage="user") + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/MLDataDevices {0} + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: lib/MLDataDevices/src,lib/MLDataDevices/ext + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true + +env: + BACKEND_GROUP: "CPU" diff --git a/.github/workflows/CI_WeightInitializers.yml b/.github/workflows/CI_WeightInitializers.yml new file mode 100644 index 000000000..36bfd48a8 --- /dev/null +++ b/.github/workflows/CI_WeightInitializers.yml @@ -0,0 +1,96 @@ +name: CI (WeightInitializers) +on: + pull_request: + branches: + - main + paths: + - "lib/WeightInitializers/**" + - ".github/workflows/CI_WeightInitializers.yml" + push: + branches: + - main + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. 
+ group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +jobs: + test: + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + version: + - "1.10" + - "1" + os: + - ubuntu-latest + - macos-latest + - windows-latest + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + - uses: actions/cache@v4 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- + - name: "Install Dependencies and Run Tests" + run: | + import Pkg + Pkg.Registry.update() + Pkg.instantiate() + Pkg.test(; coverage="user") + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/WeightInitializers {0} + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: lib/WeightInitializers/src,lib/WeightInitializers/ext + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true + + downgrade: + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && github.base_ref == github.event.repository.default_branch }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + version: ["1.10"] + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + - uses: julia-actions/julia-downgrade-compat@v1 + - name: "Install Dependencies and Run Tests" + run: | + import Pkg + Pkg.Registry.update() + Pkg.instantiate() + Pkg.test(; coverage="user") + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=lib/WeightInitializers {0} + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: lib/WeightInitializers/src,lib/WeightInitializers/ext + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true + +env: + BACKEND_GROUP: "CPU" diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml index d2f4fccd6..a930415b9 100644 --- a/.github/workflows/CompatHelper.yml +++ b/.github/workflows/CompatHelper.yml @@ -36,6 +36,7 @@ jobs: import CompatHelper subdirs = ["", "docs", "test"] append!(subdirs, joinpath.(("examples",), filter(p -> isdir(joinpath("examples", p)), readdir("examples")))) + append!(subdirs, joinpath.(("lib",), filter(p -> isdir(joinpath("lib", p)), readdir("lib")))) CompatHelper.main(; subdirs) shell: julia --color=yes {0} working-directory: "./" diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml new file mode 100644 index 000000000..c53f0cc71 --- /dev/null +++ b/.github/workflows/Downstream.yml @@ -0,0 +1,84 @@ +name: Downstream +on: + pull_request: + branches: + - main + paths: + - "src/**" + - "ext/**" + - "test/**" + - "Project.toml" + - "lib/LuxCore/**" + - "lib/LuxLib/**" + - "lib/MLDataDevices/**" + - "lib/WeightInitializers/**" + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. 
+ group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +jobs: + downstream: + if: ${{ !contains(github.event.head_commit.message, '[skip tests]') && contains(github.event.pull_request.labels.*.name, 'run downstream test') }} + runs-on: ubuntu-latest + timeout-minutes: 60 + env: + GROUP: ${{ matrix.package.group }} + strategy: + fail-fast: false + matrix: + package: + - { user: SciML, repo: DiffEqFlux.jl, group: BasicNeuralDE } + - { user: SciML, repo: DiffEqFlux.jl, group: AdvancedNeuralDE } + - { user: SciML, repo: DeepEquilibriumNetworks.jl, group: All } + - { user: SciML, repo: NeuralPDE.jl, group: NNPDE1 } + - { user: SciML, repo: NeuralPDE.jl, group: NNPDE2 } + - { user: LuxDL, repo: Boltz.jl, group: CPU } + - { user: SciML, repo: NeuralOperators.jl, group: CPU } + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: "1.10" + - name: "Build Lux" + run: | + import Pkg + dev_pkgs = Pkg.PackageSpec[] + for pkg in ("lib/LuxCore", "lib/LuxLib", "lib/MLDataDevices", "lib/WeightInitializers") + push!(dev_pkgs, Pkg.PackageSpec(path=pkg)) + end + Pkg.develop(dev_pkgs) + Pkg.instantiate() + Pkg.update() + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=. {0} + - name: "Clone Downstream" + uses: actions/checkout@v4 + with: + repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} + path: downstream + - name: "Load this and run the downstream tests" + shell: julia --color=yes --code-coverage=user --depwarn=yes --project=downstream {0} + run: | + using Pkg + try + # force it to use this PR's version of the package + Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps + Pkg.update() + Pkg.test(; coverage=true) # resolver may fail with test time deps + catch err + err isa Pkg.Resolve.ResolverError || rethrow() + # If we can't resolve that means this is incompatible by SemVer and this is fine + # It means we marked this as a breaking change, so we don't need to worry about + # Mistakenly introducing a breaking change, as we have intentionally made one + @info "Not compatible with this release. No problem." 
exception=err + exit(0) # Exit immediately, as a success + end + - uses: julia-actions/julia-processcoverage@v1 + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true diff --git a/.typos.toml b/.typos.toml index fb4c8d1e2..3d459b58c 100644 --- a/.typos.toml +++ b/.typos.toml @@ -1,3 +1,7 @@ [default.extend-words] numer = "numer" -Nd = "Nd" \ No newline at end of file +Nd = "Nd" +nd = "nd" +Ba = "Ba" +skipt = "skipt" +nin = "nin" diff --git a/Project.toml b/Project.toml index 4fc8de57f..3eaa8de65 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "1.2.0" +version = "1.2.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -67,7 +67,7 @@ LuxZygoteExt = "Zygote" [compat] ADTypes = "1.8.1" -Adapt = "4" +Adapt = "4.1" ArgCheck = "2.3" ArrayInterface = "7.10" CUDA = "5.3.2" diff --git a/README.md b/README.md index 503cc9c9d..f1cf9db17 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ [![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](http://lux.csail.mit.edu/stable/) [![CI](https://github.com/LuxDL/Lux.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/LuxDL/Lux.jl/actions/workflows/CI.yml) -[![CI (pre-release)](https://img.shields.io/github/actions/workflow/status/LuxDL/Lux.jl/CIPreRelease.yml?branch=main&label=CI%20(pre-release)&logo=github)](https://github.com/LuxDL/Lux.jl/actions/workflows/CIPreRelease.yml) +[![CI (pre-release)]()](https://github.com/LuxDL/Lux.jl/actions/workflows/CIPreRelease.yml) [![Build status](https://img.shields.io/buildkite/ba1f9622add5978c2d7b194563fd9327113c9c21e5734be20e/main.svg?label=gpu&branch=main&logo=buildkite)](https://buildkite.com/julialang/lux-dot-jl) [![codecov](https://codecov.io/gh/LuxDL/Lux.jl/branch/main/graph/badge.svg?token=IMqBM1e3hz)](https://codecov.io/gh/LuxDL/Lux.jl) [![Benchmarks](https://github.com/LuxDL/Lux.jl/actions/workflows/Benchmark.yml/badge.svg?branch=main)](https://lux.csail.mit.edu/benchmarks/) @@ -40,6 +40,95 @@ Pkg.add("Lux") > [!TIP] > If you are using a pre-v1 version of Lux.jl, please see the [Updating to v1 section](https://lux.csail.mit.edu/dev/introduction/updating_to_v1) for instructions on how to update. +
+ +| **Packages** | **Stable Version** | **Monthly Downloads** | **Total Downloads** | **Build Status** | +| :----------------------------------------------------- | :------------------------------------------------------------- | :-------------------------------------------------------------------- | :-------------------------------------------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------- | +| 📦 [Lux.jl](./src) | [![][lux-version]][lux-juliahub] | [![][downloads-lux]][downloads-lux-url] | [![][total-downloads-lux]][downloads-lux-url] | [![][gh-actions-lux]][gh-actions-lux-url] [![][gh-actions-lux-prerelease]][gh-actions-lux-prerelease-url] [![][buildkite-badge]][buildkite-url] | +| └ 📦 [LuxLib.jl](./lib/LuxLib) | [![][luxlib-version]][luxlib-juliahub] | [![][downloads-luxlib]][downloads-luxlib-url] | [![][total-downloads-luxlib]][downloads-luxlib-url] | [![][gh-actions-luxlib]][gh-actions-luxlib-url] | +| └ 📦 [LuxCore.jl](./lib/LuxCore) | [![][luxcore-version]][luxcore-juliahub] | [![][downloads-luxcore]][downloads-luxcore-url] | [![][total-downloads-luxcore]][downloads-luxcore-url] | [![][gh-actions-luxcore]][gh-actions-luxcore-url] | +| └ 📦 [MLDataDevices.jl](./lib/MLDataDevices) | [![][mldatadevices-version]][mldatadevices-juliahub] | [![][downloads-mldatadevices]][downloads-mldatadevices-url] | [![][total-downloads-mldatadevices]][downloads-mldatadevices-url] | [![][gh-actions-mldatadevices]][gh-actions-mldatadevices-url] | +| └ 📦 [WeightInitializers.jl](./lib/WeightInitializers) | [![][weightinitializers-version]][weightinitializers-juliahub] | [![][downloads-weightinitializers]][downloads-weightinitializers-url] | [![][total-downloads-weightinitializers]][downloads-weightinitializers-url] | [![][gh-actions-weightinitializers]][gh-actions-weightinitializers-url] | +| └ 📦 [LuxTestUtils.jl](./lib/LuxTestUtils) | [![][luxtestutils-version]][luxtestutils-juliahub] | [![][downloads-luxtestutils]][downloads-luxtestutils-url] | [![][total-downloads-luxtestutils]][downloads-luxtestutils-url] | [![][gh-actions-luxtestutils]][gh-actions-luxtestutils-url] | +| └ 📦 [LuxCUDA.jl](./lib/LuxCUDA) | [![][luxcuda-version]][luxcuda-juliahub] | [![][downloads-luxcuda]][downloads-luxcuda-url] | [![][total-downloads-luxcuda]][downloads-luxcuda-url] | [![][gh-actions-luxcuda]][gh-actions-luxcuda-url] | + +
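Since the sub-packages in the table above now live under `lib/` of this repository, local development follows the same pattern the CI workflows in this diff use: `Pkg.develop` each path into the active environment. A minimal sketch (not part of the PR; it assumes the working directory is the repository root and lists only paths that appear in this diff):

```julia
using Pkg

# Develop the monorepo packages by path so changes under lib/ are picked up immediately.
dev_pkgs = Pkg.PackageSpec[]
for pkg in (".", "lib/LuxCore", "lib/LuxLib", "lib/MLDataDevices", "lib/WeightInitializers")
    push!(dev_pkgs, Pkg.PackageSpec(path=pkg))
end
Pkg.develop(dev_pkgs)
Pkg.instantiate()
```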
+
+
+
+
+
+[lux-version]: https://juliahub.com/docs/General/Lux/stable/version.svg?color=blue
+[luxlib-version]: https://juliahub.com/docs/General/LuxLib/stable/version.svg?color=blue
+[luxcore-version]: https://juliahub.com/docs/General/LuxCore/stable/version.svg?color=blue
+[mldatadevices-version]: https://juliahub.com/docs/General/MLDataDevices/stable/version.svg?color=blue
+[weightinitializers-version]: https://juliahub.com/docs/General/WeightInitializers/stable/version.svg?color=blue
+[luxtestutils-version]: https://juliahub.com/docs/General/LuxTestUtils/stable/version.svg?color=blue
+[luxcuda-version]: https://juliahub.com/docs/General/LuxCUDA/stable/version.svg?color=blue
+[lux-juliahub]: https://juliahub.com/ui/Packages/General/Lux
+[luxlib-juliahub]: https://juliahub.com/ui/Packages/General/LuxLib
+[luxcore-juliahub]: https://juliahub.com/ui/Packages/General/LuxCore
+[mldatadevices-juliahub]: https://juliahub.com/ui/Packages/General/MLDataDevices
+[weightinitializers-juliahub]: https://juliahub.com/ui/Packages/General/WeightInitializers
+[luxtestutils-juliahub]: https://juliahub.com/ui/Packages/General/LuxTestUtils
+[luxcuda-juliahub]: https://juliahub.com/ui/Packages/General/LuxCUDA
+
+
+
+[docr-img]: https://img.shields.io/badge/docs-stable-blue.svg
+[docd-img]: https://img.shields.io/badge/docs-dev-blue.svg
+[docr-url]: https://lux.csail.mit.edu/stable/
+[docd-url]: https://lux.csail.mit.edu/dev/
+
+
+
+[buildkite-badge]: https://img.shields.io/buildkite/ba1f9622add5978c2d7b194563fd9327113c9c21e5734be20e/main.svg?label=gpu&branch=main&logo=buildkite
+
+[buildkite-url]: https://buildkite.com/julialang/lux-dot-jl/builds?branch=main
+
+
+
+[gh-actions-lux]: https://github.com/LuxDL/Lux.jl/workflows/CI/badge.svg
+[gh-actions-lux-prerelease]: https://github.com/LuxDL/Lux.jl/workflows/CIPreRelease/badge.svg
+[gh-actions-luxlib]: https://github.com/LuxDL/Lux.jl/workflows/CI_LuxLib/badge.svg
+[gh-actions-luxcore]: https://github.com/LuxDL/Lux.jl/workflows/CI_LuxCore/badge.svg
+[gh-actions-mldatadevices]: https://github.com/LuxDL/Lux.jl/workflows/CI_MLDataDevices/badge.svg
+[gh-actions-weightinitializers]: https://github.com/LuxDL/Lux.jl/workflows/CI_WeightInitializers/badge.svg
+[gh-actions-luxtestutils]: https://github.com/LuxDL/Lux.jl/workflows/CI_LuxTestUtils/badge.svg
+[gh-actions-luxcuda]: https://github.com/LuxDL/Lux.jl/workflows/CI_LuxCUDA/badge.svg
+[gh-actions-lux-url]: https://github.com/LuxDL/Lux.jl/actions/workflows/CI.yml
+[gh-actions-lux-prerelease-url]: https://github.com/LuxDL/Lux.jl/actions/workflows/CIPreRelease.yml
+[gh-actions-luxlib-url]: https://github.com/LuxDL/Lux.jl/actions/workflows/CI_LuxLib.yml
+[gh-actions-luxcore-url]: https://github.com/LuxDL/Lux.jl/actions/workflows/CI_LuxCore.yml
+[gh-actions-mldatadevices-url]: https://github.com/LuxDL/Lux.jl/actions/workflows/CI_MLDataDevices.yml
+[gh-actions-weightinitializers-url]: https://github.com/LuxDL/Lux.jl/actions/workflows/CI_WeightInitializers.yml
+[gh-actions-luxtestutils-url]: https://github.com/LuxDL/Lux.jl/actions/workflows/CI_LuxTestUtils.yml
+[gh-actions-luxcuda-url]: https://github.com/LuxDL/Lux.jl/actions/workflows/CI_LuxCUDA.yml
+
+
+
+[total-downloads-lux]: https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Ftotal_downloads%2FLux&query=total_requests&label=Downloads
+[total-downloads-luxlib]: https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Ftotal_downloads%2FLuxLib&query=total_requests&label=Downloads
+[total-downloads-luxcore]: https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Ftotal_downloads%2FLuxCore&query=total_requests&label=Downloads +[total-downloads-mldatadevices]: https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Ftotal_downloads%2FMLDataDevices&query=total_requests&label=Downloads +[total-downloads-weightinitializers]: https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Ftotal_downloads%2FWeightInitializers&query=total_requests&label=Downloads +[total-downloads-luxtestutils]: https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Ftotal_downloads%2FLuxTestUtils&query=total_requests&label=Downloads +[total-downloads-luxcuda]: https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Ftotal_downloads%2FLuxCUDA&query=total_requests&label=Downloads +[downloads-lux]: https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Fmonthly_downloads%2FLux&query=total_requests&suffix=%2Fmonth&label=Downloads +[downloads-luxlib]: https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Fmonthly_downloads%2FLuxLib&query=total_requests&suffix=%2Fmonth&label=Downloads +[downloads-luxcore]: https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Fmonthly_downloads%2FLuxCore&query=total_requests&suffix=%2Fmonth&label=Downloads +[downloads-mldatadevices]: https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Fmonthly_downloads%2FMLDataDevices&query=total_requests&suffix=%2Fmonth&label=Downloads +[downloads-weightinitializers]: https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Fmonthly_downloads%2FWeightInitializers&query=total_requests&suffix=%2Fmonth&label=Downloads +[downloads-luxtestutils]: https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Fmonthly_downloads%2FLuxTestUtils&query=total_requests&suffix=%2Fmonth&label=Downloads +[downloads-luxcuda]: https://img.shields.io/badge/dynamic/json?url=http%3A%2F%2Fjuliapkgstats.com%2Fapi%2Fv1%2Fmonthly_downloads%2FLuxCUDA&query=total_requests&suffix=%2Fmonth&label=Downloads +[downloads-lux-url]: http://juliapkgstats.com/pkg/Lux +[downloads-luxlib-url]: http://juliapkgstats.com/pkg/LuxLib +[downloads-luxcore-url]: http://juliapkgstats.com/pkg/LuxCore +[downloads-mldatadevices-url]: http://juliapkgstats.com/pkg/MLDataDevices +[downloads-weightinitializers-url]: http://juliapkgstats.com/pkg/WeightInitializers +[downloads-luxtestutils-url]: http://juliapkgstats.com/pkg/LuxTestUtils +[downloads-luxcuda-url]: http://juliapkgstats.com/pkg/LuxCUDA + ## 🤸 Quickstart ```julia diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml index 95b330c1a..6771aec14 100644 --- a/benchmarks/Project.toml +++ b/benchmarks/Project.toml @@ -9,11 +9,14 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = 
"9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" ThreadPinning = "811555cd-349b-4f26-b7bc-1f208b848042" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/benchmarks/setup.jl b/benchmarks/setup.jl index e2d05bc88..e08cd4e2e 100644 --- a/benchmarks/setup.jl +++ b/benchmarks/setup.jl @@ -1,30 +1,42 @@ -using ADTypes: ADTypes, AutoEnzyme, AutoZygote +using ADTypes using Adapt: adapt -using Lux: Lux, BatchNorm, Chain, Conv, Dense, Dropout, FlattenLayer, MaxPool -using MLDataDevices: AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice -using NNlib: relu, gelu +using Lux +using LuxLib +using MLDataDevices +using MLDataDevices: AbstractDevice +using NNlib using Random: Random +using StableRNGs: StableRNG # AD Backends using Enzyme: Enzyme using Zygote: Zygote # Helper Functions -@inline synchronize(::CPUDevice) = nothing -@inline synchronize(::AMDGPUDevice) = AMDGPU.synchronize() -@inline synchronize(::CUDADevice) = CUDA.synchronize() - -@inline reclaim(::CPUDevice) = GC.gc() -@inline reclaim(::AMDGPUDevice) = AMDGPU.HIP.reclaim() -@inline reclaim(::CUDADevice) = CUDA.reclaim() - -@inline sumabs2(model, x, p, st) = sum(abs2, first(Lux.apply(model, x, p, st))) -@inline sumabs2(model, x) = sum(abs2, model(x)) +synchronize(::CPUDevice) = nothing +synchronize(::AMDGPUDevice) = AMDGPU.synchronize() +synchronize(::CUDADevice) = CUDA.synchronize() +synchronize(::MetalDevice) = Metal.synchronize() +synchronize(::oneAPIDevice) = oneAPI.synchronize() + +reclaim(::CPUDevice) = GC.gc() +reclaim(::AMDGPUDevice) = AMDGPU.HIP.reclaim() +reclaim(::CUDADevice) = CUDA.reclaim() +reclaim(::MetalDevice) = nothing # Metal.reclaim() +reclaim(::oneAPIDevice) = nothing # oneAPI.reclaim() + +function sumabs2(model::Lux.AbstractLuxLayer, x, p, st) + return sum(abs2, first(Lux.apply(model, x, p, st))) +end +sumabs2(f::F, args...) where {F} = sum(abs2, f(args...)) +sumabs2first(f::F, args...) where {F} = sum(abs2, first(f(args...))) function benchmark_group_to_backend(benchmark_group::String) benchmark_group == "CPU" && return CPUDevice() benchmark_group == "AMDGPU" && return AMDGPUDevice() benchmark_group == "CUDA" && return CUDADevice() + benchmark_group == "Metal" && return MetalDevice() + benchmark_group == "oneAPI" && return oneAPIDevice() error("Unknown backend: $(benchmark_group)") end @@ -39,12 +51,14 @@ end # Main benchmark files include("setups/layers.jl") include("setups/models.jl") +include("setups/luxlib.jl") function setup_benchmarks!(suite::BenchmarkGroup, backend::String, num_cpu_threads::Int64) dev = benchmark_group_to_backend(backend) cpu_or_gpu = backend == "CPU" ? "CPU" : "GPU" final_backend = backend == "CPU" ? 
string(num_cpu_threads, " ", "thread(s)") : backend + # Model Benchmarks setup_dense_benchmarks!(suite, cpu_or_gpu, final_backend, dev) setup_conv_benchmarks!(suite, cpu_or_gpu, final_backend, dev) @@ -54,6 +68,19 @@ function setup_benchmarks!(suite::BenchmarkGroup, backend::String, num_cpu_threa setup_mlp_benchmarks!(suite, cpu_or_gpu, final_backend, dev) setup_lenet_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + + # Layer Benchmarks + setup_dense_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + + setup_bias_activation_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + + setup_batchnorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + + setup_layernorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + + setup_groupnorm_benchmarks!(suite, cpu_or_gpu, final_backend, dev) + + setup_batched_matmul_benchmarks!(suite, cpu_or_gpu, final_backend, dev) end function setup_forward_pass_benchmark!(suite::BenchmarkGroup, benchmark_name::String, diff --git a/benchmarks/setups/luxlib.jl b/benchmarks/setups/luxlib.jl new file mode 100644 index 000000000..fa2940dd4 --- /dev/null +++ b/benchmarks/setups/luxlib.jl @@ -0,0 +1,212 @@ +function dense_setup(N::Int, bias::Bool, dev::MLDataDevices.AbstractDevice) + rng = StableRNG(123) + x = randn(rng, Float32, N, 128) |> dev + w = randn(rng, Float32, N, N) |> dev + b = (bias ? randn(rng, Float32, N) : nothing) |> dev + return x, w, b +end + +function setup_dense_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, + backend::String, dev::MLDataDevices.AbstractDevice) + for bias in [true, false], activation in [identity, relu, gelu], N in [2, 32, 512] + benchmark_name = "dense($N, bias=$bias, act=$activation)($N x 128)" + suite[benchmark_name]["forward"][cpu_or_gpu][backend] = @benchmarkable begin + fused_dense_bias_activation($activation, w, x, b) + synchronize($dev) + end setup=begin + x, w, b = dense_setup($N, $bias, $dev) + end + + suite[benchmark_name]["zygote"][cpu_or_gpu][backend] = @benchmarkable begin + Zygote.gradient(sumabs2, fused_dense_bias_activation, $activation, w, x, b) + synchronize($dev) + end setup=begin + x, w, b = dense_setup($N, $bias, $dev) + Zygote.gradient(sumabs2, fused_dense_bias_activation, $activation, w, x, b) + end + end +end + +# Bias Activation +function bias_activation_setup(N::Int, dev::MLDataDevices.AbstractDevice) + rng = StableRNG(123) + x = randn(rng, Float32, N, 128) |> dev + b = randn(rng, Float32, N) |> dev + return x, b +end + +function setup_bias_activation_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, + backend::String, dev::MLDataDevices.AbstractDevice) + for activation in [tanh, relu, gelu], N in [2, 32, 512] + benchmark_name = "bias_activation($N, act=$activation)($N x 128)" + suite[benchmark_name]["forward"][cpu_or_gpu][backend] = @benchmarkable begin + bias_activation($activation, x, b) + synchronize($dev) + end setup=begin + reclaim($dev) + x, b = bias_activation_setup($N, $dev) + end + + suite[benchmark_name]["zygote"][cpu_or_gpu][backend] = @benchmarkable begin + Zygote.gradient(sumabs2, bias_activation, $activation, x, b) + synchronize($dev) + end setup=begin + reclaim($dev) + x, b = bias_activation_setup($N, $dev) + Zygote.gradient(sumabs2, bias_activation, $activation, x, b) + end + end +end + +# BatchNorm +function batchnorm_setup(shape::Dims, affine::Bool, dev::MLDataDevices.AbstractDevice) + rng = StableRNG(123) + x = randn(rng, Float32, shape...) |> dev + scale = (affine ? randn(rng, Float32, shape[end - 1]) : nothing) |> dev + bias = (affine ? 
randn(rng, Float32, shape[end - 1]) : nothing) |> dev + running_mean = rand(rng, Float32, shape[end - 1]) |> dev + running_var = rand(rng, Float32, shape[end - 1]) |> dev + return x, scale, bias, running_mean, running_var +end + +function setup_batchnorm_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, + backend::String, dev::MLDataDevices.AbstractDevice) + for activation in [identity, relu, gelu], ndims in (2, 4) + shapes = [(ntuple(Returns(16), ndims - 2)..., 4, 32), + (ntuple(Returns(16), ndims - 2)..., 32, 32)] + for shape in shapes, affine in (true, false) + benchmark_name = "batchnorm($ndims, act=$activation, affine=$affine)(\ + $(join(shape, " x ")))" + + suite[benchmark_name]["forward"][cpu_or_gpu][backend] = @benchmarkable begin + batchnorm( + x, scale, bias, running_mean, running_var, Val(false), $activation) + synchronize($dev) + end setup=begin + x, scale, bias, running_mean, running_var = batchnorm_setup( + $shape, $affine, $dev) + end + + suite[benchmark_name]["zygote"][cpu_or_gpu][backend] = @benchmarkable begin + Zygote.gradient(sumabs2first, batchnorm, x, scale, bias, + running_mean, running_var, Val(true), $activation) + synchronize($dev) + end setup=begin + reclaim($dev) + x, scale, bias, running_mean, running_var = batchnorm_setup( + $shape, $affine, $dev) + Zygote.gradient(sumabs2first, batchnorm, x, scale, bias, + running_mean, running_var, Val(true), $activation) + end + end + end +end + +# LayerNorm +function layernorm_setup(shape::Dims, affine::Bool, dev::MLDataDevices.AbstractDevice) + rng = StableRNG(123) + x = randn(rng, Float32, shape...) |> dev + scale = (affine ? randn(rng, Float32, shape[1:(end - 1)]..., 1) : nothing) |> dev + bias = (affine ? randn(rng, Float32, shape[1:(end - 1)]..., 1) : nothing) |> dev + return x, scale, bias +end + +function setup_layernorm_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, + backend::String, dev::MLDataDevices.AbstractDevice) + for activation in [identity, relu, gelu], ndims in (2, 4) + shapes = [(ntuple(Returns(16), ndims - 2)..., 4, 32), + (ntuple(Returns(16), ndims - 2)..., 32, 32)] + for shape in shapes, affine in (true, false) + benchmark_name = "layernorm($ndims, act=$activation, affine=$affine)(\ + $(join(shape, " x ")))" + + suite[benchmark_name]["forward"][cpu_or_gpu][backend] = @benchmarkable begin + layernorm(x, scale, bias, $activation, 1:($ndims - 1)) + synchronize($dev) + end setup=begin + reclaim($dev) + x, scale, bias = layernorm_setup($shape, $affine, $dev) + end + + suite[benchmark_name]["zygote"][cpu_or_gpu][backend] = @benchmarkable begin + Zygote.gradient( + sumabs2, layernorm, x, scale, bias, $activation, 1:($ndims - 1)) + synchronize($dev) + end setup=begin + reclaim($dev) + x, scale, bias = layernorm_setup($shape, $affine, $dev) + Zygote.gradient( + sumabs2, layernorm, x, scale, bias, $activation, 1:($ndims - 1)) + end + end + end +end + +# GroupNorm +function groupnorm_setup(shape::Dims, affine::Bool, dev::MLDataDevices.AbstractDevice) + rng = StableRNG(123) + x = randn(rng, Float32, shape...) |> dev + scale = (affine ? randn(rng, Float32, shape[end - 1]) : nothing) |> dev + bias = (affine ? 
randn(rng, Float32, shape[end - 1]) : nothing) |> dev + return x, scale, bias +end + +function setup_groupnorm_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, + backend::String, dev::MLDataDevices.AbstractDevice) + for activation in [identity, relu, gelu], ndims in (2, 4) + shapes = [(ntuple(Returns(16), ndims - 2)..., 4, 32), + (ntuple(Returns(16), ndims - 2)..., 32, 32)] + for shape in shapes, affine in (true, false) + benchmark_name = "groupnorm($ndims, act=$activation, affine=$affine)(\ + $(join(shape, " x ")))" + + suite[benchmark_name]["forward"][cpu_or_gpu][backend] = @benchmarkable begin + groupnorm(x, scale, bias, 4, $activation) + synchronize($dev) + end setup=begin + reclaim($dev) + x, scale, bias = groupnorm_setup($shape, $affine, $dev) + end + + suite[benchmark_name]["zygote"][cpu_or_gpu][backend] = @benchmarkable begin + Zygote.gradient(sumabs2, groupnorm, x, scale, bias, 4, $activation) + synchronize($dev) + end setup=begin + reclaim($dev) + x, scale, bias = groupnorm_setup($shape, $affine, $dev) + Zygote.gradient(sumabs2, groupnorm, x, scale, bias, 4, $activation) + end + end + end +end + +# Batched Matrix Multiplication +function batchedmm_setup(N::Int, Bsize::Int, dev::MLDataDevices.AbstractDevice) + rng = StableRNG(123) + x = randn(rng, Float32, N, N, Bsize) |> dev + return x +end + +function setup_batched_matmul_benchmarks!(suite::BenchmarkGroup, cpu_or_gpu::String, + backend::String, dev::MLDataDevices.AbstractDevice) + for N in [2, 16, 128, 512], Bsize in [4, 32, 128, 512] + benchmark_name = "batchedmm($N, Bsize=$Bsize)" + + suite[benchmark_name]["forward"][cpu_or_gpu][backend] = @benchmarkable begin + batched_matmul(x, x) + synchronize($dev) + end setup=begin + reclaim($dev) + x = batchedmm_setup($N, $Bsize, $dev) + end + + suite[benchmark_name]["zygote"][cpu_or_gpu][backend] = @benchmarkable begin + Zygote.gradient(sumabs2, batched_matmul, x, x) + synchronize($dev) + end setup=begin + reclaim($dev) + x = batchedmm_setup($N, $Bsize, $dev) + Zygote.gradient(sumabs2, batched_matmul, x, x) + end + end +end diff --git a/docs/src/api/Building_Blocks/WeightInitializers.md b/docs/src/api/Building_Blocks/WeightInitializers.md index 5981b04f9..79df20f22 100644 --- a/docs/src/api/Building_Blocks/WeightInitializers.md +++ b/docs/src/api/Building_Blocks/WeightInitializers.md @@ -18,8 +18,8 @@ learning models. 
| `AMDGPU.rocrand_rng()` | `ROCArray` | | | `AMDGPU.gpuarrays_rng()` | `ROCArray` | | | `GPUArrays.default_rng(ROCArray)` | `ROCArray` | | -| `Metal.gpuarrays_rng()` | `MtlArray` | [`orthogonal`](@ref), [`truncated_normal`](@ref) | -| `GPUArrays.default_rng(MtlArray)` | `MtlArray` | [`orthogonal`](@ref), [`truncated_normal`](@ref) | +| `Metal.gpuarrays_rng()` | `MtlArray` | [`orthogonal`](@ref) | +| `GPUArrays.default_rng(MtlArray)` | `MtlArray` | [`orthogonal`](@ref) | | `oneAPI.gpuarrays_rng()` | `oneArray` | [`orthogonal`](@ref), [`truncated_normal`](@ref) | | `GPUArrays.default_rng(oneArray)` | `oneArray` | [`orthogonal`](@ref), [`truncated_normal`](@ref) | diff --git a/docs/src/ecosystem.md b/docs/src/ecosystem.md index 09ce703ea..d1f0e00e5 100644 --- a/docs/src/ecosystem.md +++ b/docs/src/ecosystem.md @@ -210,7 +210,7 @@ const nnprimitives = [ name: 'LuxLib.jl', desc: 'Backend for Lux.jl', links: [ - { icon: 'github', link: 'https://github.com/LuxDL/LuxLib.jl' } + { icon: 'github', link: 'https://github.com/LuxDL/tree/main/lib/LuxLib.jl' } ] } ]; @@ -310,7 +310,7 @@ const test_utils = [ name: 'LuxTestUtils.jl', desc: 'Collection of Functions useful for testing various packages in the Lux Ecosystem', links: [ - { icon: 'github', link: 'https://github.com/LuxDL/LuxTestUtils.jl' } + { icon: 'github', link: 'https://github.com/LuxDL/tree/main/lib/LuxTestUtils' } ] } ]; diff --git a/lib/LuxCUDA/LICENSE b/lib/LuxCUDA/LICENSE new file mode 100644 index 000000000..e87b80c0d --- /dev/null +++ b/lib/LuxCUDA/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Avik Pal and contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/lib/LuxCUDA/Project.toml b/lib/LuxCUDA/Project.toml new file mode 100644 index 000000000..a0de0761c --- /dev/null +++ b/lib/LuxCUDA/Project.toml @@ -0,0 +1,15 @@ +name = "LuxCUDA" +uuid = "d0bbae9a-e099-4d5b-a835-1c6931763bda" +authors = ["Avik Pal and contributors"] +version = "0.3.3" + +[deps] +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" + +[compat] +CUDA = "5.3.2" +Reexport = "1" +cuDNN = "1.3" +julia = "1.10" diff --git a/lib/LuxCUDA/README.md b/lib/LuxCUDA/README.md new file mode 100644 index 000000000..453ffb332 --- /dev/null +++ b/lib/LuxCUDA/README.md @@ -0,0 +1,4 @@ +# LuxCUDA + +`LuxCUDA` is meant to be used as a trigger package for all CUDA dependencies in `Lux`. +Users requiring CUDA support should install `LuxCUDA` and load it alongside `Lux`. 
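As the README above notes, `LuxCUDA` is purely a trigger package: loading it next to `Lux` activates the CUDA-backed code paths. A minimal sketch of that workflow (not part of the PR; it assumes a functional CUDA setup, and `gpu_device` comes from MLDataDevices, which Lux re-exports):

```julia
using Lux, LuxCUDA, Random

rng = Random.default_rng()
model = Dense(4 => 2, tanh)
ps, st = Lux.setup(rng, model)

# With LuxCUDA loaded, gpu_device() returns a CUDADevice when a GPU is usable and
# falls back to the CPU device otherwise, so the same script runs on both.
dev = gpu_device()
x = dev(randn(rng, Float32, 4, 8))
y, _ = Lux.apply(model, x, dev(ps), dev(st))
```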
diff --git a/lib/LuxCUDA/src/LuxCUDA.jl b/lib/LuxCUDA/src/LuxCUDA.jl new file mode 100644 index 000000000..766058dcd --- /dev/null +++ b/lib/LuxCUDA/src/LuxCUDA.jl @@ -0,0 +1,36 @@ +module LuxCUDA + +using Reexport + +@reexport using CUDA, CUDA.CUDAKernels, cuDNN + +const USE_CUDA_GPU = Ref{Union{Nothing, Bool}}(nothing) + +function _check_use_cuda!() + USE_CUDA_GPU[] === nothing || return + + USE_CUDA_GPU[] = CUDA.functional() + if USE_CUDA_GPU[] + if !cuDNN.has_cudnn() + @warn """ + cuDNN is not functional in CUDA.jl. Some functionality will not be available. + """ maxlog=1 + end + else + @warn "LuxCUDA is loaded but the CUDA GPU is not functional." maxlog=1 + end + + return +end + +""" + functional() + +Check if LuxCUDA is functional. +""" +function functional()::Bool + _check_use_cuda!() + return USE_CUDA_GPU[] +end + +end diff --git a/lib/LuxCUDA/test/Project.toml b/lib/LuxCUDA/test/Project.toml new file mode 100644 index 000000000..379f4f88e --- /dev/null +++ b/lib/LuxCUDA/test/Project.toml @@ -0,0 +1,7 @@ +[deps] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[compat] +Aqua = "0.8.4" +Test = "1.10" diff --git a/lib/LuxCUDA/test/runtests.jl b/lib/LuxCUDA/test/runtests.jl new file mode 100644 index 000000000..4e68ea44f --- /dev/null +++ b/lib/LuxCUDA/test/runtests.jl @@ -0,0 +1,10 @@ +using Aqua, LuxCUDA, Test + +@testset "LuxCUDA" begin + @test LuxCUDA.USE_CUDA_GPU[] === nothing + + @test LuxCUDA.functional() isa Bool + + Aqua.test_all(LuxCUDA; ambiguities=false, undefined_exports=false) + Aqua.test_ambiguities(LuxCUDA) +end diff --git a/lib/LuxCore/LICENSE b/lib/LuxCore/LICENSE new file mode 100644 index 000000000..1f70fe758 --- /dev/null +++ b/lib/LuxCore/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Avik Pal and contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
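The `LuxCUDA.functional()` helper defined above lazily checks `CUDA.functional()` on the first call and caches the result, so callers can branch on GPU availability cheaply. A hedged sketch of one way downstream code might use it; the `maybe_cu` helper is purely illustrative, not an API of the package:

```julia
using LuxCUDA

# LuxCUDA re-exports CUDA, so `cu` is in scope; keep the array on the CPU when the
# CUDA stack is not functional.
maybe_cu(x) = LuxCUDA.functional() ? cu(x) : x

A = maybe_cu(rand(Float32, 16, 16))
B = maybe_cu(rand(Float32, 16, 16))
C = A * B
```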
diff --git a/lib/LuxCore/Project.toml b/lib/LuxCore/Project.toml new file mode 100644 index 000000000..83b0e2730 --- /dev/null +++ b/lib/LuxCore/Project.toml @@ -0,0 +1,42 @@ +name = "LuxCore" +uuid = "bb33d45b-7691-41d6-9220-0943567d0623" +authors = ["Avik Pal and contributors"] +version = "1.0.1" + +[deps] +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" +DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[weakdeps] +ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + +[extensions] +LuxCoreArrayInterfaceReverseDiffExt = ["ArrayInterface", "ReverseDiff"] +LuxCoreArrayInterfaceTrackerExt = ["ArrayInterface", "Tracker"] +LuxCoreChainRulesCoreExt = "ChainRulesCore" +LuxCoreEnzymeCoreExt = "EnzymeCore" +LuxCoreFunctorsExt = "Functors" +LuxCoreMLDataDevicesExt = "MLDataDevices" +LuxCoreSetfieldExt = "Setfield" + +[compat] +ArrayInterface = "7.9" +ChainRulesCore = "1.24" +Compat = "4.15.0" +DispatchDoctor = "0.4.10" +EnzymeCore = "0.7.7, 0.8" +Functors = "0.4.12" +MLDataDevices = "1" +Random = "1.10" +ReverseDiff = "1.15" +Setfield = "1" +Tracker = "0.2.34" +julia = "1.10" diff --git a/lib/LuxCore/README.md b/lib/LuxCore/README.md new file mode 100644 index 000000000..d4e0444bf --- /dev/null +++ b/lib/LuxCore/README.md @@ -0,0 +1,6 @@ +# LuxCore + +`LuxCore.jl` defines the abstract layers for Lux. Allows users to be compatible with the +entirely of `Lux.jl` without having such a heavy dependency. If you are depending on +`Lux.jl` directly, you do not need to depend on `LuxCore.jl` (all the functionality is +exported via `Lux.jl`). diff --git a/lib/LuxCore/ext/LuxCoreArrayInterfaceReverseDiffExt.jl b/lib/LuxCore/ext/LuxCoreArrayInterfaceReverseDiffExt.jl new file mode 100644 index 000000000..197fcec48 --- /dev/null +++ b/lib/LuxCore/ext/LuxCoreArrayInterfaceReverseDiffExt.jl @@ -0,0 +1,22 @@ +module LuxCoreArrayInterfaceReverseDiffExt + +using ArrayInterface: ArrayInterface +using LuxCore: LuxCore, AbstractLuxLayer +using ReverseDiff: TrackedReal, TrackedArray + +# AoS to SoA conversion +function LuxCore.apply( + m::AbstractLuxLayer, x::AbstractArray{<:TrackedReal}, ps, st) + @warn "Lux.apply(m::AbstractLuxLayer, \ + x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) input was corrected to \ + Lux.apply(m::AbstractLuxLayer, x::ReverseDiff.TrackedArray}, ps, st).\n\n\ + 1. If this was not the desired behavior overload the dispatch on `m`.\n\n\ + 2. This might have performance implications. Check which layer was causing this \ + problem using `Lux.Experimental.@debug_mode`." 
maxlog=1 + return LuxCore.apply(m, reshape(ArrayInterface.aos_to_soa(x), size(x)), ps, st) +end + +## Prevent an infinite loop +LuxCore.apply(m::AbstractLuxLayer, x::TrackedArray, ps, st) = m(x, ps, st) + +end diff --git a/lib/LuxCore/ext/LuxCoreArrayInterfaceTrackerExt.jl b/lib/LuxCore/ext/LuxCoreArrayInterfaceTrackerExt.jl new file mode 100644 index 000000000..3bfa514b7 --- /dev/null +++ b/lib/LuxCore/ext/LuxCoreArrayInterfaceTrackerExt.jl @@ -0,0 +1,21 @@ +module LuxCoreArrayInterfaceTrackerExt + +using ArrayInterface: ArrayInterface +using LuxCore: LuxCore, AbstractLuxLayer +using Tracker: TrackedReal, TrackedArray + +# AoS to SoA conversion +function LuxCore.apply(m::AbstractLuxLayer, x::AbstractArray{<:TrackedReal}, ps, st) + @warn "LuxCore.apply(m::AbstractLuxLayer, \ + x::AbstractArray{<:Tracker.TrackedReal}, ps, st) input was corrected to \ + LuxCore.apply(m::AbstractLuxLayer, x::Tracker.TrackedArray}, ps, st).\n\n\ + 1. If this was not the desired behavior overload the dispatch on `m`.\n\n\ + 2. This might have performance implications. Check which layer was causing this \ + problem using `Lux.Experimental.@debug_mode`." maxlog=1 + return LuxCore.apply(m, ArrayInterface.aos_to_soa(x), ps, st) +end + +## Prevent an infinite loop +LuxCore.apply(m::AbstractLuxLayer, x::TrackedArray, ps, st) = m(x, ps, st) + +end diff --git a/lib/LuxCore/ext/LuxCoreChainRulesCoreExt.jl b/lib/LuxCore/ext/LuxCoreChainRulesCoreExt.jl new file mode 100644 index 000000000..6b0babd8f --- /dev/null +++ b/lib/LuxCore/ext/LuxCoreChainRulesCoreExt.jl @@ -0,0 +1,15 @@ +module LuxCoreChainRulesCoreExt + +using ChainRulesCore: ChainRulesCore, @non_differentiable +using LuxCore: LuxCore, AbstractLuxLayer +using Random: AbstractRNG + +@non_differentiable LuxCore.replicate(::AbstractRNG) + +function ChainRulesCore.rrule(::typeof(getproperty), m::AbstractLuxLayer, x::Symbol) + mₓ = getproperty(m, x) + ∇getproperty(_) = ntuple(Returns(ChainRulesCore.NoTangent()), 3) + return mₓ, ∇getproperty +end + +end diff --git a/lib/LuxCore/ext/LuxCoreEnzymeCoreExt.jl b/lib/LuxCore/ext/LuxCoreEnzymeCoreExt.jl new file mode 100644 index 000000000..237ad01fc --- /dev/null +++ b/lib/LuxCore/ext/LuxCoreEnzymeCoreExt.jl @@ -0,0 +1,37 @@ +module LuxCoreEnzymeCoreExt + +using EnzymeCore: EnzymeCore, EnzymeRules +using LuxCore: LuxCore +using Random: AbstractRNG + +EnzymeRules.inactive(::typeof(LuxCore.replicate), ::AbstractRNG) = nothing + +# Handle common mistakes users might make +const LAYER_DERIVATIVE_ERROR_MSG = """ +Lux Layers only support `EnzymeCore.Const` annotation. + +Lux Layers are immutable constants and gradients w.r.t. them are `nothing`. To +compute the gradients w.r.t. the layer's parameters, use the first argument returned +by `LuxCore.setup(rng, layer)` instead. 
+""" + +function EnzymeCore.Active(::LuxCore.AbstractLuxLayer) + throw(ArgumentError(LAYER_DERIVATIVE_ERROR_MSG)) +end + +for annotation in (:Duplicated, :DuplicatedNoNeed) + @eval function EnzymeCore.$(annotation)( + ::LuxCore.AbstractLuxLayer, ::LuxCore.AbstractLuxLayer) + throw(ArgumentError(LAYER_DERIVATIVE_ERROR_MSG)) + end +end + +for annotation in (:BatchDuplicated, :BatchDuplicatedNoNeed) + @eval function EnzymeCore.$(annotation)( + ::LuxCore.AbstractLuxLayer, ::NTuple{N, <:LuxCore.AbstractLuxLayer}, + check::Bool=true) where {N} + throw(ArgumentError(LAYER_DERIVATIVE_ERROR_MSG)) + end +end + +end diff --git a/lib/LuxCore/ext/LuxCoreFunctorsExt.jl b/lib/LuxCore/ext/LuxCoreFunctorsExt.jl new file mode 100644 index 000000000..d97ed3109 --- /dev/null +++ b/lib/LuxCore/ext/LuxCoreFunctorsExt.jl @@ -0,0 +1,33 @@ +module LuxCoreFunctorsExt + +using LuxCore: LuxCore +using Functors: Functors + +LuxCore.Internal.is_extension_loaded(::Val{:Functors}) = true + +LuxCore.Internal.isleaf_impl(args...; kwargs...) = Functors.isleaf(args...; kwargs...) +LuxCore.Internal.fmap_impl(args...; kwargs...) = Functors.fmap(args...; kwargs...) +function LuxCore.Internal.fmap_with_path_impl(args...; kwargs...) + return Functors.fmap_with_path(args...; kwargs...) +end +LuxCore.Internal.fleaves_impl(args...; kwargs...) = Functors.fleaves(args...; kwargs...) + +function Functors.functor(::Type{<:LuxCore.AbstractLuxContainerLayer{layers}}, + x) where {layers} + children = NamedTuple{layers}(getproperty.((x,), layers)) + layer_reconstructor = let x = x, layers = layers + z -> reduce(LuxCore.Internal.setfield, zip(layers, z); init=x) + end + return children, layer_reconstructor +end + +function Functors.functor(::Type{<:LuxCore.AbstractLuxWrapperLayer{layer}}, + x) where {layer} + children = NamedTuple{(layer,)}((getproperty(x, layer),)) + layer_reconstructor = let x = x, layer = layer + z -> LuxCore.Internal.setfield(x, layer, getproperty(z, layer)) + end + return children, layer_reconstructor +end + +end diff --git a/lib/LuxCore/ext/LuxCoreMLDataDevicesExt.jl b/lib/LuxCore/ext/LuxCoreMLDataDevicesExt.jl new file mode 100644 index 000000000..1a2dbbd69 --- /dev/null +++ b/lib/LuxCore/ext/LuxCoreMLDataDevicesExt.jl @@ -0,0 +1,16 @@ +module LuxCoreMLDataDevicesExt + +using LuxCore: LuxCore +using MLDataDevices: MLDataDevices + +for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) + ldev = Symbol(dev, :Device) + @eval function (::MLDataDevices.$(ldev))(NN::LuxCore.AbstractLuxLayer) + @warn "Lux layers are stateless and hence don't participate in device transfers. \ + Apply this function on the parameters and states generated using \ + `LuxCore.setup`." 
+ return NN + end +end + +end diff --git a/lib/LuxCore/ext/LuxCoreSetfieldExt.jl b/lib/LuxCore/ext/LuxCoreSetfieldExt.jl new file mode 100644 index 000000000..b814536d9 --- /dev/null +++ b/lib/LuxCore/ext/LuxCoreSetfieldExt.jl @@ -0,0 +1,15 @@ +module LuxCoreSetfieldExt + +using LuxCore: LuxCore +using Setfield: Setfield + +LuxCore.Internal.is_extension_loaded(::Val{:Setfield}) = true + +function LuxCore.Internal.setfield_impl(x, prop, val) + return Setfield.set(x, Setfield.PropertyLens{prop}(), val) +end +function LuxCore.Internal.setfield_impl(x, (prop, val)) + return LuxCore.Internal.setfield_impl(x, prop, val) +end + +end diff --git a/lib/LuxCore/src/LuxCore.jl b/lib/LuxCore/src/LuxCore.jl new file mode 100644 index 000000000..4e9082786 --- /dev/null +++ b/lib/LuxCore/src/LuxCore.jl @@ -0,0 +1,388 @@ +module LuxCore + +using Compat: @compat +using DispatchDoctor: @stable +using Random: Random, AbstractRNG + +# PRNG Handling +""" + replicate(rng::AbstractRNG) + +Creates a copy of the `rng` state depending on its type. +""" +@generated function replicate(rng::T) where {T <: AbstractRNG} + hasmethod(copy, (T,)) && return :(copy(rng)) + return :(deepcopy(rng)) +end +function replicate(rng::Random.TaskLocalRNG) + @warn "`replicate` doesn't work for `TaskLocalRNG`. Returning the same \ + `TaskLocalRNG`." maxlog=1 + return rng +end + +""" + abstract type AbstractLuxLayer + +Abstract Type for all Lux Layers + +Users implementing their custom layer, **must** implement + + - `initialparameters(rng::AbstractRNG, layer::CustomAbstractLuxLayer)` -- This + returns a `NamedTuple` containing the trainable parameters for the layer. + - `initialstates(rng::AbstractRNG, layer::CustomAbstractLuxLayer)` -- This returns a + NamedTuple containing the current state for the layer. For most layers this is typically + empty. Layers that would potentially contain this include `BatchNorm`, `LSTM`, `GRU`, + etc. + +Optionally: + + - `parameterlength(layer::CustomAbstractLuxLayer)` -- These can be automatically + calculated, but it is recommended that the user defines these. + - `statelength(layer::CustomAbstractLuxLayer)` -- These can be automatically + calculated, but it is recommended that the user defines these. + +See also [`AbstractLuxContainerLayer`](@ref) +""" +abstract type AbstractLuxLayer end + +""" + initialparameters(rng::AbstractRNG, layer) + +Generate the initial parameters of the layer `l`. +""" +function initialparameters end + +""" + initialstates(rng::AbstractRNG, layer) + +Generate the initial states of the layer `l`. +""" +function initialstates end + +for op in (:initialparameters, :initialstates) + @eval begin + $(op)(::AbstractRNG, ::Union{AbstractLuxLayer, Nothing}) = NamedTuple() + $(op)(rng::AbstractRNG, l::NamedTuple) = map(Base.Fix1($op, rng), l) + function $(op)(rng::AbstractRNG, l) + contains_lux_layer(l) || throw(MethodError($op, (rng, l))) + return Internal.fmap(Base.Fix1($op, rng), l; exclude=Internal.isleaf) + end + end +end + +""" + parameterlength(layer) + +Return the total number of parameters of the layer `l`. +""" +function parameterlength(l::AbstractLuxLayer) + return parameterlength(initialparameters(Internal.default_rng(), l)) +end +function parameterlength(nt::Union{NamedTuple, Tuple}) + return length(nt) == 0 ? 0 : sum(parameterlength, nt) +end +parameterlength(a::AbstractArray) = length(a) + +""" + statelength(layer) + +Return the total number of states of the layer `l`. 
+""" +statelength(l::AbstractLuxLayer) = statelength(initialstates(Internal.default_rng(), l)) +statelength(nt::Union{NamedTuple, Tuple}) = length(nt) == 0 ? 0 : sum(statelength, nt) +statelength(a::AbstractArray) = length(a) +statelength(::Any) = 1 + +""" + outputsize(layer, x, rng) + +Return the output size of the layer. + +The fallback implementation of this function assumes the inputs were batched, i.e., +if any of the outputs are Arrays, with `ndims(A) > 1`, it will return +`size(A)[1:(end - 1)]`. If this behavior is undesirable, provide a custom +`outputsize(layer, x, rng)` implementation). + +!!! warning "Fallback Implementation" + + The fallback implementation of this function is defined once `Lux.jl` is loaded. + +!!! warning "Changes from Pre-1.0 Behavior" + + Previously it was possible to override this function by defining `outputsize(layer)`. + However, this can potentially introduce a bug that is hard to bypass. See + [this PR](https://github.com/LuxDL/LuxCore.jl/pull/43) for more information. +""" +function outputsize end + +""" + setup(rng::AbstractRNG, layer) + +Shorthand for getting the parameters and states of the layer `l`. Is equivalent to +`(initialparameters(rng, l), initialstates(rng, l))`. + +!!! warning + + This function is not pure, it mutates `rng`. +""" +setup(rng::AbstractRNG, l) = (initialparameters(rng, l), initialstates(rng, l)) + +""" + apply(model, x, ps, st) + +In most cases this function simply calls `model(x, ps, st)`. However, it is still +recommended to call `apply` instead of `model(x, ps, st)` directly. Some of the reasons for +this include: + + 1. For certain types of inputs `x`, we might want to perform preprocessing before calling + `model`. For eg, if `x` is an Array of `ReverseDiff.TrackedReal`s this can cause + significant regressions in `model(x, ps, st)` (since it won't hit any of the BLAS + dispatches). In those cases, we would automatically convert `x` to a + `ReverseDiff.TrackedArray`. + 2. Certain user defined inputs need to be applied to specific layers but we want the + datatype of propagate through all the layers (even unsupported ones). In these cases, + we can unpack the input in `apply` and pass it to the appropriate layer and then + repack it before returning. See the Lux manual on Custom Input Types for a motivating + example. + +!!! tip + + `apply` is integrated with `DispatchDoctor.jl` that allows automatic verification of + type stability. By default this is "disable"d. For more information, see the + [documentation](https://github.com/MilesCranmer/DispatchDoctor.jl). +""" +@stable default_mode="disable" function apply(model::AbstractLuxLayer, x, ps, st) + return model(x, ps, st) +end + +""" + stateless_apply(model, x, ps) + +Calls `apply` and only returns the first argument. This function requires that `model` has +an empty state of `NamedTuple()`. Behavior of other kinds of models are undefined and it is +the responsibility of the user to ensure that the model has an empty state. +""" +function stateless_apply(model::AbstractLuxLayer, x, ps) + return first(apply(model, x, ps, Internal.get_empty_state(model))) +end + +""" + display_name(layer::AbstractLuxLayer) + +Printed Name of the `layer`. If the `layer` has a field `name` that is used, else the type +name is used. 
+""" +@generated function display_name(l::L) where {L <: AbstractLuxLayer} + hasfield(L, :name) && + return :(ifelse(l.name === nothing, $(string(nameof(L))), string(l.name))) + return :($(string(nameof(L)))) +end +display_name(::T) where {T} = string(nameof(T)) + +# Abstract Container Layers +""" + abstract type AbstractLuxContainerLayer{layers} <: AbstractLuxLayer + +Abstract Container Type for certain Lux Layers. `layers` is a tuple containing fieldnames +for the layer, and constructs the parameters and states using those. + +Users implementing their custom layer can extend the same functions as in +[`AbstractLuxLayer`](@ref). + +!!! tip "Advanced Structure Manipulation" + + Advanced structure manipulation of these layers post construction is possible via + `Functors.fmap`. For a more flexible interface, we recommend using + `Lux.Experimental.@layer_map`. + +!!! note "`fmap` Support" + + `fmap` support needs to be explicitly enabled by loading `Functors.jl` and + `Setfield.jl`. + +!!! warning "Changes from Pre-1.0 Behavior" + + Previously if `layers` was a singleton tuple, [`initialparameters`](@ref) and + [`initialstates`](@ref) would return the parameters and states for the single field + `layers`. From `v1.0.0` onwards, even for singleton tuples, the parameters/states + are wrapped in a `NamedTuple` with the same name as the field. See + [`AbstractLuxWrapperLayer`](@ref) to replicate the previous behavior of singleton + tuples. +""" +abstract type AbstractLuxContainerLayer{layers} <: AbstractLuxLayer end + +function initialparameters(rng::AbstractRNG, + l::AbstractLuxContainerLayer{layers}) where {layers} + return NamedTuple{layers}(initialparameters.(rng, getfield.((l,), layers))) +end + +function initialstates(rng::AbstractRNG, + l::AbstractLuxContainerLayer{layers}) where {layers} + return NamedTuple{layers}(initialstates.(rng, getfield.((l,), layers))) +end + +function parameterlength(l::AbstractLuxContainerLayer{layers}) where {layers} + return sum(parameterlength, getfield.((l,), layers)) +end + +function statelength(l::AbstractLuxContainerLayer{layers}) where {layers} + return sum(statelength, getfield.((l,), layers)) +end + +""" + abstract type AbstractLuxWrapperLayer{layer} <: AbstractLuxLayer + +See [`AbstractLuxContainerLayer`](@ref) for detailed documentation. This abstract type is +very similar to [`AbstractLuxContainerLayer`](@ref) except that it allows for a single +layer to be wrapped in a container. + +Additionally, on calling [`initialparameters`](@ref) and [`initialstates`](@ref), the +parameters and states are **not** wrapped in a `NamedTuple` with the same name as the +field. + +As a convenience, we define the fallback call `(::AbstractLuxWrapperLayer)(x, ps, st)`, +which calls `getfield(x, layer)(x, ps, st)`. 
+""" +abstract type AbstractLuxWrapperLayer{layer} <: AbstractLuxLayer end + +function initialparameters( + rng::AbstractRNG, l::AbstractLuxWrapperLayer{layer}) where {layer} + return initialparameters(rng, getfield(l, layer)) +end + +function initialstates(rng::AbstractRNG, l::AbstractLuxWrapperLayer{layer}) where {layer} + return initialstates(rng, getfield(l, layer)) +end + +function parameterlength(l::AbstractLuxWrapperLayer{layer}) where {layer} + return parameterlength(getfield(l, layer)) +end + +function statelength(l::AbstractLuxWrapperLayer{layer}) where {layer} + return statelength(getfield(l, layer)) +end + +function (l::AbstractLuxWrapperLayer{layer})(x, ps, st) where {layer} + return apply(getfield(l, layer), x, ps, st) +end + +# Test Mode +""" + testmode(st::NamedTuple) + +Make all occurrences of `training` in state `st` -- `Val(false)`. +""" +testmode(st::NamedTuple) = update_state(st, :training, Val(false)) + +""" + trainmode(st::NamedTuple) + +Make all occurrences of `training` in state `st` -- `Val(true)`. +""" +trainmode(st::NamedTuple) = update_state(st, :training, Val(true)) + +""" + update_state(st::NamedTuple, key::Symbol, value; exclude=Internal.isleaf) + +Recursively update all occurrences of the `key` in the state `st` with the `value`. +`exclude` is a function that is passed to `Functors.fmap_with_path`'s `exclude` keyword. + +!!! warning "Needs Functors.jl" + + This function requires `Functors.jl` to be loaded. +""" +function update_state(st::NamedTuple, key::Symbol, value; exclude=Internal.isleaf) + fmap_fn = let key = key, value = value + (kp, val) -> begin + last(kp) == key && return value + return val + end + end + return Internal.fmap_with_path(fmap_fn, st; exclude) +end + +""" + contains_lux_layer(l) -> Bool + +Check if the structure `l` is a Lux AbstractLuxLayer or a container of such a layer. +""" +function contains_lux_layer(l) + return check_fmap_condition(Base.Fix2(isa, AbstractLuxLayer), AbstractLuxLayer, l) +end + +""" + check_fmap_condition(cond, tmatch::Union{Type, Nothing}, x) -> Bool + +`fmap`s into the structure `x` and see if `cond` is satisfied for any of the leaf elements. + +## Arguments + + - `cond` - A function that takes a single argument and returns a `Bool`. + - `tmatch` - A shortcut to check if `x` is of type `tmatch`. Can be disabled by passing + `nothing`. + - `x` - The structure to check. + +## Returns + +A Boolean Value +""" +check_fmap_condition(cond::C, ::Nothing, x) where {C} = any(cond, Internal.fleaves(x)) +check_fmap_condition(cond::C, ::Nothing, ::NamedTuple{()}) where {C} = any(cond, ()) +function check_fmap_condition(cond::C, ::Type{T}, x) where {C, T} + x isa T && return true + return check_fmap_condition(cond, nothing, x) +end + +module Internal + +using Random: Xoshiro +using ..LuxCore: AbstractLuxLayer, AbstractLuxContainerLayer, AbstractLuxWrapperLayer + +is_extension_loaded(::Val) = false + +function fmap_impl end # Defined in FunctorsExt +function fmap_with_path_impl end # Defined in FunctorsExt +function fleaves_impl end # Defined in FunctorsExt +function isleaf_impl end # Defined in FunctorsExt + +for op in (:fmap, :fleaves, :isleaf, :fmap_with_path) + main_op = Symbol(op, :_impl) + err_msg = "`$op` requires `Functors.jl` to be loaded." + @eval begin + function $(op)(args...; kwargs...) + is_extension_loaded(Val(:Functors)) || throw(ArgumentError($err_msg)) + return $main_op(args...; kwargs...) 
+ end + end +end + +isleaf(::AbstractLuxLayer) = true + +function setfield_impl end # Defined in SetfieldExt + +function setfield(args...; kwargs...) + is_extension_loaded(Val(:Setfield)) && return setfield_impl(args...; kwargs...) + throw(ArgumentError("`setfield` requires `Setfield.jl` to be loaded.")) +end + +default_rng() = Xoshiro(1234) + +get_empty_state(::AbstractLuxLayer) = NamedTuple() +get_empty_state(l::NamedTuple) = map(get_empty_state, l) +function get_empty_state(l::AbstractLuxContainerLayer{layers}) where {layers} + return NamedTuple{layers}(get_empty_state.(getfield.((l,), layers))) +end +function get_empty_state(l::AbstractLuxWrapperLayer{layer}) where {layer} + return get_empty_state(getfield(l, layer)) +end + +end + +@compat(public, + (replicate, trainmode, testmode, update_state, contains_lux_layer, + check_fmap_condition, initialparameters, initialstates, parameterlength, + statelength, outputsize, setup, apply, stateless_apply, display_name)) + +export AbstractLuxLayer, AbstractLuxContainerLayer, AbstractLuxWrapperLayer + +end diff --git a/lib/LuxCore/test/Project.toml b/lib/LuxCore/test/Project.toml new file mode 100644 index 000000000..a1705ea09 --- /dev/null +++ b/lib/LuxCore/test/Project.toml @@ -0,0 +1,20 @@ +[deps] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[compat] +Aqua = "0.8.7" +EnzymeCore = "0.7.7" +ExplicitImports = "1.9.0" +Functors = "0.4.12" +MLDataDevices = "1.0.0" +Optimisers = "0.3.3" +Random = "1.10" +Test = "1.10" diff --git a/lib/LuxCore/test/runtests.jl b/lib/LuxCore/test/runtests.jl new file mode 100644 index 000000000..f55dba799 --- /dev/null +++ b/lib/LuxCore/test/runtests.jl @@ -0,0 +1,423 @@ +using LuxCore, Test + +@testset "Extension Loading Checks (Fail)" begin + @test !LuxCore.Internal.is_extension_loaded(Val(:Setfield)) + @test !LuxCore.Internal.is_extension_loaded(Val(:Functors)) + @test_throws ArgumentError LuxCore.Internal.setfield(1, 2, 3) + @test_throws ArgumentError LuxCore.Internal.fmap(identity, 1) + @test_throws ArgumentError LuxCore.Internal.fleaves(1) +end + +using Functors, Setfield + +@testset "Extension Loading Checks (Pass)" begin + @test LuxCore.Internal.is_extension_loaded(Val(:Setfield)) + @test LuxCore.Internal.is_extension_loaded(Val(:Functors)) +end + +using Aqua, ExplicitImports, Optimisers, Random, EnzymeCore, MLDataDevices + +rng = LuxCore.Internal.default_rng() + +# Define some custom layers +struct Dense <: AbstractLuxLayer + in::Int + out::Int +end + +function LuxCore.initialparameters(rng::AbstractRNG, l::Dense) + return (w=randn(rng, l.out, l.in), b=randn(rng, l.out)) +end + +(::Dense)(x, ps, st) = x, st # Dummy Forward Pass + +struct DenseWrapper{L} <: AbstractLuxWrapperLayer{:layer} + layer::L +end + +# For checking ambiguities in the dispatch +struct DenseWrapper2{L} <: AbstractLuxWrapperLayer{:layer} + layer::L +end + +(d::DenseWrapper2)(x::AbstractArray, ps, st) = d.layer(x, ps, st) + +struct Chain{L} <: AbstractLuxContainerLayer{(:layers,)} + layers::L +end + +function (c::Chain)(x, ps, st) + y, st1 = c.layers[1](x, ps.layers.layer_1, st.layers.layer_1) + y, st2 = c.layers[2](y, 
ps.layers.layer_2, st.layers.layer_2) + return y, (; layers=(; layer_1=st1, layer_2=st2)) +end + +struct ChainWrapper{L} <: AbstractLuxWrapperLayer{:layers} + layers::L +end + +function (c::ChainWrapper)(x, ps, st) + y, st1 = c.layers[1](x, ps.layer_1, st.layer_1) + y, st2 = c.layers[2](y, ps.layer_2, st.layer_2) + return y, (; layer_1=st1, layer_2=st2) +end + +struct Chain2{L1, L2} <: AbstractLuxContainerLayer{(:layer1, :layer2)} + layer1::L1 + layer2::L2 +end + +function (c::Chain2)(x, ps, st) + y, st1 = c.layer1(x, ps.layer1, st.layer1) + y, st2 = c.layer1(y, ps.layer2, st.layer2) + return y, (; layer1=st1, layer2=st2) +end + +@testset "LuxCore.jl Tests" begin + @testset "AbstractLuxLayer Interface" begin + @testset "Custom Layer" begin + model = Dense(5, 6) + x = randn(rng, Float32, 5) + ps, st = LuxCore.setup(rng, model) + + @test LuxCore.parameterlength(ps) == LuxCore.parameterlength(model) + @test LuxCore.statelength(st) == LuxCore.statelength(model) + + @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) + + @test LuxCore.stateless_apply(model, x, ps) == + first(LuxCore.apply(model, x, ps, NamedTuple())) + + @test_nowarn println(model) + + @testset for wrapper in (DenseWrapper, DenseWrapper2) + model2 = DenseWrapper(model) + ps, st = LuxCore.setup(rng, model2) + + @test LuxCore.parameterlength(ps) == LuxCore.parameterlength(model2) + @test LuxCore.statelength(st) == LuxCore.statelength(model2) + + @test model2(x, ps, st)[1] == model(x, ps, st)[1] + + @test_nowarn println(model2) + end + end + + @testset "Default Fallbacks" begin + struct NoParamStateLayer <: AbstractLuxLayer end + + layer = NoParamStateLayer() + @test LuxCore.initialparameters(rng, layer) == NamedTuple() + @test LuxCore.initialstates(rng, layer) == NamedTuple() + + @test LuxCore.parameterlength(zeros(10, 2)) == 20 + @test LuxCore.statelength(zeros(10, 2)) == 20 + @test LuxCore.statelength(Val(true)) == 1 + @test LuxCore.statelength((zeros(10), zeros(5, 2))) == 20 + @test LuxCore.statelength((layer_1=zeros(10), layer_2=zeros(5, 2))) == 20 + + @test LuxCore.initialparameters(rng, NamedTuple()) == NamedTuple() + @test_throws MethodError LuxCore.initialparameters(rng, ()) + @test LuxCore.initialparameters(rng, nothing) == NamedTuple() + @test LuxCore.initialparameters(rng, (nothing, layer)) == + (NamedTuple(), NamedTuple()) + + @test LuxCore.initialstates(rng, NamedTuple()) == NamedTuple() + @test_throws MethodError LuxCore.initialstates(rng, ()) + @test LuxCore.initialstates(rng, nothing) == NamedTuple() + @test LuxCore.initialparameters(rng, (nothing, layer)) == + (NamedTuple(), NamedTuple()) + end + end + + @testset "AbstractLuxContainerLayer Interface" begin + model = Chain((; layer_1=Dense(5, 5), layer_2=Dense(5, 6))) + x = randn(rng, Float32, 5) + ps, st = LuxCore.setup(rng, model) + + @test fieldnames(typeof(ps)) == (:layers,) + @test fieldnames(typeof(st)) == (:layers,) + + @test LuxCore.parameterlength(ps) == + LuxCore.parameterlength(model) == + LuxCore.parameterlength(model.layers[1]) + + LuxCore.parameterlength(model.layers[2]) + @test LuxCore.statelength(st) == + LuxCore.statelength(model) == + LuxCore.statelength(model.layers[1]) + LuxCore.statelength(model.layers[2]) + + @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) + + @test LuxCore.stateless_apply(model, x, ps) == + first(LuxCore.apply(model, x, ps, st)) + + @test_nowarn println(model) + + model = Chain2(Dense(5, 5), Dense(5, 6)) + x = randn(rng, Float32, 5) + ps, st = LuxCore.setup(rng, model) + + @test LuxCore.parameterlength(ps) == 
+ LuxCore.parameterlength(model) == + LuxCore.parameterlength(model.layer1) + LuxCore.parameterlength(model.layer2) + @test LuxCore.statelength(st) == + LuxCore.statelength(model) == + LuxCore.statelength(model.layer1) + LuxCore.statelength(model.layer2) + + @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) + + @test LuxCore.stateless_apply(model, x, ps) == + first(LuxCore.apply(model, x, ps, st)) + + @test_nowarn println(model) + end + + @testset "AbstractLuxWrapperLayer Interface" begin + model = ChainWrapper((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))) + x = randn(rng, Float32, 5) + ps, st = LuxCore.setup(rng, model) + + @test fieldnames(typeof(ps)) == (:layer_1, :layer_2) + @test fieldnames(typeof(st)) == (:layer_1, :layer_2) + + @test LuxCore.parameterlength(ps) == + LuxCore.parameterlength(model) == + LuxCore.parameterlength(model.layers.layer_1) + + LuxCore.parameterlength(model.layers.layer_2) + @test LuxCore.statelength(st) == + LuxCore.statelength(model) == + LuxCore.statelength(model.layers.layer_1) + + LuxCore.statelength(model.layers.layer_2) + + @test LuxCore.apply(model, x, ps, st) == model(x, ps, st) + + @test LuxCore.stateless_apply(model, x, ps) == + first(LuxCore.apply(model, x, ps, st)) + + @test_nowarn println(model) + end + + @testset "update_state API" begin + st = (layer_1=(training=Val(true), val=1), + layer_2=(layer_1=(val=2,), layer_2=(training=Val(true),))) + + st_ = LuxCore.testmode(st) + + @test st_.layer_1.training == Val(false) && + st_.layer_2.layer_2.training == Val(false) && + st_.layer_1.val == st.layer_1.val && + st_.layer_2.layer_1.val == st.layer_2.layer_1.val + + st = st_ + st_ = LuxCore.trainmode(st) + + @test st_.layer_1.training == Val(true) && + st_.layer_2.layer_2.training == Val(true) && + st_.layer_1.val == st.layer_1.val && + st_.layer_2.layer_1.val == st.layer_2.layer_1.val + + st_ = LuxCore.update_state(st, :val, -1) + @test st_.layer_1.training == st.layer_1.training && + st_.layer_2.layer_2.training == st.layer_2.layer_2.training && + st_.layer_1.val == -1 && + st_.layer_2.layer_1.val == -1 + end + + @testset "Functor Compatibility" begin + @testset "Basic Usage" begin + model = Chain((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))) + + children, reconstructor = Functors.functor(model) + + @test children isa NamedTuple + @test fieldnames(typeof(children)) == (:layers,) + @test children.layers isa NamedTuple + @test fieldnames(typeof(children.layers)) == (:layer_1, :layer_2) + @test children.layers.layer_1 isa Dense + @test children.layers.layer_2 isa Dense + @test children.layers.layer_1.in == 5 + @test children.layers.layer_1.out == 10 + @test children.layers.layer_2.in == 10 + @test children.layers.layer_2.out == 5 + + new_model = reconstructor((; + layers=(; layer_1=Dense(10, 5), layer_2=Dense(5, 10)))) + + @test new_model isa Chain + @test new_model.layers.layer_1.in == 10 + @test new_model.layers.layer_1.out == 5 + @test new_model.layers.layer_2.in == 5 + @test new_model.layers.layer_2.out == 10 + + model = ChainWrapper((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))) + + children, reconstructor = Functors.functor(model) + + @test children isa NamedTuple + @test fieldnames(typeof(children)) == (:layers,) + @test children.layers isa NamedTuple + @test fieldnames(typeof(children.layers)) == (:layer_1, :layer_2) + @test children.layers.layer_1 isa Dense + @test children.layers.layer_2 isa Dense + @test children.layers.layer_1.in == 5 + @test children.layers.layer_1.out == 10 + @test children.layers.layer_2.in == 10 + @test 
children.layers.layer_2.out == 5 + + new_model = reconstructor((; + layers=(; layer_1=Dense(10, 5), layer_2=Dense(5, 10)))) + + @test new_model isa ChainWrapper + @test new_model.layers.layer_1.in == 10 + @test new_model.layers.layer_1.out == 5 + @test new_model.layers.layer_2.in == 5 + @test new_model.layers.layer_2.out == 10 + end + + @testset "Method Ambiguity" begin + # Needed if defining a layer that works with both Flux and Lux -- See DiffEqFlux.jl + # See https://github.com/SciML/DiffEqFlux.jl/pull/750#issuecomment-1373874944 + struct CustomLayer{M, P} <: AbstractLuxContainerLayer{(:model,)} + model::M + p::P + end + + @functor CustomLayer (p,) + + l = CustomLayer(x -> x, nothing) # Dummy Struct + + @test_nowarn Optimisers.trainable(l) + end + end + + @testset "Display Name" begin + struct StructWithoutName <: AbstractLuxLayer end + + model = StructWithoutName() + + @test LuxCore.display_name(model) == "StructWithoutName" + + struct StructWithName{N} <: AbstractLuxLayer + name::N + end + + model = StructWithName("Test") + + @test LuxCore.display_name(model) == "Test" + + model = StructWithName(nothing) + + @test LuxCore.display_name(model) == "StructWithName" + + @test LuxCore.display_name(rand(20)) == "Array" + end + + @testset "initialparameter/initialstate for Default Containers" begin + models1 = [Chain((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))), + Chain2(Dense(5, 10), Dense(10, 5)), [Dense(5, 10), Dense(10, 5)]] + models2 = [Chain((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))), + Chain2(Dense(5, 10), Dense(10, 5)), (Dense(5, 10), Dense(10, 5))] + + for models in (models1, models2) + ps, st = LuxCore.setup(rng, models) + @test length(ps) == length(models) + @test length(st) == length(models) + @test typeof(ps[1]) == typeof(LuxCore.initialparameters(rng, models[1])) + @test typeof(ps[2]) == typeof(LuxCore.initialparameters(rng, models[2])) + @test typeof(ps[3][1]) == typeof(LuxCore.initialparameters(rng, models[3][1])) + @test typeof(ps[3][2]) == typeof(LuxCore.initialparameters(rng, models[3][2])) + @test typeof(st[1]) == typeof(LuxCore.initialstates(rng, models[1])) + @test typeof(st[2]) == typeof(LuxCore.initialstates(rng, models[2])) + @test typeof(st[3][1]) == typeof(LuxCore.initialstates(rng, models[3][1])) + @test typeof(st[3][2]) == typeof(LuxCore.initialstates(rng, models[3][2])) + end + end + + @testset "Convenience Checks" begin + models1 = [Chain((; layer_1=Dense(5, 10), layer_2=Dense(10, 5))), + Chain2(Dense(5, 10), Dense(10, 5)), [Dense(5, 10), Dense(10, 5)]] + + @test LuxCore.contains_lux_layer(models1) + + models2 = [1, 2, 3, 4] + + @test !LuxCore.contains_lux_layer(models2) + + models3 = [1, 2, 3, (; a=Dense(5, 10), b=Dense(10, 5))] + + @test LuxCore.contains_lux_layer(models3) + end + + @testset "Quality Assurance" begin + Aqua.test_all(LuxCore) + + @test check_no_implicit_imports(LuxCore) === nothing + @test check_no_stale_explicit_imports(LuxCore) === nothing + @test check_no_self_qualified_accesses(LuxCore) === nothing + @test check_all_explicit_imports_via_owners(LuxCore) === nothing + @test check_all_qualified_accesses_via_owners(LuxCore) === nothing + @test_broken check_all_explicit_imports_are_public(LuxCore) === nothing + end + + @testset "replicate" begin + rng = Random.default_rng() + @test LuxCore.replicate(rng) === rng + @test LuxCore.replicate(rng) == rng + + rng = Xoshiro(1234) + @test LuxCore.replicate(rng) !== rng + @test LuxCore.replicate(rng) == rng + end + + @testset "empty fleaves" begin + @test length(fleaves(NamedTuple())) == 0 + 
@test !LuxCore.check_fmap_condition(isodd, nothing, NamedTuple()) + end + + @testset "Common Lux + Enzyme Mistakes" begin + d = Dense(2, 2) + + @test_throws ArgumentError Active(d) + @test_throws ArgumentError Duplicated(d, d) + @test_throws ArgumentError DuplicatedNoNeed(d, d) + @test_throws ArgumentError BatchDuplicated(d, (d, d)) + @test_throws ArgumentError BatchDuplicatedNoNeed(d, (d, d)) + @test Const(d) isa Const + end + + @testset "Device Transfer Warnings" begin + my_layer = Dense(2, 2) + + dev = cpu_device() + @test_logs ( + :warn, "Lux layers are stateless and hence don't participate in device \ + transfers. Apply this function on the parameters and states generated \ + using `LuxCore.setup`.") dev(my_layer) + end + + @testset "nested `training` key: Issue Lux.jl#849" begin + st = (encoder=(layer_1=NamedTuple(), layer_2=(; training=Val{true}())), + μ=NamedTuple(), + logσ=NamedTuple(), + decoder=(layer_1=NamedTuple(), layer_2=NamedTuple(), layer_3=NamedTuple(), + layer_4=(running_mean=Float32[0.0, 0.0], training=Val{true}())), + rng=Xoshiro(), + training=Val{true}()) + + @test st.encoder.layer_2.training isa Val{true} + @test st.decoder.layer_4.training isa Val{true} + @test st.training isa Val{true} + + st_test = LuxCore.testmode(st) + + @test st_test.encoder.layer_2.training isa Val{false} + @test st_test.decoder.layer_4.training isa Val{false} + @test st_test.training isa Val{false} + + st_train = LuxCore.trainmode(st_test) + + @test st_train.encoder.layer_2.training isa Val{true} + @test st_train.decoder.layer_4.training isa Val{true} + @test st_train.training isa Val{true} + end +end diff --git a/lib/LuxLib/LICENSE b/lib/LuxLib/LICENSE new file mode 100644 index 000000000..1f70fe758 --- /dev/null +++ b/lib/LuxLib/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Avik Pal and contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
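To make the `update_state` behaviour exercised by the tests above easier to follow, here is a minimal standalone sketch. It only uses the `LuxCore.testmode` / `LuxCore.trainmode` / `LuxCore.update_state` calls that appear in those tests; the state layout itself is invented for illustration.

```julia
using LuxCore

# A nested state tree of the kind `LuxCore.setup` returns for container layers.
st = (encoder=(layer_1=NamedTuple(), layer_2=(; training=Val(true))),
      decoder=(; running_mean=zeros(Float32, 2), training=Val(true)))

st_test  = LuxCore.testmode(st)         # every nested `training` leaf becomes Val(false)
st_train = LuxCore.trainmode(st_test)   # and back to Val(true)

# `update_state` rewrites a given key at every nesting level, leaving other keys intact.
st_off = LuxCore.update_state(st, :training, Val(false))
```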
diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml new file mode 100644 index 000000000..a053be070 --- /dev/null +++ b/lib/LuxLib/Project.toml @@ -0,0 +1,93 @@ +name = "LuxLib" +uuid = "82251201-b29d-42c6-8e01-566dec8acb11" +authors = ["Avik Pal and contributors"] +version = "1.3.6" + +[deps] +ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" +CpuId = "adafc99b-e345-5852-983c-f28acb93d879" +DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" +MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" +Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +Preferences = "21216c6a-2e73-6563-6e65-726566657250" +Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" +StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[weakdeps] +AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" +AppleAccelerate = "13e28ba4-7ad8-5781-acae-3021b1ed3924" +BLISBLAS = "6f275bd8-fec0-4d39-945b-7e95a765fa1e" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" +Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" + +[extensions] +LuxLibAppleAccelerateExt = "AppleAccelerate" +LuxLibBLISBLASExt = "BLISBLAS" +LuxLibCUDAExt = "CUDA" +LuxLibMKLExt = "MKL" +LuxLibEnzymeExt = "Enzyme" +LuxLibLoopVectorizationExt = "LoopVectorization" +LuxLibOctavianExt = ["Octavian", "LoopVectorization"] +LuxLibReverseDiffExt = "ReverseDiff" +LuxLibSLEEFPiratesExt = "SLEEFPirates" +LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"] +LuxLibTrackerExt = "Tracker" +LuxLibcuDNNExt = ["CUDA", "cuDNN"] + +[compat] +AMDGPU = "0.9.6, 1" +AppleAccelerate = "0.4" +ArrayInterface = "7.9" +BLISBLAS = "0.1" +CUDA = "5.3.2" +ChainRulesCore = "1.24" +Compat = "4.15.0" +CpuId = "0.3" +DispatchDoctor = "0.4.12" +Enzyme = "0.13.1" +EnzymeCore = "0.8.1" +FastClosures = "0.3.2" +ForwardDiff = "0.10.36" +Hwloc = "3.2" +KernelAbstractions = "0.9.27" +LinearAlgebra = "1.10" +LoopVectorization = "0.12.171" +LuxCore = "1" +MKL = "0.7" +MLDataDevices = "1.2" +Markdown = "1.10" +NNlib = "0.9.24" +Octavian = "0.3.28" +Preferences = "1.4.3" +Polyester = "0.7.15" +Random = "1.10" +Reexport = "1" +ReverseDiff = "1.15" +SLEEFPirates = "0.6.43" +Static = "0.8.4, 1" +StaticArraysCore = "1.4.3" +Statistics = "1.10" +Tracker = "0.2.34" +cuDNN = "1.3" +julia = "1.10" diff --git a/lib/LuxLib/README.md b/lib/LuxLib/README.md new file mode 100644 index 000000000..e7f0c744d --- /dev/null +++ b/lib/LuxLib/README.md @@ -0,0 +1,27 @@ +# LuxLib + +Backend for [Lux.jl](http://lux.csail.mit.edu/). 
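For readers who have not used package extensions before: each entry under `[extensions]` in the Project.toml above names a module in `ext/` that Julia only loads once its trigger package(s) from `[weakdeps]` are loaded. A minimal sketch of how this surfaces at the REPL, assuming only the package names listed above and the `Base.get_extension` query available since Julia 1.9:

```julia
using LuxLib

# Before the weak dependency is loaded, the extension module does not exist yet.
Base.get_extension(LuxLib, :LuxLibLoopVectorizationExt) === nothing  # true

using LoopVectorization  # trigger package listed under [weakdeps]

# Loading the trigger package causes Julia to compile and load the matching extension.
Base.get_extension(LuxLib, :LuxLibLoopVectorizationExt) isa Module   # true
```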
+ +## Tutorials + +This is a developer-facing project and most users **should not** depend on it directly. As +such, we don't have tutorials for this package. Instead, we recommend you check out the +[Lux tutorials](http://lux.csail.mit.edu/). + +## What's the distinction from [NNlib.jl](https://github.com/FluxML/NNlib.jl)? + +This is currently a place to hold more specialized kernels and layer implementations for +Lux.jl. Anyone is free to move these to NNlib.jl (this package is MIT licensed), but I +probably don't have the time to do so myself. But incase you do, open an issue here and let +me know I will delete the code from this package. + +## Changelog + +### Updating from v0.2 to v0.3 + +`groupnorm` with statistics tracking support has been removed. + +### Updating from v0.1 to v0.2 + +Support for `CUDA` has been moved to a weak dependency. If you want to use `CUDA`, you need +to install and load `LuxCUDA` as `using LuxCUDA` or `import LuxCUDA`. diff --git a/lib/LuxLib/ext/LuxLibAppleAccelerateExt.jl b/lib/LuxLib/ext/LuxLibAppleAccelerateExt.jl new file mode 100644 index 000000000..9cb55cbaa --- /dev/null +++ b/lib/LuxLib/ext/LuxLibAppleAccelerateExt.jl @@ -0,0 +1,8 @@ +module LuxLibAppleAccelerateExt + +using LuxLib: Utils +using Static: True + +Utils.is_extension_loaded(::Val{:AppleAccelerate}) = True() + +end diff --git a/lib/LuxLib/ext/LuxLibBLISBLASExt.jl b/lib/LuxLib/ext/LuxLibBLISBLASExt.jl new file mode 100644 index 000000000..c1d53768e --- /dev/null +++ b/lib/LuxLib/ext/LuxLibBLISBLASExt.jl @@ -0,0 +1,8 @@ +module LuxLibBLISBLASExt + +using LuxLib: Utils +using Static: True + +Utils.is_extension_loaded(::Val{:BLISBLAS}) = True() + +end diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl new file mode 100644 index 000000000..dd215e735 --- /dev/null +++ b/lib/LuxLib/ext/LuxLibCUDAExt/LuxLibCUDAExt.jl @@ -0,0 +1,13 @@ +module LuxLibCUDAExt + +using CUDA: CUDA, CUBLAS, StridedCuMatrix, StridedCuVector, CuPtr +using LinearAlgebra: LinearAlgebra, Transpose, Adjoint +using LuxLib: LuxLib, Optional +using LuxLib.Utils: ofeltype_array +using NNlib: NNlib +using Static: True, False + +# Low level functions +include("cublaslt.jl") + +end diff --git a/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl new file mode 100644 index 000000000..438b56377 --- /dev/null +++ b/lib/LuxLib/ext/LuxLibCUDAExt/cublaslt.jl @@ -0,0 +1,216 @@ +const TransOrAdjOrRegStridedCuMatrix{T} = Union{Transpose{T, <:StridedCuMatrix{T}}, + Adjoint{T, <:StridedCuMatrix{T}}, StridedCuMatrix{T}} + +function cublaslt_matmul_fused!( + @nospecialize(y::TransOrAdjOrRegStridedCuMatrix{<:Real}), σ::F, + @nospecialize(w::TransOrAdjOrRegStridedCuMatrix{<:Real}), + @nospecialize(x::TransOrAdjOrRegStridedCuMatrix{<:Real}), + b::Optional{<:StridedCuVector{<:Real}}, + aux::Optional{<:StridedCuMatrix{<:Real}}=nothing) where {F} + transy = y isa Transpose || y isa Adjoint + transx = x isa Transpose || x isa Adjoint + transw = w isa Transpose || x isa Adjoint + return cublaslt_matmul_fused!( + transy, parent(y), σ, transw, parent(w), transx, parent(x), b, aux) +end + +function cublaslt_matmul_fused!(transy::Bool, @nospecialize(y::StridedCuMatrix{yT}), σ::F, + transw::Bool, @nospecialize(w::StridedCuMatrix{wT}), transx::Bool, + @nospecialize(x::StridedCuMatrix{xT}), b::Optional{<:StridedCuVector}, + aux::Optional{<:StridedCuMatrix}) where {F, yT, wT, xT} + bT = b === nothing ? Bool : eltype(b) + auxT = aux === nothing ? 
Bool : eltype(aux) + # cuBLASLt will give wrong results if the types are not correct. As a hack we are going + # to promote the types to the largest type + wxT = promote_type(wT, xT, bT, auxT) + @warn "Mixed Precision Inputs received for `weight`: $(typeof(w)) and `x`: \ + $(typeof(x)). Promoting to $(wxT)." maxlog=1 + return cublaslt_matmul_fused!(transy, y, σ, transw, ofeltype_array(wxT, w), + transx, ofeltype_array(wxT, x), ofeltype_array(wxT, b), ofeltype_array(wxT, aux)) +end + +# TODO: use https://docs.nvidia.com/cuda/cublas/#cublasltmatmul for a more robust +# computeType mapping. Currently no one uses Lux with weird type combinations so we +# don't need to worry about it too much and just fall back to the generic +# implementation +# Returns: 0 if successful, -1 if unsuccessful +function cublaslt_matmul_fused!(transy::Bool, @nospecialize(y::StridedCuMatrix{yT}), σ::F, + transw::Bool, @nospecialize(w::StridedCuMatrix{wxT}), transx::Bool, + @nospecialize(x::StridedCuMatrix{wxT}), b::Optional{<:StridedCuVector}, + aux::Optional{<:StridedCuMatrix}) where {F, yT, wxT} + m = size(y, 1) + n = size(y, 2) + k = size(w, 2) + + if b === nothing + size(y, transy ? 2 : 1) == size(w, transw ? 2 : 1) || + throw(DimensionMismatch("size(y) = $(size(y)), size(w) = $(size(w))")) + else + size(y, transy ? 2 : 1) == size(w, transw ? 2 : 1) == size(b, 1) || + throw(DimensionMismatch("size(y) = $(size(y)), size(w) = $(size(w)), size(b) = $(size(b))")) + end + size(x, transx ? 2 : 1) == size(w, transw ? 1 : 2) || + throw(DimensionMismatch("size(x) = $(size(x)), size(w) = $(size(w))")) + + # Create the operation descriptor + operationDesc = Ref{CUBLAS.cublasLtMatmulDesc_t}() + + ## While querying the compute type, promote the types + computeType = CUBLAS.gemmExComputeType(wxT, wxT, yT, m, k, n) + computeType === nothing && return -1 + dataType = convert(CUDA.cudaDataType, yT) + CUBLAS.cublasLtMatmulDescCreate(operationDesc, computeType, dataType) + + # Set the matrix descriptors + ytransop = transy ? CUBLAS.CUBLAS_OP_T : CUBLAS.CUBLAS_OP_N + wtransop = transw ? CUBLAS.CUBLAS_OP_T : CUBLAS.CUBLAS_OP_N + xtransop = transx ? 
CUBLAS.CUBLAS_OP_T : CUBLAS.CUBLAS_OP_N + + CUBLAS.cublasLtMatmulDescSetAttribute( + operationDesc[], CUBLAS.CUBLASLT_MATMUL_DESC_TRANSA, + Ref{CUBLAS.cublasOperation_t}(wtransop), sizeof(wtransop)) + CUBLAS.cublasLtMatmulDescSetAttribute( + operationDesc[], CUBLAS.CUBLASLT_MATMUL_DESC_TRANSB, + Ref{CUBLAS.cublasOperation_t}(xtransop), sizeof(xtransop)) + CUBLAS.cublasLtMatmulDescSetAttribute( + operationDesc[], CUBLAS.CUBLASLT_MATMUL_DESC_TRANSC, + Ref{CUBLAS.cublasOperation_t}(ytransop), sizeof(ytransop)) + + # Decide on the epilogue + epilogue, activation_fused = epilogue_act(σ, b, aux) + CUBLAS.cublasLtMatmulDescSetAttribute( + operationDesc[], CUBLAS.CUBLASLT_MATMUL_DESC_EPILOGUE, + Ref{CUBLAS.cublasLtEpilogue_t}(epilogue), sizeof(epilogue)) + + # We have a bias so set the bias pointer + if b !== nothing + bias_ptr = Ref{CuPtr{Cvoid}}(pointer(b)) + CUBLAS.cublasLtMatmulDescSetAttribute( + operationDesc[], CUBLAS.CUBLASLT_MATMUL_DESC_BIAS_POINTER, + bias_ptr, sizeof(bias_ptr)) + end + + if aux !== nothing + aux_ptr = Ref{CuPtr{Cvoid}}(pointer(aux)) + CUBLAS.cublasLtMatmulDescSetAttribute( + operationDesc[], CUBLAS.CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, + aux_ptr, sizeof(aux_ptr)) + ldaux = max(1, stride(aux, 2)) + CUBLAS.cublasLtMatmulDescSetAttribute( + operationDesc[], CUBLAS.CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, + Ref{Csize_t}(ldaux), sizeof(ldaux)) + end + + # Create the matrix layouts + wdesc = Ref{CUBLAS.cublasLtMatrixLayout_t}() + xdesc = Ref{CUBLAS.cublasLtMatrixLayout_t}() + ydesc = Ref{CUBLAS.cublasLtMatrixLayout_t}() + + CUBLAS.cublasLtMatrixLayoutCreate( + wdesc, convert(CUDA.cudaDataType, wxT), m, k, max(1, stride(w, 2))) + CUBLAS.cublasLtMatrixLayoutCreate( + xdesc, convert(CUDA.cudaDataType, wxT), k, n, max(1, stride(x, 2))) + CUBLAS.cublasLtMatrixLayoutCreate( + ydesc, convert(CUDA.cudaDataType, yT), m, n, max(1, stride(y, 2))) + + # Create the preference. we can customize this but we will stick to the defaults + preference = Ref{CUBLAS.cublasLtMatmulPreference_t}() + CUBLAS.cublasLtMatmulPreferenceCreate(preference) + + # Create the light handle + lthandle = Ref{CUBLAS.cublasLtHandle_t}() + CUBLAS.cublasLtCreate(lthandle) + + # Search for the best algorithm + heuristic = Ref{CUBLAS.cublasLtMatmulHeuristicResult_t}() + returnedResults = Ref{Cint}(0) + CUBLAS.cublasLtMatmulAlgoGetHeuristic( + lthandle[], operationDesc[], wdesc[], xdesc[], ydesc[], + ydesc[], preference[], 1, heuristic, returnedResults) + + returnedResults[] == 0 && return -1 + + CUBLAS.cublasLtMatmul( + lthandle[], operationDesc[], Ref{wxT}(1), w, wdesc[], x, xdesc[], Ref{yT}(0), + y, ydesc[], y, ydesc[], Ref(heuristic[].algo), CUDA.CU_NULL, 0, CUDA.stream()) + + !activation_fused && (@. y = σ(y)) + + return 0 +end + +function epilogue_act(f::F, b, aux) where {F} + if f === identity + @assert aux===nothing "`aux` must be `nothing` for `identity` activation." 
+ b === nothing && return CUBLAS.CUBLASLT_EPILOGUE_DEFAULT, true + return CUBLAS.CUBLASLT_EPILOGUE_BIAS, true + elseif f === NNlib.relu + if b === nothing + aux === nothing && return CUBLAS.CUBLASLT_EPILOGUE_RELU, true + return CUBLAS.CUBLASLT_EPILOGUE_RELU_AUX, true + else + aux === nothing && return CUBLAS.CUBLASLT_EPILOGUE_RELU_BIAS, true + return CUBLAS.CUBLASLT_EPILOGUE_RELU_AUX_BIAS, true + end + elseif f === NNlib.gelu + if b === nothing + aux === nothing && return CUBLAS.CUBLASLT_EPILOGUE_GELU, true + return CUBLAS.CUBLASLT_EPILOGUE_GELU_AUX, true + else + aux === nothing && return CUBLAS.CUBLASLT_EPILOGUE_GELU_BIAS, true + return CUBLAS.CUBLASLT_EPILOGUE_GELU_AUX_BIAS, true + end + else + @assert aux===nothing "`aux` must be `nothing` for `$(f)` activation." + b === nothing && return CUBLAS.CUBLASLT_EPILOGUE_DEFAULT, false + return CUBLAS.CUBLASLT_EPILOGUE_BIAS, false + end +end + +len(x) = length(x) +len(::Nothing) = nothing + +function LuxLib.Impl.cublasLt_fused_dense(act::F, weight::AbstractMatrix, + x::AbstractMatrix, b::Optional{<:AbstractVector}, ::False) where {F} + z = similar(x, LuxLib.concrete_fba_output_eltype(act, weight, x, b), + size(weight, 1), size(x, 2)) + LuxLib.cublasLt_fused_dense!(z, act, weight, x, b) + return z, nothing +end + +function LuxLib.Impl.cublasLt_fused_dense(act::F, weight::AbstractMatrix, + x::AbstractMatrix, b::Optional{<:AbstractVector}, ::True) where {F} + z = similar(x, LuxLib.concrete_fba_output_eltype(act, weight, x, b), + size(weight, 1), size(x, 2)) + y = similar(z) + LuxLib.cublasLt_fused_dense!(z, act, weight, x, b, y) + return z, y +end + +function LuxLib.Impl.cublasLt_fused_dense!( + z::AbstractMatrix, act::F, weight::AbstractMatrix, x::AbstractMatrix, + b::Optional{<:AbstractVector}, y::Optional{<:AbstractMatrix}=nothing) where {F} + if hasmethod(cublaslt_matmul_fused!, + (typeof(z), typeof(act), typeof(weight), typeof(x), typeof(b), typeof(y))) + retcode = cublaslt_matmul_fused!(z, act, weight, x, b, y) + retcode == 0 && return + warn_msg = LazyString( + "cuBLASLt failed for the given inputs ", act, ", ", typeof(weight), + " [", size(weight), "], ", typeof(x), " [", size(x), "], ", + typeof(b), " [", len(b), "]. Falling back to generic implementation.") + @warn warn_msg maxlog=1 + else + @warn "cuBLASLt not available. Falling back to generic implementation." maxlog=1 + end + # Generic fallback + if y === nothing + LinearAlgebra.mul!(z, weight, x) + broadcast!(act ∘ +, z, z, reshape(b, :, 1)) + return + else + LinearAlgebra.mul!(y, weight, x) + broadcast!(+, y, y, reshape(b, :, 1)) + broadcast!(act, z, y) + return + end +end diff --git a/lib/LuxLib/ext/LuxLibEnzymeExt.jl b/lib/LuxLib/ext/LuxLibEnzymeExt.jl new file mode 100644 index 000000000..958075c46 --- /dev/null +++ b/lib/LuxLib/ext/LuxLibEnzymeExt.jl @@ -0,0 +1,8 @@ +module LuxLibEnzymeExt + +using LuxLib: Utils +using Static: True + +Utils.is_extension_loaded(::Val{:Enzyme}) = True() + +end diff --git a/lib/LuxLib/ext/LuxLibLoopVectorizationExt.jl b/lib/LuxLib/ext/LuxLibLoopVectorizationExt.jl new file mode 100644 index 000000000..87a912bec --- /dev/null +++ b/lib/LuxLib/ext/LuxLibLoopVectorizationExt.jl @@ -0,0 +1,72 @@ +module LuxLibLoopVectorizationExt + +using LoopVectorization: LoopVectorization, @tturbo, @turbo, indices +using Polyester: @batch +using Static: True + +using LuxLib: LuxLib, Utils + +Utils.is_extension_loaded(::Val{:LoopVectorization}) = True() + +Utils.can_loopvec_args_check(::True, args...) = LoopVectorization.check_args(args...) 
+ +# matmul +for serial in (true, false) + opname = serial ? :serial_matmul_loopvec! : :matmul_loopvec! + @eval @inline function LuxLib.Impl.$(opname)( + C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α::Number, β::Number) + if !iszero(β) # Secial case this because Base.FastMath.mul_fast(NaN, false) = NaN + @turbo thread=$(!serial) for K in indices((C, B), 2), J in indices((C, A), 1) + Cⱼₖ = zero(eltype(C)) + for I in indices((A, B), (2, 1)) + Cⱼₖ += A[J, I] * B[I, K] + end + C[J, K] = α * Cⱼₖ + β * C[J, K] + end + else + @turbo thread=$(!serial) for K in indices((C, B), 2), J in indices((C, A), 1) + Cⱼₖ = zero(eltype(C)) + for I in indices((A, B), (2, 1)) + Cⱼₖ += A[J, I] * B[I, K] + end + C[J, K] = α * Cⱼₖ + end + end + end +end + +@inline function LuxLib.Impl.matmuladd_loopvec!( + C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) + @tturbo for K in indices((C, B), 2), J in indices((C, A), 1) + Cⱼₖ = zero(eltype(C)) + for I in indices((A, B), (2, 1)) + Cⱼₖ += A[J, I] * B[I, K] + end + C[J, K] = bias[J] + Cⱼₖ + end + return +end + +# batched matmul +function LuxLib.Impl.batched_matmul_loopvec_impl!( + z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3}, + y::AbstractArray{yT, 3}, α::Number=true, β::Number=false) where {zT, xT, yT} + if size(x, 3) == size(y, 3) + @batch for L in axes(z, 3) + LuxLib.Impl.serial_matmul_loopvec!( + Utils.batchview(z, L), Utils.batchview(x, L), Utils.batchview(y, L), α, β) + end + elseif size(x, 3) == 1 + @batch for L in axes(z, 3) + LuxLib.Impl.serial_matmul_loopvec!( + Utils.batchview(z, L), Utils.batchview(x, 1), Utils.batchview(y, L), α, β) + end + else # has to be size(y, 3) == 1 + @batch for L in axes(z, 3) + LuxLib.Impl.serial_matmul_loopvec!( + Utils.batchview(z, L), Utils.batchview(x, L), Utils.batchview(y, 1), α, β) + end + end +end + +end diff --git a/lib/LuxLib/ext/LuxLibMKLExt.jl b/lib/LuxLib/ext/LuxLibMKLExt.jl new file mode 100644 index 000000000..64becb4fa --- /dev/null +++ b/lib/LuxLib/ext/LuxLibMKLExt.jl @@ -0,0 +1,8 @@ +module LuxLibMKLExt + +using LuxLib: Utils +using Static: True + +Utils.is_extension_loaded(::Val{:MKL}) = True() + +end diff --git a/lib/LuxLib/ext/LuxLibOctavianExt.jl b/lib/LuxLib/ext/LuxLibOctavianExt.jl new file mode 100644 index 000000000..a112fa946 --- /dev/null +++ b/lib/LuxLib/ext/LuxLibOctavianExt.jl @@ -0,0 +1,16 @@ +module LuxLibOctavianExt + +using Octavian: Octavian +using Static: True + +using LuxLib: LuxLib, Utils + +Utils.is_extension_loaded(::Val{:Octavian}) = True() + +@inline function LuxLib.Impl.matmul_octavian!( + C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α::Number, β::Number) + Octavian.matmul!(C, A, B, α, β) + return +end + +end diff --git a/lib/LuxLib/ext/LuxLibReverseDiffExt.jl b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl new file mode 100644 index 000000000..229a22a35 --- /dev/null +++ b/lib/LuxLib/ext/LuxLibReverseDiffExt.jl @@ -0,0 +1,68 @@ +module LuxLibReverseDiffExt + +using ChainRulesCore: ChainRulesCore +using LuxLib: LuxLib, Utils, Traits +using NNlib: NNlib +using ReverseDiff: ReverseDiff, TrackedArray, TrackedVector, TrackedReal, + @grad_from_chainrules +using Static: True + +const CRC = ChainRulesCore + +# Patches: Needs upstreaming (I don't know how to construct an MWE though) +function ReverseDiff.increment_deriv!( + t::Union{TrackedArray, TrackedReal}, ::CRC.NoTangent, i) + return ReverseDiff.increment_deriv!(t, zero(eltype(ReverseDiff.value(t))), i) +end +function ReverseDiff.decrement_deriv!( + t::Union{TrackedArray, TrackedReal}, 
::CRC.NoTangent, i) + return ReverseDiff.decrement_deriv!(t, zero(eltype(ReverseDiff.value(t))), i) +end + +# Patch Conv for ReverseDiff +for func in (:conv, :depthwiseconv, :∇conv_data, :∇conv_filter), + xType in (:AbstractArray, :TrackedArray), + wType in (:AbstractArray, :TrackedArray) + + Utils.is_tracked(xType, wType) || continue + + @eval @grad_from_chainrules NNlib.$(func)( + x::$(xType), w::$(wType), cdims::NNlib.ConvDims; kwargs...) +end + +# batched_mul +@grad_from_chainrules NNlib.batched_mul( + x::TrackedArray{<:Any, <:Any, 3}, y::TrackedArray{<:Any, <:Any, 3}) +@grad_from_chainrules NNlib.batched_mul( + x::TrackedArray{<:Any, <:Any, 3}, y::AbstractArray{<:Any, 3}) +@grad_from_chainrules NNlib.batched_mul( + x::AbstractArray{<:Any, 3}, y::TrackedArray{<:Any, <:Any, 3}) + +@grad_from_chainrules LuxLib.Impl.batched_matmul( + x::TrackedArray{<:Any, <:Any, 3}, y::TrackedArray{<:Any, <:Any, 3}) +@grad_from_chainrules LuxLib.Impl.batched_matmul( + x::TrackedArray{<:Any, <:Any, 3}, y::AbstractArray{<:Any, 3}) +@grad_from_chainrules LuxLib.Impl.batched_matmul( + x::AbstractArray{<:Any, 3}, y::TrackedArray{<:Any, <:Any, 3}) + +# Currently falls back to mapreduce and has a terrible performance +@grad_from_chainrules Base.sum(::typeof(abs2), x::TrackedArray; kwargs...) + +for pool in (:maxpool, :meanpool, :lpnormpool) + @eval @grad_from_chainrules NNlib.$(pool)(x::TrackedArray, ::NNlib.PoolDims; kwargs...) +end + +# Utils extensions +Utils.remove_tracking(x::TrackedReal) = ReverseDiff.value(x) +Utils.remove_tracking(x::TrackedArray) = ReverseDiff.value(x) +Utils.remove_tracking(x::AbstractArray{<:TrackedReal}) = ReverseDiff.value.(x) +Utils.remove_tracking(::Type{<:TrackedReal{T}}) where {T} = Utils.remove_tracking(T) + +Utils.within_autodiff(::TrackedReal) = True() +Utils.within_autodiff(::TrackedArray) = True() +Utils.within_autodiff(::AbstractArray{<:TrackedReal}) = True() + +# Traits extensions +Traits.is_tracked(::Type{<:TrackedReal}) = True() + +end diff --git a/lib/LuxLib/ext/LuxLibSLEEFPiratesExt.jl b/lib/LuxLib/ext/LuxLibSLEEFPiratesExt.jl new file mode 100644 index 000000000..6c522b2ba --- /dev/null +++ b/lib/LuxLib/ext/LuxLibSLEEFPiratesExt.jl @@ -0,0 +1,58 @@ +module LuxLibSLEEFPiratesExt + +using ChainRulesCore: ChainRulesCore +using NNlib: NNlib +using SLEEFPirates: SLEEFPirates + +using LuxLib: Numeric, Impl + +const CRC = ChainRulesCore + +sigmoid_fast(x::Number) = SLEEFPirates.sigmoid_fast(x) +softplus(x::Number) = SLEEFPirates.softplus(x) +logsigmoid(x::Number) = -softplus(-x) +swish(x::Number) = Base.FastMath.mul_fast(x, sigmoid_fast(x)) +lisht(x::Number) = Base.FastMath.mul_fast(x, tanh_fast(x)) +tanh(x::Number) = SLEEFPirates.tanh(x) +tanh_fast(x::Number) = SLEEFPirates.tanh_fast(x) + +for (f, dfdx) in [ + #! format: off + (:sigmoid_fast, :(conj(Base.FastMath.mul_fast(Ω, Base.FastMath.sub_fast(1, Ω))))), + (:softplus, :(sigmoid_fast(x))), + (:logsigmoid, :(sigmoid_fast(-x))), + (:swish, :(Base.FastMath.add_fast(Ω, Base.FastMath.mul_fast(sigmoid_fast(x), Base.FastMath.sub_fast(1, Ω))))), + (:lisht, :(Base.FastMath.add_fast(x, Base.FastMath.mul_fast(tanh_fast(x), Base.FastMath.sub_fast(1, Ω))))), + (:tanh, :(conj(Base.FastMath.sub_fast(1, Base.FastMath.mul_fast(Ω, Ω))))), + (:tanh_fast, :(conj(Base.FastMath.sub_fast(1, Base.FastMath.mul_fast(Ω, Ω))))) + #! 
format: on +] + @eval CRC.@scalar_rule($f(x), $(dfdx)) + + ∇f = Symbol(:∇broadcasted_, f) + @eval function CRC.rrule(::typeof(Broadcast.broadcasted), ::typeof($f), + x::Union{Numeric, Broadcast.Broadcasted}) + Ω = $(f).(x) + function $(∇f)(dΩ) + ∂x = CRC.InplaceableThunk(dx -> @.(dx+=dΩ * $(dfdx)), CRC.@thunk @.(dΩ*$(dfdx))) + return CRC.NoTangent(), CRC.NoTangent(), ∂x + end + return Ω, $(∇f) + end +end + +for (fbase, ffast) in [ + #! format: off + (NNlib.sigmoid_fast, sigmoid_fast), + (NNlib.softplus, softplus), + (NNlib.logsigmoid, logsigmoid), + (NNlib.swish, swish), + (NNlib.lisht, lisht), + (Base.tanh, tanh), + (NNlib.tanh_fast, tanh_fast) + #! format: on +] + @eval Impl.sleefpirates_fast_act(::typeof($fbase)) = $ffast +end + +end diff --git a/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl b/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl new file mode 100644 index 000000000..eef503f66 --- /dev/null +++ b/lib/LuxLib/ext/LuxLibTrackerAMDGPUExt.jl @@ -0,0 +1,58 @@ +module LuxLibTrackerAMDGPUExt + +using AMDGPU: AMDGPU +using NNlib: NNlib, PoolDims +using Tracker: Tracker, TrackedArray + +const ROCTrackedArray{T, N} = TrackedArray{T, N, <:AMDGPU.ROCArray{T, N}} + +# Taken from https://github.com/FluxML/NNlib.jl/blob/07833637dec96d12d0614308d3145b432fdb320a/ext/NNlibAMDGPUExt/NNlibAMDGPUExt.jl#L38 +function nnlib_padding(dims) + pd = NNlib.padding(dims) + if !all(pd[1:2:end] .== pd[2:2:end]) + @warn """ + MIOpen does not support asymmetric padding, defaulting to symmetric choice: + $pd -> $(pd[1:2:end]). + """ maxlog=1 + end + return pd[1:2:end] +end + +# For meanpool and maxpool NNlib directly defines the rrules so we need to define special +# rules for Tracker +for poolname in (:maxpool, :meanpool) + @eval begin + Tracker.@grad function NNlib.$(poolname)( + x_tracked::ROCTrackedArray{<:AMDGPU.MIOpen.MIOPENFloat, N}, + pdims::PoolDims) where {N} + x = Tracker.data(x_tracked) + y = similar( + x, NNlib.output_size(pdims)..., NNlib.channels_out(pdims), size(x, N)) + nd = max(0, 4 - N) + npdims = NNlib.insert_singleton_spatial_dimension(pdims, nd) + + # `workspace` is used in the pullback. 
+ _, workspace = AMDGPU.MIOpen.$(Symbol("$(poolname)!"))( + NNlib.insert_singleton_spatial_dimension(y, nd), + NNlib.insert_singleton_spatial_dimension(x, nd); + dims=NNlib.kernel_size(npdims), + padding=nnlib_padding(npdims), stride=NNlib.stride(npdims)) + + function ∇pooling(Δ) + dx = similar(x) + AMDGPU.MIOpen.$(Symbol("∇$(poolname)!"))( + NNlib.insert_singleton_spatial_dimension(dx, nd), + NNlib.insert_singleton_spatial_dimension(Δ, nd), + NNlib.insert_singleton_spatial_dimension(y, nd), + NNlib.insert_singleton_spatial_dimension(x, nd); + dims=NNlib.kernel_size(npdims), padding=nnlib_padding(npdims), + stride=NNlib.stride(npdims), workspace) + return Tracker.nobacksies($(Expr(:quote, poolname)), (dx, nothing)) + end + + return y, ∇pooling + end + end +end + +end diff --git a/lib/LuxLib/ext/LuxLibTrackerExt.jl b/lib/LuxLib/ext/LuxLibTrackerExt.jl new file mode 100644 index 000000000..230309584 --- /dev/null +++ b/lib/LuxLib/ext/LuxLibTrackerExt.jl @@ -0,0 +1,103 @@ +module LuxLibTrackerExt + +using FastClosures: @closure +using LuxLib: LuxLib, Utils, Traits +using NNlib: NNlib +using Static: True, StaticBool +using Tracker: Tracker, TrackedArray, TrackedReal, TrackedVector + +tracker_data(x) = Tracker.data(x) +tracker_data(x::NNlib.BatchedAdjoint) = NNlib.batched_adjoint(tracker_data(parent(x))) +tracker_data(x::NNlib.BatchedTranspose) = NNlib.batched_transpose(tracker_data(parent(x))) + +# batched matrix multiplication +import LuxLib.Impl: batched_matmul +import NNlib: batched_mul + +## Without the rules on BatchedAdjoint and BatchedTranspose, we end up constructing +## AbstractMatrix{<:TrackedReal} which is not efficient +for T1 in (:AbstractArray, :TrackedArray), T2 in (:AbstractArray, :TrackedArray) + Utils.is_tracked(T1, T2) || continue + + for op in (:batched_mul, :batched_matmul) + @eval begin + $(op)(x::$T1{<:Any, 3}, y::$T2{<:Any, 3}) = Tracker.track($(op), x, y) + function $(op)(x::NNlib.BatchedAdjOrTrans{<:Any, <:$T1{<:Any, 3}}, + y::$T2{<:Any, 3}) + return Tracker.track($(op), x, y) + end + function $(op)( + x::$T1{<:Any, 3}, y::NNlib.BatchedAdjOrTrans{<:Any, <:$T2{<:Any, 3}}) + return Tracker.track($(op), x, y) + end + function $(op)(x::NNlib.BatchedAdjOrTrans{<:Any, <:$T1{<:Any, 3}}, + y::NNlib.BatchedAdjOrTrans{<:Any, <:$T2{<:Any, 3}}) + return Tracker.track($(op), x, y) + end + end + end +end + +for op in (:batched_mul, :batched_matmul) + @eval Tracker.@grad function $(op)(x, y) + z = $(op)(tracker_data(x), tracker_data(y)) + ∇batched_matmul = @closure Δ -> begin + ∂x = $(op)(tracker_data(Δ), NNlib.batched_adjoint(tracker_data(y))) + size(x, 3) == 1 && (∂x = sum(∂x; dims=3)) + ∂y = $(op)(NNlib.batched_adjoint(tracker_data(x)), tracker_data(Δ)) + size(y, 3) == 1 && (∂y = sum(∂y; dims=3)) + return Tracker.nobacksies(:batched_matmul, (∂x, ∂y)) + end + return z, ∇batched_matmul + end +end + +# NNlib: gather +Tracker.@grad_from_chainrules NNlib.gather!( + dst::AbstractArray, src::TrackedArray, idx::AbstractArray) + +# Base.repeat +Tracker.@grad_from_chainrules Base.repeat(x::TrackedArray, counts...) 
+ +# Base.selectdim -- Needed for GPUArrays +Base.selectdim(x::TrackedArray, d::Integer, i) = Tracker.track(selectdim, x, d, i) + +Tracker.@grad function Base.selectdim(x::AbstractArray, d::Integer, i) + x′ = Tracker.data(x) + y = selectdim(x′, d, i) + ∇selectdim = @closure Δ -> begin + ∂x = zero(x′) + selectdim(∂x, d, i) .= Tracker.data(Δ) + return ∂x, nothing, nothing + end + return y, ∇selectdim +end + +# Impl: batchnorm_cudnn +## cuDNN batchnorm -- the chain rule gets defined once cuDNN is loaded +for RM in (:TrackedVector, :Nothing, :AbstractVector), + RV in (:TrackedVector, :Nothing, :AbstractVector), + S in (:TrackedVector, :Nothing, :AbstractVector), + B in (:TrackedVector, :Nothing, :AbstractVector), + XT in (:TrackedArray, :AbstractArray) + + Utils.is_tracked(RM, RV, S, B, XT) || continue + + @eval Tracker.@grad_from_chainrules LuxLib.Impl.batchnorm_cudnn( + γ::$S, β::$B, x::$XT, rμ::$RM, rσ²::$RV, m::Real, ϵ::Real, training::StaticBool) +end + +# Utils extensions +Utils.remove_tracking(x::TrackedReal) = Tracker.data(x) +Utils.remove_tracking(x::TrackedArray) = Tracker.data(x) +Utils.remove_tracking(x::AbstractArray{<:TrackedReal}) = Tracker.data.(x) +Utils.remove_tracking(::Type{<:TrackedReal{T}}) where {T} = Utils.remove_tracking(T) + +Utils.within_autodiff(::TrackedReal) = True() +Utils.within_autodiff(::TrackedArray) = True() +Utils.within_autodiff(::AbstractArray{<:TrackedReal}) = True() + +# Traits extensions +Traits.is_tracked(::Type{<:TrackedReal}) = True() + +end diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl new file mode 100644 index 000000000..77e59d3e4 --- /dev/null +++ b/lib/LuxLib/ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl @@ -0,0 +1,42 @@ +module LuxLibcuDNNExt + +using LuxLib: LuxLib, Optional, ∂∅, Impl +using LuxLib.Utils: safe_reshape, safe_vec, unsafe_known +using CUDA: CUDA, CuArray, CuVector, CU_NULL, DenseCuArray, DenseCuVector +using ChainRulesCore: ChainRulesCore +using cuDNN: cuDNN, cudnnBatchNormalizationBackward, + cudnnBatchNormalizationForwardInference, CUDNN_BATCHNORM_SPATIAL, + cudnnBatchNormalizationForwardTraining, cudnnTensorDescriptor, + CUDNN_TENSOR_NCHW, cudnnDataType +using FastClosures: @closure +using Static: StaticBool, False, True + +const CRC = ChainRulesCore + +const cuDNNFloat = Union{Float32, Float64} + +include("batchnorm.jl") + +# api/batchnorm.jl +function Impl.batchnorm(x::Union{<:CuArray{T, 2}, <:CuArray{T, 4}, <:CuArray{T, 5}}, + γ::Optional{<:CuVector{T}}, β::Optional{<:CuVector{T}}, + rμ::Optional{<:CuVector{T}}, rσ²::Optional{<:CuVector{T}}, + training::StaticBool, σ::F, m::Real, ϵ::Real) where {T <: cuDNNFloat, F} + rμₙ, rσ²ₙ = Impl.get_batchnorm_statistics(x, rμ, rσ², training) + y = Impl.batchnorm_cudnn(γ, β, x, rμₙ, rσ²ₙ, m, ϵ, training)[1] + return Impl.activation!!(σ, y), safe_vec(rμₙ), safe_vec(rσ²ₙ) +end + +function CRC.rrule( + ::typeof(Impl.batchnorm_cudnn), γ, β, x, rμ, rσ², m, ϵ, training::StaticBool) + y, xμ, xσ⁻² = Impl.batchnorm_cudnn(γ, β, x, rμ, rσ², m, ϵ, training) + 𝒫x, 𝒫γ, 𝒫β = CRC.ProjectTo(x), CRC.ProjectTo(γ), CRC.ProjectTo(β) + ∇batchnorm_cudnn = @closure Δ -> begin + ∂γ, ∂β, ∂x = Impl.∇batchnorm_cudnn( + γ, β, x, CRC.unthunk(first(Δ)), rμ, rσ², xμ, xσ⁻², ϵ) + return ∂∅, 𝒫γ(∂γ), 𝒫β(∂β), 𝒫x(∂x), ∂∅, ∂∅, ∂∅, ∂∅, ∂∅ + end + return (y, xμ, xσ⁻²), ∇batchnorm_cudnn +end + +end diff --git a/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl b/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl new file mode 100644 index 000000000..1cb7bccc1 --- /dev/null +++ 
b/lib/LuxLib/ext/LuxLibcuDNNExt/batchnorm.jl @@ -0,0 +1,144 @@ +# Difference from the NNlib version: We expose the mean and inv_variance computed in the +# cudnn call, since they can be used at other places like forward mode AD +wsize(x::AbstractArray{T, N}, ::False) where {T, N} = (size(x, N - 1),) +function wsize(x::AbstractArray{T, N}, ::True) where {T, N} + return ntuple(i -> i == N - 1 ? size(x, N - 1) : 1, N) +end + +# Try to avoid hitting this in the first place. An easy workaround is to store the +# gamma and bias parameters in states so that they are never trained +function Impl.batchnorm_cudnn(::Nothing, ::Nothing, x::DenseCuArray, args...) + affine_sz = wsize(x, False()) + γ = CUDA.ones(eltype(x), affine_sz) + β = CUDA.zeros(eltype(x), affine_sz) + + y, xμ, xσ⁻² = Impl.batchnorm_cudnn(γ, β, x, args...) + + CUDA.unsafe_free!(γ) + CUDA.unsafe_free!(β) + + return y, xμ, xσ⁻² +end + +function Impl.batchnorm_cudnn(γ::DenseCuVector{T}, β::DenseCuVector{T}, + x::DenseCuArray{T, 2}, args...) where {T <: cuDNNFloat} + x = reshape(x, 1, 1, size(x, 1), size(x, 2)) + y, xμ, xσ⁻² = Impl.batchnorm_cudnn(γ, β, x, args...) + return dropdims(y; dims=(1, 2)), xμ, xσ⁻² +end + +function Impl.batchnorm_cudnn(γ::DenseCuVector{T}, β::DenseCuVector{T}, + x::Union{DenseCuArray{T, 4}, DenseCuArray{T, 5}}, rμ::Optional{<:DenseCuVector{T}}, + rσ²::Optional{<:DenseCuVector{T}}, args...) where {T <: cuDNNFloat} + y = similar(x) + μ, σ⁻² = batchnorm_cudnn!(y, γ, β, x, rμ, rσ², args...) + return y, μ, σ⁻² +end + +function batchnorm_cudnn!( + y::DenseCuArray{T}, γ′::DenseCuVector{T}, β′::DenseCuVector{T}, x::DenseCuArray{T}, + rμ′::Optional{<:DenseCuVector{T}}, rσ²′::Optional{<:DenseCuVector{T}}, + m, ϵ, training::StaticBool) where {T <: cuDNNFloat} + dims = wsize(x, True()) + + γ = reshape(γ′, dims) + β = reshape(β′, dims) + rμ = safe_reshape(rμ′, dims...) + rσ² = safe_reshape(rσ²′, dims...) + + if rμ === nothing || rσ² === nothing + rμ !== rσ² && throw(ArgumentError("both or neither of rμ and rσ² must be nothing")) + rμ = CU_NULL + rσ² = CU_NULL + end + + xd = cudnnTensorDescriptor(x) + yd = cudnnTensorDescriptor(y) + γβd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(dims)), + cuDNN.dim4(dims, Val(CUDNN_TENSOR_NCHW))) + + if unsafe_known(training) + μ = CUDA.zeros(T, dims) + σ⁻² = CUDA.ones(T, dims) + + cudnnBatchNormalizationForwardTraining(cuDNN.handle(), CUDNN_BATCHNORM_SPATIAL, + cuDNN.scalingParameter(T, true), cuDNN.scalingParameter(T, false), + xd, x, yd, y, γβd, γ, β, m, rμ, rσ², ϵ, μ, σ⁻²) + + return μ, σ⁻² + else + cudnnBatchNormalizationForwardInference( + cuDNN.handle(), CUDNN_BATCHNORM_SPATIAL, cuDNN.scalingParameter(T, true), + cuDNN.scalingParameter(T, false), xd, x, yd, y, γβd, γ, β, rμ, rσ², ϵ) + + return similar(x, zero.(dims)), similar(x, zero.(dims)) + end +end + +function Impl.∇batchnorm_cudnn(::Nothing, ::Nothing, x::DenseCuArray, ∂y::DenseCuArray, + rμ::Optional{<:DenseCuVector}, rσ²::Optional{<:DenseCuVector}, args...) + affine_sz = wsize(x, False()) + γ = CUDA.ones(eltype(x), affine_sz) + β = CUDA.zeros(eltype(x), affine_sz) + + ∂γ, ∂β, ∂x = Impl.∇batchnorm_cudnn(γ, β, x, ∂y, rμ, rσ², args...) + + CUDA.unsafe_free!(γ) + CUDA.unsafe_free!(β) + CUDA.unsafe_free!(∂γ) + CUDA.unsafe_free!(∂β) + + return nothing, nothing, ∂x +end + +function Impl.∇batchnorm_cudnn( + γ::DenseCuVector{T}, β::DenseCuVector{T}, x::DenseCuArray{T, 2}, + ∂y::DenseCuArray{T, 2}, rμ::Optional{<:DenseCuVector{T}}, + rσ²::Optional{<:DenseCuVector{T}}, args...) 
where {T <: cuDNNFloat} + ∂γ, ∂β, ∂x = Impl.∇batchnorm_cudnn(γ, β, reshape(x, 1, 1, size(x, 1), size(x, 2)), + reshape(∂y, 1, 1, size(∂y, 1), size(∂y, 2)), rμ, rσ², args...) + return ∂γ, ∂β, dropdims(∂x; dims=(1, 2)) +end + +function Impl.∇batchnorm_cudnn( + γ::DenseCuVector{T}, β::DenseCuVector{T}, x::DenseCuArray{T, N}, + ∂y::DenseCuArray{T, N}, rμ::Optional{<:DenseCuVector{T}}, + rσ²::Optional{<:DenseCuVector{T}}, args...) where {T <: cuDNNFloat, N} + ∂γ, ∂β, ∂x = similar(γ), similar(β), similar(x) + ∇batchnorm_cudnn!(∂γ, γ, ∂β, ∂x, x, ∂y, rμ, rσ², args...) + return ∂γ, ∂β, ∂x +end + +function ∇batchnorm_cudnn!( + ∂γ′::DenseCuVector{T}, γ′::DenseCuVector{T}, ∂β′::DenseCuVector{T}, + ∂x::DenseCuArray{T, N}, x::DenseCuArray{T, N}, ∂y::DenseCuArray{T, N}, + rμ′::Optional{<:DenseCuVector{T}}, rσ²′::Optional{<:DenseCuVector{T}}, + xμ::Optional{<:DenseCuArray{<:cuDNNFloat, N}}, + xσ⁻²::Optional{<:DenseCuArray{<:cuDNNFloat, N}}, ϵ) where {T <: cuDNNFloat, N} + dims = wsize(x, True()) + + ∂γ = reshape(∂γ′, dims) + γ = reshape(γ′, dims) + ∂β = reshape(∂β′, dims) + rμ = safe_reshape(rμ′, dims...) + rσ² = safe_reshape(rσ²′, dims...) + + if rμ === nothing && rσ² === nothing + rμ = CU_NULL + rσ² = CU_NULL + end + + xd = cudnnTensorDescriptor(x) + ∂yd = cudnnTensorDescriptor(∂y) + ∂xd = cudnnTensorDescriptor(∂x) + γd = cudnnTensorDescriptor(CUDNN_TENSOR_NCHW, cudnnDataType(T), Cint(length(dims)), + cuDNN.dim4(dims, Val(CUDNN_TENSOR_NCHW))) + + xμ = xμ === nothing ? CU_NULL : xμ + xσ⁻² = xσ⁻² === nothing ? CU_NULL : xσ⁻² + + return cudnnBatchNormalizationBackward(cuDNN.handle(), CUDNN_BATCHNORM_SPATIAL, + cuDNN.scalingParameter(T, true), cuDNN.scalingParameter(T, false), + cuDNN.scalingParameter(T, true), cuDNN.scalingParameter(T, false), + xd, x, ∂yd, ∂y, ∂xd, ∂x, γd, γ, ∂γ, ∂β, ϵ, xμ, xσ⁻²) +end diff --git a/lib/LuxLib/src/LuxLib.jl b/lib/LuxLib/src/LuxLib.jl new file mode 100644 index 000000000..f0e5ca707 --- /dev/null +++ b/lib/LuxLib/src/LuxLib.jl @@ -0,0 +1,29 @@ +module LuxLib + +using Compat: @compat +using Preferences: @load_preference +using Reexport: @reexport +using Static: Static, known + +using ChainRulesCore: ChainRulesCore, NoTangent + +using LuxCore: LuxCore +using MLDataDevices: get_device_type, AbstractGPUDevice +using NNlib: NNlib + +const Optional{T} = Union{Nothing, T} +const Numeric = Union{AbstractArray{<:T}, T} where {T <: Number} +const ∂∅ = NoTangent() +const CRC = ChainRulesCore + +const DISABLE_LOOP_VECTORIZATION = @load_preference("disable_loop_vectorization", false) + +include("utils.jl") +include("traits.jl") +include("impl/Impl.jl") +include("api/API.jl") + +@compat(public, + (internal_operation_mode, GenericBroadcastOp, GPUBroadcastOp, LoopedArrayOp)) + +end diff --git a/lib/LuxLib/src/api/API.jl b/lib/LuxLib/src/api/API.jl new file mode 100644 index 000000000..d222d92e8 --- /dev/null +++ b/lib/LuxLib/src/api/API.jl @@ -0,0 +1,46 @@ +module API + +using ChainRulesCore: ChainRulesCore +using Markdown: @doc_str +using NNlib: NNlib, ConvDims +using Random: Random, AbstractRNG +using Static: Static, StaticBool, static + +using ..LuxLib: Optional +using ..Impl: Impl, select_fastest_activation +using ..Utils: default_epsilon, expand_batchdim, remove_tracking, static_training_mode + +const CRC = ChainRulesCore + +const TrainingType = Union{Val{true}, Val{false}, StaticBool, Nothing} + +# The names are aliased so we define constants for them +for op in (:batched_matmul, :batchnorm, :bias_activation, :bias_activation!!, + :dropout, :alpha_dropout, :groupnorm, :instancenorm, 
:layernorm, + :activation, :activation!!, :fused_conv, :fused_dense) + impl_op = Symbol(op, :_impl) + @eval const $impl_op = Impl.$op +end + +include("activation.jl") +include("batched_mul.jl") +include("batchnorm.jl") +include("bias_activation.jl") +include("conv.jl") +include("dense.jl") +include("dropout.jl") +include("groupnorm.jl") +include("instancenorm.jl") +include("layernorm.jl") + +export alpha_dropout, dropout +export bias_activation, bias_activation!! +export batched_matmul +export batchnorm, groupnorm, instancenorm, layernorm +export fast_activation, fast_activation!! +export fused_conv_bias_activation +export fused_dense_bias_activation + +end + +@reexport using .API diff --git a/lib/LuxLib/src/api/activation.jl b/lib/LuxLib/src/api/activation.jl new file mode 100644 index 000000000..df44aa0c6 --- /dev/null +++ b/lib/LuxLib/src/api/activation.jl @@ -0,0 +1,56 @@ +""" + fast_activation!!(σ::F, x::AbstractArray) where {F} + +Compute `σ.(x)` with the best possible implementation available. If it is possible to +rewrite `x` in-place, it does so. If `x` is an immutable array, it falls back to the +generic implementation. + +!!! note + + This function doesn't replace `σ` with `NNlib.fast_act(σ, ...)`, that needs to be + done by the user if needed. + +!!! tip "Load `SLEEFPirates.jl` to get faster activations" + + Certain activation functions are replaced with specialized implementations from + [SLEEFPirates.jl](https://github.com/JuliaSIMD/SLEEFPirates.jl) for FP32. This might + lead to faster performance but can cause slight decrease in accuracy (in the floating + point limit). + +## Arguments + + - `σ`: Activation function + - `x`: Input array + +## Returns + + - Output Array with the same size as `x` +""" +function fast_activation!!(σ::F, x::AbstractArray) where {F} + return activation!!_impl(select_fastest_activation(σ, x), x) +end + +""" + fast_activation(σ::F, x::AbstractArray) where {F} + +Compute `σ.(x)` with the best possible implementation available. On CPUs we unroll the +loop and use LoopVectorization.jl to vectorize the computation. On GPUs we use simply use +broadcasting. + +!!! note + + This function doesn't replace `σ` with `NNlib.fast_act(σ, ...)`, that needs to be + done by the user if needed. + +## Arguments + + - `σ`: Activation function + - `x`: Input array + +## Returns + + - Output Array with the same size as `x` +""" +function fast_activation(σ::F, x::AbstractArray) where {F} + return activation_impl(select_fastest_activation(σ, x), x) +end diff --git a/lib/LuxLib/src/api/batched_mul.jl b/lib/LuxLib/src/api/batched_mul.jl new file mode 100644 index 000000000..c6cb379a6 --- /dev/null +++ b/lib/LuxLib/src/api/batched_mul.jl @@ -0,0 +1,23 @@ +""" + batched_matmul(x, y) + +Computes the batched matrix multiplication of `x` and `y`. For more details see the NNlib +documentation on `NNlib.batched_mul`. This function is mostly a wrapper around `batched_mul` +but attempts to be faster on CPUs. + +!!! tip "Load `LoopVectorization.jl` to get faster batched matrix multiplication" + + On CPUs loading LoopVectorization adds faster implementations of batched matrix + multiplication. 
+""" +function batched_matmul(x::AbstractMatrix, y::AbstractArray{yT, 3}) where {yT} + return batched_matmul(expand_batchdim(x), y) +end + +function batched_matmul(x::AbstractArray{xT, 3}, y::AbstractMatrix) where {xT} + return batched_matmul(x, expand_batchdim(y)) +end + +function batched_matmul(x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {xT, yT} + return batched_matmul_impl(x, y) +end diff --git a/lib/LuxLib/src/api/batchnorm.jl b/lib/LuxLib/src/api/batchnorm.jl new file mode 100644 index 000000000..05964f0c6 --- /dev/null +++ b/lib/LuxLib/src/api/batchnorm.jl @@ -0,0 +1,46 @@ +@doc doc""" + batchnorm(x, scale, bias, running_mean, running_var, training, + σ=identity, momentum = 0.1f0, epsilon = eps(eltype(x)) ^ (5 // 7)) + +Batch Normalization. For details see [1]. + +Batch Normalization computes the mean and variance for each +``D_1 \times ... \times D_{N - 2} \times 1 \times D_N`` input slice and normalises the input +accordingly. + +## Arguments + + - `x`: Input to be Normalized + - `scale`: Scale factor (``\gamma``) (can be `nothing`) + - `bias`: Bias factor (``\beta``) (can be `nothing`) + - `running_mean`: Running mean (can be `nothing`) + - `running_var`: Running variance (can be `nothing`) + - `training`: Set to `Val(true)` or `True()` if running in training mode. Can be set to + `nothing` to automatically determine if the function is being called within an autodiff + context + - `σ`: Activation function (default: `identity`) + - `momentum`: Momentum for updating running mean and variance (default: `0.1f0`) + - `epsilon`: Value added to the denominator for numerical stability + (default: `eps(eltype(x)) ^ (5 / 7)`) + +## Returns + +Normalized Array of same size as `x`. And a Named Tuple containing the updated running +mean and variance. + +## References + +[1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network + training by reducing internal covariate shift." International conference on machine + learning. PMLR, 2015. +""" +function batchnorm(x::AbstractArray{T, N}, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector}, + rσ²::Optional{<:AbstractVector}, training::TrainingType, act::F=identity, + momentum::Real=0.1f0, epsilon::Real=default_epsilon(x)) where {F, T, N} + σ = select_fastest_activation(act, x, γ, β, rμ, rσ²) + y, rμ, rσ² = batchnorm_impl( + x, γ, β, rμ, rσ², static_training_mode(training, x, γ, β, rμ, rσ²), + σ, momentum, epsilon) + return y, (; running_mean=remove_tracking(rμ), running_var=remove_tracking(rσ²)) +end diff --git a/lib/LuxLib/src/api/bias_activation.jl b/lib/LuxLib/src/api/bias_activation.jl new file mode 100644 index 000000000..9be9d3a2d --- /dev/null +++ b/lib/LuxLib/src/api/bias_activation.jl @@ -0,0 +1,46 @@ +""" + bias_activation(σ, x, bias) + +Applies the activation function `σ` elementwise to the result of broadcasted addition of `x` +and `bias` along the penultimate dimension. A vector `x` is treated as a matrix with a +single last dimension. + +## Arguments + + - `σ`: Activation function + - `x`: Input to be transformed + - `bias`: Bias to be added. Can be `nothing`. + +See also [`bias_activation!!`](@ref), [`fast_activation`](@ref). 
+""" +function bias_activation(σ::F, x::AbstractArray, bias::Optional{<:AbstractVector}) where {F} + bias_act_check(x, bias) + σ′ = select_fastest_activation(σ, x, bias) + return bias_activation_impl(select_fastest_activation(σ, x, bias), x, bias) +end + +""" + bias_activation!!(σ, x, bias) + +Same as [`bias_activation`](@ref) but might update `x` in-place if possible. Users should +not rely on `x` being mutated, it is recommended to use it like +`y = bias_activation!!(σ, x, bias)`. If `x` is updated in-place, `y` aliases `x`. + +See also [`bias_activation`](@ref), [`fast_activation!!`](@ref). +""" +function bias_activation!!( + σ::F, x::AbstractArray, bias::Optional{<:AbstractVector}) where {F} + bias_act_check(x, bias) + return bias_activation!!_impl(select_fastest_activation(σ, x, bias), x, bias) +end + +bias_act_check(_, __) = nothing +function bias_act_check(x::AbstractArray{xT, N}, bias::AbstractVector) where {xT, N} + if N == 1 + @assert length(bias) == length(x) + else + @assert length(bias) == size(x, N - 1) + end +end + +CRC.@non_differentiable bias_act_check(::Any, ::Any) diff --git a/lib/LuxLib/src/api/conv.jl b/lib/LuxLib/src/api/conv.jl new file mode 100644 index 000000000..031e340be --- /dev/null +++ b/lib/LuxLib/src/api/conv.jl @@ -0,0 +1,34 @@ +""" + fused_conv_bias_activation(σ::F, weight::AbstractArray, x::AbstractArray, + b::Optional{<:AbstractVector}, cdims::ConvDims) where {F} + +Computes `σ.(conv(x, weight, cdims) .+ b)` (`b` is not exactly broadcasted like this, +rather it is reshaped and broadcasted to the penultimate dimension) with the best possible +implementation available. This operation fuses operations into a single kernel if possible, +and minimizes reallocations by reusing the output buffer for multiple operations. + +## Arguments + + - `σ`: Activation function + - `weight`: Weight tensor + - `x`: Input tensor + - `b`: Bias tensor (can be `nothing`) + - `cdims`: `ConvDims` object + +## Notes on implementation + + - For CUDA Arrays, this uses fused CUDNN kernels when the activation is `identity` or + `relu`. For other activations, it tries to fuse the operations on the Julia side. + - If any of the inputs, don't support setindexing (aka immutable arrays) we fallback to + the generic non-mutating implementation. + - Maximum memory reuse and operation fusion is guaranteed for ChainRules compatible AD + backends or backends that support mutation. Backends like `Tracker` and `ReverseDiff` + fallback to the generic implementation. + - For Mixed-Precision Inputs on GPU, we type promote the inputs to the highest precision, + with a warning. +""" +function fused_conv_bias_activation( + σ::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, + b::Optional{<:AbstractVector}, cdims::ConvDims) where {F, N, wT, xT} + return fused_conv_impl(select_fastest_activation(σ, weight, x, b), weight, x, b, cdims) +end diff --git a/lib/LuxLib/src/api/dense.jl b/lib/LuxLib/src/api/dense.jl new file mode 100644 index 000000000..f51b2518f --- /dev/null +++ b/lib/LuxLib/src/api/dense.jl @@ -0,0 +1,36 @@ +""" + fused_dense_bias_activation(σ::F, weight::AbstractMatrix, x::AbstractMatrix, + b::Optional{<:AbstractVector}) where {F} + +Compute `σ.(weight * x .+ b)` with the best possible implementation available. Currently +this implementation attempts to minimize reallocations by reusing the output buffer for +multiple operations. 
+ +## Arguments + + - `σ`: Activation function + - `weight`: Weight matrix + - `x`: Input matrix + - `b`: Bias vector (can be `nothing`) + +## Notes on implementation + + - If any of the inputs, don't support setindexing (aka immutable arrays) we fallback to + the generic non-mutating implementation. + - Maximum memory reuse and operation fusion is guaranteed for ChainRules compatible AD + backends or backends that support mutation. Backends like `Tracker` and `ReverseDiff` + fallback to the generic implementation. + - For CUDA Arrays, this uses a special fused implementation via cuBLASLt. + - For small CPU Arrays, we use LoopVectorization.jl. On `x86_64` we use Octavian for + medium sized matrices. This is overridden if special BLAS implementations are loaded + (currently `MKL`, `AppleAccelerate`, and `BLISBLAS`). + +!!! tip "Load `Octavian.jl` + + Loading `Octavian.jl` enables a polyalgorithm that uses different backends based on the + input sizes. +""" +function fused_dense_bias_activation(σ::F, weight::AbstractMatrix, x::AbstractMatrix, + b::Optional{<:AbstractVector}) where {F} + return fused_dense_impl(select_fastest_activation(σ, weight, x, b), weight, x, b) +end diff --git a/lib/LuxLib/src/api/dropout.jl b/lib/LuxLib/src/api/dropout.jl new file mode 100644 index 000000000..3d4e4c6dd --- /dev/null +++ b/lib/LuxLib/src/api/dropout.jl @@ -0,0 +1,80 @@ +""" + dropout(rng::AbstractRNG, x, p, training, invp, dims) + dropout(rng::AbstractRNG, x, mask, p, training, update_mask::Union{Val, StaticBool}, + invp, dims) + +Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see [1]. + +## Arguments + + - `rng`: Random number generator + - `x`: Input Array + - `mask`: Dropout Mask. If not used then it is constructed automatically + - `p`: Probability of an element to be dropped out + - `training`: Set to `Val(true)` or `True()` if running in training mode. Can be set to + `nothing` to automatically determine if the function is being called within an autodiff + context + - `update_mask`: If `Val(true)` or `True()` then the mask is generated and used. Else, the + `mask` provided is directly used + - `invp`: Inverse multiplied to the mask. Calculated as `invp = 1 / (1 - p)`. + +## Returns + + - Output Array after applying dropout + - Dropout Mask (if `training == false`, the returned value is meaningless) + - Updated state for the random number generator + +## References + +[1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent neural networks from +overfitting." The journal of machine learning research 15.1 (2014): 1929-1958. +""" +function dropout(rng::AbstractRNG, x::AbstractArray, p::T, training::TrainingType, invp::T, + dims) where {T} + return dropout_impl(rng, x, p, static_training_mode(training, x), invp, dims) +end + +function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray, + p::T, training::TrainingType, update_mask::TrainingType, invp::T, dims) where {T} + return dropout_impl(rng, x, mask, p, static_training_mode(training, x), + static(update_mask), invp, dims) +end + +""" + alpha_dropout(rng::AbstractRNG, x, p, training) + alpha_dropout(rng::AbstractRNG, x, p, training, α, A, B) + +Alpha Dropout: Dropout ensuring that the mean and variance of the output remains same as the +input. For details see [1]. Use the second call signature to avoid recomputing the constants +for a fixed dropout probability. 
+
+## Arguments
+
+  - `rng`: Random number generator
+  - `x`: Input Array
+  - `p`: Probability of an element to be dropped out
+  - `training`: Set to `Val(true)` or `True()` if running in training mode. Can be set to
+    `nothing` to automatically determine if the function is being called within an autodiff
+    context
+  - `α`: `-1.7580993408473766`. Computed in the limit `x → -∞`, where `selu(x) = -λβ = α`
+  - `A`: Scaling factor for the mean
+  - `B`: Scaling factor for the variance
+
+## Returns
+
+  - Output Array after applying alpha dropout
+  - Updated state for the random number generator
+
+## References
+
+[1] Klambauer, Günter, et al. "Self-normalizing neural networks." Advances in neural
+information processing systems 30 (2017).
+"""
+function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, training::TrainingType)
+    return alpha_dropout_impl(rng, x, p, static_training_mode(training, x))
+end
+
+function alpha_dropout(
+        rng::AbstractRNG, x::AbstractArray, p, training::TrainingType, α, A, B)
+    return alpha_dropout_impl(rng, x, p, static_training_mode(training, x), α, A, B)
+end
diff --git a/lib/LuxLib/src/api/groupnorm.jl b/lib/LuxLib/src/api/groupnorm.jl
new file mode 100644
index 000000000..4e6a7bff8
--- /dev/null
+++ b/lib/LuxLib/src/api/groupnorm.jl
@@ -0,0 +1,56 @@
+@doc doc"""
+    groupnorm(x, scale, bias, groups::Int, σ::F=identity,
+        epsilon::Real=eps(eltype(x)) ^ (5 // 7))
+
+Group Normalization. For details see [1].
+
+This op is similar to batch normalization, but statistics are shared across equally-sized
+groups of channels and are not shared across the batch dimension. Thus, group normalization
+does not depend on the batch composition and does not require maintaining internal state
+for storing statistics.
+
+## Arguments
+
+  - `x`: Input to be Normalized
+  - `scale`: Scale factor (``\gamma``) (can be `nothing`)
+  - `bias`: Bias factor (``\beta``) (can be `nothing`)
+  - `groups`: Number of groups
+  - `σ`: Activation function (default: `identity`)
+  - `epsilon`: Value added to the denominator for numerical stability
+    (default: `eps(eltype(x)) ^ (5 / 7)`)
+
+## Returns
+
+The normalized array is returned.
+
+## References
+
+[1] Wu, Yuxin, and Kaiming He. "Group normalization." Proceedings of the European conference
+    on computer vision (ECCV). 2018.
+"""
+function groupnorm(x::AbstractArray{<:Real, N}, scale::Optional{<:AbstractVector},
+        bias::Optional{<:AbstractVector}, groups::Int, σ::F=identity,
+        epsilon::Real=default_epsilon(x)) where {F, N}
+    assert_valid_groupnorm_arguments(x, scale, bias, groups)
+    return groupnorm_impl(
+        x, scale, bias, groups, select_fastest_activation(σ, x, scale, bias), epsilon)
+end
+
+function assert_valid_groupnorm_arguments(
+        x::AbstractArray{T, N}, scale, bias, groups) where {T, N}
+    @assert length(scale)==length(bias)==size(x, N - 1) "Length of `scale` and `bias` must \
+                                                         be equal to the number of \
+                                                         channels ((N - 1) dim of the \
+                                                         input array)."
+    assert_valid_groupnorm_arguments(x, nothing, nothing, groups)
+    return nothing
+end
+
+function assert_valid_groupnorm_arguments(
+        x::AbstractArray{T, N}, ::Nothing, ::Nothing, groups::Int) where {T, N}
+    @assert size(x, N - 1) % groups==0 "Number of channels $(size(x, N - 1)) must be \
+                                        divisible by the number of groups $groups."
+    return nothing
+end
+
+CRC.@non_differentiable assert_valid_groupnorm_arguments(::Any...)
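
A minimal usage sketch for the `groupnorm` API documented above (not part of the patch;
shapes are illustrative, the activation is Base's `tanh`, and it assumes `groupnorm` is
exported from LuxLib as shown in the docstring):

    using LuxLib, Random

    x = randn(Random.default_rng(), Float32, 8, 8, 12, 4)  # 12 channels, batch of 4
    scale = ones(Float32, 12)                               # γ, one entry per channel
    bias = zeros(Float32, 12)                               # β, one entry per channel

    # 12 channels split into 3 groups of 4; statistics are computed per group and per sample
    y = groupnorm(x, scale, bias, 3, tanh)
    @assert size(y) == size(x)                              # normalization preserves the shape
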
diff --git a/lib/LuxLib/src/api/instancenorm.jl b/lib/LuxLib/src/api/instancenorm.jl
new file mode 100644
index 000000000..158785524
--- /dev/null
+++ b/lib/LuxLib/src/api/instancenorm.jl
@@ -0,0 +1,62 @@
+@doc doc"""
+    instancenorm(x, scale, bias, training, act, epsilon = eps(eltype(x)) ^ (5 // 7))
+    instancenorm(x, scale, bias, running_mean, running_var, training, act, momentum,
+        epsilon = eps(eltype(x)) ^ (5 // 7))
+
+Instance Normalization. For details see [1].
+
+Instance Normalization computes the mean and variance for each
+``D_1 \times ... \times D_{N - 2} \times 1 \times 1`` input slice and normalizes the input
+accordingly.
+
+## Arguments
+
+  - `x`: Input to be Normalized (must be at least 3D)
+  - `scale`: Scale factor (``\gamma``) (can be `nothing`)
+  - `bias`: Bias factor (``\beta``) (can be `nothing`)
+  - `running_mean`: Running mean (can be `nothing`)
+  - `running_var`: Running variance (can be `nothing`)
+  - `training`: Set to `Val(true)` or `True()` if running in training mode. Can be set to
+    `nothing` to automatically determine if the function is being called within an autodiff
+    context
+  - `σ`: Activation function (default: `identity`)
+  - `epsilon`: Value added to the denominator for numerical stability
+    (default: `eps(eltype(x)) ^ (5 / 7)`)
+  - `momentum`: Momentum for updating running mean and variance (default: `0.1f0`)
+
+## Returns
+
+Normalized Array of the same size as `x`, along with a Named Tuple containing the updated
+running mean and variance.
+
+## References
+
+[1] Ulyanov, Dmitry, Andrea Vedaldi, and Victor Lempitsky. "Instance normalization: The
+    missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016).
+"""
+function instancenorm(x::AbstractArray, γ::Optional{<:AbstractVector},
+        β::Optional{<:AbstractVector}, training::TrainingType,
+        σ::F=identity, epsilon::Real=default_epsilon(x)) where {F}
+    # This API is kept for legacy purposes when we didn't support passing running stats
+    return instancenorm(x, γ, β, nothing, nothing, training, σ, nothing, epsilon)
+end
+
+function instancenorm(x::AbstractArray, γ::Optional{<:AbstractVector},
+        β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector},
+        rσ²::Optional{<:AbstractVector}, training::TrainingType, σ::F=identity,
+        momentum::Optional{<:Real}=0.1f0, epsilon::Real=default_epsilon(x)) where {F}
+    assert_valid_instancenorm_arguments(x)
+
+    y, rμₙ, rσ²ₙ = instancenorm_impl(
+        x, γ, β, rμ, rσ², static_training_mode(training, x, γ, β, rμ, rσ²),
+        select_fastest_activation(σ, x, γ, β), momentum, epsilon)
+
+    return y, (; running_mean=remove_tracking(rμₙ), running_var=remove_tracking(rσ²ₙ))
+end
+
+function assert_valid_instancenorm_arguments(::AbstractArray{T, N}) where {T, N}
+    @assert N>2 "`ndims(x) = $(N)` must be > 2."
+    return nothing
+end
+
+CRC.@non_differentiable assert_valid_instancenorm_arguments(::Any...)
diff --git a/lib/LuxLib/src/api/layernorm.jl b/lib/LuxLib/src/api/layernorm.jl
new file mode 100644
index 000000000..eb147d30e
--- /dev/null
+++ b/lib/LuxLib/src/api/layernorm.jl
@@ -0,0 +1,42 @@
+@doc doc"""
+    layernorm(x::AbstractArray{xT, N}, scale, bias, σ = identity, dims=1:(N - 1),
+        epsilon = eps(eltype(x)) ^ (5 / 7)) where {xT, N}
+
+Layer Normalization. For details see [1].
+
+Given an input array ``x``, this layer computes
+
+```math
+y = \frac{x - \mathbb{E}[x]}{\sqrt{Var[x] + \epsilon}} * \gamma + \beta
+```
+
+and applies the activation function `σ` elementwise to `y`.
+
+## Arguments
+
+  - `x`: Input to be Normalized
+  - `scale`: Scale factor (``\gamma``) (can be `nothing`)
+  - `bias`: Bias factor (``\beta``) (can be `nothing`)
+  - `σ`: Activation function (default: `identity`)
+  - `dims`: Dimensions along which the mean and std of `x` are computed. If `nothing` is
+    passed, the dims are inferred based on the dimensions of scale and bias. For example,
+    if `x` is `N` dimensional and `scale` and `bias` are `M` dimensional, then the dims
+    will be `1:(N - M)`.
+  - `epsilon`: Value added to the denominator for numerical stability
+    (default: `eps(eltype(x)) ^ (5 / 7)`)
+
+## Returns
+
+Normalized Array of the same size as `x`.
+
+## References
+
+[1] Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton. "Layer normalization." arXiv
+    preprint arXiv:1607.06450 (2016).
+"""
+function layernorm(x::AbstractArray{xT, N}, scale::Optional{<:AbstractArray},
+        bias::Optional{<:AbstractArray}, σ::F=identity, dims=1:(N - 1),
+        epsilon::Real=default_epsilon(x)) where {F, xT, N}
+    return layernorm_impl(
+        x, scale, bias, select_fastest_activation(σ, x, scale, bias), dims, epsilon)
+end
diff --git a/lib/LuxLib/src/impl/Impl.jl b/lib/LuxLib/src/impl/Impl.jl
new file mode 100644
index 000000000..3bd59797d
--- /dev/null
+++ b/lib/LuxLib/src/impl/Impl.jl
@@ -0,0 +1,56 @@
+module Impl
+
+using ArrayInterface: ArrayInterface, aos_to_soa
+using DispatchDoctor: @stable
+using FastClosures: @closure
+using StaticArraysCore: StaticVector, SArray
+using Static: StaticBool, True, False, static
+
+using ChainRulesCore: ChainRulesCore, NoTangent, HasReverseMode, RuleConfig
+using EnzymeCore: EnzymeCore, EnzymeRules
+using ForwardDiff: ForwardDiff
+
+using KernelAbstractions: KernelAbstractions, @kernel, @Const, @index
+
+using Polyester: @batch
+
+using LinearAlgebra: LinearAlgebra, mul!
+using Random: Random, AbstractRNG, rand!
+using Statistics: Statistics, mean, var + +using LuxCore: LuxCore +using MLDataDevices: get_device_type, CPUDevice, AMDGPUDevice, CUDADevice, XLADevice, + AbstractGPUDevice, AbstractDevice +using NNlib: NNlib, ConvDims + +using ..LuxLib: Optional, Numeric, ∂∅, internal_operation_mode, AbstractInternalArrayOpMode, + GenericBroadcastOp, GPUBroadcastOp, LoopedArrayOp +using ..Utils: Utils, NotaNumber, batchview, concrete_bias_act_output_eltype, contiguous, + copy_drop_gradients, eltype_mismatch, expand_batchdim, + maybe_reduce_BLAS_threads, ofeltype_array, only_derivative, remove_tracking, + reset_BLAS_threads, run_ka_kernel, safe_eltype, safe_vec, safe_warning, + unsafe_known, unrolled_mapreduce, can_loopvec_args, is_extension_loaded, + @enzyme_alternative +using ..Traits: activation_intermediate_not_needed, activation_has_rrule, is_mutable_array, + fuse_cpu_activation +using ..System: explicit_blas_loaded, use_octavian, fits_in_l1cache, fits_in_l2cache, + fits_in_l3cache + +const CRC = ChainRulesCore +const KA = KernelAbstractions + +include("activation.jl") +include("batched_mul.jl") +include("batchnorm.jl") +include("bias_activation.jl") +include("common_ops.jl") +include("conv.jl") +include("dense.jl") +include("dropout.jl") +include("forward_diff.jl") +include("groupnorm.jl") +include("layernorm.jl") +include("matmul.jl") +include("normalization.jl") + +end diff --git a/lib/LuxLib/src/impl/activation.jl b/lib/LuxLib/src/impl/activation.jl new file mode 100644 index 000000000..0b015e3b1 --- /dev/null +++ b/lib/LuxLib/src/impl/activation.jl @@ -0,0 +1,144 @@ +# Entry Points +function activation!!(σ::F, x::AbstractArray) where {F} + return activation!!(internal_operation_mode(x), is_mutable_array(x), σ, x) +end + +activation!(::typeof(identity), ::AbstractArray) = nothing +function activation!(σ::F, x::AbstractArray) where {F} + activation!(x, internal_operation_mode(x), σ, x) + return nothing +end + +activation(::typeof(identity), x::AbstractArray) = x +activation(σ::F, x::AbstractArray) where {F} = activation(internal_operation_mode(x), σ, x) + +# Core Implementation +function activation!!( + opmode::AbstractInternalArrayOpMode, ::False, σ::F, x::AbstractArray) where {F} + return activation(opmode, σ, x) +end +@stable default_mode="disable" function activation!!( + opmode::AbstractInternalArrayOpMode, ::True, σ::F, x::AbstractArray) where {F} + activation!(x, opmode, σ, x) + return x +end + +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(activation!!), + opmode::AbstractInternalArrayOpMode, ::True, + σ::F, x::AbstractArray{T}) where {F, T} + if unsafe_known(activation_intermediate_not_needed(σ, T)) + activation!(x, opmode, σ, x) + 𝒫x_no_intermediate = CRC.ProjectTo(x) + ∇activation_no_intermediate_rrule = @closure Δ -> begin + ∂x = ∇activation(CRC.unthunk(Δ), x, σ, NotaNumber()) + return ∂∅, ∂∅, ∂∅, ∂∅, 𝒫x_no_intermediate(∂x) + end + return x, ∇activation_no_intermediate_rrule + end + + if unsafe_known(activation_has_rrule(σ, T)) + y = activation(opmode, σ, x) + 𝓟x_cached = CRC.ProjectTo(x) + ∇activation_rrule = @closure Δ -> begin + ∂x = ∇activation(CRC.unthunk(Δ), y, σ, x) + return ∂∅, ∂∅, ∂∅, ∂∅, 𝓟x_cached(∂x) + end + return y, ∇activation_rrule + end + + res, ∇activation_from_ad = CRC.rrule_via_ad(cfg, activation, opmode, σ, x) + ∇activation_fallback = @closure Δ -> begin + _, ∂opmode, ∂σ, ∂x = ∇activation_from_ad(Δ) + return ∂∅, ∂opmode, ∂∅, ∂σ, ∂x + end + return res, ∇activation_fallback +end + +function activation(::AbstractInternalArrayOpMode, σ::F, 
x::AbstractArray) where {F} + return broadcast(σ, x) +end +@stable default_mode="disable" function activation( + opmode::LoopedArrayOp, σ::F, x::AbstractArray{T}) where {F, T} + RT = Core.Compiler._return_type(σ, Tuple{T}) + y = similar(x, ifelse(isconcretetype(RT), RT, T)) + activation!(y, opmode, σ, x) + return y +end + +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(activation), + opmode::LoopedArrayOp, σ::F, x::AbstractArray{T}) where {F, T} + if unsafe_known(activation_has_rrule(σ, T)) + y = activation(opmode, σ, x) + 𝓟x = CRC.ProjectTo(x) + ∇activation_rrule = @closure Δ -> begin + ∂x = ∇activation(CRC.unthunk(Δ), y, σ, x) + return ∂∅, ∂∅, ∂∅, 𝓟x(∂x) + end + return y, ∇activation_rrule + end + + z, ∇broadcast = CRC.rrule_via_ad(cfg, broadcast, σ, x) + ∇activation_fallback = @closure Δ -> begin + ∂f, ∂σ, ∂x = ∇broadcast(Δ) + return ∂f, ∂∅, ∂σ, ∂x + end + return z, ∇activation_fallback +end + +function activation!( + y::AbstractArray, ::AbstractInternalArrayOpMode, σ::F, x::AbstractArray) where {F} + broadcast!(σ, y, x) + return +end +function activation!(y::AbstractArray, ::LoopedArrayOp, σ::F, x::AbstractArray) where {F} + activation_simd_loop!(y, σ, x) + return +end + +function activation_simd_loop!(y::AbstractArray, σ::F, x::AbstractArray) where {F} + @simd ivdep for I in eachindex(y, x) + @inbounds y[I] = σ(x[I]) + end +end + +# Gradient for activations +∇activation(Δ, _, ::typeof(identity), x) = Δ +function ∇activation(Δ, out, act::F, x) where {F} + return ∇activation(internal_operation_mode((Δ, out)), Δ, out, act, x) +end +function ∇activation(::AbstractInternalArrayOpMode, Δ, out, act::F, x) where {F} + return @. Δ * only_derivative(out, act, x) +end +@inbounds function ∇activation(::LoopedArrayOp, Δ, out, act::F, x) where {F} + y = similar(out) + if x isa NotaNumber + @simd ivdep for i in eachindex(Δ, out) + @inbounds y[i] = only_derivative(out[i], act, x) * Δ[i] + end + else + @simd ivdep for i in eachindex(Δ, out, x) + @inbounds y[i] = only_derivative(out[i], act, x[i]) * Δ[i] + end + end + return y +end + +# Switch some of the activations to use SLEEFPirates.jl if needed +function select_fastest_activation(f::F, xs...) where {F} + return select_fastest_activation( + f, internal_operation_mode(xs), unrolled_mapreduce(safe_eltype, promote_type, xs)) +end + +select_fastest_activation(f::F, ::AbstractInternalArrayOpMode, ::Type{T}) where {F, T} = f + +function select_fastest_activation(f::F, ::LoopedArrayOp, ::Type{T}) where {F, T} + return sleefpirates_fast_act(f, T) +end + +CRC.@non_differentiable select_fastest_activation(::Any...) + +sleefpirates_fast_act(f::F, ::Type{T}) where {F, T} = f +sleefpirates_fast_act(f::F, ::Type{Float32}) where {F} = sleefpirates_fast_act(f) +sleefpirates_fast_act(f::F) where {F} = f + +CRC.@non_differentiable sleefpirates_fast_act(::Any...) 
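
As a point of reference for the CPU fast path above, the looped implementation is just an
explicitly vectorized form of broadcasting. A standalone sketch of the same pattern (the
helper name `simd_activation!` is made up for illustration and is not part of the patch):

    # Apply σ elementwise into a preallocated output, mirroring `activation_simd_loop!`.
    function simd_activation!(y::AbstractArray, σ::F, x::AbstractArray) where {F}
        @simd ivdep for I in eachindex(y, x)
            @inbounds y[I] = σ(x[I])
        end
        return y
    end

    x = randn(Float32, 1024)
    y = similar(x)
    simd_activation!(y, tanh, x)
    y ≈ tanh.(x)  # true: the loop and the broadcast compute the same values
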
diff --git a/lib/LuxLib/src/impl/batched_mul.jl b/lib/LuxLib/src/impl/batched_mul.jl new file mode 100644 index 000000000..b8900d8eb --- /dev/null +++ b/lib/LuxLib/src/impl/batched_mul.jl @@ -0,0 +1,258 @@ +# Entry Point +function batched_matmul(x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {xT, yT} + return batched_matmul(internal_operation_mode((x, y)), x, y) +end + +function batched_matmul(::GenericBroadcastOp, x::AbstractArray{xT, 3}, + y::AbstractArray{yT, 3}) where {xT, yT} + return NNlib.batched_mul(x, y) +end + +for dev in (AMDGPUDevice, CUDADevice) + @eval function batched_matmul(::GPUBroadcastOp{$(dev)}, + x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {xT, yT} + return NNlib.batched_mul(x, y) # GPU versions are well optimized + end +end + +function batched_matmul(opmode::GPUBroadcastOp{<:AbstractGPUDevice}, + x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {xT, yT} + if isconcretetype(Core.Compiler._return_type( + NNlib.batched_mul, Tuple{typeof(x), typeof(y)})) + return NNlib.batched_mul(x, y) # GPU versions are well optimized + end + return fallback_batched_matmul(opmode, x, y) +end + +function batched_matmul( + opmode::GPUBroadcastOp{AMDGPUDevice}, x::AbstractArray{<:Complex, 3}, + y::AbstractArray{<:Complex, 3}) + return fallback_batched_matmul(opmode, x, y) +end + +function batched_matmul(opmode::LoopedArrayOp, x::AbstractArray{xT, 3}, + y::AbstractArray{yT, 3}) where {xT, yT} + if (size(x, 3) != size(y, 3) && size(x, 3) != 1 && size(y, 3) != 1) || + (size(x, 2) != size(y, 1)) + throw(DimensionMismatch(lazy"size(x) = $(size(x)), size(y) = $(size(y)) inconsistent for batched_matmul.")) + end + z = similar(x, promote_type(eltype(x), eltype(y)), size(x, 1), + size(y, 2), max(size(x, 3), size(y, 3))) + batched_matmul!(z, opmode, x, y) + return z +end + +function batched_matmul!(z::AbstractArray{zT, 3}, ::AbstractInternalArrayOpMode, + x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {zT, xT, yT} + batched_mul!(z, x, y) + return +end + +function batched_matmul!(z::AbstractArray{zT, 3}, ::LoopedArrayOp, + x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {zT, xT, yT} + batched_matmul_cpu!(z, x, y) + return +end + +function batched_matmul_cpu!(z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3}, + y::AbstractArray{yT, 3}) where {zT, xT, yT} + if can_loopvec_args(batchview(z, 1), batchview(x, 1), batchview(y, 1)) && + !unsafe_known(explicit_blas_loaded()) + batched_matmul_loopvec_impl!(z, x, y) + return + end + # Avoid an Enzyme segfault https://github.com/EnzymeAD/Enzyme.jl/issues/1983 + fallback_batched_matmul!(z, LoopedArrayOp(), x, y) + # NNlib.batched_mul!(z, x, y) # XXX: restore once the enzyme segfault is fixed + return +end + +function batched_matmul_loopvec_impl! end + +function fallback_batched_matmul( + opmode, x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {xT, yT} + z = similar(x, promote_type(eltype(x), eltype(y)), size(x, 1), + size(y, 2), max(size(x, 3), size(y, 3))) + fallback_batched_matmul!(z, opmode, x, y) + return z +end + +function fallback_batched_matmul!( + z::AbstractArray{zT, 3}, opmode, x::AbstractArray{xT, 3}, + y::AbstractArray{yT, 3}) where {zT, xT, yT} + # XXX: bring back once the enzyme segfault is fixed + # @warn "Using fallback Batched Matrix Multiply routine for $(dev) with A: size = \ + # $(size(x)) eltype = $(xT) and B: size = $(size(y)) eltype = $(yT). This may be \ + # slow." 
maxlog=1 + + if (size(x, 3) != size(y, 3) && size(x, 3) != 1 && size(y, 3) != 1) || + (size(x, 2) != size(y, 1)) + throw(DimensionMismatch(lazy"size(x) = $(size(x)), size(y) = $(size(y)) inconsistent for batched_matmul.")) + end + + if use_threaded_batched_matmul(get_device_type(x)) + unsafe_fallback_threaded_batched_matmul!(z, x, y) + else + unsafe_fallback_serial_batched_matmul!(z, x, y) + end + + return +end + +function unsafe_fallback_serial_batched_matmul!( + z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3}, + y::AbstractArray{yT, 3}) where {zT, xT, yT} + if size(x, 3) == size(y, 3) + for L in axes(z, 3) + mul!(batchview(z, L), batchview(x, L), batchview(y, L)) + end + elseif size(x, 3) == 1 + for L in axes(z, 3) + mul!(batchview(z, L), batchview(x, 1), batchview(y, L)) + end + else # has to be size(y, 3) == 1 + for L in axes(z, 3) + mul!(batchview(z, L), batchview(x, L), batchview(y, 1)) + end + end +end + +function unsafe_fallback_threaded_batched_matmul!( + z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3}, + y::AbstractArray{yT, 3}) where {zT, xT, yT} + old_threads = maybe_reduce_BLAS_threads(z) + + if size(x, 3) == size(y, 3) + Threads.@threads for L in axes(z, 3) + mul!(batchview(z, L), batchview(x, L), batchview(y, L)) + end + elseif size(x, 3) == 1 + Threads.@threads for L in axes(z, 3) + mul!(batchview(z, L), batchview(x, 1), batchview(y, L)) + end + else # has to be size(y, 3) == 1 + Threads.@threads for L in axes(z, 3) + mul!(batchview(z, L), batchview(x, L), batchview(y, 1)) + end + end + + reset_BLAS_threads(old_threads) + return +end + +use_threaded_batched_matmul(::Type) = false +use_threaded_batched_matmul(::Type{CUDADevice}) = true +use_threaded_batched_matmul(::Type{CPUDevice}) = true + +function CRC.rrule(::typeof(batched_matmul), x::AbstractArray{xT, 3}, + y::AbstractArray{yT, 3}) where {xT, yT} + ∇batched_matmul = @closure Δ_ -> begin + Δ = CRC.unthunk(Δ_) + ∂x = CRC.@thunk begin + tmp = batched_matmul(Δ, NNlib.batched_adjoint(y)) + size(x, 3) == 1 ? sum(tmp; dims=3) : tmp + end + ∂y = CRC.@thunk begin + tmp = batched_matmul(NNlib.batched_adjoint(x), Δ) + size(y, 3) == 1 ? sum(tmp; dims=3) : tmp + end + return ∂∅, ∂x, ∂y + end + return batched_matmul(x, y), ∇batched_matmul +end + +# This is type-piracy but needed to fix a blocking issue. TODO: upstream to NNlib +# Enzyme causes a "active variables passed by value to jl_new_task are not yet supported" +# warning without this patch. +for func in (NNlib.batched_mul!, batched_matmul_loopvec_impl!) + @eval begin + function EnzymeRules.augmented_primal( + cfg::EnzymeRules.RevConfigWidth, ::EnzymeCore.Const{typeof($(func))}, + ::Type{RT}, C::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, + A::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, + B::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}) where {RT} + if typeof(C) <: EnzymeCore.Duplicated || typeof(C) <: EnzymeCore.BatchDuplicated + $(func)(C.val, A.val, B.val) + end + + primal = EnzymeRules.needs_primal(cfg) ? C.val : nothing + shadow = EnzymeRules.needs_shadow(cfg) ? C.dval : nothing + + cache_A = (EnzymeRules.overwritten(cfg)[3] && + !(typeof(C) <: EnzymeCore.Const) && + !(typeof(B) <: EnzymeCore.Const)) ? copy(A.val) : nothing + cache_B = (EnzymeRules.overwritten(cfg)[3] && + !(typeof(C) <: EnzymeCore.Const) && + !(typeof(A) <: EnzymeCore.Const)) ? 
copy(B.val) : nothing + + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_B)) + end + + function EnzymeRules.reverse( + cfg::EnzymeRules.RevConfigWidth, ::EnzymeCore.Const{typeof($(func))}, + ::Type{RT}, cache, C::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, + A::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}, + B::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}) where {RT} + cache_A, cache_B = cache + + if !(typeof(B) <: EnzymeCore.Const) && !(typeof(C) <: EnzymeCore.Const) + if !EnzymeRules.overwritten(cfg)[3] + cache_A = A.val + end + end + + if !(typeof(A) <: EnzymeCore.Const) && !(typeof(C) <: EnzymeCore.Const) + if !EnzymeRules.overwritten(cfg)[3] + cache_B = B.val + end + end + + dCs = C.dval + dAs = A isa EnzymeCore.Const ? dCs : A.dval + dBs = B isa EnzymeCore.Const ? dCs : B.dval + + if EnzymeRules.width(cfg) == 1 + dCs = (dCs,) + dAs = (dAs,) + dBs = (dBs,) + end + + # NOTE: The implementation here is memory efficient and non-allocating. However, + # for maximum performance we would want to reuse the parallel batched_mul + # followed by a reduction. + for (dC, dA, dB) in zip(dCs, dAs, dBs) + if !(typeof(C) <: EnzymeCore.Const) && dC !== C.val + if !(typeof(A) <: EnzymeCore.Const) && dA !== A.val + if size(dA, 3) == 1 && size(B.val, 3) != 1 + B′ = NNlib.batched_adjoint(B.val) + dA′ = batchview(dA, 1) + for L in axes(B′, 3) + mul!(dA′, batchview(dC, L), + batchview(B′, L), true, true) + end + else + $(func)(dA, dC, NNlib.batched_adjoint(B.val), true, true) + end + end + + if !(typeof(B) <: EnzymeCore.Const) && dB !== B.val + if size(dB, 3) == 1 && size(A.val, 3) != 1 + A′ = NNlib.batched_adjoint(A.val) + dB′ = batchview(dB, 1) + for L in axes(A′, 3) + mul!(dB′, batchview(A′, L), + batchview(dC, L), true, true) + end + else + $(func)(dB, NNlib.batched_adjoint(A.val), dC, true, true) + end + end + + dC .= 0 + end + end + + return ntuple(Returns(nothing), 3) + end + end +end diff --git a/lib/LuxLib/src/impl/batchnorm.jl b/lib/LuxLib/src/impl/batchnorm.jl new file mode 100644 index 000000000..b15490f1f --- /dev/null +++ b/lib/LuxLib/src/impl/batchnorm.jl @@ -0,0 +1,440 @@ +function batchnorm_cudnn end # Defined in LuxLibcuDNNExt +function ∇batchnorm_cudnn end # Defined in LuxLibcuDNNExt + +function batchnorm_reduce_dims(::AbstractArray{T, N}) where {T, N} + return (ntuple(static, N - 2)..., static(N)) +end + +CRC.@non_differentiable batchnorm_reduce_dims(::Any...) + +function get_batchnorm_statistics(::AbstractArray, rμ::Optional{<:AbstractVector}, + rσ²::Optional{<:AbstractVector}, ::True) + return copy_drop_gradients(rμ), copy_drop_gradients(rσ²) +end + +function get_batchnorm_statistics(x::AbstractArray, ::Nothing, ::Nothing, ::False) + μ, σ² = mean_var(x; dims=unsafe_known(batchnorm_reduce_dims(x)), corrected=false) + return safe_vec(μ), safe_vec(σ²) +end + +function get_batchnorm_statistics( + ::AbstractArray, rμ::AbstractVector, rσ²::AbstractVector, ::False) + return rμ, rσ² +end + +CRC.@non_differentiable get_batchnorm_statistics(::Any...) 
+ +function batchnorm(x::AbstractArray{xT, N}, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector}, + rσ²::Optional{<:AbstractVector}, training::StaticBool, act::F, + momentum::Real, ϵ::Real) where {F, xT, N} + (μ, σ²), (rμ, rσ²) = compute_batch_statistics( + x, reshape_norm_dims(x, rμ), reshape_norm_dims(x, rσ²), + batchnorm_reduce_dims(x), training, momentum) + return batchnorm_affine_normalize(act, x, μ, σ², γ, β, ϵ), safe_vec(rμ), safe_vec(rσ²) +end + +function batchnorm_affine_normalize( + act::F, x::AbstractArray{xT, N}, μ::AbstractArray{μT, N}, + σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, ϵ::Real) where {F, xT, μT, σ²T, N} + return batchnorm_affine_normalize( + internal_operation_mode((x, μ, σ², γ, β)), act, x, μ, σ², γ, β, ϵ) +end + +function batchnorm_affine_normalize( + ::GenericBroadcastOp, act::F, x::AbstractArray{xT, N}, μ::AbstractArray{μT, N}, + σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, ϵ::Real) where {F, xT, μT, σ²T, N} + return affine_normalize( + act, x, μ, σ², reshape_norm_dims(x, γ), reshape_norm_dims(x, β), ϵ) +end + +function batchnorm_affine_normalize( + opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{xT, N}, + μ::AbstractArray{μT, N}, σ²::AbstractArray{σ²T, N}, + γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, + ϵ::Real) where {F, xT, μT, σ²T, N} + x′ = reshape(x, :, size(x, N - 1), size(x, N)) + return reshape( + batchnorm_affine_normalize_internal(opmode, act, x′, vec(μ), vec(σ²), γ, β, ϵ), + size(x)) +end + +@stable default_mode="disable" function batchnorm_affine_normalize_internal( + opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{xT, 3}, + μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, ϵ::Real) where {F, xT} + y = similar(x, + promote_type(safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), + safe_eltype(γ), safe_eltype(β))) + batchnorm_affine_normalize_internal!(y, opmode, act, x, μ, σ², γ, β, ϵ) + return y +end + +function batchnorm_affine_normalize_internal!( + y::AbstractArray{yT, 3}, opmode::LoopedArrayOp, act::F, x::AbstractArray{xT, 3}, + μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, ϵ::Real, + γ′::Optional{<:AbstractVector}=nothing) where {F, xT, yT} + N = size(y, 2) + γ′ = γ′ === nothing ? 
+ similar(x, promote_type(safe_eltype(γ), safe_eltype(σ²), safe_eltype(ϵ)), N) : + γ′ + β′ = similar(x, promote_type(safe_eltype(β), safe_eltype(σ²), safe_eltype(ϵ)), N) + + compute_batchnorm_scale_bias!(γ′, β′, γ, β, μ, σ², ϵ) + + if unsafe_known(fuse_cpu_activation(act)) + apply_batchnorm_scale_bias_act_cpu!(y, γ′, β′, x, act) + else + apply_batchnorm_scale_bias_cpu!(y, γ′, β′, x) + activation!(y, opmode, act, y) + end + + return +end + +function compute_batchnorm_scale_bias!(γ′, β′, γ, β, μ, σ², ϵ) + if γ === nothing && β === nothing + @simd ivdep for J in eachindex(γ′, β′, μ, σ²) + @fastmath @inbounds γ′[J] = inv(sqrt(σ²[J] + ϵ)) + @fastmath @inbounds β′[J] = -μ[J] * γ′[J] + end + else + @simd ivdep for J in eachindex(γ′, β′, γ, β, μ, σ²) + @fastmath @inbounds γ′[J] = γ[J] / sqrt(σ²[J] + ϵ) + @fastmath @inbounds β′[J] = β[J] - μ[J] * γ′[J] + end + end +end + +function apply_batchnorm_scale_bias_act_cpu!( + y::AbstractArray{yT, 3}, γ′::AbstractVector, β′::AbstractVector, + x::AbstractArray{xT, 3}, σ::F) where {F, xT, yT} + if size(y, 1) == 1 + apply_batchnorm_scale_bias_act_2d_serial_cpu!(y, γ′, β′, x, σ) + else + apply_batchnorm_scale_bias_act_3d_threaded_cpu!(y, γ′, β′, x, σ) + end +end + +@inline function apply_batchnorm_scale_bias_act_2d_serial_cpu!( + y::AbstractArray{yT, 3}, γ′::AbstractVector, β′::AbstractVector, + x::AbstractArray{xT, 3}, σ::F) where {F, xT, yT} + for K in axes(x, 3) + @simd ivdep for J in axes(x, 2) + @fastmath @inbounds y[1, J, K] = σ(x[1, J, K] * γ′[J] + β′[J]) + end + end +end + +@inline function apply_batchnorm_scale_bias_act_3d_threaded_cpu!( + y::AbstractArray{yT, 3}, γ′::AbstractVector, β′::AbstractVector, + x::AbstractArray{xT, 3}, σ::F) where {F, xT, yT} + @batch for K in axes(x, 3) + for J in axes(x, 2) + @simd ivdep for I in axes(x, 1) + @fastmath @inbounds y[I, J, K] = σ(x[I, J, K] * γ′[J] + β′[J]) + end + end + end +end + +@inline function apply_batchnorm_scale_bias_act_3d_serial_cpu!( + y::AbstractArray{yT, 3}, γ′::AbstractVector, β′::AbstractVector, + x::AbstractArray{xT, 3}, σ::F) where {F, xT, yT} + for K in axes(x, 3) + for J in axes(x, 2) + @simd ivdep for I in axes(x, 1) + @fastmath @inbounds y[I, J, K] = σ(x[I, J, K] * γ′[J] + β′[J]) + end + end + end +end + +@enzyme_alternative apply_batchnorm_scale_bias_act_3d_threaded_cpu! apply_batchnorm_scale_bias_act_3d_serial_cpu! 
+ +function apply_batchnorm_scale_bias_cpu!(y::AbstractArray{yT, 3}, γ′::AbstractVector, + β′::AbstractVector, x::AbstractArray{xT, 3}) where {xT, yT} + if size(y, 1) == 1 + apply_batchnorm_scale_bias_2d_serial_cpu!(y, γ′, β′, x) + else + apply_batchnorm_scale_bias_3d_threaded_cpu!(y, γ′, β′, x) + end +end + +@inline function apply_batchnorm_scale_bias_2d_serial_cpu!( + y::AbstractArray{yT, 3}, γ′::AbstractVector, β′::AbstractVector, + x::AbstractArray{xT, 3}) where {xT, yT} + for K in axes(x, 3) + @simd ivdep for J in axes(x, 2) + @fastmath @inbounds y[1, J, K] = x[1, J, K] * γ′[J] + β′[J] + end + end +end + +@inline function apply_batchnorm_scale_bias_3d_threaded_cpu!( + y::AbstractArray{yT, 3}, γ′::AbstractVector, β′::AbstractVector, + x::AbstractArray{xT, 3}) where {xT, yT} + @batch for K in axes(x, 3) + for J in axes(x, 2) + @simd ivdep for I in axes(x, 1) + @fastmath @inbounds y[I, J, K] = x[I, J, K] * γ′[J] + β′[J] + end + end + end +end + +@inline function apply_batchnorm_scale_bias_3d_serial_cpu!( + y::AbstractArray{yT, 3}, γ′::AbstractVector, β′::AbstractVector, + x::AbstractArray{xT, 3}) where {xT, yT} + for K in axes(x, 3) + for J in axes(x, 2) + @simd ivdep for I in axes(x, 1) + @fastmath @inbounds y[I, J, K] = x[I, J, K] * γ′[J] + β′[J] + end + end + end +end + +@enzyme_alternative apply_batchnorm_scale_bias_3d_threaded_cpu! apply_batchnorm_scale_bias_3d_serial_cpu! + +function batchnorm_affine_normalize_internal!( + y::AbstractArray{yT, 3}, ::GPUBroadcastOp, act::F, x::AbstractArray{xT, 3}, + μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, ϵ::Real, + γ′::Optional{<:AbstractVector}=nothing) where {F, xT, yT} + backend = KA.get_backend(y) + run_ka_kernel( + batchnorm_affine_normalize_internal_kernel!, backend, nothing, size(y), + y, γ′, act, x, μ, σ², γ, β, ϵ) + KA.synchronize(backend) +end + +@kernel cpu=false inbounds=true function batchnorm_affine_normalize_internal_kernel!( + y::AbstractArray{<:Number, 3}, @Const(γ′::Nothing), + @Const(f), @Const(x), @Const(μ), @Const(σ²), + @Const(γ::Nothing), @Const(β::Nothing), @Const(ϵ)) + i, j, k = @index(Global, NTuple) + γ′′ = inv(sqrt(σ²[j] + ϵ)) + β′ = -μ[j] * γ′′ + y[i, j, k] = f(muladd(x[i, j, k], γ′′, β′)) +end + +@kernel cpu=false inbounds=true function batchnorm_affine_normalize_internal_kernel!( + y::AbstractArray{<:Number, 3}, γ′::AbstractVector{<:Number}, + @Const(f), @Const(x), @Const(μ), @Const(σ²), + @Const(γ::Nothing), @Const(β::Nothing), @Const(ϵ)) + i, j, k = @index(Global, NTuple) + γ′[j] = inv(sqrt(σ²[j] + ϵ)) + β′ = -μ[j] * γ′[j] + y[i, j, k] = f(muladd(x[i, j, k], γ′[j], β′)) +end + +@kernel cpu=false inbounds=true function batchnorm_affine_normalize_internal_kernel!( + y::AbstractArray{<:Number, 3}, @Const(γ′::Nothing), + @Const(f), @Const(x), @Const(μ), @Const(σ²), + @Const(γ), @Const(β), @Const(ϵ)) + i, j, k = @index(Global, NTuple) + γ′′ = γ[j] / sqrt(σ²[j] + ϵ) + β′ = muladd(-μ[j], γ′′, β[j]) + y[i, j, k] = f(muladd(x[i, j, k], γ′′, β′)) +end + +@kernel cpu=false inbounds=true function batchnorm_affine_normalize_internal_kernel!( + y::AbstractArray{<:Number, 3}, γ′::AbstractVector{<:Number}, + @Const(f), @Const(x), @Const(μ), @Const(σ²), + @Const(γ), @Const(β), @Const(ϵ)) + i, j, k = @index(Global, NTuple) + γ′[j] = γ[j] / sqrt(σ²[j] + ϵ) + β′ = muladd(-μ[j], γ′[j], β[j]) + y[i, j, k] = f(muladd(x[i, j, k], γ′[j], β′)) +end + +function CRC.rrule( + cfg::RuleConfig{>:HasReverseMode}, ::typeof(batchnorm_affine_normalize_internal), + 
opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{T, N}, + μ::AbstractVector, σ²::AbstractVector, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, ϵ::Real) where {F, T, N} + y = similar(x, + promote_type(safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), + safe_eltype(γ), safe_eltype(β))) + γ′ = similar( + x, promote_type(safe_eltype(γ), safe_eltype(σ²), safe_eltype(ϵ)), size(x, N - 1)) + + batchnorm_affine_normalize_internal!(y, opmode, identity, x, μ, σ², γ, β, ϵ, γ′) + z, ∇activation = CRC.rrule_via_ad( + cfg, activation!!, opmode, is_mutable_array(y), act, y) + + 𝒫x, 𝒫μ, 𝒫σ² = CRC.ProjectTo(x), CRC.ProjectTo(μ), CRC.ProjectTo(σ²) + 𝒫γ = γ === nothing ? identity : CRC.ProjectTo(γ) + 𝒫β = β === nothing ? identity : CRC.ProjectTo(β) + + ∇batchnorm_affine_normalize_internal = @closure Δ -> begin + ∂y = last(∇activation(Δ)) + ∂x, ∂μ, ∂σ², ∂γ, ∂β = ∇batchnorm_affine_normalize(opmode, ∂y, x, μ, σ², γ, β, ϵ, γ′) + return ∂∅, ∂∅, ∂∅, 𝒫x(∂x), 𝒫μ(∂μ), 𝒫σ²(∂σ²), 𝒫γ(∂γ), 𝒫β(∂β), ∂∅ + end + + return z, ∇batchnorm_affine_normalize_internal +end + +function ∇batchnorm_affine_normalize(opmode::LoopedArrayOp, ∂y::AbstractArray{∂yT, 3}, + x::AbstractArray{xT, 3}, μ::AbstractVector, σ²::AbstractVector, + γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real, + γ′::AbstractVector) where {∂yT, xT} + ∂x, ∂μ, ∂σ² = similar(x), similar(μ), similar(σ²) + ∂γ = γ === nothing ? nothing : similar(γ) + ∂β = β === nothing ? nothing : similar(β) + + ∇batchnorm_affine_normalize_cpu!(∂x, ∂μ, ∂σ², ∂γ, ∂β, ∂y, x, μ, σ², γ, ϵ, γ′) + + ∂γ = γ === nothing ? ∂∅ : ∂γ + ∂β = β === nothing ? ∂∅ : ∂β + + return ∂x, ∂μ, ∂σ², ∂γ, ∂β +end + +function ∇batchnorm_affine_normalize_cpu!( + ∂x::AbstractArray{∂xT, 3}, ∂μ::AbstractVector{∂μT}, + ∂σ²::AbstractVector{∂σ²T}, ::Nothing, ::Nothing, ∂y::AbstractArray{∂yT, 3}, + x::AbstractArray{xT, 3}, μ::AbstractVector, σ²::AbstractVector, ::Nothing, + ϵ::Real, γ′::AbstractVector) where {∂xT, ∂μT, ∂σ²T, ∂yT, xT} + half = eltype(∂σ²)(0.5) + + fill!(∂μ, 0) + fill!(∂σ², 0) + + if size(∂y, 1) == 1 + @fastmath @inbounds for K in axes(∂y, 3) + @simd for J in axes(∂y, 2) + idenom = γ′[J] + idenom² = idenom^2 + + xμ = x[1, J, K] - μ[J] + + ∂x[1, J, K] = ∂y[1, J, K] * idenom + ∂μ[J] -= ∂x[1, J, K] + ∂σ²[J] -= ∂x[1, J, K] * xμ * half * idenom² + end + end + else + @fastmath @inbounds for K in axes(∂y, 3), J in axes(∂y, 2) + idenom = γ′[J] + idenom² = idenom^2 + + @simd for I in axes(∂y, 1) + xμ = x[I, J, K] - μ[J] + + ∂x[I, J, K] = ∂y[I, J, K] * idenom + ∂μ[J] -= ∂x[I, J, K] + ∂σ²[J] -= ∂x[I, J, K] * xμ * half * idenom² + end + end + end +end + +function ∇batchnorm_affine_normalize_cpu!( + ∂x::AbstractArray{∂xT, 3}, ∂μ::AbstractVector{∂μT}, + ∂σ²::AbstractVector{∂σ²T}, ∂γ::AbstractVector{∂γT}, + ∂β::AbstractVector{∂βT}, ∂y::AbstractArray{∂yT, 3}, x::AbstractArray{xT, 3}, + μ::AbstractVector, σ²::AbstractVector, γ::AbstractVector, ϵ::Real, + γ′::AbstractVector) where {∂xT, ∂μT, ∂σ²T, ∂γT, ∂βT, ∂yT, xT} + half = eltype(∂σ²)(0.5) + + fill!(∂μ, 0) + fill!(∂σ², 0) + fill!(∂γ, 0) + fill!(∂β, 0) + + if size(∂y, 1) == 1 + @fastmath @inbounds for K in axes(∂y, 3) + @simd for J in axes(∂y, 2) + idenom = inv(sqrt(σ²[J] + ϵ)) + idenom² = idenom^2 + + xμ = x[1, J, K] - μ[J] + + ∂x[1, J, K] = ∂y[1, J, K] * γ′[J] + ∂μ[J] -= ∂x[1, J, K] + ∂σ²[J] -= ∂x[1, J, K] * xμ * half * idenom² + ∂γ[J] += ∂y[1, J, K] * xμ * idenom + ∂β[J] += ∂y[1, J, K] + end + end + else + @fastmath @inbounds for K in axes(∂y, 3), J in axes(∂y, 2) + idenom = inv(sqrt(σ²[J] + ϵ)) + idenom² = idenom^2 + + 
@simd for I in axes(∂y, 1) + xμ = x[I, J, K] - μ[J] + + ∂x[I, J, K] = ∂y[I, J, K] * γ′[J] + ∂μ[J] -= ∂x[I, J, K] + ∂σ²[J] -= ∂x[I, J, K] * xμ * half * idenom² + ∂γ[J] += ∂y[I, J, K] * xμ * idenom + ∂β[J] += ∂y[I, J, K] + end + end + end +end + +function ∇batchnorm_affine_normalize( + opmode::AbstractInternalArrayOpMode, ∂y::AbstractArray{∂yT, 3}, + x::AbstractArray{xT, 3}, μ::AbstractVector, σ²::AbstractVector, + γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, ϵ::Real, + γ′::AbstractVector) where {∂yT, xT} + ∂x, ∂σ² = similar(x), similar(σ², size(x)) + ∂γ = γ === nothing ? nothing : similar(γ, size(x)) + + ∇batchnorm_affine_normalize!(∂x, ∂σ², ∂γ, opmode, ∂y, x, μ, σ², γ, ϵ, γ′) + + ∂μ = dropdims(sum(-, ∂x; dims=(1, 3)); dims=(1, 3)) + ∂σ² = dropdims(sum(∂σ²; dims=(1, 3)); dims=(1, 3)) + ∂γ = γ === nothing ? ∂∅ : dropdims(sum(∂γ; dims=(1, 3)); dims=(1, 3)) + ∂β = β === nothing ? ∂∅ : dropdims(sum(∂y; dims=(1, 3)); dims=(1, 3)) + + return ∂x, ∂μ, ∂σ², ∂γ, ∂β +end + +function ∇batchnorm_affine_normalize!( + ∂x::AbstractArray{∂xT, 3}, ∂σ²::AbstractArray{∂σ²T, 3}, + ∂γ::Optional{<:AbstractArray{<:Any, 3}}, ::GPUBroadcastOp, + ∂y::AbstractArray{∂yT, 3}, x::AbstractArray{xT, 3}, μ::AbstractVector, + σ²::AbstractVector, γ::Optional{<:AbstractVector}, ϵ::Real, + γ′::AbstractVector) where {∂xT, ∂σ²T, ∂yT, xT} + backend = KA.get_backend(∂x) + run_ka_kernel( + ∇batchnorm_affine_normalize_kernel!, backend, nothing, size(∂x), + ∂x, ∂σ², ∂γ, ∂y, x, μ, σ², ϵ, γ′) + KA.synchronize(backend) +end + +@kernel cpu=false inbounds=true function ∇batchnorm_affine_normalize_kernel!( + ∂x, ∂σ², @Const(∂γ::Nothing), @Const(∂y), @Const(x), + @Const(μ), @Const(σ²), @Const(ϵ), @Const(γ′)) + i, j, k = @index(Global, NTuple) + idenom = γ′[j] + idenom² = idenom * idenom + + xμ = x[i, j, k] - μ[j] + + ∂x[i, j, k] = ∂y[i, j, k] * γ′[j] + ∂σ²[i, j, k] = -∂x[i, j, k] * xμ * idenom² / 2 +end + +@kernel cpu=false inbounds=true function ∇batchnorm_affine_normalize_kernel!( + ∂x, ∂σ², ∂γ, @Const(∂y), @Const(x), + @Const(μ), @Const(σ²), @Const(ϵ), @Const(γ′)) + i, j, k = @index(Global, NTuple) + idenom = inv(sqrt(σ²[j] + ϵ)) + idenom² = idenom * idenom + + xμ = x[i, j, k] - μ[j] + + ∂x[i, j, k] = ∂y[i, j, k] * γ′[j] + ∂σ²[i, j, k] = -∂x[i, j, k] * xμ * idenom² / 2 + ∂γ[i, j, k] = ∂y[i, j, k] * xμ * idenom +end diff --git a/lib/LuxLib/src/impl/bias_activation.jl b/lib/LuxLib/src/impl/bias_activation.jl new file mode 100644 index 000000000..f96531a7d --- /dev/null +++ b/lib/LuxLib/src/impl/bias_activation.jl @@ -0,0 +1,283 @@ +# Entry Points +bias_activation(::typeof(identity), x::AbstractVector, ::Nothing) = x +for bType in (Nothing, AbstractVector) + @eval function bias_activation(σ::F, x::AbstractVector, bias::$(bType)) where {F} + return vec(bias_activation(σ, expand_batchdim(x), bias)) + end +end + +bias_activation(::typeof(identity), x::AbstractArray, ::Nothing) = x +function bias_activation(σ::F, x::AbstractArray{xT, N}, ::Nothing) where {F, N, xT} + return activation(σ, x) +end +function bias_activation( + σ::F, x::AbstractArray{xT, N}, bias::AbstractVector{bT}) where {F, N, xT, bT} + return bias_activation(internal_operation_mode((x, bias)), σ, x, bias) +end + +## General Implementation +function bias_activation( + ::GenericBroadcastOp, ::typeof(identity), x::AbstractArray{T1, N}, + bias::AbstractVector{T2}) where {N, T1, T2} + return x .+ reshape_bias(x, bias) +end +function bias_activation(::GenericBroadcastOp, σ::F, x::AbstractArray{T1, N}, + bias::AbstractVector) where {F, N, T1} + return σ.(x .+ 
reshape_bias(x, bias)) +end + +function bias_activation(::AbstractInternalArrayOpMode, ::typeof(identity), + x::AbstractArray{xT, N}, bias::AbstractVector) where {N, xT} + return x .+ reshape_bias(x, bias) +end +function bias_activation( + ::AbstractInternalArrayOpMode, σ::F, x::AbstractArray{xT, N}, + bias::AbstractVector) where {F, N, xT} + return broadcast(σ ∘ +, x, reshape_bias(x, bias)) +end + +# Prevent ambiguity +@stable default_mode="disable" function bias_activation( + opmode::LoopedArrayOp, ::typeof(identity), + x::AbstractArray{xT, N}, bias::AbstractVector) where {N, xT} + y = similar(x, concrete_bias_act_output_eltype(identity, x, bias)) + bias_activation!(y, opmode, identity, x, bias) + return y +end +@stable default_mode="disable" function bias_activation( + opmode::LoopedArrayOp, σ::F, x::AbstractArray{xT, N}, + bias::AbstractVector) where {F, N, xT} + y = similar(x, concrete_bias_act_output_eltype(σ, x, bias)) + bias_activation!(y, opmode, σ, x, bias) + return y +end + +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(bias_activation), + opmode::AbstractInternalArrayOpMode, σ::F, x::AbstractArray{xT, N}, + bias::AbstractVector) where {F, N, xT} + T = concrete_bias_act_output_eltype(σ, x, bias) + 𝒫x, 𝒫bias = CRC.ProjectTo(x), CRC.ProjectTo(bias) + + if unsafe_known(activation_intermediate_not_needed(σ, T)) + y = bias_activation(opmode, σ, x, bias) + ∇bias_activation_no_intermediate = @closure Δ -> begin + ∂x = ∇activation(CRC.unthunk(Δ), y, σ, NotaNumber()) + ∂b = ∇bias_add(bias, ∂x) + return ∂∅, ∂∅, ∂∅, 𝒫x(∂x), 𝒫bias(∂b) + end + return y, ∇bias_activation_no_intermediate + end + + if unsafe_known(activation_has_rrule(σ, T)) + tmp = similar(x, T) + bias_add!(tmp, opmode, x, bias) + y = activation(opmode, σ, tmp) + ∇bias_activation_rrule = @closure Δ -> begin + ∂x = ∇activation(CRC.unthunk(Δ), y, σ, tmp) + ∂b = ∇bias_add(bias, ∂x) + return ∂∅, ∂∅, ∂∅, 𝒫x(∂x), 𝒫bias(∂b) + end + return y, ∇bias_activation_rrule + end + + y, ∇broadcast = CRC.rrule_via_ad(cfg, broadcast, σ ∘ +, x, reshape_bias(x, bias)) + ∇bias_activation_rrule = @closure Δ -> begin + _, _, ∂x, ∂bias = ∇broadcast(Δ) + return ∂∅, ∂∅, ∂∅, 𝒫x(∂x), 𝒫bias(vec(∂bias)) + end + return y, ∇bias_activation_rrule +end + +bias_activation!!(::typeof(identity), x::AbstractVector, ::Nothing) = x +for bType in (Nothing, AbstractVector) + @eval function bias_activation!!(σ::F, x::AbstractVector, bias::$(bType)) where {F} + return vec(bias_activation!!(σ, expand_batchdim(x), bias)) + end +end + +bias_activation!!(::typeof(identity), x::AbstractArray, ::Nothing) = x +function bias_activation!!(σ::F, x::AbstractArray{xT, N}, ::Nothing) where {F, N, xT} + return activation!!(σ, x) +end +function bias_activation!!( + σ::F, x::AbstractArray{xT, N}, bias::AbstractVector) where {F, N, xT} + return bias_activation!!( + internal_operation_mode((x, bias)), is_mutable_array(x), σ, x, bias) +end + +function bias_activation!!(opmode::AbstractInternalArrayOpMode, ::False, σ::F, + x::AbstractArray{xT, N}, bias::AbstractVector) where {F, N, xT} + return bias_activation(opmode, σ, x, bias) +end + +function bias_activation!!( + opmode::GenericBroadcastOp, ::True, σ::F, x::AbstractArray{xT, N}, + bias::AbstractVector) where {F, N, xT} + return bias_activation(opmode, σ, x, bias) +end + +@stable default_mode="disable" function bias_activation!!( + opmode::AbstractInternalArrayOpMode, ::True, σ::F, + x::AbstractArray{xT, N}, bias::AbstractVector) where {F, N, xT} + bias_activation!(x, opmode, σ, x, bias) + return x +end + +function 
CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(bias_activation!!), + opmode::AbstractInternalArrayOpMode, ::True, σ::F, + x::AbstractArray{xT, N}, bias::AbstractVector) where {F, N, xT} + T = concrete_bias_act_output_eltype(σ, x, bias) + 𝒫x, 𝒫bias = CRC.ProjectTo(x), CRC.ProjectTo(bias) + + if unsafe_known(activation_intermediate_not_needed(σ, T)) + bias_activation!(x, opmode, σ, x, bias) + ∇bias_activation_no_intermediate = @closure Δ -> begin + ∂x = ∇activation(CRC.unthunk(Δ), x, σ, NotaNumber()) + ∂b = ∇bias_add(bias, ∂x) + return ∂∅, ∂∅, ∂∅, ∂∅, 𝒫x(∂x), 𝒫bias(∂b) + end + return x, ∇bias_activation_no_intermediate + end + + if unsafe_known(activation_has_rrule(σ, T)) + y, tmp = bias_activation_cached!!(σ, x, bias) + ∇bias_activation_rrule = @closure Δ -> begin + ∂x = ∇activation(CRC.unthunk(Δ), y, σ, tmp) + ∂b = ∇bias_add(bias, ∂x) + return ∂∅, ∂∅, ∂∅, ∂∅, 𝒫x(∂x), 𝒫bias(∂b) + end + return y, ∇bias_activation_rrule + end + + res, ∇bias_activation_from_ad = CRC.rrule_via_ad( + cfg, bias_activation, opmode, σ, x, bias) + ∇bias_activation_fallback = @closure Δ -> begin + _, _, _, ∂x, ∂b = ∇bias_activation_from_ad(Δ) + return ∂∅, ∂∅, ∂∅, ∂∅, 𝒫x(∂x), 𝒫bias(∂b) + end + return res, ∇bias_activation_fallback +end + +# Core Implementation +function bias_activation!( + y::AbstractArray{yT, N}, opmode::AbstractInternalArrayOpMode, + σ::F, x::AbstractArray{xT, N}, ::Nothing) where {F, N, xT, yT} + activation!(y, opmode, σ, x) + return +end + +function bias_activation!( + y::AbstractArray{yT, N}, opmode::AbstractInternalArrayOpMode, σ::F, + x::AbstractArray{xT, N}, bias::AbstractVector) where {F, N, xT, yT} + if σ === identity + bias_add!(y, opmode, x, bias) + else + broadcast!(σ ∘ +, y, x, reshape_bias(x, bias)) + end + return +end + +function bias_activation!(y::AbstractArray{yT, N}, ::LoopedArrayOp, σ::F, + x::AbstractArray{xT, N}, bias::AbstractVector) where {F, N, xT, yT} + bias_activation_cpu!( + reshape(y, flattened_bias_dims(y), size(y, N - 1), size(y, N)), + fuse_cpu_activation(σ), + σ, reshape(x, flattened_bias_dims(x), size(x, N - 1), size(x, N)), bias) + return +end + +function bias_activation_cpu!(y::AbstractArray{yT, 3}, ::True, σ::F, + x::AbstractArray{xT, 3}, bias::AbstractVector) where {F, xT, yT} + bias_activation_simd_loop!(y, σ, x, bias) + return +end + +function bias_activation_cpu!(y::AbstractArray{yT, 3}, ::False, σ::F, + x::AbstractArray{xT, 3}, bias::AbstractVector) where {F, xT, yT} + bias_activation_simd_loop!(y, σ, x, bias) + return +end + +function bias_activation_simd_loop!(y::AbstractArray{yT, 3}, σ::F, x::AbstractArray{xT, 3}, + bias::AbstractVector) where {F, xT, yT} + if size(y, 1) == 1 + for K in axes(x, 3) + @simd ivdep for J in axes(x, 2) + @inbounds y[1, J, K] = σ(x[1, J, K] + bias[J]) + end + end + else + for K in axes(x, 3), J in axes(x, 2) + @simd ivdep for I in axes(x, 1) + @inbounds y[I, J, K] = σ(x[I, J, K] + bias[J]) + end + end + end + return +end + +function bias_add!(y::AbstractArray{yT, N}, ::AbstractInternalArrayOpMode, + x::AbstractArray{xT, N}, bias::AbstractVector) where {N, xT, yT} + broadcast!(+, y, x, reshape_bias(x, bias)) + return +end + +function bias_add!(y::AbstractArray{yT, N}, ::LoopedArrayOp, + x::AbstractArray{xT, N}, bias::AbstractVector) where {N, xT, yT} + bias_add_loop!(reshape(y, flattened_bias_dims(y), size(y, N - 1), size(y, N)), + reshape(x, flattened_bias_dims(x), size(x, N - 1), size(x, N)), bias) + return +end + +function bias_add_loop!(y::AbstractArray{yT, 3}, x::AbstractArray{xT, 3}, + bias::AbstractVector) where 
{xT, yT} + if size(y, 1) == 1 + for K in axes(x, 3) + @simd ivdep for J in axes(x, 2) + @inbounds y[1, J, K] = x[1, J, K] + bias[J] + end + end + else + for K in axes(x, 3), J in axes(x, 2) + @simd ivdep for I in axes(y, 1) + @inbounds y[I, J, K] = x[I, J, K] + bias[J] + end + end + end +end + +# Some helper functions for the rrule +function bias_activation_cached!!(σ::F, x::AbstractArray{xT, N}, + bias::Optional{<:AbstractVector}) where {F, N, xT} + @assert σ !== identity + bias === nothing && return activation(σ, x), x + return bias_activation_cached!!( + internal_operation_mode((x, bias)), is_mutable_array(x), σ, x, bias) +end + +function bias_activation_cached!!( + ::AbstractInternalArrayOpMode, ::False, σ::F, x::AbstractArray{xT, N}, + bias::Optional{<:AbstractVector}) where {F, N, xT} + y = broadcast(+, x, reshape_bias(x, bias)) + return activation(σ, y), y +end + +function bias_activation_cached!!( + ::AbstractInternalArrayOpMode, ::True, σ::F, x::AbstractArray{xT, N}, + bias::Optional{<:AbstractVector}) where {F, N, xT} + broadcast!(+, x, x, reshape_bias(x, bias)) + return activation(σ, x), x +end + +function bias_activation_cached!!( + ::LoopedArrayOp, ::True, σ::F, x::AbstractArray{xT, N}, + bias::Optional{<:AbstractVector}) where {F, N, xT} + x′ = reshape(x, flattened_bias_dims(x), size(x, N - 1), size(x, N)) + bias_add_loop!(x′, x′, bias) + x′′ = reshape(x′, size(x)) + return activation(σ, x′′), x′′ +end + +flattened_bias_dims(x::AbstractArray{T, N}) where {T, N} = prod(size(x)[1:(N - 2)]; init=1) + +CRC.@non_differentiable flattened_bias_dims(::Any...) diff --git a/lib/LuxLib/src/impl/common_ops.jl b/lib/LuxLib/src/impl/common_ops.jl new file mode 100644 index 000000000..ed25da525 --- /dev/null +++ b/lib/LuxLib/src/impl/common_ops.jl @@ -0,0 +1,69 @@ +function reshaped_bias_dims(x::AbstractArray, bias::AbstractVector) + return ntuple(i -> ifelse(i == ndims(x) - 1, length(bias), 1), ndims(x)) +end + +reshape_bias(::AbstractArray, ::Nothing) = nothing +reshape_bias(::AbstractVector, bias::Union{AbstractVector, StaticVector}) = bias +function reshape_bias(x::AbstractArray, bias::AbstractVector) + return reshape(bias, reshaped_bias_dims(x, bias)) +end +function reshape_bias(x::AbstractArray{<:Any, N}, bias::StaticVector) where {N} + return SArray{Tuple{reshaped_bias_dims(x, bias)...}, eltype(bias), N, length(bias)}(bias.data) +end + +## Needed for type stability +function CRC.rrule(::typeof(reshape_bias), x::AbstractArray{xT, N}, + bias::AbstractVector{bT}) where {xT, bT, N} + bias_r = reshape_bias(x, bias) + 𝒫bias = CRC.ProjectTo(bias) + return bias_r, Δ -> (∂∅, ∂∅, 𝒫bias(vec(Δ))) +end + +∇bias_add(::Nothing, Δ::AbstractArray) = ∂∅ +function ∇bias_add(b::AbstractArray{xT, N}, Δ::AbstractArray{yT, N}) where {xT, yT, N} + return reduce_sum(b, Δ) +end +function ∇bias_add(b::AbstractVector{xT}, Δ::AbstractArray{yT}) where {xT, yT} + return vec(reduce_sum(reshape_bias(Δ, b), Δ)) +end + +reduce_sum(::Nothing, ::NoTangent) = ∂∅ +function reduce_sum(x::AbstractArray, y::AbstractArray) + z = similar(x, promote_type(eltype(x), eltype(y))) + sum!(z, y) + return z +end + +function mean_var(x::AbstractArray; dims=:, corrected::Bool=true) + μ = mean(x; dims) + return μ, var(x; dims, corrected, mean=μ) +end + +function CRC.rrule(::typeof(mean_var), x::AbstractArray; dims=:, corrected::Bool=true) + μ, σ² = mean_var(x; dims, corrected) + + 𝒫x = CRC.ProjectTo(x) + ∇mean_var = @closure Δ -> begin + ∂μ, ∂σ² = CRC.unthunk(Δ) + n = dims_denom(x, dims) + ∂x₁ = unsum(x, CRC.unthunk(∂μ) / n, dims) + pre 
= 2 // (dims_denom(x, dims) - corrected) + ∂x₂ = pre .* CRC.unthunk(∂σ²) .* (x .- μ) + return NoTangent(), 𝒫x(add!!(∂x₁, ∂x₂)) + end + + return (μ, σ²), ∇mean_var +end + +add!!(x, y) = add!!(is_mutable_array(x), x, y) +add!!(::True, x, y) = x .+= y +add!!(::False, x, y) = x .+ y + +dims_denom(x, dims) = size(x, dims) +dims_denom(x, ::Colon) = length(x) +function dims_denom(x, dims::Union{Tuple, AbstractArray}) + return mapreduce(Base.Fix1(size, x), Base.mul_prod, unique(dims); init=1) +end + +unsum(x, dy, _) = broadcast(last ∘ tuple, x, dy) +unsum(x, dy, ::Colon) = broadcast(last ∘ tuple, x, Ref(dy)) diff --git a/lib/LuxLib/src/impl/conv.jl b/lib/LuxLib/src/impl/conv.jl new file mode 100644 index 000000000..3a3d22ee3 --- /dev/null +++ b/lib/LuxLib/src/impl/conv.jl @@ -0,0 +1,248 @@ +function get_conv_input_weight(x, weight) + return get_conv_input_weight(get_device_type((x, weight)), x, weight) +end + +function get_conv_input_weight(::Type{Device}, x, weight) where {Device <: AbstractDevice} + return get_conv_input_weight( + Device, eltype_mismatch(safe_eltype(x), safe_eltype(weight)), x, weight) +end + +function get_conv_input_weight(::Type{<:AbstractGPUDevice}, ::True, x, weight) + T = promote_type(safe_eltype(x), safe_eltype(weight)) + safe_warning("Mixed Precision Inputs received for GPU convolution [weight: \ + $(safe_eltype(weight))] and [x: $(safe_eltype(x))]. Promoting to $(T).", 1) + return contiguous(ofeltype_array(T, x)), contiguous(ofeltype_array(T, weight)) +end + +function get_conv_input_weight(::Type{<:AbstractGPUDevice}, ::False, x, weight) + return contiguous(x), contiguous(weight) +end + +get_conv_input_weight(::Type{<:AbstractDevice}, ::StaticBool, x, weight) = x, weight + +# Define some wrappers over NNlib operations. Useful once we ship our own versions +# with Kernel Abstractions and Loop Vectorization +function conv!(y, x, weight, cdims::ConvDims) + return conv!(y, get_device_type((y, x, weight)), x, weight, cdims) +end +function conv!(y::AbstractArray{yT, N}, ::Type{<:AbstractDevice}, + x::AbstractArray{xT, N}, weight::AbstractArray{wT, N}, + cdims::ConvDims) where {yT, xT, wT, N} + NNlib.conv!(y, x, weight, cdims) + return +end +function conv!(y::AbstractArray{yT, N}, ::Type{<:Union{CUDADevice, AMDGPUDevice}}, + x::AbstractArray{xT, N}, weight::AbstractArray{wT, N}, + cdims::ConvDims) where {yT, xT, wT, N} + if xT !== wT !== yT + safe_warning( + "Mixed Precision Inputs received for GPU convolution [weight: $(wT)] and \ + [x: $(xT)]. Promoting to $(yT).", 1) + end + NNlib.conv!(y, contiguous(ofeltype_array(yT, x)), + contiguous(ofeltype_array(yT, weight)), cdims) + return +end +function conv!(y::AbstractArray{yT, N}, dev::Type{<:AbstractGPUDevice}, + x::AbstractArray{xT, N}, weight::AbstractArray{wT, N}, + cdims::ConvDims) where {yT, xT, wT, N} + if xT !== wT !== yT + safe_warning( + "Mixed Precision Inputs received for GPU convolution [weight: $(wT)] and \ + [x: $(xT)]. Promoting to $(yT).", 1) + end + x_cont = contiguous(ofeltype_array(yT, x)) + weight_cont = contiguous(ofeltype_array(yT, weight)) + fallback_slow_conv!(y, dev, x_cont, weight_cont, cdims) + return +end + +function fallback_slow_conv!(y::AbstractArray{yT, N}, dev::Type{<:AbstractDevice}, + x::AbstractArray{xT, N}, weight::AbstractArray{wT, N}, + cdims::ConvDims) where {yT, xT, wT, N} + @warn "Falling back to slow convolution routine for $(dev) with x: size = \ + $(size(x)) eltype = $(xT) and weight: size = $(size(weight)) \ + eltype = $(wT)." 
maxlog=1 + # TODO: We should be able to reuse `y` for some part here for some efficiency + @assert NNlib.groupcount(cdims)==1 "Only groups=1 is supported for now." # FIXME + tmp = NNlib.unfold(x, cdims) + weight_compact = reshape(weight, :, size(weight, N), 1) + res = batched_matmul(tmp, weight_compact) + copyto!(y, reshape(res, size(y))) + return +end + +conv(x, weight, cdims::ConvDims) = conv(get_device_type((x, weight)), x, weight, cdims) + +function conv(::Type{<:Union{CPUDevice, CUDADevice, AMDGPUDevice, XLADevice}}, + x′, weight′, cdims::ConvDims) + x, weight = get_conv_input_weight(x′, weight′) + return NNlib.conv(x, weight, cdims) +end +function conv(dev::Type{<:AbstractDevice}, x′, weight′, cdims::ConvDims) + x, weight = get_conv_input_weight(dev, x′, weight′) + return fallback_slow_conv(dev, x, weight, cdims) +end + +function fallback_slow_conv(dev, x, weight, cdims::ConvDims) + y = similar(x, promote_type(eltype(x), eltype(weight)), NNlib.output_size(cdims)..., + NNlib.channels_out(cdims), size(x, ndims(x))) + fallback_slow_conv!(y, dev, x, weight, cdims) + return y +end + +function ∇conv_data(x′, weight′, cdims::ConvDims) + x, weight = get_conv_input_weight(x′, weight′) + return NNlib.∇conv_data(x, weight, cdims) +end + +function ∇conv_filter(x′, y′, cdims::ConvDims) + x, y = get_conv_input_weight(x′, y′) + return NNlib.∇conv_filter(x, y, cdims) +end + +function conv_bias_act(x′, weight′, cdims::ConvDims, bias′, act::F) where {F} + x, weight = get_conv_input_weight(x′, weight′) + bias = ofeltype_array(promote_type(safe_eltype(x), safe_eltype(weight)), bias′) + return conv_bias_act(get_device_type((x, weight, bias)), x, weight, cdims, bias, act) +end + +function conv_bias_act(::Type, x, weight, cdims, bias, act::F) where {F} + y = similar(x, concrete_bias_act_output_eltype(act, weight, x, bias), + NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, ndims(x))) + conv!(y, x, weight, cdims) + bias_activation!(y, internal_operation_mode((y, bias)), act, y, bias) + return y +end + +function conv_bias_act(::Type{CUDADevice}, x, weight, cdims, ::Nothing, act::F) where {F} + return activation!!(act, conv(x, weight, cdims)) +end +function conv_bias_act(::Type{CUDADevice}, x, weight, cdims, bias′, act::F) where {F} + if act === identity || act === NNlib.relu + bias = reshape_bias(x, bias′) + return NNlib.conv_bias_act(x, weight, cdims, bias, act) + end + return conv_bias_act(Nothing, x, weight, cdims, bias′, act) +end + +# Entry Points +function fused_conv( + act::F, weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, + bias::Optional{<:AbstractVector}, cdims::ConvDims) where {F, wT, xT, N} + old_threads = maybe_reduce_BLAS_threads(weight) + y = fused_conv(internal_operation_mode((weight, x, bias)), act, weight, x, bias, cdims) + reset_BLAS_threads(old_threads) + return y +end + +function fused_conv(::GenericBroadcastOp, act::F, weight::AbstractArray{wT, N}, + x::AbstractArray{xT, N}, bias::Optional{<:AbstractVector}, + cdims::ConvDims) where {F, wT, xT, N} + return bias_activation(act, conv(x, weight, cdims), bias) +end + +@stable default_mode="disable" function fused_conv(::AbstractInternalArrayOpMode, act::F, + weight::AbstractArray{wT, N}, x::AbstractArray{xT, N}, + bias::Optional{<:AbstractVector}, cdims::ConvDims) where {F, wT, xT, N} + return conv_bias_act(x, weight, cdims, bias, act) +end + +function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fused_conv), + opmode::AbstractInternalArrayOpMode, act::F, + weight::AbstractArray{wT, N}, x::AbstractArray{xT, 
N}, + bias::Optional{<:AbstractVector}, cdims::ConvDims) where {F, wT, xT, N} + T = concrete_bias_act_output_eltype(act, weight, x, bias) + 𝒫w, 𝒫x, 𝒫b = CRC.ProjectTo(weight), CRC.ProjectTo(x), CRC.ProjectTo(bias) + + if unsafe_known(activation_intermediate_not_needed(act, T)) + y = conv_bias_act(x, weight, cdims, bias, act) + ∇fused_conv_no_cached = @closure Δ -> begin + return ∇fused_conv( + Δ, weight, x, bias, cdims, y, NotaNumber(), 𝒫w, 𝒫x, 𝒫b, act) + end + return y, ∇fused_conv_no_cached + end + + # In any case here we need the intermediate pre-activation values + y = similar(x, T, NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, N)) + conv!(y, x, weight, cdims) + + if unsafe_known(activation_has_rrule(act, T)) + z, tmp = bias_activation_cached!!(act, y, bias) + ∇fused_conv_cached = @closure Δ -> begin + return ∇fused_conv(Δ, weight, x, bias, cdims, z, tmp, 𝒫w, 𝒫x, 𝒫b, act) + end + return z, ∇fused_conv_cached + end + + z, ∇bias_activation = CRC.rrule_via_ad(cfg, bias_activation, act, y, bias) + ∇fused_conv_cached = @closure Δ -> begin + old_threads = maybe_reduce_BLAS_threads(weight) + Δ = NNlib.colmajor(Δ) + _, _, ∂y, ∂b = ∇bias_activation(Δ) + ∂w, ∂x, _ = ∇conv_bias(∂y, ∂b, weight, x, bias, cdims) + reset_BLAS_threads(old_threads) + return ∂∅, ∂∅, ∂∅, 𝒫w(∂w), 𝒫x(∂x), 𝒫b(∂b), ∂∅ + end + + return z, ∇fused_conv_cached +end + +CRC.@opt_out rrule( + ::RuleConfig{>:HasReverseMode}, ::typeof(fused_conv), ::GenericBroadcastOp, + ::F, ::AbstractArray{wT, N}, ::AbstractArray{xT, N}, + ::Optional{<:AbstractVector}, ::ConvDims) where {F, wT, xT, N} + +function ∇fused_conv(Δ′, weight, x, bias, cdims::ConvDims, z, tmp, 𝒫w, 𝒫x, 𝒫b, act) + old_threads = maybe_reduce_BLAS_threads(weight) + Δ = CRC.unthunk(NNlib.colmajor(Δ′)) + ∂y = ∇activation(Δ, z, act, tmp) + ∂w, ∂x, ∂b = ∇conv_bias(∂y, weight, x, bias, cdims) + reset_BLAS_threads(old_threads) + return ∂∅, ∂∅, ∂∅, 𝒫w(∂w), 𝒫x(∂x), 𝒫b(∂b), ∂∅ +end + +function ∇conv_bias(∂y, weight, x, bias, cdims::ConvDims) + return ∇conv_bias(∂y, ∇bias_add(bias, ∂y), weight, x, bias, cdims) +end +function ∇conv_bias(∂y, ∂b, weight, x, _, cdims::ConvDims) + return ∇conv_filter(x, ∂y, cdims), ∇conv_data(∂y, weight, cdims), ∂b +end + +# Special handling for AMDGPU: AMDGPU doesn't support Float64 convolutions, so we need to +# type-cast everything +for (wT, xT) in [(Float64, Float64), (Float64, Float32), (Float32, Float64)] + for bT in (Float32, Float64) + @eval begin + function fused_conv(opmode::GPUBroadcastOp{AMDGPUDevice}, act::F, + weight::AbstractArray{$(wT), N}, x::AbstractArray{$(xT), N}, + bias::AbstractVector{$(bT)}, cdims::ConvDims) where {F, N} + @warn "MIOpen doesn't support Float64 convolutions, type-casting \ + everything to Float32 to avoid runtime errors" maxlog=1 + return ofeltype_array(Float64, + fused_conv(opmode, act, ofeltype_array(Float32, weight), + ofeltype_array(Float32, x), ofeltype_array(Float32, bias), cdims)) + end + + CRC.@opt_out rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fused_conv), + opmode::GPUBroadcastOp{AMDGPUDevice}, act::F, + weight::AbstractArray{$(wT), N}, x::AbstractArray{$(xT), N}, + bias::Optional{<:AbstractVector{$(bT)}}, cdims::ConvDims) where {F, N} + end + end + + @eval begin + function fused_conv(opmode::GPUBroadcastOp{AMDGPUDevice}, act::F, + weight::AbstractArray{$(wT), N}, x::AbstractArray{$(xT), N}, + ::Nothing, cdims::ConvDims) where {F, N} + return ofeltype_array(Float64, + fused_conv(opmode, act, ofeltype_array(Float32, weight), + ofeltype_array(Float32, x), nothing, cdims)) + end + + 
CRC.@opt_out rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fused_conv), + opmode::GPUBroadcastOp{AMDGPUDevice}, act::F, weight::AbstractArray{$(wT), N}, + x::AbstractArray{$(xT), N}, ::Nothing, cdims::ConvDims) where {F, N} + end +end diff --git a/lib/LuxLib/src/impl/dense.jl b/lib/LuxLib/src/impl/dense.jl new file mode 100644 index 000000000..26e70b51a --- /dev/null +++ b/lib/LuxLib/src/impl/dense.jl @@ -0,0 +1,218 @@ +function cublasLt_fused_dense end # Defined in `LuxLibCUDAExt` +function cublasLt_fused_dense! end # Defined in `LuxLibCUDAExt` + +function fused_dense(::typeof(identity), weight::AbstractMatrix, + x::AbstractMatrix, b::Optional{<:AbstractVector}) + return matmuladd(weight, x, b) +end + +function fused_dense(act::F, weight::AbstractMatrix, x::AbstractMatrix, + b::Optional{<:AbstractVector}) where {F} + return fused_dense(internal_operation_mode((weight, x, b)), act, weight, x, b) +end + +function fused_dense(opmode::GenericBroadcastOp, act::F, weight::AbstractMatrix, + x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} + return bias_activation(act, matmul(opmode, weight, x), b) +end + +@stable default_mode="disable" function fused_dense( + opmode::AbstractInternalArrayOpMode, act::F, weight::AbstractMatrix, + x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} + y = similar(weight, concrete_bias_act_output_eltype(act, weight, x, b), + size(weight, 1), size(x, 2)) + fused_dense!(y, opmode, act, weight, x, b) + return y +end + +function fused_dense!(y::AbstractMatrix, opmode::AbstractInternalArrayOpMode, act::F, + weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} + matmul!(y, opmode, weight, x) + bias_activation!(y, opmode, act, y, b) + return nothing +end + +function fused_dense!( + y::AbstractMatrix, ::GPUBroadcastOp{CUDADevice}, act::F, weight::AbstractMatrix, + x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} + cublasLt_fused_dense!(y, act, weight, x, b) + return nothing +end + +function CRC.rrule(cfg::CRC.RuleConfig{>:HasReverseMode}, ::typeof(fused_dense), + opmode::AbstractInternalArrayOpMode, act::F, weight::AbstractMatrix, + x::AbstractMatrix, b::Optional{<:AbstractVector}) where {F} + T = concrete_bias_act_output_eltype(act, weight, x, b) + 𝒫weight, 𝒫x, 𝒫b = CRC.ProjectTo(weight), CRC.ProjectTo(x), CRC.ProjectTo(b) + + if unsafe_known(activation_intermediate_not_needed(act, T)) + y = fused_dense(opmode, act, weight, x, b) + ∇fused_dense_no_intermediate = @closure Δ -> begin + ∂y = ∇activation(CRC.unthunk(Δ), y, act, NotaNumber()) + ∂w, ∂x, ∂b = ∇matmul_bias(∂y, weight, x, b) + return ∂∅, ∂∅, ∂∅, 𝒫weight(∂w), 𝒫x(∂x), 𝒫b(∂b) + end + return y, ∇fused_dense_no_intermediate + end + + if unsafe_known(activation_has_rrule(act, T)) + y = matmuladd(weight, x, b) + z = activation(opmode, act, y) + ∇fused_dense_cached = @closure Δ -> begin + ∂y = ∇activation(CRC.unthunk(Δ), z, act, y) + ∂w, ∂x, ∂b = ∇matmul_bias(∂y, weight, x, b) + return ∂∅, ∂∅, ∂∅, 𝒫weight(∂w), 𝒫x(∂x), 𝒫b(∂b) + end + return z, ∇fused_dense_cached + end + + y = similar(weight, T, size(weight, 1), size(x, 2)) + matmul!(y, opmode, weight, x) + z, ∇bias_activation = CRC.rrule_via_ad(cfg, bias_activation, act, y, b) + ∇fused_dense_fallback = @closure Δ -> begin + _, _, ∂y, ∂b = ∇bias_activation(Δ) + ∂w, ∂x, _ = ∇matmul_bias(∂y, ∂b, weight, x, b) + return ∂∅, ∂∅, ∂∅, 𝒫weight(∂w), 𝒫x(∂x), 𝒫b(∂b) + end + return z, ∇fused_dense_fallback +end + +## Special Reverse Pass for gelu activation. 
All other cases, we don't need special handling +function CRC.rrule( + ::typeof(fused_dense), ::GPUBroadcastOp{CUDADevice}, ::typeof(NNlib.gelu), + weight::AbstractMatrix, x::AbstractMatrix, b::Optional{<:AbstractVector}) + z, y = cublasLt_fused_dense(NNlib.gelu, weight, x, b, True()) + 𝒫weight, 𝒫x, 𝒫b = CRC.ProjectTo(weight), CRC.ProjectTo(x), CRC.ProjectTo(b) + + ∇fused_dense = @closure Δ -> begin + ∂y = ∇activation(CRC.unthunk(Δ), z, NNlib.gelu, y) + ∂w, ∂x, ∂b = ∇matmul_bias(∂y, weight, x, b) + return ∂∅, ∂∅, ∂∅, 𝒫weight(∂w), 𝒫x(∂x), 𝒫b(∂b) + end + + return z, ∇fused_dense +end + +# TODO: We can optimize these a bit further by checking for cases where the forward pass +# is not needed. We skip such optimizations for now +function EnzymeRules.augmented_primal(cfg, ::EnzymeCore.Const{typeof(fused_dense!)}, + ::Type{EnzymeCore.Const{Nothing}}, y::EnzymeCore.Annotation{<:AbstractMatrix}, + opmode::EnzymeCore.Const{<:AbstractInternalArrayOpMode}, act::EnzymeCore.Const, + weight::EnzymeCore.Annotation{<:AbstractMatrix}, + x::EnzymeCore.Annotation{<:AbstractMatrix}, + b::EnzymeCore.Annotation{<:Optional{<:AbstractVector}}) + + # NOTE: Here we are using the ChainRulesCore rrules if they are defined for simplicity + all_const = weight isa EnzymeCore.Const && b isa EnzymeCore.Const && + x isa EnzymeCore.Const + intermediate_not_needed = unsafe_known(activation_intermediate_not_needed( + act.val, eltype(y.val))) || all_const + + weight_cache = EnzymeRules.overwritten(cfg)[5] && !(x isa EnzymeCore.Const) && + !(y isa EnzymeCore.Const) ? copy(weight.val) : nothing + x_cache = EnzymeRules.overwritten(cfg)[6] && !(weight isa EnzymeCore.Const) && + !(y isa EnzymeCore.Const) ? copy(x.val) : nothing + + case_specific_cache = if act.val === NNlib.gelu && + opmode.val isa GPUBroadcastOp{CUDADevice} + tmp = similar(y.val) + cublasLt_fused_dense!(y.val, act.val, weight.val, x.val, b.val, tmp) + (1, tmp) + elseif intermediate_not_needed + fused_dense!(y.val, opmode.val, act.val, weight.val, x.val, b.val) + (1, NotaNumber()) + elseif unsafe_known(activation_has_rrule(act.val, eltype(y.val))) + tmp = matmuladd(weight.val, x.val, b.val) + activation!(y.val, opmode.val, act.val, tmp) + (1, tmp) + else + # TODO: Here for performance we might want to fuse the bias and activation together. 
+ # We skip this optimization for now + if b.val !== nothing + matmuladd!(y.val, opmode.val, weight.val, x.val, b.val) + else + matmul!(y.val, opmode.val, weight.val, x.val) + end + tmp = zero.(y.val) + EnzymeCore.autodiff(EnzymeCore.Forward, EnzymeCore.Const(activation!), + EnzymeCore.Duplicated(y.val, tmp), opmode, act, + EnzymeCore.Duplicated(y.val, one.(y.val))) + (2, tmp) + end + + cache = (case_specific_cache, weight_cache, x_cache) + + return EnzymeRules.AugmentedReturn(nothing, nothing, cache) +end + +function EnzymeRules.reverse(cfg, ::EnzymeCore.Const{typeof(fused_dense!)}, + ::Type{EnzymeCore.Const{Nothing}}, cache, y::EnzymeCore.Annotation{<:AbstractMatrix}, + opmode::EnzymeCore.Const{<:AbstractInternalArrayOpMode}, act::EnzymeCore.Const, + weight::EnzymeCore.Annotation{<:AbstractMatrix}, + x::EnzymeCore.Annotation{<:AbstractMatrix}, + b::EnzymeCore.Annotation{<:Optional{<:AbstractVector}}) + case_specific_cache, weight_cache, x_cache = cache + + (case, tmp) = case_specific_cache + + if !(x isa EnzymeCore.Const) && !(y isa EnzymeCore.Const) + if !EnzymeRules.overwritten(cfg)[5] + weight_cache = weight.val + end + end + + if !(weight isa EnzymeCore.Const) && !(y isa EnzymeCore.Const) + if !EnzymeRules.overwritten(cfg)[6] + x_cache = x.val + end + end + + ∂ys = y.dval + ∂xs = x isa EnzymeCore.Const ? ∂ys : x.dval + ∂ws = weight isa EnzymeCore.Const ? ∂ys : weight.dval + ∂bs = b isa EnzymeCore.Const ? ∂ys : b.dval + + if EnzymeRules.width(cfg) == 1 + ∂ys = (∂ys,) + ∂xs = (∂xs,) + ∂ws = (∂ws,) + ∂bs = (∂bs,) + end + + for (∂y, ∂w, ∂x, ∂b) in zip(∂ys, ∂ws, ∂xs, ∂bs) + if !(y isa EnzymeCore.Const) && ∂y !== y.val + # Compute preactivation gradients + ∂pre_act = if case == 1 + ∇activation(∂y, y.val, act.val, tmp) + elseif case == 2 + ∂y .* tmp + else + error("Unknown case: $case. This should not happen, open an issue.") + end + + if !(b isa EnzymeCore.Const) && ∂b !== b.val + # FIXME: Can we do this without allocating? 
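+                # `sum!` overwrites its destination, so the pre-activation gradient is
+                # reduced into a temporary first and then accumulated into `∂b`, which
+                # may already hold previously accumulated gradients.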
+ ∂b₁ = similar(∂b) + sum!(∂b₁, ∂pre_act) + ∂b .+= ∂b₁ + end + + if !(weight isa EnzymeCore.Const) && ∂w !== weight.val + # TODO: we don't use our faster matmul here since we lack the 5 arg version + mul!(∂w, ∂pre_act, x_cache', true, true) + end + + if !(x isa EnzymeCore.Const) && ∂x !== x.val + # TODO: we don't use our faster matmul here since we lack the 5 arg version + mul!(∂x, weight_cache', ∂pre_act, true, true) + end + + ∂y .= 0 + end + end + + return ntuple(Returns(nothing), 6) +end + +∇matmul_bias(∂y, weight, x, bias) = ∇matmul_bias(∂y, ∇bias_add(bias, ∂y), weight, x, bias) +∇matmul_bias(∂y, ∂b, weight, x, _) = matmul(∂y, x'), matmul(weight', ∂y), ∂b diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl new file mode 100644 index 000000000..5b4248291 --- /dev/null +++ b/lib/LuxLib/src/impl/dropout.jl @@ -0,0 +1,173 @@ +# Entry Points +## dropout +function dropout(rng::AbstractRNG, x::AbstractArray, p::T, ::True, invp::T, dims) where {T} + mask, rngₙ = generate_dropout_mask(rng, x, p, invp, dims) + return dropout_dot_mul(x, mask), mask, rngₙ +end + +dropout(rng::AbstractRNG, x::AbstractArray, ::T, ::False, ::T, dims) where {T} = (x, x, rng) + +function dropout(rng::AbstractRNG, x::AbstractArray, ::AbstractArray, p::T, + training::StaticBool, ::True, invp::T, dims) where {T} + return dropout(rng, x, p, training, invp, dims) +end + +function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray, + ::T, ::True, ::False, invp::T, dims) where {T} + check_dropout_mask_shape_mismatch(x, mask, dims) + return dropout_dot_mul(x, mask), mask, rng +end + +function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray, + ::T, ::False, ::False, invp::T, dims) where {T} + return x, mask, rng +end + +function check_dropout_mask_shape_mismatch(x::AbstractArray, mask::AbstractArray, dims) + @assert dropout_shape(x, dims)==size(mask) "`mask` is not of the same size as `LuxLib.dropout_shape(x, dims)`." + return nothing +end + +CRC.@non_differentiable check_dropout_mask_shape_mismatch(::Any...) + +## alpha_dropout +function alpha_dropout(rng::AbstractRNG, x::AbstractArray{T}, p, ::True) where {T} + α = T(-1.7580993408473766) + A = T(inv(sqrt((1 - p) * (1 + p * α^2)))) + B = T(-A * α * p) + return alpha_dropout(rng, x, p, True(), α, A, B) +end + +function alpha_dropout(rng::AbstractRNG, x::AbstractArray{T}, p, ::False) where {T} + return alpha_dropout(rng, x, p, False(), T(0), T(0), T(0)) +end + +function alpha_dropout(rng::AbstractRNG, x::AbstractArray{T}, p, ::True, α, A, B) where {T} + noise, rngₙ = generate_alpha_dropout_noise(rng, x) + return alpha_dropout(noise, p, x, α, A, B), rngₙ +end + +alpha_dropout(rng::AbstractRNG, x::AbstractArray{T}, p, ::False, α, A, B) where {T} = x, rng + +# Core Implementation +dropout_shape(s, ::Colon) = size(s) +function dropout_shape(s, dims) + return ntuple(@closure(i->ifelse(i ∈ dims, size(s, i), 1)), ndims(s)) +end + +CRC.@non_differentiable dropout_shape(::Any...) + +function alpha_dropout(noise::AbstractArray, p, x::AbstractArray, α, A, B) + return alpha_dropout(internal_operation_mode((noise, x)), noise, p, x, α, A, B) +end + +@stable default_mode="disable" function alpha_dropout( + ::AbstractInternalArrayOpMode, noise::AbstractArray, p::Real, + x::AbstractArray{T}, α::Real, A::Real, B::Real) where {T} + A′, B′, α = T(A), T(B), T(α) + return @. 
muladd(ifelse(noise > p, x, α), A′, B′) +end + +@stable default_mode="disable" function alpha_dropout( + opmode::LoopedArrayOp, noise::AbstractArray, p::Real, + x::AbstractArray, α::Real, A::Real, B::Real) + res = similar(x, promote_type(typeof(p), typeof(α))) + alpha_dropout!(res, opmode, noise, p, x, α, A, B) + return res +end + +function CRC.rrule(::typeof(alpha_dropout), ::LoopedArrayOp, noise::AbstractArray, + p::Real, x::AbstractArray, α::Real, A::Real, B::Real) + cond = similar(noise, Bool) + y = similar(x, promote_type(typeof(p), typeof(α), typeof(A), typeof(B), eltype(x))) + @simd ivdep for I in eachindex(noise, x, y, cond) + @inbounds cond[I] = noise[I] > p + @inbounds y[I] = ifelse(cond[I], x[I], α) * A + B + end + + ∇alpha_dropout = let cond = cond, 𝒫x = CRC.ProjectTo(x), x = x + Δ -> begin + ∂x = similar(x) + @simd ivdep for I in eachindex(cond, Δ, ∂x) + @inbounds ∂x[I] = cond[I] * Δ[I] * A + end + return (ntuple(Returns(∂∅), 4)..., 𝒫x(∂x), ntuple(Returns(∂∅), 3)...) + end + end + + return y, ∇alpha_dropout +end + +function CRC.rrule(::typeof(alpha_dropout), ::AbstractInternalArrayOpMode, + noise::AbstractArray, p::Real, x::AbstractArray, α::Real, A::Real, B::Real) + cond = noise .> p + y = @. ifelse(cond, x, α) * A + B + + 𝒫x = CRC.ProjectTo(x) + ∇alpha_dropout = @closure Δ -> begin + ∂x = 𝒫x(Δ .* cond .* A) + return (ntuple(Returns(∂∅), 4)..., ∂x, ntuple(Returns(∂∅), 3)...) + end + + return y, ∇alpha_dropout +end + +function alpha_dropout!( + res::AbstractArray{T}, ::LoopedArrayOp, noise::AbstractArray{T}, + p::Real, x::AbstractArray{T}, α::Real, A::Real, B::Real) where {T} + @simd ivdep for I in eachindex(noise, x, res) + res[I] = ifelse(noise[I] > p, x[I], α) * A + B + end +end + +dropout_fptype(x) = float(real(remove_tracking(eltype(x)))) + +CRC.@non_differentiable dropout_fptype(::Any...) + +@stable default_mode="disable" function generate_alpha_dropout_noise(rng::AbstractRNG, x) + rng = LuxCore.replicate(rng) + noise = similar(x, dropout_fptype(x)) + rand!(rng, noise) + return noise, rng +end + +CRC.@non_differentiable generate_alpha_dropout_noise(::Any...) + +@stable default_mode="disable" function generate_dropout_mask( + rng::AbstractRNG, x, p, invp, dims) + rng = LuxCore.replicate(rng) + y = similar(remove_tracking(x), dropout_fptype(x), dropout_shape(x, dims)) + rand!(rng, y) + generate_dropout_mask!(y, internal_operation_mode(y), p, invp) + return y, rng +end + +CRC.@non_differentiable generate_dropout_mask(::Any...) + +function generate_dropout_mask!(y::AbstractArray, ::LoopedArrayOp, p, invp) + generate_dropout_mask_loop!(y, p, invp) + return +end + +function generate_dropout_mask_loop!(y::AbstractArray{T}, p, invp) where {T} + p, invp = T(p), T(invp) + @simd ivdep for I in eachindex(y) + y[I] = (y[I] > p) * invp + end +end + +function generate_dropout_mask!( + y::AbstractArray{T}, ::AbstractInternalArrayOpMode, p, invp) where {T} + p, invp = T(p), T(invp) + @. 
y = (y > p) * invp + return +end + +dropout_dot_mul(x::AbstractArray, mask::AbstractArray) = x .* mask + +function CRC.rrule(::typeof(dropout_dot_mul), x::AbstractArray, mask::AbstractArray) + ∇dropout_dot_mul = @closure Δ -> begin + return ∂∅, (CRC.ProjectTo(x))(dropout_dot_mul(Δ, mask)), ∂∅ + end + return dropout_dot_mul(x, mask), ∇dropout_dot_mul +end diff --git a/lib/LuxLib/src/impl/forward_diff.jl b/lib/LuxLib/src/impl/forward_diff.jl new file mode 100644 index 000000000..56a45c4ec --- /dev/null +++ b/lib/LuxLib/src/impl/forward_diff.jl @@ -0,0 +1,50 @@ +for op in [:conv, :depthwiseconv, :∇conv_data, :∇conv_filter] + patched_op = op !== :depthwiseconv ? eval(op) : getfield(NNlib, op) + + @eval function NNlib.$(op)(x1::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N}, + x2::AbstractArray{<:Real, N}, cdims::NNlib.ConvDims; + kwargs...) where {N, Tag, V, P} + value_fn(x) = ForwardDiff.value(Tag, x) + partial_fn(x, i) = ForwardDiff.partials(Tag, x, i) + + y = $(patched_op)(value_fn.(x1), x2, cdims; kwargs...) + dys = ntuple(i -> $(patched_op)(partial_fn.(x1, i), x2, cdims; kwargs...), P) + + partials = ForwardDiff.Partials.(tuple.(dys...)) + return ForwardDiff.Dual{Tag, eltype(y), P}.(y, partials) + end + + @eval function NNlib.$(op)(x1::AbstractArray{<:Real, N}, + x2::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N}, + cdims::NNlib.ConvDims; kwargs...) where {N, Tag, V, P} + value_fn(x) = ForwardDiff.value(Tag, x) + partial_fn(x, i) = ForwardDiff.partials(Tag, x, i) + + y = $(patched_op)(x1, value_fn.(x2), cdims; kwargs...) + dys = ntuple(i -> $(patched_op)(x1, partial_fn.(x2, i), cdims; kwargs...), P) + + partials = ForwardDiff.Partials.(tuple.(dys...)) + return ForwardDiff.Dual{Tag, eltype(y), P}.(y, partials) + end + + @eval function NNlib.$(op)(x1::AbstractArray{<:ForwardDiff.Dual{Tag, Vₓ, P}, N}, + x2::AbstractArray{<:ForwardDiff.Dual{Tag, Vₚ, P}, N}, + cdims::NNlib.ConvDims; kwargs...) where {N, Tag, Vₓ, Vₚ, P} + value_fn(x) = ForwardDiff.value(Tag, x) + partial_fn(x, i) = ForwardDiff.partials(Tag, x, i) + + x1_data, x2_data = value_fn.(x1), value_fn.(x2) + + y = $(patched_op)(x1_data, x2_data, cdims; kwargs...) + + dys₁ = ntuple(P) do i + dys₁ᵢ = $(patched_op)(partial_fn.(x1, i), x2_data, cdims; kwargs...) + dys₂ᵢ = $(patched_op)(x1_data, partial_fn.(x2, i), cdims; kwargs...) 
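+            # conv and its adjoints are linear in each argument, so the i-th partial of
+            # the output is the operation applied to (∂x1ᵢ, x2) plus the operation
+            # applied to (x1, ∂x2ᵢ); the second piece is accumulated in-place below.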
+ dys₁ᵢ .+= dys₂ᵢ + return dys₁ᵢ + end + + partials = ForwardDiff.Partials.(tuple.(dys₁...)) + return ForwardDiff.Dual{Tag, eltype(y), P}.(y, partials) + end +end diff --git a/lib/LuxLib/src/impl/groupnorm.jl b/lib/LuxLib/src/impl/groupnorm.jl new file mode 100644 index 000000000..9a64fd735 --- /dev/null +++ b/lib/LuxLib/src/impl/groupnorm.jl @@ -0,0 +1,424 @@ +groupnorm_reduce_dims(::AbstractArray{T, N}) where {T, N} = ntuple(static, N - 1) + +CRC.@non_differentiable groupnorm_reduce_dims(::Any) + +function groupnorm(x::AbstractArray{xT, N}, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, groups::Int, act::F, ϵ::Real) where {F, N, xT} + x′ = reshape(x, size(x)[1:(N - 2)]..., size(x, N - 1) ÷ groups, groups, size(x, N)) + (μ, σ²), _ = compute_batch_statistics( + x′, nothing, nothing, groupnorm_reduce_dims(x), False(), nothing) + return reshape(groupnorm_affine_normalize(act, x′, μ, σ², γ, β, ϵ), size(x)) +end + +function groupnorm_affine_normalize( + act::F, x::AbstractArray{xT, N}, μ::AbstractArray{μT, N}, + σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, ϵ::Real) where {F, N, xT, μT, σ²T} + return groupnorm_affine_normalize( + internal_operation_mode((x, μ, σ², γ, β)), act, x, μ, σ², γ, β, ϵ) +end + +function groupnorm_affine_normalize( + ::GenericBroadcastOp, act::F, x::AbstractArray{xT, N}, μ::AbstractArray{μT, N}, + σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, ϵ::Real) where {F, N, xT, μT, σ²T} + return affine_normalize( + act, x, μ, σ², reshape_norm_dims(x, γ), reshape_norm_dims(x, β), ϵ) +end + +@generated function groupnorm_affine_normalize( + opmode::AbstractInternalArrayOpMode, act::F, x::AbstractArray{xT, N}, + μ::AbstractArray{μT, N}, σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, ϵ::Real) where {F, N, xT, μT, σ²T} + reshape_calls = if γ != Nothing + quote + γ′ = reshape(γ, 1, size(x, N - 2), size(x, N - 1), 1) + β′ = reshape(β, 1, size(x, N - 2), size(x, N - 1), 1) + end + else + quote + γ′ = nothing + β′ = nothing + end + end + + return quote + x′ = reshape(x, :, size(x, N - 2), size(x, N - 1), size(x, N)) + μ′ = reshape(μ, 1, 1, size(x, N - 1), size(x, N)) + σ²′ = reshape(σ², 1, 1, size(x, N - 1), size(x, N)) + $(reshape_calls) + return reshape( + groupnorm_affine_normalize_internal(opmode, act, x′, μ′, σ²′, γ′, β′, ϵ), + size(x)) + end +end + +@stable default_mode="disable" function groupnorm_affine_normalize_internal( + opmode::AbstractInternalArrayOpMode, act::F, + x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, + γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, + ϵ::Real) where {F, xT, μT, σ²T} + y = similar(x, + promote_type(safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), + safe_eltype(γ), safe_eltype(β))) + groupnorm_affine_normalize_internal!(y, opmode, act, x, μ, σ², γ, β, ϵ) + return y +end + +function groupnorm_affine_normalize_internal!( + y::AbstractArray{yT, 4}, opmode::LoopedArrayOp, act::F, x::AbstractArray{xT, 4}, + μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, + γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, + ϵ::Real) where {F, xT, yT, μT, σ²T} + if unsafe_known(fuse_cpu_activation(act)) + groupnorm_affine_normalize_act_cpu!(y, x, μ, σ², γ, β, ϵ, act) + else + groupnorm_affine_normalize_cpu!(y, x, μ, σ², γ, β, ϵ) + activation!(y, opmode, act, y) + end + return +end + +function groupnorm_affine_normalize_act_cpu!( 
+ y::AbstractArray{yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, + σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, + β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real, act::F) where {F, xT, yT, μT, σ²T} + if size(y, 1) == 1 + groupnorm_affine_normalize_act_3d_serial_cpu!(y, x, μ, σ², γ, β, ϵ, act) + else + groupnorm_affine_normalize_act_4d_serial_cpu!(y, x, μ, σ², γ, β, ϵ, act) + end +end + +function groupnorm_affine_normalize_act_3d_serial_cpu!( + y::AbstractArray{yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, + σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, + β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real, σ::F) where {F, xT, yT, μT, σ²T} + if γ === nothing && β === nothing + @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3) + γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + β′ = -μ[1, 1, K, L] * γ′ + @simd ivdep for J in axes(y, 2) + y[1, J, K, L] = σ(x[1, J, K, L] * γ′ + β′) + end + end + else + @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3) + idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + @simd for J in axes(y, 2) + γ′ = γ[1, J, K, 1] * idenom + β′ = β[1, J, K, 1] - μ[1, 1, K, L] * γ′ + y[1, J, K, L] = σ(x[1, J, K, L] * γ′ + β′) + end + end + end +end + +function groupnorm_affine_normalize_act_4d_serial_cpu!( + y::AbstractArray{yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, + σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, + β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real, σ::F) where {F, xT, yT, μT, σ²T} + if γ === nothing && β === nothing + @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3) + γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + β′ = -μ[1, 1, K, L] * γ′ + for J in axes(y, 2) + @simd ivdep for I in axes(y, 1) + y[I, J, K, L] = σ(x[I, J, K, L] * γ′ + β′) + end + end + end + else + @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3) + idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + for J in axes(y, 2) + γ′ = γ[1, J, K, 1] * idenom + β′ = β[1, J, K, 1] - μ[1, 1, K, L] * γ′ + @simd ivdep for I in axes(y, 1) + y[I, J, K, L] = σ(x[I, J, K, L] * γ′ + β′) + end + end + end + end +end + +function groupnorm_affine_normalize_cpu!( + y::AbstractArray{yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, + σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, + β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real) where {xT, yT, μT, σ²T} + if size(y, 1) == 1 + groupnorm_affine_normalize_3d_serial_cpu!(y, x, μ, σ², γ, β, ϵ) + else + groupnorm_affine_normalize_4d_serial_cpu!(y, x, μ, σ², γ, β, ϵ) + end +end + +@inline function groupnorm_affine_normalize_3d_serial_cpu!( + y::AbstractArray{yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, + σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, + β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real) where {xT, yT, μT, σ²T} + if γ === nothing && β === nothing + @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3) + γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + β′ = -μ[1, 1, K, L] * γ′ + @simd ivdep for J in axes(y, 2) + y[1, J, K, L] = x[1, J, K, L] * γ′ + β′ + end + end + else + @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3) + idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + @simd for J in axes(y, 2) + γ′ = γ[1, J, K, 1] * idenom + β′ = β[1, J, K, 1] - μ[1, 1, K, L] * γ′ + y[1, J, K, L] = x[1, J, K, L] * γ′ + β′ + end + end + end +end + +@inline function groupnorm_affine_normalize_4d_serial_cpu!( + y::AbstractArray{yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, + σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 
4}}, + β::Optional{<:AbstractArray{<:Any, 4}}, ϵ::Real) where {xT, yT, μT, σ²T} + if γ === nothing && β === nothing + @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3) + γ′ = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + β′ = -μ[1, 1, K, L] * γ′ + for J in axes(y, 2) + @simd ivdep for I in axes(y, 1) + y[I, J, K, L] = x[I, J, K, L] * γ′ + β′ + end + end + end + else + @fastmath @inbounds for L in axes(y, 4), K in axes(y, 3) + idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + for J in axes(y, 2) + γ′ = γ[1, J, K, 1] * idenom + β′ = β[1, J, K, 1] - μ[1, 1, K, L] * γ′ + @simd ivdep for I in axes(y, 1) + y[I, J, K, L] = x[I, J, K, L] * γ′ + β′ + end + end + end + end +end + +function groupnorm_affine_normalize_internal!( + y::AbstractArray{yT, 4}, ::GPUBroadcastOp, act::F, x::AbstractArray{xT, 4}, + μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, + γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, + ϵ::Real) where {F, xT, yT, μT, σ²T} + backend = KA.get_backend(y) + run_ka_kernel( + groupnorm_affine_normalize_kernel!, backend, nothing, size(y), + y, act, x, μ, σ², γ, β, ϵ) + KA.synchronize(backend) +end + +@kernel cpu=false inbounds=true function groupnorm_affine_normalize_kernel!( + y::AbstractArray{<:Number, 4}, @Const(f), + @Const(x), @Const(μ), @Const(σ²), @Const(γ::Nothing), @Const(β::Nothing), @Const(ϵ)) + i, j, k, l = @index(Global, NTuple) + γ′ = inv(sqrt(σ²[1, 1, k, l] + ϵ)) + β′ = -μ[1, 1, k, l] * γ′ + y[i, j, k, l] = f(muladd(x[i, j, k, l], γ′, β′)) +end + +@kernel cpu=false inbounds=true function groupnorm_affine_normalize_kernel!( + y::AbstractArray{<:Number, 4}, @Const(f), @Const(x), + @Const(μ), @Const(σ²), @Const(γ), @Const(β), @Const(ϵ)) + i, j, k, l = @index(Global, NTuple) + γ′ = γ[1, j, k, 1] / sqrt(σ²[1, 1, k, l] + ϵ) + β′ = muladd(-μ[1, 1, k, l], γ′, β[1, j, k, 1]) + y[i, j, k, l] = f(muladd(x[i, j, k, l], γ′, β′)) +end + +function CRC.rrule( + cfg::RuleConfig{>:HasReverseMode}, ::typeof(groupnorm_affine_normalize_internal), + opmode::AbstractInternalArrayOpMode, f::F, + x::AbstractArray{T, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, + γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, + ϵ::Real) where {F, T, μT, σ²T} + y = similar(x, + promote_type(safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), + safe_eltype(γ), safe_eltype(β))) + groupnorm_affine_normalize_internal!(y, opmode, identity, x, μ, σ², γ, β, ϵ) + z, ∇activation = CRC.rrule_via_ad(cfg, activation!!, opmode, is_mutable_array(y), f, y) + + 𝒫x, 𝒫μ, 𝒫σ² = CRC.ProjectTo(x), CRC.ProjectTo(μ), CRC.ProjectTo(σ²) + 𝒫γ = γ === nothing ? identity : CRC.ProjectTo(γ) + 𝒫β = β === nothing ? identity : CRC.ProjectTo(β) + + ∇groupnorm_affine_normalize_internal = @closure Δ -> begin + ∂y = last(∇activation(Δ)) + ∂x, ∂μ, ∂σ², ∂γ, ∂β = ∇groupnorm_affine_normalize(opmode, ∂y, x, μ, σ², γ, β, ϵ) + return ∂∅, ∂∅, ∂∅, 𝒫x(∂x), 𝒫μ(∂μ), 𝒫σ²(∂σ²), 𝒫γ(∂γ), 𝒫β(∂β), ∂∅ + end + + return z, ∇groupnorm_affine_normalize_internal +end + +function ∇groupnorm_affine_normalize( + opmode::AbstractInternalArrayOpMode, ∂y::AbstractArray{∂yT, 4}, + x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, + γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, + ϵ::Real) where {∂yT, xT, μT, σ²T} + ∂x, ∂σ² = similar(x), similar(σ², size(x)) + ∂γ = γ === nothing ? 
nothing : similar(γ, size(x)) + + ∇groupnorm_affine_normalize!(∂x, ∂σ², ∂γ, opmode, ∂y, x, μ, σ², γ, ϵ) + + ∂μ = sum(-, ∂x; dims=(1, 2)) + ∂σ² = sum(∂σ²; dims=(1, 2)) + ∂γ = γ === nothing ? ∂∅ : sum(∂γ; dims=(1, 4)) + ∂β = β === nothing ? ∂∅ : sum(∂y; dims=(1, 4)) + + return ∂x, ∂μ, ∂σ², ∂γ, ∂β +end + +function ∇groupnorm_affine_normalize(::LoopedArrayOp, ∂y::AbstractArray{∂yT, 4}, + x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, + γ::Optional{<:AbstractArray{<:Any, 4}}, β::Optional{<:AbstractArray{<:Any, 4}}, + ϵ::Real) where {∂yT, xT, μT, σ²T} + ∂x, ∂μ, ∂σ² = similar(x), similar(μ), similar(σ²) + ∂γ = γ === nothing ? nothing : similar(γ) + ∂β = β === nothing ? nothing : similar(β) + + ∇groupnorm_affine_normalize_cpu!(∂x, ∂μ, ∂σ², ∂γ, ∂β, ∂y, x, μ, σ², γ, ϵ) + + ∂γ = γ === nothing ? ∂∅ : ∂γ + ∂β = β === nothing ? ∂∅ : ∂β + + return ∂x, ∂μ, ∂σ², ∂γ, ∂β +end + +function ∇groupnorm_affine_normalize_cpu!( + ∂x::AbstractArray{∂xT, 4}, ∂μ::AbstractArray{∂μT, 4}, ∂σ²::AbstractArray{∂σ²T, 4}, + ::Nothing, ::Nothing, ∂y::AbstractArray{∂yT, 4}, x::AbstractArray{xT, 4}, + μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, ::Nothing, + ϵ::Real) where {∂xT, ∂μT, ∂σ²T, ∂yT, xT, μT, σ²T} + half = eltype(∂σ²)(0.5) + + fill!(∂μ, 0) + fill!(∂σ², 0) + + if size(∂y, 1) == 1 + @fastmath @inbounds for L in axes(∂y, 4), K in axes(∂y, 3) + idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + idenom² = idenom^2 + + @simd for J in axes(∂y, 2) + xμ = x[1, J, K, L] - μ[1, 1, K, L] + + ∂x[1, J, K, L] = ∂y[1, J, K, L] * idenom + ∂μ[1, 1, K, L] -= ∂x[1, J, K, L] + ∂σ²[1, 1, K, L] -= ∂x[1, J, K, L] * xμ * half * idenom² + end + end + else + @fastmath @inbounds for L in axes(∂y, 4), K in axes(∂y, 3) + idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + idenom² = idenom^2 + + for J in axes(∂y, 2) + @simd for I in axes(∂y, 1) + xμ = x[I, J, K, L] - μ[1, 1, K, L] + + ∂x[I, J, K, L] = ∂y[I, J, K, L] * idenom + ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] + ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ * half * idenom² + end + end + end + end +end + +function ∇groupnorm_affine_normalize_cpu!( + ∂x::AbstractArray{∂xT, 4}, ∂μ::AbstractArray{∂μT, 4}, ∂σ²::AbstractArray{∂σ²T, 4}, + ∂γ::AbstractArray{∂γT, 4}, ∂β::AbstractArray{∂βT, 4}, ∂y::AbstractArray{∂yT, 4}, + x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, σ²::AbstractArray{σ²T, 4}, + γ::AbstractArray{γT, 4}, + ϵ::Real) where {∂xT, ∂μT, ∂σ²T, ∂γT, ∂βT, ∂yT, xT, μT, σ²T, γT} + half = eltype(∂σ²)(0.5) + + fill!(∂μ, 0) + fill!(∂σ², 0) + fill!(∂γ, 0) + fill!(∂β, 0) + + if size(∂y, 1) == 1 + @fastmath @inbounds for L in axes(∂y, 4), K in axes(∂y, 3) + idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + idenom² = idenom^2 + + @simd for J in axes(∂y, 2) + γ′ = γ[1, J, K, 1] * idenom + + xμ = x[1, J, K, L] - μ[1, 1, K, L] + + ∂x[1, J, K, L] = ∂y[1, J, K, L] * γ′ + ∂μ[1, 1, K, L] -= ∂x[1, J, K, L] + ∂σ²[1, 1, K, L] -= ∂x[1, J, K, L] * xμ * half * idenom² + ∂γ[1, J, K, 1] += ∂y[1, J, K, L] * xμ * idenom + ∂β[1, J, K, 1] += ∂y[1, J, K, L] + end + end + else + @fastmath @inbounds for L in axes(∂y, 4), K in axes(∂y, 3) + idenom = inv(sqrt(σ²[1, 1, K, L] + ϵ)) + idenom² = idenom^2 + + for J in axes(∂y, 2) + γ′ = γ[1, J, K, 1] * idenom + @simd for I in axes(∂y, 1) + xμ = x[I, J, K, L] - μ[1, 1, K, L] + + ∂x[I, J, K, L] = ∂y[I, J, K, L] * γ′ + ∂μ[1, 1, K, L] -= ∂x[I, J, K, L] + ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ * half * idenom² + ∂γ[1, J, K, 1] += ∂y[I, J, K, L] * xμ * idenom + ∂β[1, J, K, 1] += ∂y[I, J, K, L] + end + end + end + end +end + +function ∇groupnorm_affine_normalize!( + ∂x::AbstractArray{∂xT, 
4}, ∂σ²::AbstractArray{∂σ²T, 4}, + ∂γ::Optional{<:AbstractArray{<:Any, 4}}, ::GPUBroadcastOp, + ∂y::AbstractArray{∂yT, 4}, x::AbstractArray{xT, 4}, μ::AbstractArray{μT, 4}, + σ²::AbstractArray{σ²T, 4}, γ::Optional{<:AbstractArray{<:Any, 4}}, + ϵ::Real) where {∂xT, ∂σ²T, ∂yT, xT, μT, σ²T} + backend = KA.get_backend(∂x) + run_ka_kernel( + ∇groupnorm_affine_normalize_kernel!, backend, nothing, size(∂x), + ∂x, ∂σ², ∂γ, ∂y, x, μ, σ², ϵ, γ) + KA.synchronize(backend) +end + +@kernel cpu=false inbounds=true function ∇groupnorm_affine_normalize_kernel!( + ∂x, ∂σ², @Const(∂γ::Nothing), @Const(∂y), @Const(x), + @Const(μ), @Const(σ²), @Const(ϵ), @Const(γ::Nothing)) + i, j, k, l = @index(Global, NTuple) + idenom = inv(sqrt(σ²[1, 1, k, l] + ϵ)) + + ∂x[i, j, k, l] = ∂y[i, j, k, l] * idenom + ∂σ²[i, j, k, l] = ∂x[i, j, k, l] * (μ[1, 1, k, l] - x[i, j, k, l]) * idenom * idenom / 2 +end + +@kernel cpu=false inbounds=true function ∇groupnorm_affine_normalize_kernel!( + ∂x, ∂σ², ∂γ, @Const(∂y), @Const(x), + @Const(μ), @Const(σ²), @Const(ϵ), @Const(γ)) + i, j, k, l = @index(Global, NTuple) + idenom = inv(sqrt(σ²[1, 1, k, l] + ϵ)) + γ′ = γ[1, j, k, 1] * idenom + + xμ_d = (x[i, j, k, l] - μ[1, 1, k, l]) * idenom + + ∂x[i, j, k, l] = ∂y[i, j, k, l] * γ′ + ∂σ²[i, j, k, l] = -∂x[i, j, k, l] * xμ_d * idenom / 2 + ∂γ[i, j, k, l] = ∂y[i, j, k, l] * xμ_d +end diff --git a/lib/LuxLib/src/impl/layernorm.jl b/lib/LuxLib/src/impl/layernorm.jl new file mode 100644 index 000000000..465597267 --- /dev/null +++ b/lib/LuxLib/src/impl/layernorm.jl @@ -0,0 +1,41 @@ +# TODO: For the `dims === nothing` case, we can optimize using a loop vectorization and +# kernel abstractions +function layernorm(x::AbstractArray{xT, N}, γ::Optional{<:AbstractArray}, + β::Optional{<:AbstractArray}, act::F, dims, epsilon::Real) where {N, F, xT} + μ, σ² = mean_var(x; dims=compute_layernorm_dims(x, γ, β, dims), corrected=false) + γ′, β′ = expand_layernorm_dims(x, γ, β, dims) + return affine_normalize(act, x, μ, σ², γ′, β′, epsilon) +end + +function compute_layernorm_dims(::AbstractArray, ::Nothing, ::Nothing, ::Nothing) + throw(ArgumentError("`dims` must be passed explicitly if `scale` and `bias` are \ + `nothing`")) +end + +function compute_layernorm_dims(::AbstractArray{xT, N}, ::AbstractArray{γT, M}, + ::AbstractArray{βT, M}, ::Nothing) where {xT, γT, βT, N, M} + @assert N>M "`x` must have more dimensions than `scale` and `bias` when `dims` is \ + `nothing`" + return 1:(N - M) +end + +function compute_layernorm_dims( + ::AbstractArray, ::Optional{<:AbstractArray}, ::Optional{<:AbstractArray}, dims) + return dims +end + +CRC.@non_differentiable compute_layernorm_dims(::Any...) + +expand_layernorm_dims(::AbstractArray, ::Nothing, ::Nothing, _) = nothing, nothing + +function expand_layernorm_dims(::AbstractArray{xT, N}, γ::AbstractArray{γT, M}, + β::AbstractArray{βT, M}, ::Nothing) where {xT, γT, βT, N, M} + new_γ_size = (size(γ)..., ntuple(i -> 1, N - M)...) + new_β_size = (size(β)..., ntuple(i -> 1, N - M)...) 
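+    # appending N - M trailing singleton dimensions lets `γ` and `β` broadcast against
+    # the N-dimensional `x` inside `affine_normalize`, e.g. a (3, 4) scale paired with
+    # a (3, 4, 5, 6) input is reshaped to (3, 4, 1, 1).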
+ return reshape(γ, new_γ_size), reshape(β, new_β_size) +end + +function expand_layernorm_dims(::AbstractArray{yT, N}, γ::AbstractArray{γT, N}, + β::AbstractArray{βT, N}, dims) where {yT, γT, βT, N} + return γ, β +end diff --git a/lib/LuxLib/src/impl/matmul.jl b/lib/LuxLib/src/impl/matmul.jl new file mode 100644 index 000000000..e202df32a --- /dev/null +++ b/lib/LuxLib/src/impl/matmul.jl @@ -0,0 +1,279 @@ +# Wrappers over Base & LinearAlgebra implementations to use poly algs if needed +matmuladd(A, B, ::Nothing) = matmul(A, B) +function matmuladd(A::AbstractMatrix, B::AbstractVector, bias::AbstractVector) + return matmuladd(A, expand_batchdim(B), bias) +end +function matmuladd(A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) + return matmuladd(internal_operation_mode((A, B, bias)), A, B, bias) +end + +function matmuladd( + ::GenericBroadcastOp, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) + return muladd(A, B, bias) +end +function matmuladd(opmode::AbstractInternalArrayOpMode, A::AbstractMatrix, + B::AbstractMatrix, bias::AbstractVector) + if size(A, 2) != size(B, 1) + throw(DimensionMismatch(lazy"A has shape ($(size(A, 1)), $(size(A, 2))) but B has shape ($(size(B, 1)), $(size(B, 2)))")) + end + if length(bias) != size(A, 1) + throw(DimensionMismatch(lazy"bias has length $(length(bias)) but A has shape ($(size(A, 1)), $(size(A, 2)))")) + end + C = similar(A, promote_type(eltype(A), eltype(B), eltype(bias)), size(A, 1), size(B, 2)) + matmuladd!(C, opmode, A, B, bias) + return C +end + +function matmul(A::AbstractMatrix, B::AbstractVector) + return vec(matmul(A, expand_batchdim(B))) +end +function matmul(A::AbstractMatrix, B::AbstractMatrix) + if size(A, 2) != size(B, 1) + throw(DimensionMismatch(lazy"A has shape ($(size(A, 1)), $(size(A, 2))) but B has shape ($(size(B, 1)), $(size(B, 2)))")) + end + return matmul(internal_operation_mode((A, B)), A, B) +end + +matmul(::GenericBroadcastOp, A::AbstractMatrix, B::AbstractMatrix) = A * B +function matmul(::AbstractInternalArrayOpMode, A::AbstractMatrix, B::AbstractMatrix) + C = similar(A, promote_type(eltype(A), eltype(B)), size(A, 1), size(B, 2)) + matmul!(C, A, B) + return C +end + +# Slightly higher level. 
Here we make decisions about which implementation to use +function matmuladd!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, ::Nothing) + matmul!(C, A, B) + return +end +function matmuladd!( + C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) + matmuladd!(C, internal_operation_mode((C, A, B, bias)), A, B, bias) + return +end + +function matmuladd!(C::AbstractMatrix, ::AbstractInternalArrayOpMode, + A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) + C .= bias + mul!(C, A, B, true, true) + return +end + +function matmuladd!(C::AbstractMatrix, ::GPUBroadcastOp{CUDADevice}, + A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) + cublasLt_fused_dense!(C, identity, A, B, bias) + return +end + +function matmuladd!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, + B::AbstractMatrix, bias::AbstractVector) + if can_loopvec_args(C, A, B, bias) && fits_in_l2cache(C, A, B, bias) + matmuladd_loopvec!(C, A, B, bias) + return + end + matmuladd_cpu_fallback!(C, A, B, bias) + return +end + +function matmul!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix) + matmul!(C, internal_operation_mode((C, A, B)), A, B) + return +end + +function matmul!(C::AbstractMatrix, ::AbstractInternalArrayOpMode, + A::AbstractMatrix, B::AbstractMatrix) + mul!(C, A, B) + return +end + +function matmul!(C::AbstractMatrix, ::LoopedArrayOp, A::AbstractMatrix, B::AbstractMatrix) + return matmul_cpu!(C, use_octavian(), explicit_blas_loaded(), A, B) +end + +for spl_blas in (True, False) + @eval begin + function matmul_cpu!( # Octavian can be used + C::AbstractMatrix, ::True, ::$(spl_blas), + A::AbstractMatrix, B::AbstractMatrix) + if can_loopvec_args(C, A, B) + if fits_in_l1cache(C, A, B) + matmul_loopvec!(C, A, B, true, false) + return + elseif $(unsafe_known(spl_blas()) ? fits_in_l2cache : + fits_in_l3cache)(C, A, B) + matmul_octavian!(C, A, B, true, false) + return + end + end + matmul_cpu_fallback!(C, A, B, true, false) + return + end + + function matmul_cpu!( # Octavian cannot be used + C::AbstractMatrix, ::False, ::$(spl_blas), + A::AbstractMatrix, B::AbstractMatrix) + if can_loopvec_args(C, A, B) + if $(unsafe_known(spl_blas()) ? fits_in_l1cache : fits_in_l2cache)(C, A, B) + matmul_loopvec!(C, A, B, true, false) + return + end + end + matmul_cpu_fallback!(C, A, B, true, false) + return + end + end +end + +# Low-Level Matmul implementations -- Either call libraries or implement our own +# We force inlining here to avoid allocations in the inner loops + +# Best case fallback, we are likely going to hit BLAS +@inline function matmul_cpu_fallback!(C::AbstractMatrix{T}, A::AbstractMatrix{T}, + B::AbstractMatrix{T}, α::Number, β::Number) where {T} + matmul_linalg_default!(C, A, B, α, β) + return +end + +@inline function matmul_cpu_fallback!(C::AbstractMatrix{T}, A::AbstractMatrix{AT}, + B::AbstractMatrix{BT}, α::Number, β::Number) where {T, AT, BT} + if can_loopvec_args(C, A, B) && unsafe_known(is_extension_loaded(Val(:Octavian))) + matmul_octavian!(C, A, B, α, β) + return + end + # Generic fallback is actually quite good starting julia 1.11 + @static if VERSION ≥ v"1.11-" + @warn lazy"Mixed-Precision `matmul_cpu_fallback!` detected and Octavian.jl cannot be used for this set of inputs (C [$(typeof(C))]: A [$(typeof(A))] x B [$(typeof(B))]). Falling back to generic implementation. This may be slow." 
maxlog=1 + A′, B′ = A, B + else + @warn lazy"Mixed-Precision `matmul_cpu_fallback!` detected and Octavian.jl cannot be used for this set of inputs (C [$(typeof(C))]: A [$(typeof(A))] x B [$(typeof(B))]). Converting to common type to to attempt to use BLAS. This may be slow." maxlog=1 + A′, B′ = ofeltype_array(T, A), ofeltype_array(T, B) + end + matmul_linalg_default!(C, A′, B′, α, β) + return +end + +@inline function matmul_linalg_default!( + C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α::Number, β::Number) + mul!(C, A, B, α, β) + return +end + +function serial_matmul_loopvec! end +function matmul_loopvec! end +function matmuladd_loopvec! end + +function matmul_octavian! end + +@inline function matmuladd_cpu_fallback!( + C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) + C .= bias + matmul_cpu_fallback!(C, A, B, true, true) + return +end + +# ChainRules +function CRC.rrule(::typeof(matmul), A::AbstractMatrix, B::AbstractMatrix) + 𝒫A, 𝒫B = CRC.ProjectTo(A), CRC.ProjectTo(B) + ∇matmul = @closure Δ′ -> begin + Δ = CRC.unthunk(Δ′) + ∂A = CRC.@thunk(𝒫A(matmul(Δ, B'))) + ∂B = CRC.@thunk(𝒫B(matmul(A', Δ))) + return ∂∅, ∂A, ∂B + end + return matmul(A, B), ∇matmul +end + +function CRC.rrule( + ::typeof(matmuladd), A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector) + 𝒫A, 𝒫B, 𝒫bias = CRC.ProjectTo(A), CRC.ProjectTo(B), CRC.ProjectTo(bias) + ∇matmuladd = @closure Δ′ -> begin + Δ = CRC.unthunk(Δ′) + ∂A = CRC.@thunk(𝒫A(matmul(Δ, B'))) + ∂B = CRC.@thunk(𝒫B(matmul(A', Δ))) + ∂bias = CRC.@thunk(𝒫bias(∇bias_add(bias, Δ))) + return ∂∅, ∂A, ∂B, ∂bias + end + return matmuladd(A, B, bias), ∇matmuladd +end + +# EnzymeRules +function EnzymeRules.augmented_primal(cfg, ::EnzymeCore.Const{typeof(matmuladd!)}, + ::Type{EnzymeCore.Const{Nothing}}, C::EnzymeCore.Annotation{<:AbstractMatrix}, + opmode::EnzymeCore.Const{<:AbstractInternalArrayOpMode}, + A::EnzymeCore.Annotation{<:AbstractMatrix}, + B::EnzymeCore.Annotation{<:AbstractMatrix}, + bias::EnzymeCore.Annotation{<:AbstractVector}) + A_cache = EnzymeRules.overwritten(cfg)[4] && !(B isa EnzymeCore.Const) && + !(C isa EnzymeCore.Const) ? copy(A.val) : nothing + B_cache = EnzymeRules.overwritten(cfg)[5] && !(A isa EnzymeCore.Const) && + !(C isa EnzymeCore.Const) ? copy(B.val) : nothing + + if !(C isa EnzymeCore.DuplicatedNoNeed || C isa EnzymeCore.BatchDuplicatedNoNeed) + matmuladd!(C.val, opmode.val, A.val, B.val, bias.val) + end + + return EnzymeRules.AugmentedReturn(nothing, nothing, (A_cache, B_cache)) +end + +function EnzymeRules.reverse(cfg, ::EnzymeCore.Const{typeof(matmuladd!)}, + ::Type{EnzymeCore.Const{Nothing}}, (A_cache, B_cache), + C::EnzymeCore.Annotation{<:AbstractMatrix}, + opmode::EnzymeCore.Const{<:AbstractInternalArrayOpMode}, + A::EnzymeCore.Annotation{<:AbstractMatrix}, + B::EnzymeCore.Annotation{<:AbstractMatrix}, + bias::EnzymeCore.Annotation{<:AbstractVector}) + if !(C isa EnzymeCore.Const) && !(B isa EnzymeCore.Const) + if !EnzymeRules.overwritten(cfg)[4] + A_cache = A.val + end + end + + if !(C isa EnzymeCore.Const) && !(A isa EnzymeCore.Const) + if !EnzymeRules.overwritten(cfg)[5] + B_cache = B.val + end + end + + ∂Cs = C.dval + ∂As = A isa EnzymeCore.Const ? ∂Cs : A.dval + ∂Bs = B isa EnzymeCore.Const ? ∂Cs : B.dval + ∂bs = bias isa EnzymeCore.Const ? 
∂Cs : bias.dval + + if EnzymeRules.width(cfg) == 1 + ∂Cs = (∂Cs,) + ∂As = (∂As,) + ∂Bs = (∂Bs,) + ∂bs = (∂bs,) + end + + for (∂C, ∂A, ∂B, ∂b) in zip(∂Cs, ∂As, ∂Bs, ∂bs) + if !(C isa EnzymeCore.Const) && ∂C !== C.val + if !(bias isa EnzymeCore.Const) && ∂b !== bias.val + # FIXME: Can we do this without allocating? + ∂b₁ = similar(∂b) + sum!(∂b₁, ∂C) + ∂b .+= ∂b₁ + end + + if !(A isa EnzymeCore.Const) && ∂A !== A.val + # TODO: we don't use our faster matmul here since we lack the 5 arg version + mul!(∂A, ∂C, B_cache', true, true) + end + + if !(B isa EnzymeCore.Const) && ∂B !== B.val + # TODO: we don't use our faster matmul here since we lack the 5 arg version + mul!(∂B, A_cache', ∂C, true, true) + end + + ∂C .= 0 + end + end + + return ntuple(Returns(nothing), 5) +end + +@enzyme_alternative matmul_octavian! matmul_linalg_default! +@enzyme_alternative serial_matmul_loopvec! matmul_linalg_default! +@enzyme_alternative matmul_loopvec! matmul_linalg_default! + +@enzyme_alternative matmuladd_loopvec! matmuladd_cpu_fallback! diff --git a/lib/LuxLib/src/impl/normalization.jl b/lib/LuxLib/src/impl/normalization.jl new file mode 100644 index 000000000..f9dafcdf0 --- /dev/null +++ b/lib/LuxLib/src/impl/normalization.jl @@ -0,0 +1,146 @@ +# In most cases this implementation should not be preferred. But this is nice to have +# because it works for arbitrary dimensions +function affine_normalize(act::F, x::AbstractArray, μ::Numeric, σ²::Numeric, + ::Nothing, ::Nothing, ϵ::Real) where {F} + γ′ = @. inv(sqrt(σ² + ϵ)) + β′ = @. -μ * γ′ + return @. act(x * γ′ + β′) +end + +function affine_normalize(act::F, x::AbstractArray, μ::Numeric, σ²::Numeric, + γ::AbstractArray, β::AbstractArray, ϵ::Real) where {F} + γ′ = @. γ / sqrt(σ² + ϵ) + β′ = @. β - μ * γ′ + return @. act(x * γ′ + β′) +end + +# Deal with statistics +function update_running_statistics(rμ, rσ², μ, σ², m₁, m₂) + return update_running_statistics( + internal_operation_mode((rμ, rσ², μ, σ²)), rμ, rσ², μ, σ², m₁, m₂, 1 - m₁) +end + +function update_running_statistics(::GenericBroadcastOp, rμ, rσ², μ, σ², m₁, m₂, m₃) + rμₙ = @. m₃ * rμ + m₁ * μ + rσ²ₙ = @. m₃ * rσ² + m₂ * σ² + return rμₙ, rσ²ₙ +end + +function update_running_statistics(opmode, rμ, rσ², μ, σ², m₁, m₂, m₃) + rμₙ = similar(rμ, promote_type(eltype(rμ), eltype(μ), typeof(m₃), typeof(m₁))) + rσ²ₙ = similar(rσ², promote_type(eltype(rσ²), eltype(σ²), typeof(m₂), typeof(m₃))) + update_running_statistics!(rμₙ, rσ²ₙ, opmode, rμ, rσ², μ, σ², m₁, m₂, m₃) + return rμₙ, rσ²ₙ +end + +CRC.@non_differentiable update_running_statistics(::Any...) 
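+
+# Illustrative sketch (hypothetical values): with `m₁ = 0.1` the entry point above sets
+# `m₃ = 1 - m₁ = 0.9`, so for single-element running statistics
+#
+#   rμₙ, rσ²ₙ = update_running_statistics([0.0], [1.0], [2.0], [4.0], 0.1, 0.1)
+#   # rμₙ  ≈ [0.2]    (0.9 * 0.0 + 0.1 * 2.0)
+#   # rσ²ₙ ≈ [1.3]    (0.9 * 1.0 + 0.1 * 4.0)
+#
+# `update_normalization_statistics` below passes `m₂ = momentum * m / (m - 1)` so that
+# the accumulated variance picks up Bessel's correction.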
+ +function update_running_statistics!(rμₙ, rσ²ₙ, ::LoopedArrayOp, rμ, rσ², μ, σ², m₁, m₂, m₃) + update_running_statistics_simd_loop!( + rμₙ, rσ²ₙ, LoopedArrayOp(), rμ, rσ², μ, σ², m₁, m₂, m₃) + return +end + +function update_running_statistics_simd_loop!( + rμₙ, rσ²ₙ, ::LoopedArrayOp, rμ, rσ², μ, σ², m₁, m₂, m₃) + @simd ivdep for I in eachindex(rμₙ, rσ²ₙ) + rμₙ[I] = m₃ * rμ[I] + m₁ * μ[I] + rσ²ₙ[I] = m₃ * rσ²[I] + m₂ * σ²[I] + end +end + +function update_running_statistics!(rμₙ, rσ²ₙ, ::GPUBroadcastOp, rμ, rσ², μ, σ², m₁, m₂, m₃) + backend = KA.get_backend(rμₙ) + run_ka_kernel( + update_running_statistics_kernel!, backend, nothing, size(rμₙ), + rμₙ, rσ²ₙ, rμ, rσ², μ, σ², m₁, m₂, m₃) + KA.synchronize(backend) + return +end + +@kernel cpu=false inbounds=true function update_running_statistics_kernel!( + rμₙ, rσ²ₙ, @Const(rμ), @Const(rσ²), @Const(μ), + @Const(σ²), @Const(m₁), @Const(m₂), @Const(m₃)) + I = @index(Global) + rμₙ[I] = m₃ * rμ[I] + m₁ * μ[I] + rσ²ₙ[I] = m₃ * rσ²[I] + m₂ * σ²[I] +end + +function update_normalization_statistics( + x::AbstractArray{T, N}, rμ::AbstractArray{rμT, N}, rσ²::AbstractArray{rσ²T, N}, + μ::AbstractArray{μT, N}, σ²::AbstractArray{σ²T, N}, + momentum::Real, reduce_dims) where {T, N, rμT, rσ²T, μT, σ²T} + if last(reduce_dims) != N + μ = mean(μ; dims=N) + σ² = mean(σ²; dims=N) + end + m = remove_tracking(T(accum_size(x, reduce_dims))) + return update_running_statistics(rμ, rσ², μ, σ², momentum, momentum * m / (m - one(m))) +end + +accum_size(x, reduce_dims) = prod(Base.Fix1(size, x), unsafe_known(reduce_dims)) + +CRC.@non_differentiable update_normalization_statistics(::Any...) + +function compute_batch_statistics( + x::AbstractArray, ::Nothing, ::Nothing, reduce_dims, ::StaticBool, momentum) + μ, σ² = mean_var(x; dims=unsafe_known(reduce_dims), corrected=false) + return (aos_to_soa(μ), aos_to_soa(σ²)), (nothing, nothing) +end + +function compute_batch_statistics( + ::AbstractArray, rμ::AbstractArray, rσ²::AbstractArray, _, ::False, momentum) + return (remove_tracking(rμ), remove_tracking(rσ²)), (rμ, rσ²) +end + +function compute_batch_statistics(x::AbstractArray, rμ::AbstractArray, + rσ²::AbstractArray, reduce_dims, ::True, momentum) + μ, σ² = mean_var(x; dims=unsafe_known(reduce_dims), corrected=false) + rμ, rσ² = update_normalization_statistics( + remove_tracking(x), remove_tracking(rμ), remove_tracking(rσ²), + remove_tracking(μ), remove_tracking(σ²), momentum, reduce_dims) + return (aos_to_soa(μ), aos_to_soa(σ²)), (rμ, rσ²) +end + +# Main Implementation +## The idea here is to be generic. This is useful for testing the more optimized +## implementations as well. +function normalization( + x::AbstractArray, rμ::Optional{<:AbstractVector}, rσ²::Optional{<:AbstractVector}, + γ::Optional{<:AbstractVector}, β::Optional{<:AbstractVector}, reduce_dims, + training::StaticBool, momentum, epsilon, act::F=identity) where {F} + (μ, σ²), (rμ, rσ²) = compute_batch_statistics( + x, reshape_norm_dims(x, rμ), reshape_norm_dims(x, rσ²), + reduce_dims, training, momentum) + γ, β = reshape_norm_dims(x, γ), reshape_norm_dims(x, β) + return affine_normalize(act, x, μ, σ², γ, β, epsilon), rμ, rσ² +end + +reshape_norm_dims(_, ::Nothing) = nothing +reshape_norm_dims(y, x) = reshape(x, get_norm_reshape_dims(size(y), length(x))) + +@inbounds function get_norm_reshape_dims(sx::NTuple{N, <:Int}, ly::Int) where {N} + if ly == sx[N - 1] + return ntuple(i -> i == N - 1 ? ly : 1, N) + elseif N > 2 && ly == sx[N - 1] * sx[N - 2] + return ntuple(i -> i == (N - 1) || i == (N - 2) ? 
sx[i] : 1, N) + end + throw(ArgumentError("Invalid Dimensions!")) +end + +CRC.@non_differentiable get_norm_reshape_dims(::Any...) + +# Entry Points +## InstanceNorm +function instancenorm(x::AbstractArray{xT, N}, γ::Optional{<:AbstractVector}, + β::Optional{<:AbstractVector}, rμ::Optional{<:AbstractVector}, + rσ²::Optional{<:AbstractVector}, training::StaticBool, + act::F, momentum, epsilon) where {xT, N, F} + y, rμₙ, rσ²ₙ = normalization( + x, rμ, rσ², γ, β, instancenorm_reduce_dims(x), training, momentum, epsilon, act) + return y, safe_vec(rμₙ), safe_vec(rσ²ₙ) +end + +instancenorm_reduce_dims(::AbstractArray{T, N}) where {T, N} = ntuple(static, N - 2) + +CRC.@non_differentiable instancenorm_reduce_dims(::Any...) diff --git a/lib/LuxLib/src/traits.jl b/lib/LuxLib/src/traits.jl new file mode 100644 index 000000000..29d3dc1e0 --- /dev/null +++ b/lib/LuxLib/src/traits.jl @@ -0,0 +1,220 @@ +module Traits + +using ArrayInterface: ArrayInterface, can_setindex +using ChainRulesCore: ChainRulesCore +using ForwardDiff: ForwardDiff +using NNlib: NNlib +using Static: True, False, static +using StaticArraysCore: StaticArray + +using ..LuxLib: Numeric +using ..Utils: NotaNumber, only_derivative, unrolled_any, unrolled_map + +function fast_scalar_indexing(::T) where {T <: AbstractArray} + return static(ArrayInterface.fast_scalar_indexing(T)) +end +fast_scalar_indexing(::Nothing) = True() +fast_scalar_indexing(x::NNlib.BatchedAdjOrTrans) = fast_scalar_indexing(parent(x)) + +is_mutable_array(::T) where {T <: AbstractArray} = static(can_setindex(T)) +is_mutable_array(::Nothing) = True() + +ChainRulesCore.@non_differentiable is_mutable_array(::Any...) + +for op in (:has_dual, :has_float16, :is_tracked) + @eval $op(::Nothing) = False() + @eval $op(x::Numeric) = $op(eltype(x)) +end + +unwrap_array(x) = x +function unwrap_array(x::AbstractArray) + parent(x) === x && return x + return unwrap_array(parent(x)) +end + +has_dual(_) = False() +has_dual(::Type{<:ForwardDiff.Dual}) = True() + +has_float16(_) = False() +has_float16(::Type{<:Float16}) = True() + +is_tracked(_) = False() + +has_autodiff_value(x) = is_tracked(x) | has_dual(x) + +static_isa(::Type{T}) where {T} = Base.Fix2(static_isa, T) +static_isa(x, ::Type{T}) where {T} = static(isa(x, T)) + +function use_generic_broadcasting(xs::Tuple) + # Float16 is a bit iffy and reordering operations are not optimal for numerical + # stability so we use the generic implementation for now. 
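+    # unwrap nested array wrappers first: the `StaticArray` check below is `isa`-based,
+    # so e.g. a `view` of an `SMatrix` would otherwise slip past it.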
+ xs_unwrapped = unrolled_map(unwrap_array, xs) + return unrolled_any(has_autodiff_value, xs_unwrapped) | + unrolled_any(has_float16, xs_unwrapped) | + unrolled_any(static_isa(StaticArray), xs_unwrapped) +end + +activation_intermediate_not_needed(::typeof(identity), ::Type) = True() + +function activation_intermediate_not_needed(::F, ::Type{T}) where {F, T} + return static(isconcretetype(Core.Compiler._return_type( + only_derivative, Tuple{T, F, NotaNumber}))) +end + +function activation_has_rrule(::F, ::Type{T}) where {F, T} + return static(isconcretetype(Core.Compiler._return_type( + only_derivative, Tuple{T, F, T}))) +end + +# Which activations can be fused into a single kernel +for act in (:identity, :(NNlib.relu), :abs, :abs2) + @eval fuse_cpu_activation(::typeof($act)) = True() +end +fuse_cpu_activation(::F) where {F} = False() + +end + +module System + +using ChainRulesCore: ChainRulesCore +using Hwloc: Hwloc +using Static: static, False, True + +using ..LuxLib: DISABLE_LOOP_VECTORIZATION +using ..Utils: is_extension_loaded, safe_minimum + +const CRC = ChainRulesCore + +# Technically Octavian works fine on non-server AMD CPUs, but for safety we disable it +# on non Intel CPUs. +const INTEL_HARDWARE = @static if Sys.ARCH === :x86_64 || Sys.ARCH === :i686 + try + using CpuId: CpuId + static(lowercase(string(CpuId.cpuvendor())) == "intel") + catch + @warn "Could not detect cpu vendor via CpuId.jl, assuming not Intel. Open an \ + issue in `LuxLib.jl` if this is unexpected." + False() + end +else + False() +end + +const AMD_RYZEN_HARDWARE = @static if Sys.ARCH === :x86_64 || Sys.ARCH === :i686 + try + using CpuId: CpuId + static(occursin("ryzen", lowercase(string(CpuId.cpubrand())))) + catch + @warn "Could not detect cpu brand via CpuId.jl, assuming not Ryzen. Open an issue \ + in `LuxLib.jl` if this is unexpected." + False() + end +else + False() +end + +function is_x86_64() + @static if Sys.ARCH === :x86_64 + return True() + else + return False() + end +end + +CRC.@non_differentiable is_x86_64() + +function explicit_blas_loaded() + return is_extension_loaded(Val(:MKL)) | + is_extension_loaded(Val(:AppleAccelerate)) | + is_extension_loaded(Val(:BLISBLAS)) +end + +CRC.@non_differentiable explicit_blas_loaded() + +@static if DISABLE_LOOP_VECTORIZATION + use_octavian() = False() +else + function use_octavian() + return is_extension_loaded(Val(:Octavian)) & is_x86_64() & + (INTEL_HARDWARE | AMD_RYZEN_HARDWARE) + end +end + +CRC.@non_differentiable use_octavian() + +const L1CacheSize::Int = safe_minimum(Hwloc.l1cache_sizes(), 0) +const L2CacheSize::Int = safe_minimum(Hwloc.l2cache_sizes(), 0) +const L3CacheSize::Int = safe_minimum(Hwloc.l3cache_sizes(), 0) + +# NOTE: some systems might not have L3 cache, so we check whether it fits in L(N - 1) cache +fits_in_l1cache(xs::AbstractArray...) = sum(sizeof, xs) ≤ L1CacheSize +CRC.@non_differentiable fits_in_l1cache(::Any...) + +function fits_in_l2cache(xs::AbstractArray...) + return fits_in_l1cache(xs...) || sum(sizeof, xs) ≤ L2CacheSize +end +CRC.@non_differentiable fits_in_l2cache(::Any...) + +function fits_in_l3cache(xs::AbstractArray...) + return fits_in_l2cache(xs...) || sum(sizeof, xs) ≤ L3CacheSize +end +CRC.@non_differentiable fits_in_l3cache(::Any...) + +end + +# How to do an internal operation? +# 1. Generic Broadcasting without Preallocation -- GenericBroadcastOp +# 2. Broadcasting with Fusion -- GPUBroadcastOp +# 3. Use Loops possibly accelerating with LoopVectorization or Polyester. 
This might +# still use broadcasting if needed + +abstract type AbstractInternalArrayOpMode end + +abstract type AbstractBroadcastOpMode <: AbstractInternalArrayOpMode end + +struct GenericBroadcastOp <: AbstractBroadcastOpMode end +struct GPUBroadcastOp{dev} <: AbstractBroadcastOpMode end +struct LoopedArrayOp <: AbstractInternalArrayOpMode end + +## NOTE: Ensure that this always gets compiled out! Else we will have terrible type +## inference. +""" + internal_operation_mode(xs::Tuple) + internal_operation_mode(x::AbstractArray) + +Returns the internal operation mode for the given array(s). This is useful to define custom +implementations using different backends like simple Julia broadcasting, Kernel +Abstractions, Loop Vectorization, etc. + +Currently supported modes are: + + - `GenericBroadcastOp`: This is the fallback for most types. For the following types this + is the preferred mode: + + + Arrays with `fast_scalar_indexing` set to `False`. + + Static Arrays + + ReverseDiff Arrays + + Tracker Arrays + + ForwardDiff.Dual Arrays + + - `GPUBroadcastOp{dev}`: GPU Arrays where `dev` is obtained from `get_device_type(xs)`. + This option dispatches should preferably use `KernelAbstractions` or specialized vendor + dispatches. + - `LoopedArrayOp`: CPU arrays that can be optimized using SIMD Loops, ideally using + `LoopVectorization.jl` or `Polyester.jl`. +""" +function internal_operation_mode(xs::Tuple) + xs = filter(!isnothing, xs) + known(Traits.use_generic_broadcasting(xs)) && return GenericBroadcastOp() + + dev = get_device_type(xs) + dev <: AbstractGPUDevice && return GPUBroadcastOp{dev}() + + # This check needs to be done after the GPU Check + known(Utils.unrolled_any(!Traits.fast_scalar_indexing, xs)) && + return GenericBroadcastOp() + return LoopedArrayOp() +end +internal_operation_mode(x::AbstractArray) = internal_operation_mode((x,)) + +CRC.@non_differentiable internal_operation_mode(::Any...) diff --git a/lib/LuxLib/src/utils.jl b/lib/LuxLib/src/utils.jl new file mode 100644 index 000000000..0104457c7 --- /dev/null +++ b/lib/LuxLib/src/utils.jl @@ -0,0 +1,342 @@ +module Utils + +using ChainRulesCore: ChainRulesCore +using EnzymeCore: EnzymeCore, EnzymeRules +using FastClosures: @closure +using ForwardDiff: ForwardDiff +using KernelAbstractions: KernelAbstractions +using LinearAlgebra: LinearAlgebra, BLAS +using MLDataDevices: get_device_type, CPUDevice +using NNlib: NNlib +using Static: Static, StaticBool, False, True, static +using StaticArraysCore: SVector, SMatrix + +using ..LuxLib: Optional, ∂∅, DISABLE_LOOP_VECTORIZATION + +const CRC = ChainRulesCore +const KA = KernelAbstractions + +is_extension_loaded(::Val) = False() + +CRC.@non_differentiable is_extension_loaded(::Any...) +EnzymeRules.inactive_noinl(::typeof(is_extension_loaded), ::Any...) = nothing + +# Simple Operations -- no rrules needed +ofeltype_array(::Type{T}, x::AbstractArray{T}) where {T} = x +function ofeltype_array( + ::Type{T}, x::AbstractArray{<:ForwardDiff.Dual{Tag, T, N}}) where {Tag, T, N} + return x +end +ofeltype_array(::Type{T}, x::AbstractArray) where {T} = T.(x) +function ofeltype_array( + ::Type{T}, x::AbstractArray{<:ForwardDiff.Dual{Tag, T2, N}}) where {Tag, T, T2, N} + return ForwardDiff.Dual{Tag, T, N}.(x) +end +ofeltype_array(::Type{T}, ::Nothing) where {T} = nothing + +contiguous(x::AbstractArray) = x +contiguous(x::SubArray) = copy(x) + +safe_reshape(x::AbstractArray, dims...) = reshape(x, dims...) +safe_reshape(::Nothing, dims...) 
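# Illustrative usage sketch for `internal_operation_mode` (documented above).
# `my_scale!` is a made-up helper, not part of LuxLib; the point is the dispatch
# pattern on the returned mode. Assumes the mode types are reachable as written here.
using LuxLib: LuxLib

function my_scale!(::LuxLib.AbstractInternalArrayOpMode, y, x, α)
    y .= α .* x                     # generic / broadcast fallback (also covers GPU modes)
    return nothing
end
function my_scale!(::LuxLib.LoopedArrayOp, y, x, α)
    @inbounds @simd for i in eachindex(y, x)   # plain SIMD loop for fast-indexing CPU arrays
        y[i] = α * x[i]
    end
    return nothing
end

x = rand(Float32, 16); y = similar(x)
my_scale!(LuxLib.internal_operation_mode((y, x)), y, x, 2.0f0)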
= nothing + +remove_tracking(x) = x +remove_tracking(x::AbstractArray) = x +remove_tracking(::Type{T}) where {T} = T +remove_tracking(x::ForwardDiff.Dual) = ForwardDiff.value(x) +remove_tracking(x::AbstractArray{<:ForwardDiff.Dual}) = ForwardDiff.value.(x) +remove_tracking(::Type{<:ForwardDiff.Dual{Tag, T}}) where {Tag, T} = remove_tracking(T) +remove_tracking(::Nothing) = nothing + +safe_vec(x) = x +safe_vec(x::AbstractArray) = vec(x) +safe_vec(::Nothing) = nothing + +## This part is taken from NNlib.jl +# This has no methods, used for testing whether `derivatives_given_output(Ω, f, x)` +# is independent of `x`, as `_return_type` says `Union{}` when calling is an error. +struct NotaNumber <: Real end + +# This just saves typing `only.(only.(` many times: +only_derivative(y, f::F, x) where {F} = only(only(CRC.derivatives_given_output(y, f, x))) + +# Non-differentiable functions +eltype_mismatch(::Type, ::Type) = True() +eltype_mismatch(::Type{T}, ::Type{T}) where {T} = False() +function eltype_mismatch(::Type{T}, ::Type{<:ForwardDiff.Dual{Tag, T, N}}) where {Tag, T, N} + return False() +end +function eltype_mismatch(::Type{<:ForwardDiff.Dual{Tag, T, N}}, ::Type{T}) where {Tag, T, N} + return False() +end + +CRC.@non_differentiable eltype_mismatch(::Any...) + +## Reduce BLAS threads if we are going to use a native Julia implementation +maybe_reduce_BLAS_threads(x::AbstractArray) = maybe_reduce_BLAS_threads(get_device_type(x)) +maybe_reduce_BLAS_threads(::Type{T}) where {T} = -1 +function maybe_reduce_BLAS_threads(::Type{CPUDevice})::Int + old_threads = BLAS.get_num_threads() + BLAS.set_num_threads(1) + return old_threads +end + +CRC.@non_differentiable maybe_reduce_BLAS_threads(::AbstractArray) + +function reset_BLAS_threads(old_threads::Int) + old_threads ≥ 1 && BLAS.set_num_threads(old_threads) + return nothing +end + +CRC.@non_differentiable reset_BLAS_threads(::Int) + +unsafe_free!(_) = nothing +unsafe_free!(x::AbstractArray) = KA.unsafe_free!(x) + +CRC.@non_differentiable unsafe_free!(::Any) + +unsafe_known(x) = Static.known(x) # will drop gradients. needed for type stability in Zygote + +CRC.@non_differentiable unsafe_known(::Any) + +## depwarn but marked non-differentiable to prevent type instability +depwarn(msg::String, f::Symbol) = Base.depwarn(msg, f) + +CRC.@non_differentiable depwarn(::Any...) + +safe_eltype(::AbstractArray{T}) where {T} = T +safe_eltype(::T) where {T} = T +safe_eltype(::Nothing) = Bool + +CRC.@non_differentiable safe_eltype(::Any) + +default_epsilon(::Type{T}) where {T} = T(eps(T)^(5 / 7)) +default_epsilon(::AbstractArray{T}) where {T} = default_epsilon(T) + +CRC.@non_differentiable default_epsilon(::Any...) + +function concrete_bias_act_output_eltype(act::F, ::AbstractArray{Tw}, ::AbstractArray{Tx}, + b::Optional{<:AbstractVector}) where {F, Tw, Tx} + Ty = promote_type(Tw, Tx, safe_eltype(b)) + Tact = Core.Compiler._return_type(act, Tuple{Ty}) + return ifelse(isconcretetype(Tact), Tact, Ty) +end + +function concrete_bias_act_output_eltype( + act::F, x::AbstractArray, b::Optional{<:AbstractVector}) where {F} + return concrete_bias_act_output_eltype(act, x, x, b) +end + +CRC.@non_differentiable concrete_bias_act_output_eltype(::Any...) + +## Copy and don't allow gradient propagation +copy_drop_gradients(x) = copy(remove_tracking(x)) +copy_drop_gradients(::Nothing) = nothing + +CRC.@non_differentiable copy_drop_gradients(::Any) +EnzymeRules.inactive_noinl(::typeof(copy_drop_gradients), ::Any...) 
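# Illustrative sketch of the intended save/restore pattern for
# `maybe_reduce_BLAS_threads` / `reset_BLAS_threads` above: drop BLAS to a single
# thread around a hand-threaded native Julia kernel, then restore the old setting.
# `run_native_kernel!` is a made-up placeholder; assumes the two helpers are in scope.
function with_single_blas_thread(f::F, x::AbstractArray) where {F}
    old = maybe_reduce_BLAS_threads(x)   # returns the old thread count on CPU, -1 otherwise
    try
        return f()
    finally
        reset_BLAS_threads(old)          # restores only when `old ≥ 1`
    end
end

# with_single_blas_thread(() -> run_native_kernel!(y, x), x)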
= nothing + +# Meta Programming Utilities +is_tracked(x) = x == :TrackedArray || x == :TrackedVector +is_tracked(args...) = unrolled_any(is_tracked, args) + +inferred_length(::Type{<:NTuple{N, Any}}) where {N} = N +@generated static_length(itr) = return :($(Val(inferred_length(itr)))) + +@generated function unrolled_any(f::F, xs) where {F} + L = inferred_length(xs) + L == 1 && return :(f(xs[1])) + return Expr(:call, :|, (:(f(xs[$i])) for i in 1:L)...) +end + +@generated function unrolled_map(f::F, xs) where {F} + L = inferred_length(xs) + return quote + $(Expr(:meta, :inline)) + $(Expr(:inbounds, true)) + res = $(Expr(:tuple, (:(f(xs[$i])) for i in 1:L)...)) + $(Expr(:inbounds, :pop)) + return res + end +end + +function unrolled_mapreduce(f::F, op::O, itr) where {F, O} + return unrolled_mapreduce(f, op, itr, static_length(itr)) +end + +function unrolled_mapreduce(::F, ::O, _, ::Val{0}) where {F, O} + error("Cannot unroll over an empty iterator.") +end + +unrolled_mapreduce(f::F, ::O, itr, ::Val{1}) where {F, O} = f(only(itr)) + +@generated function unrolled_mapreduce(f::F, op::O, itr, ::Val{N}) where {F, O, N} + syms = [gensym("f_itr_$(i)") for i in 1:N] + op_syms = [gensym("op_$(i)") for i in 1:(N - 1)] + f_applied = [:($(syms[i]) = f(itr[$i])) for i in 1:N] + combine_expr = [:($(op_syms[1]) = op($(syms[1]), $(syms[2])))] + for i in 2:(N - 1) + push!(combine_expr, :($(op_syms[i]) = op($(op_syms[i - 1]), $(syms[i + 1])))) + end + return quote + $(Expr(:meta, :inline)) + $(Expr(:inbounds, true)) + $(Expr(:block, f_applied...)) + $(Expr(:inbounds, :pop)) + $(Expr(:block, combine_expr...)) + return $(op_syms[end]) + end +end + +# Working with batches +batchview(x::AbstractArray{<:Any, 3}, k::Int) = view(x, :, :, k) +batchview(x::NNlib.BatchedTranspose, k::Int) = transpose(batchview(parent(x), k)) +batchview(x::NNlib.BatchedAdjoint, k::Int) = adjoint(batchview(parent(x), k)) + +batchview(x::AbstractArray{<:Any, 3}) = map(Base.Fix1(batchview, x), 1:size(x, 3)) + +expand_batchdim(x::AbstractMatrix) = reshape(x, size(x)..., 1) +function expand_batchdim(x::LinearAlgebra.Adjoint) + return NNlib.BatchedAdjoint(reshape(parent(x), size(parent(x))..., 1)) +end +function expand_batchdim(x::LinearAlgebra.Transpose) + return NNlib.BatchedTranspose(reshape(parent(x), size(parent(x))..., 1)) +end +expand_batchdim(x::AbstractVector) = reshape(x, :, 1) +expand_batchdim(x::SVector{L, T}) where {L, T} = SMatrix{L, 1, T}(x) + +function CRC.rrule(::typeof(expand_batchdim), x::AbstractMatrix) + proj_x = CRC.ProjectTo(x) + ∇expand_batchdim = @closure Δ -> begin + return ∂∅, proj_x(view(Δ, :, :, 1)) + end + return expand_batchdim(x), ∇expand_batchdim +end + +function safe_warning(msg::String, maxlog::Int) + if maxlog < 0 + @warn msg + else + @warn msg maxlog=maxlog + end +end + +CRC.@non_differentiable safe_warning(::Any...) + +function safe_minimum(x::AbstractArray, default) + length(x) == 0 && return default + return minimum(x) +end + +CRC.@non_differentiable safe_minimum(::Any...) + +# Switches function `foo` with function `bar`. To be used when Enzyme cannot differentiate +# through `foo` but supports `bar`. Use with caution, avoid multiple dispatch on `foo`. +# Also the function should always return `nothing` +macro enzyme_alternative(f₁, f₂) + return esc(quote + function EnzymeRules.augmented_primal( + ::EnzymeRules.RevConfig, ::EnzymeCore.Const{typeof($(f₁))}, + ::Type{RT}, args...) 
where {RT} + fwd, rev = EnzymeCore.autodiff_thunk( + EnzymeCore.ReverseSplitWithPrimal, EnzymeCore.Const{typeof($(f₂))}, + EnzymeCore.Const, typeof.(args)...) + + tape, result, shadow_result = fwd(EnzymeCore.Const($(f₂)), args...) + + return EnzymeRules.AugmentedReturn(result, shadow_result, (tape, rev)) + end + + function EnzymeRules.reverse( + ::EnzymeRules.RevConfig, ::EnzymeCore.Const{typeof($(f₁))}, + ::Type{RT}, (tape, rev), args...) where {RT} + return only(rev(EnzymeCore.Const($(f₂)), args..., tape)) + end + + function EnzymeRules.forward(cfg::EnzymeRules.FwdConfig, + ::EnzymeCore.Const{typeof($(f₁))}, ::Type{RT}, args...) where {RT} + EnzymeCore.autodiff(EnzymeCore.Forward, EnzymeCore.Const($(f₂)), RT, args...) + return + end + end) +end + +@inline function run_ka_kernel(f::F, backend, workgroupsize, ndrange, args...) where {F} + if workgroupsize === nothing + kernel = f(backend) + kernel(args...; ndrange) + return + end + kernel = f(backend, KA.StaticSize(workgroupsize), KA.StaticSize(ndrange)) + kernel(args...) + return +end + +within_autodiff_vararg(args...) = unrolled_any(within_autodiff, args) + +function within_autodiff(_) + unsafe_known(is_extension_loaded(Val(:Enzyme))) && + return static(EnzymeCore.within_autodiff()) + return False() +end +within_autodiff(::ForwardDiff.Dual) = True() +within_autodiff(::AbstractArray{<:ForwardDiff.Dual}) = True() + +CRC.rrule(::typeof(within_autodiff), x) = True(), _ -> (∂∅, ∂∅) + +static_training_mode(::Nothing, args...) = within_autodiff_vararg(args...) + +function static_training_mode( + training::Union{Bool, Val{true}, Val{false}, StaticBool}, args...) + return static_training_mode_check( + training, static(training), within_autodiff_vararg(args...)) +end + +function CRC.rrule(::typeof(static_training_mode), ::Nothing, args...) + return True(), _ -> ntuple(Returns(∂∅), length(args) + 2) +end + +function CRC.rrule(::typeof(static_training_mode), + training::Union{Bool, Val{true}, Val{false}, StaticBool}, args...) + res = static_training_mode_check(training, static(training), True()) + return res, _ -> ntuple(Returns(∂∅), length(args) + 2) +end + +static_training_mode_check(_, ::True, ::True) = True() +static_training_mode_check(_, ::False, ::False) = False() + +function static_training_mode_check(training, ::True, ::False) + @warn "`training` is set to `$(training)` but is not being used within an autodiff \ + call (gradient, jacobian, etc...). This will be slow. If you are using a \ + `Lux.jl` model, set it to inference (test) mode using `LuxCore.testmode`. \ + Reliance on this behavior is discouraged, and is not guaranteed by Semantic \ + Versioning, and might be removed without a deprecation cycle. It is recommended \ + to fix this issue in your code." maxlog=1 + return True() +end + +function static_training_mode_check(training, ::False, ::True) + @warn "`training` is set to `$(training)` but is being used within an autodiff call \ + (gradient, jacobian, etc...). This might lead to incorrect results. If you are \ + using a `Lux.jl` model, set it to training mode using \ + `LuxCore.trainmode`." maxlog=1 + return False() +end + +CRC.@non_differentiable static_training_mode_check(::Any...) + +@static if DISABLE_LOOP_VECTORIZATION + @inline can_loopvec_args(args...) = false +else + @inline function can_loopvec_args(args...) + return can_loopvec_args_check(is_extension_loaded(Val(:LoopVectorization)), args...) + end +end + +@inline can_loopvec_args_check(::False, args...) 
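# Illustrative usage sketch for `@enzyme_alternative` (defined above). Both kernel
# names are made up and the macro is assumed to be in scope; the pattern is a fast
# kernel Enzyme cannot differentiate plus a simpler fallback it can, with both
# returning `nothing` as the note above requires.
function fast_axpy!(y, x, α)                 # stands in for e.g. an LV-accelerated kernel
    for i in eachindex(y, x)
        @inbounds y[i] = muladd(α, x[i], y[i])
    end
    return nothing
end

function generic_axpy!(y, x, α)              # plain broadcast that Enzyme handles directly
    y .+= α .* x
    return nothing
end

# Differentiate `generic_axpy!` whenever Enzyme encounters `fast_axpy!`:
@enzyme_alternative fast_axpy! generic_axpy!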
= false + +CRC.@non_differentiable can_loopvec_args_check(::Any...) + +EnzymeRules.inactive_noinl(::typeof(can_loopvec_args_check), ::Any...) = nothing + +end diff --git a/lib/LuxLib/test/Project.toml b/lib/LuxLib/test/Project.toml new file mode 100644 index 000000000..1005c4881 --- /dev/null +++ b/lib/LuxLib/test/Project.toml @@ -0,0 +1,73 @@ +[deps] +AppleAccelerate = "13e28ba4-7ad8-5781-acae-3021b1ed3924" +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +BLISBLAS = "6f275bd8-fec0-4d39-945b-7e95a765fa1e" +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" +InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" +LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" +LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" +MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2" +MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +Preferences = "21216c6a-2e73-6563-6e65-726566657250" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" +Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[compat] +AppleAccelerate = "0.4" +Aqua = "0.8.7" +BLISBLAS = "0.1" +BenchmarkTools = "1.5" +ChainRulesCore = "1.24" +ComponentArrays = "0.15.16" +Enzyme = "0.13.1" +EnzymeCore = "0.8" +ExplicitImports = "1.9.0" +ForwardDiff = "0.10.36" +Hwloc = "3.2" +InteractiveUtils = "<0.0.1, 1" +JLArrays = "0.1.5" +LoopVectorization = "0.12.171" +LuxTestUtils = "1.2.1" +MKL = "0.7" +MLDataDevices = "1.0.0" +NNlib = "0.9.21" +Octavian = "0.3.28" +Pkg = "1.10" +Preferences = "1.4.3" +Random = "1.10" +ReTestItems = "1.24.0" +Reexport = "1" +ReverseDiff = "1.15" +StableRNGs = "1.0.2" +Static = "0.8.4, 1" +StaticArrays = "1.9.7" +Statistics = "1.10" +Test = "1.10" +Tracker = "0.2.34" +Zygote = "0.6.70" + +[extras] +CUDA_Driver_jll = "4ee394cb-3365-5eb0-8335-949819d2adfc" + +[preferences.CUDA_Driver_jll] +compat = false diff --git a/lib/LuxLib/test/common_ops/activation_tests.jl b/lib/LuxLib/test/common_ops/activation_tests.jl new file mode 100644 index 000000000..e2b80e711 --- /dev/null +++ b/lib/LuxLib/test/common_ops/activation_tests.jl @@ -0,0 +1,56 @@ +@testitem "Activation Functions" tags=[:other_ops] setup=[SharedTestSetup] begin + rng = StableRNG(1234) + + apply_act(f::F, x) where {F} = sum(abs2, f.(x)) + apply_act_fast(f::F, x) where {F} = sum(abs2, fast_activation!!(f, copy(x))) + apply_act_fast2(f::F, x) where {F} = sum(abs2, fast_activation(f, x)) + + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "$f: $T" for f in [identity, relu, sigmoid, sigmoid_fast, softplus, + logsigmoid, 
gelu, swish, lisht, tanh, tanh_fast], + T in [Float16, Float32, Float64] + + !fp64 && T == Float64 && continue + + x = rand(rng, T, 4, 3) |> aType + + y1 = apply_act(f, x) + y2 = apply_act_fast(f, x) + y3 = apply_act_fast2(f, x) + + fp16 = T == Float16 + atol = fp16 ? 1.0f-1 : 1.0f-3 + rtol = fp16 ? 1.0f-1 : 1.0f-3 + + @test y1≈y2 atol=atol rtol=rtol + @test y1≈y3 atol=atol rtol=rtol + @test eltype(y1) == T + @test eltype(y2) == T + @test eltype(y3) == T + + @test @inferred(apply_act(f, x)) isa Any + @test @inferred(apply_act_fast(f, x)) isa Any + @test @inferred(apply_act_fast2(f, x)) isa Any + + @jet apply_act_fast(f, x) + @jet apply_act_fast2(f, x) + + @test @inferred(Zygote.gradient(apply_act, f, x)) isa Any + if f !== lisht + @test @inferred(Zygote.gradient(apply_act_fast, f, x)) isa Any + end + @test @inferred(Zygote.gradient(apply_act_fast2, f, x)) isa Any + + @test_gradients(Base.Fix1(apply_act, f), x; atol, rtol) + @test_gradients(Base.Fix1(apply_act_fast, f), x; atol, rtol) + @test_gradients(Base.Fix1(apply_act_fast2, f), x; atol, rtol) + + ∂x1 = Zygote.gradient(apply_act, f, x)[2] + ∂x2 = Zygote.gradient(apply_act_fast, f, x)[2] + ∂x3 = Zygote.gradient(apply_act_fast2, f, x)[2] + + @test ∂x1≈∂x2 atol=atol rtol=rtol + @test ∂x1≈∂x3 atol=atol rtol=rtol + end + end +end diff --git a/lib/LuxLib/test/common_ops/bias_act_tests.jl b/lib/LuxLib/test/common_ops/bias_act_tests.jl new file mode 100644 index 000000000..3b2f22d0c --- /dev/null +++ b/lib/LuxLib/test/common_ops/bias_act_tests.jl @@ -0,0 +1,103 @@ +@testitem "Bias Activation" tags=[:other_ops] setup=[SharedTestSetup] begin + rng = StableRNG(1234) + + bias_act_loss1(act, x, b) = sum(abs2, act.(x .+ LuxLib.Impl.reshape_bias(x, b))) + bias_act_loss2(act, x, b) = sum(abs2, bias_activation(act, x, b)) + bias_act_loss3(act, x, b) = sum(abs2, bias_activation!!(act, copy(x), b)) + + struct __Fix1{F, A} + f::F + act::A + end + (f::__Fix1)(x, b) = f.f(f.act, x, b) + + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "$act, $T, $sz" for act in [ + identity, relu, sigmoid, sigmoid_fast, softplus, + logsigmoid, gelu, swish, lisht, tanh, tanh_fast], + T in [Float16, Float32, Float64], + sz in [(2, 2, 3, 4), (4, 5)] + + !fp64 && T == Float64 && continue + + x = rand(rng, T, sz) |> aType + b = rand(rng, T, sz[end - 1]) |> aType + + y1 = bias_act_loss1(act, x, b) + y2 = bias_act_loss2(act, x, b) + y3 = bias_act_loss3(act, x, b) + + fp16 = T == Float16 + atol = fp16 ? 1.0f-2 : 1.0f-3 + rtol = fp16 ? 1.0f-2 : 1.0f-3 + + @test y1≈y2 atol=atol rtol=rtol + @test y1≈y3 atol=atol rtol=rtol + @test eltype(y1) == T + @test eltype(y2) == T + @test eltype(y3) == T + + @test @inferred(bias_act_loss1(act, x, b)) isa Any + @test @inferred(bias_act_loss2(act, x, b)) isa Any + @test @inferred(bias_act_loss3(act, x, b)) isa Any + + @jet bias_act_loss2(act, x, b) + @jet bias_act_loss3(act, x, b) + + if act !== lisht && T != Float16 + @test @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any + @test @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any + end + + @test_gradients(__Fix1(bias_act_loss1, act), x, b; atol, rtol, + soft_fail=fp16 ? [AutoFiniteDiff()] : []) + @test_gradients(__Fix1(bias_act_loss2, act), x, b; atol, rtol, + soft_fail=fp16 ? [AutoFiniteDiff()] : []) + @test_gradients(__Fix1(bias_act_loss3, act), x, b; atol, rtol, + soft_fail=fp16 ? 
[AutoFiniteDiff()] : []) + + ∂x1, ∂b1 = Zygote.gradient(__Fix1(bias_act_loss1, act), x, b) + ∂x2, ∂b2 = Zygote.gradient(__Fix1(bias_act_loss2, act), x, b) + ∂x3, ∂b3 = Zygote.gradient(__Fix1(bias_act_loss3, act), x, b) + + @test ∂x1≈∂x2 atol=atol rtol=rtol + @test ∂x1≈∂x3 atol=atol rtol=rtol + @test ∂b1≈∂b2 atol=atol rtol=rtol + @test ∂b1≈∂b3 atol=atol rtol=rtol + end + end +end + +@testitem "Bias Activation (ReverseDiff)" tags=[:other_ops] setup=[SharedTestSetup] begin + using ReverseDiff, Tracker + + x = rand(Float32, 3, 4) + b = rand(Float32, 3) + act = tanh + + z = bias_activation(act, ReverseDiff.track(x), b) + @test z isa ReverseDiff.TrackedArray # If this fails then we fail to compile the tape + + z = bias_activation(identity, ReverseDiff.track(x), b) + @test z isa ReverseDiff.TrackedArray + + z = bias_activation(act, Tracker.param(x), b) + @test z isa Tracker.TrackedArray + + z = bias_activation(identity, Tracker.param(x), b) + @test z isa Tracker.TrackedArray +end + +@testitem "Bias Activation: Zero-sized Arrays" tags=[:other_ops] setup=[SharedTestSetup] begin + @testset "$mode" for (mode, aType, ongpu) in MODES + x = rand(Float32, 4, 3, 2, 0) |> aType + b = rand(Float32, 2) |> aType + @test size(bias_activation(identity, x, b)) == (4, 3, 2, 0) + @test size(bias_activation!!(identity, x, b)) == (4, 3, 2, 0) + + x = rand(Float32, 2, 0) |> aType + b = rand(Float32, 2) |> aType + @test size(bias_activation(relu, x, b)) == (2, 0) + @test size(bias_activation!!(relu, x, b)) == (2, 0) + end +end diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl new file mode 100644 index 000000000..c7426b205 --- /dev/null +++ b/lib/LuxLib/test/common_ops/conv_tests.jl @@ -0,0 +1,142 @@ +@testsetup module ConvSetup +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib + +expand(_, i::Tuple) = i +expand(N, i::Integer) = ntuple(_ -> i, N) + +function convfilter(gen_f::Function, ::Type{wT}, filter::NTuple{N, Integer}, + ch::Pair{<:Integer, <:Integer}; groups=1) where {wT, N} + cin, cout = ch + @assert cin % groups==0 "Input channel dimension must be divisible by groups." + @assert cout % groups==0 "Output channel dimension must be divisible by groups." + return gen_f(wT, filter..., cin ÷ groups, cout) +end + +calc_padding(pad, ::NTuple{N}, dilation, stride) where {N} = expand(Val(2 * N), pad) + +function run_conv_testing(gen_f::Function, activation, kernel, stride, padding, + hasbias, groups, Tw, Tx, aType, mode, ongpu) + weight = convfilter(gen_f, Tw, kernel, 4 => 8; groups) |> aType + x = gen_f(Tx, ntuple(Returns(4), length(kernel))..., 4, 2) |> aType + bias = hasbias ? aType(gen_f(Tx, 8)) : nothing + + cdims = DenseConvDims( + x, weight; stride, padding=calc_padding(padding, kernel, 1, stride), + dilation=1, groups) + + y = fused_conv_bias_activation(activation, weight, x, bias, cdims) + + generic_testing = !(mode == "amdgpu" && (Tx == Float64 || Tw == Float64)) + + fp16 = Tx == Float16 || Tw == Float16 + atol = fp16 ? 1.0f-1 : 1.0f-3 + rtol = fp16 ? 1.0f-1 : 1.0f-3 + + if generic_testing + y_generic = LuxLib.Impl.conv(x, weight, cdims) + y_generic = bias === nothing ? 
activation.(y_generic) : + activation.(y_generic .+ LuxLib.Impl.reshape_bias(y_generic, bias)) + # Operation reordering has an effect on the accuracy of the results + @test y≈y_generic atol=atol rtol=rtol + end + + @test eltype(y) == promote_type(Tw, Tx) + + @test @inferred(fused_conv_bias_activation(activation, weight, x, bias, cdims)) isa Any + @jet fused_conv_bias_activation(activation, weight, x, bias, cdims) + + __f = (σ, w, x, b, cdims) -> sum(abs2, fused_conv_bias_activation(σ, w, x, b, cdims)) + + if mode != "amdgpu" && activation !== anonact && !fp16 + @test @inferred(Zygote.gradient(__f, activation, weight, x, bias, cdims)) isa Any + else + try + @inferred(Zygote.gradient(__f, activation, weight, x, bias, cdims)) + @test true + catch e + e isa ErrorException || rethrow() + @test_broken false + end + end + + __f_grad = let activation = activation, cdims = cdims + (w, x, b) -> __f(activation, w, x, b, cdims) + end + + skip_backends = [] + mp = Tx != Tw + mp && push!(skip_backends, AutoReverseDiff()) + ((mp && ongpu) || (mode == "amdgpu" && (Tx == Float64 || Tw == Float64))) && + push!(skip_backends, AutoTracker()) + @test_gradients(__f_grad, weight, x, bias; atol, rtol, skip_backends, soft_fail=fp16) +end + +anonact = x -> gelu(x) + +const ELTYPES = [(Float16, Float16), (Float32, Float16), (Float32, Float32), + (Float32, Float64), (Float64, Float64)] +const ACTIVATIONS = [ + identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, swish, anonact] + +const ALL_TEST_CONFIGS = Iterators.product(ELTYPES, + (true, false), + ACTIVATIONS, + (((2,), (1,), (1,), 1), ((2, 2), (1, 1), (1, 1), 1), + ((2, 2), (0, 0), (2, 2), 1), ((2, 2), (0, 0), (1, 1), 2))) + +const TEST_BLOCKS = collect(Iterators.partition( + ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) + +export expand, convfilter, calc_padding, anonact, TEST_BLOCKS, run_conv_testing + +end + +@testitem "Fused Conv: Group 1" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[1] + !fp64 && (Tx == Float64 || Tw == Float64) && continue + run_conv_testing(generate_fixed_array, activation, kernel, stride, + padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) + end + end +end + +@testitem "Fused Conv: Group 2" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[2] + !fp64 && (Tx == Float64 || Tw == Float64) && continue + run_conv_testing(generate_fixed_array, activation, kernel, stride, + padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) + end + end +end + +@testitem "Fused Conv: Group 3" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[3] + !fp64 && (Tx == Float64 || Tw == Float64) && continue + run_conv_testing(generate_fixed_array, activation, 
kernel, stride, + padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) + end + end +end + +@testitem "Fused Conv: Group 4" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[4] + !fp64 && (Tx == Float64 || Tw == Float64) && continue + run_conv_testing(generate_fixed_array, activation, kernel, stride, + padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) + end + end +end + +@testitem "Fused Conv: Group 5" tags=[:conv] setup=[SharedTestSetup, ConvSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "$(Tw) x $(Tx) hasbias: $(hasbias) activation: $(activation) kernel: $(kernel) padding: $(padding) stride: $(stride) groups: $(groups)" for ((Tx, Tw), hasbias, activation, (kernel, padding, stride, groups)) in TEST_BLOCKS[5] + !fp64 && (Tx == Float64 || Tw == Float64) && continue + run_conv_testing(generate_fixed_array, activation, kernel, stride, + padding, hasbias, groups, Tw, Tx, aType, mode, ongpu) + end + end +end diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl new file mode 100644 index 000000000..99d1810c9 --- /dev/null +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -0,0 +1,248 @@ +@testsetup module DenseSetup +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, StableRNGs + +anonact = x -> x^3 + +dense_simple(act, w, x, ::Nothing) = act.(w * x) +dense_simple(act, w, x, b) = act.(w * x .+ b) + +function run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) + rng = StableRNG(1234) + + bias = hasbias ? randn(rng, Tw, M) |> aType : nothing + w = randn(rng, Tw, M, N) |> aType + x = randn(rng, Tx, N, 3) |> aType + + if activation === tanh_fast || activation === tanh || activation === gelu + bias = bias === nothing ? nothing : (bias .* eltype(bias)(0.001)) + w = w .* eltype(w)(0.001) + x = x .* eltype(x)(0.001) + end + + y = fused_dense_bias_activation(activation, w, x, bias) + y_generic = bias === nothing ? activation.(w * x) : activation.(w * x .+ bias) + + @test y ≈ y_generic + @test eltype(y) == promote_type(Tw, Tx) + + @test @inferred(fused_dense_bias_activation(activation, w, x, bias)) isa Any + @jet fused_dense_bias_activation(activation, w, x, bias) + + fp16 = Tx == Float16 || Tw == Float16 + atol = fp16 ? 1.0f-1 : 1.0f-3 + rtol = fp16 ? 
1.0f-1 : 1.0f-3 + + __f = (σ, w, x, b) -> sum(abs2, fused_dense_bias_activation(σ, w, x, b)) + + if !fp16 && activation !== anonact + @test @inferred(Zygote.gradient(__f, activation, w, x, bias)) isa Any + end + + skip_backends = [] + Tw != Tx && push!(skip_backends, AutoReverseDiff()) + fp16 && push!(skip_backends, AutoFiniteDiff()) + fp16 && push!(skip_backends, AutoTracker()) + + __f_grad = let activation = activation + (w, x, b) -> __f(activation, w, x, b) + end + @test_gradients(__f_grad, w, x, bias; atol, rtol, skip_backends, soft_fail=fp16) + + y_simple = dense_simple(activation, w, x, bias) + y_zyg = fused_dense_bias_activation(activation, w, x, bias) + @test y_simple≈y_zyg atol=atol rtol=rtol + + _, ∂w_true, ∂x_true, ∂b_true = Zygote.gradient( + sum ∘ dense_simple, activation, w, x, bias) + _, ∂w_zyg, ∂x_zyg, ∂b_zyg = Zygote.gradient( + sum ∘ fused_dense_bias_activation, activation, w, x, bias) + @test ∂w_true≈∂w_zyg atol=atol rtol=rtol + @test ∂x_true≈∂x_zyg atol=atol rtol=rtol + if bias !== nothing + @test ∂b_true≈∂b_zyg atol=atol rtol=rtol + end +end + +const ALL_TEST_CONFIGS = Iterators.product( + ((Float16, Float16), (Float32, Float16), (Float32, Float32), + (Float32, Float64), (Float64, Float64)), + (4, 32), + (4, 32), + (true, false), + (identity, tanh, tanh_fast, sigmoid, sigmoid_fast, relu, gelu, anonact)) + +const TEST_BLOCKS = collect(Iterators.partition( + ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) + +export ALL_TEST_CONFIGS, TEST_BLOCKS, run_dense_testing + +end + +@testitem "Fused Dense: Group 1" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[1] + !fp64 && (Tx == Float64 || Tw == Float64) && continue + run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) + end + end +end + +@testitem "Fused Dense: Group 2" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[2] + !fp64 && (Tx == Float64 || Tw == Float64) && continue + run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) + end + end +end + +@testitem "Fused Dense: Group 3" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[3] + !fp64 && (Tx == Float64 || Tw == Float64) && continue + run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) + end + end +end + +@testitem "Fused Dense: Group 4" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, hasbias, activation) in TEST_BLOCKS[4] + !fp64 && (Tx == Float64 || Tw == Float64) && continue + run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) + end + end +end + +@testitem "Fused Dense: Group 5" tags=[:dense] setup=[SharedTestSetup, DenseSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "eltype $Tw x $Tx, size $M x $N, bias $hasbias, activation $activation" for ((Tx, Tw), M, N, 
hasbias, activation) in TEST_BLOCKS[5] + !fp64 && (Tx == Float64 || Tw == Float64) && continue + run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu) + end + end +end + +@testitem "Fused Dense: StaticArrays" tags=[:dense] begin + using StaticArrays, NNlib + + x = @SArray rand(2, 4) + weight = @SArray rand(3, 2) + bias = @SArray rand(3) + + @test @inferred(fused_dense_bias_activation(relu, weight, x, bias)) isa SArray +end + +@testitem "Fused Dense: CPU No Scalar Indexing" tags=[:dense] begin + using JLArrays, NNlib + + x = JLArray(rand(Float32, 2, 4)) + weight = JLArray(rand(Float32, 3, 2)) + bias = JLArray(rand(Float32, 3)) + + @test @inferred(fused_dense_bias_activation(relu, weight, x, bias)) isa JLArray + @test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp +end + +@testitem "`LuxLib.Impl.matmul(add)` allocations" tags=[:dense] setup=[SharedTestSetup] begin + using BenchmarkTools, Statistics + + if BACKEND_GROUP == "all" || BACKEND_GROUP == "cpu" + @testset "size $N" for N in (1, 4, 32, 256, 1024) + x = rand(Float32, N, N) + + trial_opt = median(@benchmark(LuxLib.Impl.matmul($x, $x))) + trial_baseline = median(@benchmark($x*$x)) + + @test trial_opt.allocs ≤ trial_baseline.allocs + @test trial_opt.memory ≤ trial_baseline.memory + + bias = rand(Float32, N) + + trial_opt = median(@benchmark(LuxLib.Impl.matmuladd($x, $x, $bias))) + trial_baseline = median(@benchmark(muladd($x, $x, $bias))) + + @test trial_opt.allocs ≤ trial_baseline.allocs + @test trial_opt.memory ≤ trial_baseline.memory + end + end +end + +@testitem "Enzyme.Forward patch: dense" tags=[:dense] setup=[SharedTestSetup] skip=:(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED) begin + using LuxLib, Random, ForwardDiff, Enzyme + + x = rand(Float32, 2, 2) + + f(x) = sum(abs2, LuxLib.Impl.matmul(x, x)) + + @test only(Enzyme.gradient(Forward, f, x)) ≈ ForwardDiff.gradient(f, x) +end + +@testitem "Enzyme rules for fused dense" tags=[:dense] setup=[SharedTestSetup] skip=:(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED) begin + using LuxLib, NNlib, Zygote, Enzyme + + # These are mostly for testing the CUDA rules since we don't enable the CUDA tests + # in LuxTestUtils currently + function fused_dense!(y, act, weight, x, b) + op = LuxLib.internal_operation_mode((y, weight, x, b)) + LuxLib.Impl.fused_dense!(y, op, act, weight, x, b) + return + end + + function matmuladd!(C, A, B, bias) + op = LuxLib.internal_operation_mode((C, A, B, bias)) + LuxLib.Impl.matmuladd!(C, op, A, B, bias) + return + end + + rng = StableRNG(1234) + + ALL_ACTS = [identity, tanh, tanh_fast, sigmoid, sigmoid_fast, + relu, gelu, x -> x^3, x -> gelu(x)] + + @testset "$mode" for (mode, aType, ongpu) in MODES + mode ∈ ("cpu", "cuda") || continue + + y = zeros(Float32, 2, 2) |> aType + weight = randn(rng, Float32, 2, 2) |> aType + x = randn(rng, Float32, 2, 2) |> aType + @testset for (act, hasbias) in Iterators.product(ALL_ACTS, (true, false)) + b = hasbias ? aType(randn(rng, Float32, 2)) : nothing + + dy = randn(rng, Float32, 2, 2) |> aType + + dweight = zeros(Float32, 2, 2) |> aType + dx = zeros(Float32, 2, 2) |> aType + db = hasbias ? aType(zeros(Float32, 2)) : nothing + + b_enz = hasbias ? 
Duplicated(b, db) : Const(b) + + Enzyme.autodiff(Reverse, fused_dense!, Duplicated(y, copy(dy)), Const(act), + Duplicated(weight, dweight), Duplicated(x, dx), b_enz) + + _, pb_f = Zygote.pullback(fused_dense_bias_activation, act, weight, x, b) + _, dweight_zyg, dx_zyg, db_zyg = pb_f(dy) + + @test dweight≈dweight_zyg atol=1e-3 rtol=1e-3 + @test dx≈dx_zyg atol=1e-3 rtol=1e-3 + if hasbias + @test db≈db_zyg atol=1e-3 rtol=1e-3 + end + + (act === identity && hasbias) || continue + + dweight .= 0 + dx .= 0 + db .= 0 + Enzyme.autodiff(Reverse, matmuladd!, Duplicated(y, copy(dy)), + Duplicated(weight, dweight), Duplicated(x, dx), b_enz) + + _, pb_f = Zygote.pullback(LuxLib.Impl.matmuladd, weight, x, b) + dweight_zyg, dx_zyg, db_zyg = pb_f(dy) + + @test dweight≈dweight_zyg atol=1e-3 rtol=1e-3 + @test dx≈dx_zyg atol=1e-3 rtol=1e-3 + @test db≈db_zyg atol=1e-3 rtol=1e-3 + end + end +end diff --git a/lib/LuxLib/test/common_ops/dropout_tests.jl b/lib/LuxLib/test/common_ops/dropout_tests.jl new file mode 100644 index 000000000..45f8fd017 --- /dev/null +++ b/lib/LuxLib/test/common_ops/dropout_tests.jl @@ -0,0 +1,179 @@ +@testitem "Dropout" tags=[:other_ops] setup=[SharedTestSetup] begin + rng = StableRNG(12345) + + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "$T, $x_shape, $dims" for T in (Float16, Float32, Float64), + x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)), + dims in (:, 1, (1, 2)) + + !fp64 && T == Float64 && continue + + x = randn(rng, T, x_shape) |> aType + + @test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), dims)) isa Any + + y, mask_, rng_ = dropout(rng, x, T(0.5), Val(true), T(2), dims) + + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape + @test mask_ isa aType{T, length(x_shape)} + dims isa Colon && @test size(mask_) == x_shape + @test rng != rng_ + + @jet sum(first(dropout(rng, x, T(0.5), Val(true), T(2), dims))) + @test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), dims)) isa Any + + __f = x -> sum(first(dropout(StableRNG(0), x, 0.5, Val(true), 2.0, dims))) + @test @inferred(Zygote.gradient(__f, x)) isa Any + + __f = let rng = rng, T = T + x -> sum(first(dropout(rng, x, T(0.5), Val(true), T(2), dims))) + end + @test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, + soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), + broken_backends=(T == Float16 && Sys.iswindows() ? 
[AutoEnzyme()] : [])) + + y, mask_, rng_ = dropout(rng, x, T(0.5), Val(false), T(2), dims) + + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape + @test rng == rng_ + @test y == x + end + end +end + +@testitem "Dropout with Preset Mask" tags=[:other_ops] setup=[SharedTestSetup] begin + using Statistics + + rng = StableRNG(12345) + + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "$T: $x_shape" for T in (Float16, Float32, Float64), + x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) + + !fp64 && T == Float64 && continue + + x = randn(rng, T, x_shape) |> aType + mask = rand(T, x_shape) |> aType + + # Update mask + @test @inferred(dropout( + rng, x, mask, T(0.5), Val(true), Val(true), T(2), :)) isa Any + + y, mask_, rng_ = dropout( + rng, x, mask, T(0.5), Val(true), Val(true), T(2), :) + + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape + @test mask_ isa aType{T, length(x_shape)} + @test size(mask_) == x_shape + @test rng != rng_ + @test mask != mask_ + + __f = (x, mask) -> sum(first(dropout( + StableRNG(0), x, mask, 0.5, Val(true), Val(true), 2.0, :))) + @test @inferred(Zygote.gradient(__f, x, mask)) isa Any + + __f = let rng = rng, mask = mask, p = T(0.5), invp = T(2) + x -> sum(first(dropout(rng, x, mask, p, Val(true), Val(true), invp, :))) + end + @test_gradients(__f, x; atol=1.0f-3, + rtol=1.0f-3, + soft_fail=(T == Float16 ? [AutoFiniteDiff()] : [])) + + @jet sum(first(dropout( + rng, x, mask, T(0.5), Val(true), Val(true), T(2), :))) + + # Try using mask if possible (possible!!) + @test @inferred(dropout( + rng, x, mask, T(0.5), Val(true), Val(false), T(2), :)) isa Any + + y, mask_, rng_ = dropout( + rng, x, mask, T(0.5), Val(true), Val(false), T(2), :) + + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape + @test mask_ isa aType{T, length(x_shape)} + @test size(mask_) == x_shape + @test rng == rng_ + @test mask == mask_ + + __f = (x, mask) -> sum(first(dropout( + StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, :))) + @test @inferred(Zygote.gradient(__f, x, mask)) isa Any + + __f = let rng = rng, mask = mask, p = T(0.5), invp = T(2) + x -> sum(first(dropout(rng, x, mask, p, Val(true), Val(false), invp, :))) + end + + soft_fail = T == Float16 ? Any[AutoFiniteDiff()] : [] + skip_backends = length(x_shape) == 5 ? 
[AutoEnzyme()] : [] + + @test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, soft_fail, skip_backends) + + @jet sum(first(dropout( + rng, x, mask, T(0.5), Val(true), Val(false), T(2), :))) + mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType + + # Testing Mode + @test @inferred(dropout( + rng, x, mask, T(0.5), Val(false), Val(false), T(2), :)) isa Any + + y, mask_, rng_ = dropout( + rng, x, mask, T(0.5), Val(false), Val(false), T(2), :) + + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape + @test mask_ isa aType{T, length(x_shape)} + @test mask_ == mask + @test rng == rng_ + end + end +end + +@testitem "Alpha Dropout" tags=[:other_ops] setup=[SharedTestSetup] begin + using Statistics + + rng = StableRNG(12345) + + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "$T: $x_shape" for T in (Float16, Float32, Float64), + x_shape in ((2, 3), (2, 2, 3), (2, 2, 3, 1), (2, 2, 1, 3, 1)) + + !fp64 && T == Float64 && continue + + x = randn(rng, T, x_shape) |> aType + + @test @inferred(alpha_dropout(rng, x, T(0.5), Val(true))) isa Any + + y, rng_ = alpha_dropout(rng, x, T(0.5), Val(true)) + + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape + @test rng != rng_ + + @test_broken std(y)≈std(x) atol=1.0f-2 rtol=1.0f-2 + + __f = x -> sum(first(alpha_dropout(StableRNG(0), x, 0.5, Val(true)))) + @test @inferred(Zygote.gradient(__f, x)) isa Any + + __f = let rng = rng + x -> sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) + end + @test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, + soft_fail=(T == Float16 ? [AutoFiniteDiff()] : []), + broken_backends=(T == Float16 && Sys.iswindows() ? [AutoEnzyme()] : [])) + + @jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) + @test @inferred(alpha_dropout(rng, x, T(0.5), Val(false))) isa Any + + y, rng_ = alpha_dropout(rng, x, T(0.5), Val(false)) + + @test y isa aType{T, length(x_shape)} + @test size(y) == x_shape + @test rng == rng_ + @test y == x + end + end +end diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl new file mode 100644 index 000000000..3936200a8 --- /dev/null +++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl @@ -0,0 +1,195 @@ +@testsetup module BatchNormSetup +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static + +function setup_batchnorm(gen_f, aType, T, sz; affine::Bool=true, track_stats::Bool) + x = gen_f(T, sz) |> aType + scale = affine ? aType(gen_f(T, sz[end - 1])) : nothing + bias = affine ? 
aType(gen_f(T, sz[end - 1])) : nothing + + if track_stats + running_mean = gen_f(T, sz[end - 1]) |> aType + running_var = abs2.(gen_f(T, sz[end - 1])) |> aType + return x, scale, bias, running_mean, running_var + else + return x, scale, bias, nothing, nothing + end +end + +# Bypassing all optimizations +function batchnorm_fallback( + x::AbstractArray{<:Real, N}, scale::LuxLib.Optional{<:AbstractVector}, + bias::LuxLib.Optional{<:AbstractVector}, + running_mean::LuxLib.Optional{<:AbstractVector}, + running_var::LuxLib.Optional{<:AbstractVector}, training::Val, + σ::F=identity, momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F, N} + y, xm, xv = LuxLib.Impl.normalization(x, LuxLib.Utils.remove_tracking(running_mean), + LuxLib.Utils.remove_tracking(running_var), scale, bias, + LuxLib.Impl.batchnorm_reduce_dims(x), static(training), momentum, epsilon, σ) + return (y, + (; running_mean=LuxLib.Utils.remove_tracking(LuxLib.Utils.safe_vec(xm)), + running_var=LuxLib.Utils.remove_tracking(LuxLib.Utils.safe_vec(xv)))) +end + +anonact = x -> x^3 + +is_training(::Val{training}) where {training} = training + +function run_batchnorm_testing( + gen_f, T, sz, training, affine, track_stats, act, aType, mode, ongpu) + epsilon = eps(T)^(5 // 7) + x, scale, bias, rm, rv = setup_batchnorm(gen_f, aType, T, sz; track_stats, affine) + + y, nt = batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) + y_simple, nt_simple = batchnorm_fallback( + x, scale, bias, rm, rv, training, act, T(0.9), epsilon) + + fp16 = T == Float16 + atol = fp16 ? 1.0f-2 : 1.0f-3 + rtol = fp16 ? 1.0f-2 : 1.0f-3 + + @test y≈y_simple atol=atol rtol=rtol + if track_stats + @test nt.running_mean≈nt_simple.running_mean atol=atol rtol=rtol + @test nt.running_var≈nt_simple.running_var atol=atol rtol=rtol + end + + # Check the rrules + if is_training(training) + _f = (args...) -> sum(first(batchnorm( + args..., rm, rv, training, act, T(0.9), epsilon))) + _f2 = (args...) -> sum(first(batchnorm_fallback( + args..., rm, rv, training, act, T(0.9), epsilon))) + + ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias) + ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f2, x, scale, bias) + @test ∂x≈∂x_simple atol=atol rtol=rtol + if affine + @test ∂scale≈∂scale_simple atol=atol rtol=rtol + @test ∂bias≈∂bias_simple atol=atol rtol=rtol + end + end + + @test @inferred(batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon)) isa + Any + @jet batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) + + @test y isa aType{T, length(sz)} + @test size(y) == sz + if rm !== nothing + @test size(nt.running_mean) == (size(x, length(sz) - 1),) + @test size(nt.running_var) == (size(x, length(sz) - 1),) + end + + if is_training(training) && affine + skip_backends = [] + act === relu && push!(skip_backends, AutoFiniteDiff()) + + soft_fail = if fp16 + if Sys.iswindows() + [AutoTracker(), AutoFiniteDiff(), AutoReverseDiff(), AutoForwardDiff()] + else + true + end + else + false + end + + broken_backends = Sys.iswindows() && fp16 ? [AutoEnzyme()] : [] + + __f = (args...) 
-> sum(first(batchnorm( + args..., rm, rv, training, act, T(0.9), epsilon))) + @test_gradients(__f, x, scale, bias; atol, rtol, skip_backends, soft_fail, + broken_backends) + end + + if anonact !== act + lfn = (x, sc, b, rm, rv, tr, act, ϵ) -> sum(first(batchnorm( + x, sc, b, rm, rv, tr, act, ϵ))) + @test @inferred(Zygote.gradient( + lfn, x, scale, bias, rm, rv, training, act, epsilon)) isa Any + end +end + +const ALL_TEST_CONFIGS = Iterators.product( + [Float16, Float32, Float64], ((4, 4, 6, 2), (8, 2), (4, 4, 4, 3, 2)), + (Val(true), Val(false)), (true, false), (true, false), + (identity, relu, tanh_fast, sigmoid_fast, anonact)) + +const TEST_BLOCKS = collect(Iterators.partition( + ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) + +export setup_batchnorm, ALL_TEST_CONFIGS, TEST_BLOCKS, run_batchnorm_testing + +end + +@testitem "Batch Norm: Group 1" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[1] + !fp64 && T == Float64 && continue + run_batchnorm_testing(generate_fixed_array, T, sz, training, + affine, track_stats, act, aType, mode, ongpu) + end + end +end + +@testitem "Batch Norm: Group 2" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[2] + !fp64 && T == Float64 && continue + run_batchnorm_testing(generate_fixed_array, T, sz, training, + affine, track_stats, act, aType, mode, ongpu) + end + end +end + +@testitem "Batch Norm: Group 3" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[3] + !fp64 && T == Float64 && continue + run_batchnorm_testing(generate_fixed_array, T, sz, training, + affine, track_stats, act, aType, mode, ongpu) + end + end +end + +@testitem "Batch Norm: Group 4" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[4] + !fp64 && T == Float64 && continue + run_batchnorm_testing(generate_fixed_array, T, sz, training, + affine, track_stats, act, aType, mode, ongpu) + end + end +end + +@testitem "Batch Norm: Group 5" tags=[:batch_norm] setup=[SharedTestSetup, BatchNormSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "eltype $T, size $sz, $act $affine $track_stats" for (T, sz, training, affine, track_stats, act) in TEST_BLOCKS[5] + !fp64 && T == Float64 && continue + run_batchnorm_testing(generate_fixed_array, T, sz, training, + affine, track_stats, act, aType, mode, ongpu) + end + end +end + +@testitem "Batch Norm: Mixed Precision" tags=[:batch_norm] setup=[SharedTestSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + !fp64 && aType == Float64 && continue + + x = rand(Float64, 4, 4, 6, 2) |> aType + scale = rand(Float32, 6) |> aType + bias = rand(Float32, 6) |> aType + running_mean = rand(Float32, 6) |> aType + running_var = rand(Float32, 6) |> aType + + y, nt = batchnorm( + x, scale, bias, running_mean, running_var, Val(true), 
identity, 0.9f0, 1.0f-5) + @test y isa aType{Float64, 4} + @test nt.running_mean isa aType && length(nt.running_mean) == 6 + @test nt.running_var isa aType && length(nt.running_var) == 6 + + __f = (args...) -> sum(first(batchnorm( + args..., running_mean, running_var, Val(true), identity, 0.9f0, 1.0f-5))) + @test_gradients(__f, x, scale, bias; atol=1.0f-3, rtol=1.0f-3) + end +end diff --git a/lib/LuxLib/test/normalization/groupnorm_tests.jl b/lib/LuxLib/test/normalization/groupnorm_tests.jl new file mode 100644 index 000000000..3c638885c --- /dev/null +++ b/lib/LuxLib/test/normalization/groupnorm_tests.jl @@ -0,0 +1,138 @@ +@testsetup module GroupNormSetup +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static, StableRNGs +using LuxTestUtils: check_approx + +function setup_groupnorm(rng, aType, T, sz, affine) + x = randn(rng, T, sz) |> aType + if affine + scale = randn(rng, T, sz[end - 1]) |> aType + bias = randn(rng, T, sz[end - 1]) |> aType + return x, scale, bias + end + return x, nothing, nothing +end + +# Bypassing all optimizations +function groupnorm_fallback( + x::AbstractArray{<:Real, N}, scale::LuxLib.Optional{<:AbstractVector}, + bias::LuxLib.Optional{<:AbstractVector}, groups::Int, + σ::F=identity, epsilon::Real=1.0f-5) where {F, N} + sz = size(x) + x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N]) + y, _, _ = LuxLib.Impl.normalization(x_reshaped, nothing, nothing, scale, bias, + LuxLib.Impl.groupnorm_reduce_dims(x), False(), nothing, epsilon, σ) + return reshape(y, sz) +end + +anonact = x -> x^3 + +is_training(::Val{training}) where {training} = training + +function run_groupnorm_testing(T, sz, groups, affine, act, aType, mode, ongpu) + _f = (args...) -> groupnorm(args..., groups, act, epsilon) + _f2 = (args...) -> groupnorm_fallback(args..., groups, act, epsilon) + + epsilon = LuxLib.Utils.default_epsilon(T) + x, scale, bias = setup_groupnorm(StableRNG(0), aType, T, sz, affine) + + y = _f(x, scale, bias) + y_simple = _f2(x, scale, bias) + + fp16 = T == Float16 + atol = fp16 ? 1.0f-2 : 1.0f-3 + rtol = fp16 ? 1.0f-2 : 1.0f-3 + + @test y≈y_simple atol=atol rtol=rtol + + # Check the rrules + if !fp16 + ∂x, ∂scale, ∂bias = Zygote.gradient(sum ∘ _f, x, scale, bias) + ∂x_simple, ∂scale_simple, ∂bias_simple = Zygote.gradient(sum ∘ _f2, x, scale, bias) + if length(sz) == 5 && !ongpu + @test_softfail check_approx(∂x, ∂x_simple; atol, rtol) + else + @test ∂x≈∂x_simple atol=atol rtol=rtol + end + if affine + @test ∂scale≈∂scale_simple atol=atol rtol=rtol + @test ∂bias≈∂bias_simple atol=atol rtol=rtol + end + end + + @test @inferred(groupnorm(x, scale, bias, groups, act, epsilon)) isa Any + @jet groupnorm(x, scale, bias, groups, act, epsilon) + + if anonact !== act + lfn = (x, sc, b, g, act, ϵ) -> sum(groupnorm(x, sc, b, g, act, ϵ)) + @test @inferred(Zygote.gradient(lfn, x, scale, bias, groups, act, epsilon)) isa Any + end + + @test y isa aType{T, length(sz)} + @test size(y) == sz + + soft_fail = fp16 ? fp16 : [AutoFiniteDiff()] + + if affine + __f = (args...) 
-> sum(groupnorm(args..., groups, act, epsilon)) + @test_gradients(__f, x, scale, bias; atol, rtol, soft_fail) + end +end + +const ALL_TEST_CONFIGS = Iterators.product([Float16, Float32, Float64], + ((6, 2), (4, 6, 2), (8, 8, 8, 6, 2), (3, 16, 16, 12, 2), + (4, 4, 6, 2), (2, 2, 6, 2), (3, 3, 12, 4)), + (2, 3), + (true, false), + (identity, relu, tanh_fast, sigmoid_fast, anonact)) + +const TEST_BLOCKS = collect(Iterators.partition( + ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) + +export setup_groupnorm, ALL_TEST_CONFIGS, TEST_BLOCKS, run_groupnorm_testing + +end + +@testitem "Group Norm: Group 1" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[1] + !fp64 && T == Float64 && continue + run_groupnorm_testing(T, sz, groups, affine, act, aType, mode, ongpu) + end + end +end + +@testitem "Group Norm: Group 2" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[2] + !fp64 && T == Float64 && continue + run_groupnorm_testing(T, sz, groups, affine, act, aType, mode, ongpu) + end + end +end + +@testitem "Group Norm: Group 3" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[3] + !fp64 && T == Float64 && continue + run_groupnorm_testing(T, sz, groups, affine, act, aType, mode, ongpu) + end + end +end + +@testitem "Group Norm: Group 4" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[4] + !fp64 && T == Float64 && continue + run_groupnorm_testing(T, sz, groups, affine, act, aType, mode, ongpu) + end + end +end + +@testitem "Group Norm: Group 5" tags=[:group_norm] setup=[SharedTestSetup, GroupNormSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "eltype $T, size $sz, $groups, $affine, $act" for (T, sz, groups, affine, act) in TEST_BLOCKS[5] + !fp64 && T == Float64 && continue + run_groupnorm_testing(T, sz, groups, affine, act, aType, mode, ongpu) + end + end +end diff --git a/lib/LuxLib/test/normalization/instancenorm_tests.jl b/lib/LuxLib/test/normalization/instancenorm_tests.jl new file mode 100644 index 000000000..ff166cfa5 --- /dev/null +++ b/lib/LuxLib/test/normalization/instancenorm_tests.jl @@ -0,0 +1,138 @@ +@testsetup module InstanceNormSetup +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib + +is_training(::Val{training}) where {training} = training + +function setup_instancenorm(gen_f, aType, T, sz; affine::Bool=true) + x = gen_f(T, sz) |> aType + scale = affine ? aType(gen_f(T, sz[end - 1])) : nothing + bias = affine ? aType(gen_f(T, sz[end - 1])) : nothing + return x, scale, bias +end + +anonact = x -> x^3 + +function run_instancenorm_testing(gen_f, T, sz, training, act, aType, mode, ongpu) + _f = (args...) 
-> first(instancenorm(args..., training, act, epsilon)) + + epsilon = LuxLib.Utils.default_epsilon(T) + x, scale, bias = setup_instancenorm(gen_f, aType, T, sz) + + # First test without running stats + y, nt = instancenorm(x, scale, bias, training, act, epsilon) + + fp16 = T == Float16 + atol = fp16 ? 1.0f-2 : 1.0f-3 + rtol = fp16 ? 1.0f-2 : 1.0f-3 + + @test @inferred(instancenorm(x, scale, bias, training, act, epsilon)) isa Any + @jet instancenorm(x, scale, bias, training, act, epsilon) + + if anonact !== act && is_training(training) + lfn = (x, sc, b, act, ϵ) -> sum(first(instancenorm(x, sc, b, Val(true), act, ϵ))) + @test @inferred(Zygote.gradient(lfn, x, scale, bias, act, epsilon)) isa Any + end + + @test y isa aType{T, length(sz)} + @test size(y) == sz + + if is_training(training) + __f = (args...) -> sum(first(instancenorm(args..., training, act, epsilon))) + soft_fail = fp16 ? fp16 : [AutoFiniteDiff()] + @test_gradients(__f, x, scale, bias; atol, rtol, soft_fail) + end + + # Now test with running stats + rm = rand(T, sz[end - 1]) |> aType + rv = abs2.(gen_f(T, sz[end - 1])) |> aType + + y, nt = instancenorm(x, scale, bias, rm, rv, training, act, T(0.1), epsilon) + + @test @inferred(instancenorm( + x, scale, bias, rm, rv, training, act, T(0.1), epsilon)) isa Any + @jet instancenorm(x, scale, bias, rm, rv, training, act, T(0.1), epsilon) + + if anonact !== act && is_training(training) + lfn = (x, sc, b, rm, rv, act, m, ϵ) -> sum(first(instancenorm( + x, sc, b, rm, rv, Val(true), act, m, ϵ))) + @test @inferred(Zygote.gradient( + lfn, x, scale, bias, rm, rv, act, T(0.1), epsilon)) isa Any + end + + @test y isa aType{T, length(sz)} + @test size(y) == sz + + if is_training(training) + __f = (args...) -> sum(first(instancenorm( + args..., rm, rv, training, act, T(0.1), epsilon))) + soft_fail = fp16 ? fp16 : [AutoFiniteDiff()] + skip_backends = (Sys.iswindows() && fp16) ? 
[AutoEnzyme()] : [] + @test_gradients(__f, x, scale, bias; atol, rtol, soft_fail, skip_backends) + end +end + +const ALL_TEST_CONFIGS = Iterators.product( + [Float16, Float32, Float64], ((4, 4, 6, 2), (3, 4, 2), (4, 4, 4, 3, 2)), + (Val(true), Val(false)), (identity, relu, tanh_fast, sigmoid_fast, anonact)) + +const TEST_BLOCKS = collect(Iterators.partition( + ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) + +export setup_instancenorm, ALL_TEST_CONFIGS, TEST_BLOCKS, run_instancenorm_testing + +end + +@testitem "Instance Norm: Group 1" tags=[:instance_norm] setup=[ + SharedTestSetup, InstanceNormSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[1] + !fp64 && T == Float64 && continue + run_instancenorm_testing( + generate_fixed_array, T, sz, training, act, aType, mode, ongpu) + end + end +end + +@testitem "Instance Norm: Group 2" tags=[:instance_norm] setup=[ + SharedTestSetup, InstanceNormSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[2] + !fp64 && T == Float64 && continue + run_instancenorm_testing( + generate_fixed_array, T, sz, training, act, aType, mode, ongpu) + end + end +end + +@testitem "Instance Norm: Group 3" tags=[:instance_norm] setup=[ + SharedTestSetup, InstanceNormSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[3] + !fp64 && T == Float64 && continue + run_instancenorm_testing( + generate_fixed_array, T, sz, training, act, aType, mode, ongpu) + end + end +end + +@testitem "Instance Norm: Group 4" tags=[:instance_norm] setup=[ + SharedTestSetup, InstanceNormSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[4] + !fp64 && T == Float64 && continue + run_instancenorm_testing( + generate_fixed_array, T, sz, training, act, aType, mode, ongpu) + end + end +end + +@testitem "Instance Norm: Group 5" tags=[:instance_norm] setup=[ + SharedTestSetup, InstanceNormSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "eltype $T, size $sz, $training $act" for (T, sz, training, act) in TEST_BLOCKS[5] + !fp64 && T == Float64 && continue + run_instancenorm_testing( + generate_fixed_array, T, sz, training, act, aType, mode, ongpu) + end + end +end diff --git a/lib/LuxLib/test/normalization/layernorm_tests.jl b/lib/LuxLib/test/normalization/layernorm_tests.jl new file mode 100644 index 000000000..37ca3c702 --- /dev/null +++ b/lib/LuxLib/test/normalization/layernorm_tests.jl @@ -0,0 +1,155 @@ +@testsetup module LayerNormSetup +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Statistics +using LuxTestUtils: check_approx + +function setup_layernorm(gen_f, aType, T, x_size, affine_shape, expand_dims::Bool=true) + x = gen_f(T, x_size) |> aType + if affine_shape !== nothing + if expand_dims + scale = gen_f(T, (affine_shape..., 1)) |> aType + bias = gen_f(T, (affine_shape..., 1)) |> aType + else + scale = gen_f(T, affine_shape) |> aType + bias = gen_f(T, affine_shape) |> aType + end + return x, scale, bias + else + return x, nothing, nothing + end +end + +function run_layernorm_testing(gen_f, aType, T, x_size, affine_shape, act, ongpu, mode) + @testset for dims in (Colon(), nothing) + if dims === nothing + 
affine_shape === nothing && continue + length(x_size) ≤ length(affine_shape) && continue + x, scale, bias = setup_layernorm(gen_f, aType, T, x_size, affine_shape, false) + else + x, scale, bias = setup_layernorm(gen_f, aType, T, x_size, affine_shape) + end + + run_layernorm_testing_core( + aType, T, x_size, affine_shape, act, dims, x, scale, bias) + end +end + +function run_layernorm_testing_core( + aType, T, x_size, affine_shape, act, dims, x, scale, bias) + epsilon = LuxLib.Utils.default_epsilon(T) + _f = (args...) -> layernorm(args..., act, dims, epsilon) + + @test @inferred(layernorm(x, scale, bias, act, dims, epsilon)) isa Any + @jet layernorm(x, scale, bias, act, dims, epsilon) + + y = _f(x, scale, bias) + + @test y isa aType{T, length(x_size)} + @test size(y) == x_size + + if affine_shape === nothing && act === identity + @test check_approx(mean(y; dims), 0; atol=1e-3, rtol=1e-3) + @test check_approx(std(y; dims), 1; atol=1e-1, rtol=1e-1) + end + + fp16 = T == Float16 + atol = fp16 ? 1.0f-2 : 1.0f-3 + rtol = fp16 ? 1.0f-2 : 1.0f-3 + + soft_fail = fp16 ? fp16 : [AutoFiniteDiff()] + if affine_shape !== nothing + __f = (args...) -> sum(_f(args...)) + @test_gradients(__f, x, scale, bias; atol, rtol, soft_fail) + else + __f = x -> sum(_f(x, scale, bias)) + @test_gradients(__f, x; atol, rtol, soft_fail) + end + + if anonact !== act + lfn = (x, sc, b, act, dim, ϵ) -> sum(layernorm(x, sc, b, act, dim, ϵ)) + @test @inferred(Zygote.gradient(lfn, x, scale, bias, act, dims, epsilon)) isa Any + end +end + +anonact = x -> x^3 + +const ALL_TEST_CONFIGS = Any[] + +for T in (Float16, Float32, Float64), + x_shape in ((3, 3, 2, 1), (2, 2, 2, 1), (2, 3, 2, 2)), + affine_shape in (nothing, x_shape[1:3], (1, 1, 1), (1, 1, x_shape[3])), + act in (identity, relu, tanh_fast, sigmoid_fast, anonact) + + push!(ALL_TEST_CONFIGS, (T, x_shape, affine_shape, act)) +end + +const TEST_BLOCKS = collect(Iterators.partition( + ALL_TEST_CONFIGS, ceil(Int, length(ALL_TEST_CONFIGS) / 5))) + +export ALL_TEST_CONFIGS, TEST_BLOCKS, run_layernorm_testing + +end + +@testitem "Layer Norm: Group 1" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[1] + !fp64 && T == Float64 && continue + run_layernorm_testing( + generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) + end + end +end + +@testitem "Layer Norm: Group 2" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[2] + !fp64 && T == Float64 && continue + run_layernorm_testing( + generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) + end + end +end + +@testitem "Layer Norm: Group 3" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[3] + !fp64 && T == Float64 && continue + run_layernorm_testing( + generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) + end + end +end + +@testitem "Layer Norm: Group 4" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[4] + 
!fp64 && T == Float64 && continue + run_layernorm_testing( + generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) + end + end +end + +@testitem "Layer Norm: Group 5" tags=[:layer_norm] setup=[SharedTestSetup, LayerNormSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "eltype $T, size $x_shape, $act" for (T, x_shape, affine_shape, act) in TEST_BLOCKS[5] + !fp64 && T == Float64 && continue + run_layernorm_testing( + generate_fixed_array, aType, T, x_shape, affine_shape, act, ongpu, mode) + end + end +end + +@testitem "Layer Norm: Error Checks" tags=[:layer_norm] setup=[SharedTestSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + !fp64 && continue + + x = rand(2, 3) |> aType + + @test_throws ArgumentError layernorm(x, nothing, nothing, identity, nothing, 1e-5) + + sc = rand(2, 1) |> aType + b = rand(2, 1) |> aType + + @test_throws AssertionError layernorm(x, sc, b, identity, nothing, 1e-5) + end +end diff --git a/lib/LuxLib/test/others/bmm_tests.jl b/lib/LuxLib/test/others/bmm_tests.jl new file mode 100644 index 000000000..2b89b0ef2 --- /dev/null +++ b/lib/LuxLib/test/others/bmm_tests.jl @@ -0,0 +1,336 @@ +# Most of the tests in this file were derived from https://github.com/FluxML/NNlib.jl/blob/master/test/batchedmul.jl +@testsetup module BatchedMMSetup + +using NNlib + +function bmm_test(a, b; transA=false, transB=false) + bs = size(a, 3) + transA && (a = permutedims(a, [2, 1, 3])) + transB && (b = permutedims(b, [2, 1, 3])) + c = [] + for i in 1:bs + push!(c, a[:, :, i] * b[:, :, i]) + end + + return cat(c...; dims=3) +end + +function bmm_adjtest(a, b; adjA=false, adjB=false) + bs = size(a, 3) + c = [] + for i in 1:bs + ai = adjA ? adjoint(a[:, :, i]) : a[:, :, i] + bi = adjB ? 
adjoint(b[:, :, i]) : b[:, :, i] + push!(c, ai * bi) + end + + return cat(c...; dims=3) +end + +function half_batched_mul(x, y) + @assert size(y, 3) == 1 + d = size(x, 2) + x_mat = reshape(permutedims(x, (1, 3, 2)), :, d) + y_mat = reshape(y, d, :) + z_mat = x_mat * y_mat + return permutedims(reshape(z_mat, size(x, 1), size(x, 3), :), (1, 3, 2)) +end + +perm_12(A) = PermutedDimsArray(A, (2, 1, 3)) +perm_23(A) = PermutedDimsArray(A, (1, 3, 2)) + +export bmm_test, bmm_adjtest, half_batched_mul, perm_12, perm_23 + +end + +@testitem "batched_mul" tags=[:batched_ops] setup=[SharedTestSetup, BatchedMMSetup] begin + rng = StableRNG(1234) + + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + @testset "batched_mul: Float64 × $(TB)" for TB in [Float64, Float32] + !fp64 && continue + + @testset "real" begin + A = randn(rng, 7, 5, 3) |> aType + B = randn(rng, TB, 5, 7, 3) |> aType + C = randn(rng, 7, 6, 3) |> aType + + @test batched_matmul(A, B) ≈ bmm_test(A, B) + @test batched_matmul(batched_transpose(A), batched_transpose(B)) ≈ + bmm_test(A, B; transA=true, transB=true) + @test batched_matmul(batched_transpose(A), C) ≈ bmm_test(A, C; transA=true) + @test batched_matmul(A, batched_transpose(A)) ≈ bmm_test(A, A; transB=true) + end + + @testset "complex" begin + cA = randn(rng, Complex{Float64}, 7, 5, 3) |> aType + cB = randn(rng, Complex{TB}, 5, 7, 3) |> aType + cC = randn(rng, Complex{Float64}, 7, 6, 3) |> aType + + @test batched_matmul(cA, cB) ≈ bmm_adjtest(cA, cB) + @test batched_matmul(batched_adjoint(cA), batched_adjoint(cB)) ≈ + bmm_adjtest(cA, cB; adjA=true, adjB=true) + @test batched_matmul(batched_adjoint(cA), cC) ≈ + bmm_adjtest(cA, cC; adjA=true) + @test batched_matmul(cA, batched_adjoint(cA)) ≈ + bmm_adjtest(cA, cA; adjB=true) + + @testset "Integers" begin + TBi = TB == Float64 ? 
Int64 : Int32 + iA = rand(rng, 1:99, 7, 5, 3) |> aType + iB = TB.(rand(rng, 1:99, 5, 7, 3)) |> aType + iC = zeros(Int, 7, 6, 3) |> aType + + @test batched_matmul(iA, iB) == bmm_adjtest(iA, iB) + @test batched_matmul(cA, iB) ≈ bmm_adjtest(cA, iB) + end + end + + @testset "Errors" begin + @test_throws DimensionMismatch batched_matmul( + aType(rand(rng, 2, 2, 2)), aType(rand(rng, TB, 2, 2, 10))) + @test_throws DimensionMismatch batched_matmul( + aType(rand(rng, 2, 2, 2)), aType(rand(rng, TB, 10, 2, 2))) + end + + @testset "PermutedDimsArrays" begin + if !ongpu + for perm in [(1, 3, 2), (2, 1, 3), (3, 2, 1)], + fun in [identity, batched_adjoint], + ty in [identity, complex] + + A = randn(rng, ty(Float64), 4, 4, 4) |> aType + B = randn(rng, ty(TB), 4, 4, 4) |> aType + + @test batched_matmul(fun(A), PermutedDimsArray(B, perm)) ≈ + batched_matmul(fun(A), permutedims(B, perm)) + @test batched_matmul(fun(PermutedDimsArray(A, perm)), B) ≈ + batched_matmul(fun(permutedims(A, perm)), B) + end + end + end + + @testset "Large output, multi-threaded path" begin + if TB == Float64 + N = 50 + A = rand(rng, N, N, N) |> aType + B = rand(rng, N, N, N) |> aType + C = reshape( + reduce(hcat, [vec(A[:, :, k] * B[:, :, k]) for k in 1:N]), N, N, N) + @test C ≈ A ⊠ B + + D = rand(rng, N, N, 1) |> aType + E = reshape( + reduce(hcat, [vec(A[:, :, k] * D[:, :, 1]) for k in 1:N]), N, N, N) + @test E ≈ A ⊠ D + end + end + end + end +end + +@testitem "batched_mul: trivial dimensions & unit strides" tags=[:batched_ops] setup=[ + SharedTestSetup, BatchedMMSetup] begin + rng = StableRNG(1234) + + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + !fp64 && continue + + @testset "Float64 × $(TB)" for TB in [Float64, ComplexF64] + @testset "trivial dimensions & unit strides" begin + @testset "$tA(rand$((sA...,3))) ⊠ $tB(rand$((sB...,3)))" for tA in [ + identity, batched_adjoint, batched_transpose, perm_12, perm_23], + sA in [(1, 1), (1, 3), (3, 1), (3, 3)], + tB in [identity, batched_adjoint, batched_transpose, perm_12, perm_23], + sB in [(1, 1), (1, 3), (3, 1), (3, 3)] + + A = tA(rand(rng, TB, sA..., 3)) |> aType + B = tB(rand(rng, TB, sB..., 3)) |> aType + + if size(A, 2) != size(B, 1) || size(A, 3) != 3 || size(B, 3) != 3 + @test true # avoid a warning in ReTestItems.jl + continue + end + + C = cat(A[:, :, 1] * B[:, :, 1], A[:, :, 2] * B[:, :, 2], + A[:, :, 3] * B[:, :, 3]; dims=3) + @test batched_matmul(A, B) ≈ C + end + end + end + end +end + +@testitem "BatchedAdjOrTrans interface" tags=[:batched_ops] setup=[ + SharedTestSetup, BatchedMMSetup] begin + rng = StableRNG(1234) + + @testset "Float64 × $(TB)" for TB in [Float64, Float32] + A = randn(rng, 7, 5, 3) + B = randn(rng, TB, 5, 7, 3) + C = randn(rng, 7, 6, 3) + + function interface_tests(X, _X) + @test length(_X) == length(X) + @test size(_X) == (size(X, 2), size(X, 1), size(X, 3)) + @test axes(_X) == (axes(X, 2), axes(X, 1), axes(X, 3)) + + @test getindex(_X, 2, 3, 3) == getindex(X, 3, 2, 3) + @test getindex(_X, 5, 4, 1) == getindex(X, 4, 5, 1) + + setindex!(_X, 2.0, 2, 4, 1) + @test getindex(_X, 2, 4, 1) == 2.0 + setindex!(_X, 3.0, 1, 2, 2) + @test getindex(_X, 1, 2, 2) == 3.0 + + _sim = similar(_X, TB, (2, 3)) + @test size(_sim) == (2, 3) + @test typeof(_sim) == Array{TB, 2} + + _sim = similar(_X, TB) + @test length(_sim) == length(_X) + @test typeof(_sim) == Array{TB, 3} + + _sim = similar(_X, (2, 3)) + @test size(_sim) == (2, 3) + @test typeof(_sim) == Array{Float64, 2} + + _sim = similar(_X) + @test length(_sim) == length(_X) + @test typeof(_sim) == 
Array{Float64, 3} + + @test parent(_X) == _X.parent + end + + for (X, _X) in zip([A, B, C], map(batched_adjoint, [A, B, C])) + interface_tests(X, _X) + + @test -_X == NNlib.BatchedAdjoint(-_X.parent) + + _copyX = copy(_X) + @test _X == _copyX + + setindex!(_copyX, 2.0, 1, 2, 1) + @test _X != _copyX + end + + for (X, _X) in zip([A, B, C], map(batched_transpose, [A, B, C])) + interface_tests(X, _X) + + @test -_X == NNlib.BatchedTranspose(-_X.parent) + + _copyX = copy(_X) + @test _X == _copyX + + setindex!(_copyX, 2.0, 1, 2, 1) + @test _X != _copyX + end + end +end + +@testitem "batched_matmul(ndims < 3)" tags=[:batched_ops] setup=[ + SharedTestSetup, BatchedMMSetup] begin + rng = StableRNG(1234) + + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + !fp64 && continue + + @testset "Float64 × $(TB)" for TB in [Float64, ComplexF64] + A = randn(rng, 3, 3, 3) |> aType + M = aType(rand(rng, TB, 3, 3)) .+ im + V = aType(rand(rng, TB, 3)) + + # These are all reshaped and sent to batched_matmul(3-array, 3-array) + @test batched_matmul(A, M) ≈ cat([A[:, :, k] * M for k in 1:3]...; dims=3) + @test batched_matmul(A, M') ≈ cat([A[:, :, k] * M' for k in 1:3]...; dims=3) + @test A ⊠ transpose(M) ≈ + cat([A[:, :, k] * transpose(M) for k in 1:3]...; dims=3) + + @test batched_matmul(M, A) ≈ cat([M * A[:, :, k] for k in 1:3]...; dims=3) + @test batched_matmul(M', A) ≈ cat([M' * A[:, :, k] for k in 1:3]...; dims=3) + @test transpose(M) ⊠ A ≈ + cat([transpose(M) * A[:, :, k] for k in 1:3]...; dims=3) + + # batched_vec + @test batched_vec(A, M) ≈ hcat([A[:, :, k] * M[:, k] for k in 1:3]...) + @test batched_vec(A, M') ≈ hcat([A[:, :, k] * (M')[:, k] for k in 1:3]...) + @test batched_vec(A, V) ≈ hcat([A[:, :, k] * V for k in 1:3]...) + end + end +end + +@testitem "BMM AutoDiff" tags=[:batched_ops] setup=[SharedTestSetup, BatchedMMSetup] begin + rng = StableRNG(1234) + + fn(A, B) = sum(batched_matmul(A, B)) + fn_vec(A, B) = sum(batched_vec(A, B)) + + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + M, P, Q = 13, 7, 11 + B = 3 + + @testset "Two 3-arrays" begin + @test_gradients(fn, aType(randn(rng, Float32, M, P, B)), + aType(randn(rng, Float32, P, Q, B)); atol=1e-3, rtol=1e-3) + @test_gradients(fn, batched_adjoint(aType(randn(rng, Float32, P, M, B))), + aType(randn(rng, Float32, P, Q, B)); atol=1e-3, rtol=1e-3) + @test_gradients(fn, aType(randn(rng, Float32, M, P, B)), + batched_transpose(aType(randn(rng, Float32, Q, P, B))); atol=1e-3, + rtol=1e-3) + end + + @testset "One a matrix..." begin + @test_gradients(fn, aType(randn(rng, Float32, M, P)), + aType(randn(rng, Float32, P, Q, B)); atol=1e-3, rtol=1e-3) + @test_gradients(fn, adjoint(aType(randn(rng, Float32, P, M))), + aType(randn(rng, Float32, P, Q, B)); atol=1e-3, rtol=1e-3) + @test_gradients(fn, aType(randn(rng, Float32, M, P)), + batched_adjoint(aType(randn(rng, Float32, Q, P, B))); atol=1e-3, rtol=1e-3) + + @test_gradients(fn, aType(randn(rng, Float32, M, P)), + aType(randn(rng, Float32, P, Q, B)); atol=1e-3, rtol=1e-3) + @test_gradients(fn, adjoint(aType(randn(rng, Float32, P, M))), + aType(randn(rng, Float32, P, Q, B)); atol=1e-3, rtol=1e-3) + @test_gradients(fn, aType(randn(rng, Float32, M, P)), + batched_adjoint(aType(randn(rng, Float32, Q, P, B))); atol=1e-3, rtol=1e-3) + end + + @testset "... 
or equivalent to a matrix" begin + @test_gradients(fn, aType(randn(rng, Float32, M, P, 1)), + aType(randn(rng, Float32, P, Q, B)); atol=1e-3, rtol=1e-3) + @test_gradients(fn, batched_transpose(aType(randn(rng, Float32, P, M, 1))), + aType(randn(rng, Float32, P, Q, B)); atol=1e-3, rtol=1e-3) + @test_gradients(fn, aType(randn(rng, Float32, M, P, 1)), + batched_transpose(aType(randn(rng, Float32, Q, P, B))); atol=1e-3, + rtol=1e-3) + end + end +end + +@testitem "BMM Tracker AoS" tags=[:batched_ops] setup=[SharedTestSetup, BatchedMMSetup] begin + using Tracker, Zygote, NNlib + + rng = StableRNG(1234) + + fn(A, B) = sum(batched_matmul(A, B)) + + ops = (identity, NNlib.batched_adjoint, NNlib.batched_transpose) + + @testset "$mode" for (mode, aType, ongpu) in MODES + x = randn(rng, Float32, 3, 3, 2) |> aType + + @testset "$(op1) x $(op2)" for (op1, op2) in Iterators.product(ops, ops) + x1 = op1(x) + x2 = op2(x) + + ∂x1_tr, ∂x2_tr = Tracker.gradient(fn, x1, x2) + ∂x1_zy, ∂x2_zy = Zygote.gradient(fn, x1, x2) + + @test ∂x1_tr≈∂x1_zy atol=1e-3 rtol=1e-3 + @test ∂x2_tr≈∂x2_zy atol=1e-3 rtol=1e-3 + + @test ∂x1_tr isa Tracker.TrackedArray + @test ∂x2_tr isa Tracker.TrackedArray + end + end +end diff --git a/lib/LuxLib/test/others/forwarddiff_tests.jl b/lib/LuxLib/test/others/forwarddiff_tests.jl new file mode 100644 index 000000000..228aa7d38 --- /dev/null +++ b/lib/LuxLib/test/others/forwarddiff_tests.jl @@ -0,0 +1,113 @@ +@testitem "Efficient JVPs" tags=[:others] setup=[SharedTestSetup] begin + using ForwardDiff, Zygote, ComponentArrays + using LuxTestUtils: check_approx + + # Computes (∂f/∂x)u + function jvp_forwarddiff(f::F, x, u) where {F} + uu = reshape(u, axes(x)) + y = ForwardDiff.Dual{typeof(ForwardDiff.Tag(f, eltype(x))), eltype(x), + 1}.(x, ForwardDiff.Partials.(tuple.(uu))) + return vec(ForwardDiff.partials.(vec(f(y)), 1)) + end + + function jvp_forwarddiff(f::F, x::ComponentArray, u) where {F} + xx = getdata(x) + uu = vec(u) + y = ComponentArray( + ForwardDiff.Dual{typeof(ForwardDiff.Tag(f, eltype(x))), eltype(x), + 1}.(xx, ForwardDiff.Partials.(tuple.(uu))), + getaxes(x)) + return vec(ForwardDiff.partials.(vec(f(y)), 1)) + end + + ## This exists exclusively for testing. It has horrifying performance implications + jvp_forwarddiff_concrete(f::F, x, u) where {F} = ForwardDiff.jacobian(f, x) * vec(u) + jvp_zygote(f::F, x, u) where {F} = only(Zygote.jacobian(f, x)) * vec(u) + + function test_jvp_computation(f::F, x, u, ongpu, nested=false) where {F} + jvp₁ = jvp_forwarddiff(f, x, u) + if !(x isa ComponentArray && ongpu) + # ComponentArray + ForwardDiff on GPU don't play nice + jvp₂ = jvp_forwarddiff_concrete(f, x, u) + @test check_approx(jvp₁, jvp₂; atol=1e-5, rtol=1e-5) + end + + if !nested + jvp₃ = jvp_zygote(f, x, u) + @test check_approx(jvp₁, jvp₃; atol=1e-5, rtol=1e-5) + end + end + + @testset "$(mode): Jacobian Vector Products" for (mode, aType, ongpu, fp64) in MODES + @testset "$(op)(; flipped = $flipped)" for flipped in (true, false), + op in (depthwiseconv, conv) + + op === depthwiseconv && ongpu && continue + + input_dims = [(2, 4, 2, 1, 3), (4, 4, 1, 3), (4, 4, 3, 2), (4, 1, 3), (4, 3, 2)] + weight_dims = if op === depthwiseconv + [(2, 2, 2, 1, 1), (3, 3, 1, 1), (3, 3, 3, 3), (3, 1, 1), (3, 3, 3)] + else + [(2, 2, 2, 1, 4), (3, 3, 1, 4), (3, 3, 3, 2), (3, 1, 4), (3, 3, 2)] + end + + @testset "Input Dims: $(in_dims) | Weight Dims: $(w_dims)" for (in_dims, w_dims) in zip( + input_dims, weight_dims) + x = randn(Float32, in_dims...) |> aType + w = randn(Float32, w_dims...) 
|> aType + ux = randn(Float32, size(x)...) |> aType + uw = randn(Float32, size(w)...) |> aType + u = randn(Float32, length(x) + length(w)) |> aType + + test_jvp_computation(x -> op(x, w; flipped), x, ux, ongpu) + test_jvp_computation(w -> op(x, w; flipped), w, uw, ongpu) + test_jvp_computation( + xw -> op(xw.x, xw.w; flipped), ComponentArray(; x, w), u, ongpu) + + op === depthwiseconv && continue + + # Zygote.gradient here is used to test the ∇conv_data and ∇conv_filter + # functions. Also implicitly tests nested AD + test_jvp_computation( + x -> only(Zygote.gradient(w -> sum(abs2, op(x, w; flipped)), w)), + x, ux, ongpu, true) + test_jvp_computation( + x -> only(Zygote.gradient(x -> sum(abs2, op(x, w; flipped)), x)), + x, ux, ongpu, true) + test_jvp_computation( + w -> only(Zygote.gradient(x -> sum(abs2, op(x, w; flipped)), x)), + w, uw, ongpu, true) + test_jvp_computation( + w -> only(Zygote.gradient(w -> sum(abs2, op(x, w; flipped)), w)), + w, uw, ongpu, true) + test_jvp_computation( + xw -> only(Zygote.gradient( + xw -> sum(abs2, op(xw.x, xw.w; flipped)), xw)), + ComponentArray(; x, w), + u, + ongpu, + true) + end + end + end +end + +@testitem "ForwardDiff dropout" tags=[:other_ops] setup=[SharedTestSetup] begin + using ForwardDiff + using LuxTestUtils: check_approx + + rng = StableRNG(12345) + + @testset "$mode: dropout" for (mode, aType, ongpu, fp64) in MODES + x = randn(rng, Float32, 10, 2) |> aType + x_dual = ForwardDiff.Dual.(x) + + @test_nowarn dropout(rng, x_dual, 0.5f0, Val(true), 2.0f0, :) + + x_dropout = dropout(rng, x, 0.5f0, Val(true), 2.0f0, :)[1] + x_dual_dropout = ForwardDiff.value.(dropout( + rng, x_dual, 0.5f0, Val(true), 2.0f0, :)[1]) + + @test check_approx(x_dropout, x_dual_dropout) + end +end diff --git a/lib/LuxLib/test/others/misc_tests.jl b/lib/LuxLib/test/others/misc_tests.jl new file mode 100644 index 000000000..6e046eea2 --- /dev/null +++ b/lib/LuxLib/test/others/misc_tests.jl @@ -0,0 +1,33 @@ +@testitem "internal_operation_mode: Wrapped Arrays" tags=[:others] setup=[SharedTestSetup] begin + @testset "$mode" for (mode, aType, ongpu, fp64) in MODES + x = rand(Float32, 4, 3) |> aType + retval = ongpu ? 
LuxLib.GPUBroadcastOp : LuxLib.LoopedArrayOp + @test LuxLib.internal_operation_mode(x) isa retval + end + + using StaticArrays, JLArrays + + x = rand(Float32, 4, 3) |> JLArray + @test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp + + x = @SArray rand(Float32, 4, 3) + @test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp + + x = reshape(@SArray(rand(Float32, 4)), :, 1) + @test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp +end + +@testitem "Matmul: StaticArrays" tags=[:others] setup=[SharedTestSetup] begin + using LuxLib.Impl: matmuladd + using StaticArrays + + A = rand(2, 2) + bias = rand(2) + + # This works with LoopVectorization + B = ones(SMatrix{2, 1, Float64}) + @test matmuladd(A, B, bias) ≈ A * B .+ bias + + b = ones(SVector{2, Float64}) + @test matmuladd(A, b, bias) ≈ A * b .+ bias +end diff --git a/lib/LuxLib/test/others/qa_tests.jl b/lib/LuxLib/test/others/qa_tests.jl new file mode 100644 index 000000000..ed7e9f980 --- /dev/null +++ b/lib/LuxLib/test/others/qa_tests.jl @@ -0,0 +1,25 @@ +@testitem "Aqua: Quality Assurance" tags=[:others] begin + using Aqua, ChainRulesCore, EnzymeCore, NNlib + using EnzymeCore: EnzymeRules + + Aqua.test_all( + LuxLib; ambiguities=false, piracies=false, stale_deps=Sys.ARCH === :x86_64) + Aqua.test_ambiguities(LuxLib; recursive=false, + exclude=[conv, ∇conv_data, ∇conv_filter, depthwiseconv, ChainRulesCore.frule]) + Aqua.test_piracies(LuxLib; + treat_as_own=[conv, ∇conv_data, ∇conv_filter, depthwiseconv, + EnzymeRules.augmented_primal, EnzymeRules.reverse]) +end + +@testitem "Explicit Imports" tags=[:others] setup=[SharedTestSetup] begin + using ExplicitImports + + @test check_no_implicit_imports(LuxLib) === nothing + @test check_no_stale_explicit_imports( + LuxLib; ignore=(:TrackedVector, :batched_mul, :batched_matmul)) === nothing + @test check_no_self_qualified_accesses(LuxLib) === nothing + @test check_all_explicit_imports_via_owners(LuxLib) === nothing + @test check_all_qualified_accesses_via_owners(LuxLib) === nothing + @test_broken check_all_explicit_imports_are_public(LuxLib) === nothing # mostly upstream problems + @test_broken check_all_qualified_accesses_are_public(LuxLib) === nothing # mostly upstream problems +end diff --git a/lib/LuxLib/test/runtests.jl b/lib/LuxLib/test/runtests.jl new file mode 100644 index 000000000..fea1e6422 --- /dev/null +++ b/lib/LuxLib/test/runtests.jl @@ -0,0 +1,54 @@ +using ReTestItems, Pkg, LuxTestUtils, Preferences +using InteractiveUtils, Hwloc + +@info sprint(versioninfo) + +Preferences.set_preferences!("LuxLib", "instability_check" => "error") + +const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "all")) +const EXTRA_PKGS = PackageSpec[] +const EXTRA_DEV_PKGS = PackageSpec[] + +const LUXLIB_BLAS_BACKEND = lowercase(get(ENV, "LUXLIB_BLAS_BACKEND", "default")) +@assert LUXLIB_BLAS_BACKEND in ("default", "appleaccelerate", "blis", "mkl") +@info "Running tests with BLAS backend: $(LUXLIB_BLAS_BACKEND)" + +if (BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") + if isdir(joinpath(@__DIR__, "../../LuxCUDA")) + @info "Using local LuxCUDA" + push!(EXTRA_DEV_PKGS, PackageSpec(; path=joinpath(@__DIR__, "../../LuxCUDA"))) + else + push!(EXTRA_PKGS, PackageSpec(; name="LuxCUDA")) + end +end +(BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") && + push!(EXTRA_PKGS, PackageSpec(; name="AMDGPU")) +(BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi") && + push!(EXTRA_PKGS, PackageSpec(; name="oneAPI")) +(BACKEND_GROUP == "all" || BACKEND_GROUP == "metal") && + 
push!(EXTRA_PKGS, PackageSpec(; name="Metal")) + +if !isempty(EXTRA_PKGS) || !isempty(EXTRA_DEV_PKGS) + @info "Installing Extra Packages for testing" EXTRA_PKGS EXTRA_DEV_PKGS + isempty(EXTRA_PKGS) || Pkg.add(EXTRA_PKGS) + isempty(EXTRA_DEV_PKGS) || Pkg.develop(EXTRA_DEV_PKGS) + Pkg.update() + Base.retry_load_extensions() + Pkg.instantiate() +end + +const LUXLIB_TEST_GROUP = get(ENV, "LUXLIB_TEST_GROUP", "all") +const RETESTITEMS_NWORKERS = parse( + Int, get(ENV, "RETESTITEMS_NWORKERS", string(min(Hwloc.num_physical_cores(), 4)))) +const RETESTITEMS_NWORKER_THREADS = parse(Int, + get(ENV, "RETESTITEMS_NWORKER_THREADS", + string(max(Hwloc.num_virtual_cores() ÷ RETESTITEMS_NWORKERS, 1)))) + +@info "Running tests for group: $LUXLIB_TEST_GROUP with $RETESTITEMS_NWORKERS workers" + +using LuxLib + +ReTestItems.runtests( + LuxLib; tags=(LUXLIB_TEST_GROUP == "all" ? nothing : [Symbol(LUXLIB_TEST_GROUP)]), + nworkers=RETESTITEMS_NWORKERS, + nworker_threads=RETESTITEMS_NWORKER_THREADS, testitem_timeout=3600) diff --git a/lib/LuxLib/test/shared_testsetup.jl b/lib/LuxLib/test/shared_testsetup.jl new file mode 100644 index 000000000..2ba51d0a0 --- /dev/null +++ b/lib/LuxLib/test/shared_testsetup.jl @@ -0,0 +1,84 @@ +@testsetup module SharedTestSetup +import Reexport: @reexport + +using LuxLib, MLDataDevices +@reexport using LuxTestUtils, StableRNGs, Test, Enzyme, Zygote, NNlib + +LuxTestUtils.jet_target_modules!(["LuxLib"]) + +const LUXLIB_BLAS_BACKEND = lowercase(get(ENV, "LUXLIB_BLAS_BACKEND", "default")) + +if parse(Bool, get(ENV, "LUXLIB_LOAD_LOOPVEC", "true")) + import LoopVectorization, Octavian +end + +if LUXLIB_BLAS_BACKEND == "default" + @info "Using default BLAS backend: OpenBLAS" +elseif LUXLIB_BLAS_BACKEND == "appleaccelerate" + @info "Using AppleAccelerate BLAS backend" + using AppleAccelerate +elseif LUXLIB_BLAS_BACKEND == "blis" + @info "Using BLIS BLAS backend" + using BLISBLAS +elseif LUXLIB_BLAS_BACKEND == "mkl" + @info "Using MKL BLAS backend" + using MKL +else + error("Unknown BLAS backend: $(LUXLIB_BLAS_BACKEND)") +end + +const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "All")) + +if BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda" + using LuxCUDA +end + +if BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu" + using AMDGPU +end + +if BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi" + using oneAPI +end + +if BACKEND_GROUP == "all" || BACKEND_GROUP == "metal" + using Metal +end + +cpu_testing() = BACKEND_GROUP == "all" || BACKEND_GROUP == "cpu" +function cuda_testing() + return (BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") && + MLDataDevices.functional(CUDADevice) +end +function amdgpu_testing() + return (BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") && + MLDataDevices.functional(AMDGPUDevice) +end +function oneapi_testing() + return (BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi") && + MLDataDevices.functional(oneAPIDevice) +end +function metal_testing() + return (BACKEND_GROUP == "all" || BACKEND_GROUP == "metal") && + MLDataDevices.functional(MetalDevice) +end + +const MODES = begin + modes = [] + cpu_testing() && push!(modes, ("cpu", Array, false, true)) + cuda_testing() && push!(modes, ("cuda", CuArray, true, true)) + amdgpu_testing() && push!(modes, ("amdgpu", ROCArray, true, true)) + oneapi_testing() && push!(modes, ("oneapi", oneArray, true, false)) + metal_testing() && push!(modes, ("metal", MtlArray, true, false)) + modes +end + +generate_fixed_array(::Type{T}, sz...) 
where {T} = generate_fixed_array(T, sz) +function generate_fixed_array(::Type{T}, sz) where {T} + return reshape(T.(collect(1:prod(sz)) ./ prod(sz)), sz...) +end +generate_fixed_array(::Type{T}, sz::Int) where {T} = T.(collect(1:sz) ./ sz) + +export MODES, StableRNG, generate_fixed_array, BACKEND_GROUP + +end diff --git a/lib/LuxTestUtils/LICENSE b/lib/LuxTestUtils/LICENSE new file mode 100644 index 000000000..f7f6ca989 --- /dev/null +++ b/lib/LuxTestUtils/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023: Avik Pal. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/lib/LuxTestUtils/Project.toml b/lib/LuxTestUtils/Project.toml new file mode 100644 index 000000000..92c319980 --- /dev/null +++ b/lib/LuxTestUtils/Project.toml @@ -0,0 +1,39 @@ +name = "LuxTestUtils" +uuid = "ac9de150-d08f-4546-94fb-7472b5760531" +authors = ["Avik Pal "] +version = "1.3.1" + +[deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[compat] +ADTypes = "1.8.1" +ArrayInterface = "7.9" +ChainRulesCore = "1.24.0" +ComponentArrays = "0.15.14" +DispatchDoctor = "0.4.12" +Enzyme = "0.13" +FiniteDiff = "2.23.1" +ForwardDiff = "0.10.36" +Functors = "0.4.11" +JET = "0.9.6" +MLDataDevices = "1.0.0" +ReverseDiff = "1.15.3" +Test = "1.10" +Tracker = "0.2.34" +Zygote = "0.6.70" +julia = "1.10" diff --git a/lib/LuxTestUtils/README.md b/lib/LuxTestUtils/README.md new file mode 100644 index 000000000..6715404a5 --- /dev/null +++ b/lib/LuxTestUtils/README.md @@ -0,0 +1,15 @@ +# LuxTestUtils.jl + +Utilities for testing [Lux.jl](http://lux.csail.mit.edu/). + +## Installation + +```julia +] add LuxTestUtils +``` + +> [!WARNING] +> +> This is a testing package. Hence, we don't use features like weak dependencies to reduce + load times. 
It is recommended that you exclusively use this package for testing and not
+ add a dependency to it in your main package Project.toml.
diff --git a/lib/LuxTestUtils/src/LuxTestUtils.jl b/lib/LuxTestUtils/src/LuxTestUtils.jl
new file mode 100644
index 000000000..795665cdd
--- /dev/null
+++ b/lib/LuxTestUtils/src/LuxTestUtils.jl
@@ -0,0 +1,59 @@
+module LuxTestUtils
+
+using ArrayInterface: ArrayInterface
+using ComponentArrays: ComponentArray, getdata, getaxes
+using DispatchDoctor: allow_unstable
+using Functors: Functors
+using MLDataDevices: cpu_device, gpu_device, get_device, get_device_type, AbstractGPUDevice
+using Test: Test, Error, Broken, Pass, Fail, get_testset, @testset, @test, @test_skip,
+            @test_broken, eval_test, Threw, Returned
+
+# Autodiff
+using ADTypes: AutoEnzyme, AutoFiniteDiff, AutoTracker, AutoForwardDiff, AutoReverseDiff,
+               AutoZygote
+using ChainRulesCore: ChainRulesCore
+using FiniteDiff: FiniteDiff
+using ForwardDiff: ForwardDiff
+using ReverseDiff: ReverseDiff
+using Tracker: Tracker
+using Zygote: Zygote
+
+const CRC = ChainRulesCore
+const FD = FiniteDiff
+
+# Check if JET will work
+try
+    using JET: JET, JETTestFailure, get_reports, report_call, report_opt
+    global JET_TESTING_ENABLED = true
+catch err
+    @error "`JET.jl` did not successfully precompile on $(VERSION). All `@jet` tests will \
+            be skipped." maxlog=1 err=err
+    global JET_TESTING_ENABLED = false
+end
+
+# Check if Enzyme will work
+try
+    using Enzyme: Enzyme
+    __ftest(x) = x
+    Enzyme.autodiff(Enzyme.Reverse, __ftest, Enzyme.Active, Enzyme.Active(2.0))
+    global ENZYME_TESTING_ENABLED = length(VERSION.prerelease) == 0
+catch err
+    global ENZYME_TESTING_ENABLED = false
+end
+
+if !ENZYME_TESTING_ENABLED
+    @warn "`Enzyme.jl` is currently not functional on $(VERSION) either because it errored \
+           or because the current version is a prerelease. Enzyme tests will be skipped..."
+end
+
+include("test_softfail.jl")
+include("utils.jl")
+include("autodiff.jl")
+include("jet.jl")
+
+export AutoEnzyme, AutoFiniteDiff, AutoTracker, AutoForwardDiff, AutoReverseDiff, AutoZygote
+export test_gradients, @test_gradients
+export @jet, jet_target_modules!
+export @test_softfail
+
+end
diff --git a/lib/LuxTestUtils/src/autodiff.jl b/lib/LuxTestUtils/src/autodiff.jl
new file mode 100644
index 000000000..f46136f53
--- /dev/null
+++ b/lib/LuxTestUtils/src/autodiff.jl
@@ -0,0 +1,248 @@
+# Zygote.jl
+function gradient(f::F, ::AutoZygote, args...) where {F}
+    return map((xᵢ, dxᵢ) -> dxᵢ === nothing || xᵢ isa Number ? CRC.NoTangent() : dxᵢ,
+        args, Zygote.gradient(f, args...))
+end
+
+# FiniteDiff.jl
+function gradient(f::F, ::AutoFiniteDiff, args...) where {F}
+    return gradient(f, FD.finite_difference_gradient, args...)
+end
+
+# Enzyme.jl
+function gradient(f::F, ::AutoEnzyme{Nothing}, args...) where {F}
+    return gradient(f, AutoEnzyme(; mode=Enzyme.Reverse), args...)
+end
+
+function gradient(f::F, ad::AutoEnzyme{<:Enzyme.ReverseMode}, args...) where {F}
+    !ENZYME_TESTING_ENABLED &&
+        return ntuple(Returns(GradientComputationSkipped()), length(args))
+
+    args_activity = map(args) do x
+        needs_gradient(x) && return Enzyme.Duplicated(x, Enzyme.make_zero(x))
+        return Enzyme.Const(x)
+    end
+    Enzyme.autodiff(ad.mode, Enzyme.Const(f), Enzyme.Active, args_activity...)
+    return Tuple(map(enumerate(args)) do (i, x)
+        needs_gradient(x) && return args_activity[i].dval
+        return CRC.NoTangent()
+    end)
+end
+
+function gradient(::F, ::AutoEnzyme{<:Enzyme.ForwardMode}, args...)
where {F} + return error("AutoEnzyme{ForwardMode} is not supported yet.") +end + +# Tracker.jl +function gradient(f::F, ::AutoTracker, args...) where {F} + counter = 0 + tracked_args = map(args) do x + if needs_gradient(x) + counter += 1 + return Functors.fmap(Tracker.param, x; exclude=_tracker_leaf) + end + return x + end + + @assert counter>0 "No tracked arguments found in `gradient(f, AutoTracker, args...)`" + Tracker.back!(f(tracked_args...)) + + return Tuple(map(tracked_args) do x + needs_gradient(x) && return Functors.fmap(__tracker_grad, x; exclude=_tracker_leaf) + return CRC.NoTangent() + end) +end + +_tracker_leaf(x) = Functors.isleaf(x) +_tracker_leaf(::AbstractArray) = true + +__tracker_grad(x) = Tracker.grad(x) +__tracker_grad(x::ComponentArray) = ComponentArray(__tracker_grad(getdata(x)), getaxes(x)) + +# ReverseDiff.jl +function gradient(f::F, ::AutoReverseDiff, args...) where {F} + return gradient(f, ReverseDiff.gradient, args...) +end + +# ForwardDiff.jl +function gradient(f::F, ::AutoForwardDiff, args...) where {F} + return gradient(f, ForwardDiff.gradient, args...) +end + +function gradient(f::F, grad_fn::GFN, args...) where {F, GFN <: Function} + gs = Vector{Any}(undef, length(args)) + for i in 1:length(args) + _f, x = partial_function(f, i, args...) + if x isa AbstractArray + gs[i] = grad_fn(_f, x) + elseif x isa NamedTuple + __f, x_flat = flatten_gradient_computable(_f, x) + gs[i] = x_flat === nothing ? CRC.NoTangent() : NamedTuple(grad_fn(__f, x_flat)) + else + gs[i] = CRC.NoTangent() + end + end + return Tuple(gs) +end + +# Main Functionality to Test Gradient Correctness +""" + test_gradients(f, args...; skip_backends=[], broken_backends=[], kwargs...) + +Test the gradients of `f` with respect to `args` using the specified backends. + +| Backend | ADType | CPU | GPU | Notes | +|:-------------- |:------------------- |:--- |:--- |:----------------- | +| Zygote.jl | `AutoZygote()` | ✔ | ✔ | | +| Tracker.jl | `AutoTracker()` | ✔ | ✔ | | +| ReverseDiff.jl | `AutoReverseDiff()` | ✔ | ✖ | | +| ForwardDiff.jl | `AutoForwardDiff()` | ✔ | ✖ | `len ≤ 100` | +| FiniteDiff.jl | `AutoFiniteDiff()` | ✔ | ✖ | `len ≤ 100` | +| Enzyme.jl | `AutoEnzyme()` | ✔ | ✖ | Only Reverse Mode | + +## Arguments + + - `f`: The function to test the gradients of. + - `args`: The arguments to test the gradients of. Only `AbstractArray`s are considered + for gradient computation. Gradients wrt all other arguments are assumed to be + `NoTangent()`. + +## Keyword Arguments + + - `skip_backends`: A list of backends to skip. + - `broken_backends`: A list of backends to treat as broken. + - `soft_fail`: If `true`, then the test will be recorded as a `soft_fail` test. This + overrides any `broken` kwargs. Alternatively, a list of backends can be passed to + `soft_fail` to allow soft_fail tests for only those backends. + - `enzyme_set_runtime_activity`: If `true`, then activate runtime activity for Enzyme. + - `kwargs`: Additional keyword arguments to pass to `check_approx`. + +## Example + +```julia +julia> f(x, y, z) = x .+ sum(abs2, y.t) + sum(y.x.z) + +julia> x = (; t=rand(10), x=(z=[2.0],)) + +julia> test_gradients(f, 1.0, x, nothing) + +``` +""" +function test_gradients(f, args...; skip_backends=[], broken_backends=[], + soft_fail::Union{Bool, Vector}=false, + enzyme_set_runtime_activity::Bool=false, + # Internal kwargs start + source::LineNumberNode=LineNumberNode(0, nothing), + test_expr::Expr=:(check_approx(∂args, ∂args_gt; kwargs...)), + # Internal kwargs end + kwargs...) 
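+    # Illustrative sketch of how the keyword arguments documented above combine, using a
+    # hypothetical `f` with inputs `x` and `ps` (not taken from the test suite):
+    #
+    #   test_gradients(f, x, ps; atol=1.0f-3, rtol=1.0f-3,
+    #       skip_backends=[AutoFiniteDiff()], broken_backends=[AutoEnzyme()],
+    #       soft_fail=[AutoTracker()])
+    #
+    # With these settings FiniteDiff is skipped outright, an Enzyme mismatch is recorded
+    # as `Broken` (and an unexpected Enzyme pass as an error), Tracker failures are
+    # downgraded to soft failures, and `atol`/`rtol` are forwarded to `check_approx`.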
+ # TODO: We should add a macro version that propagates the line number info and the test_expr + on_gpu = get_device_type(args) <: AbstractGPUDevice + total_length = mapreduce(__length, +, Functors.fleaves(args); init=0) + + # Choose the backends to test + backends = [] + push!(backends, AutoZygote()) + if !on_gpu + push!(backends, AutoReverseDiff()) + total_length ≤ 100 && push!(backends, AutoForwardDiff()) + total_length ≤ 100 && push!(backends, AutoFiniteDiff()) + # TODO: Move Enzyme out of here once it supports GPUs + if ENZYME_TESTING_ENABLED + mode = enzyme_set_runtime_activity ? + Enzyme.set_runtime_activity(Enzyme.Reverse) : + Enzyme.Reverse + push!(backends, AutoEnzyme(; mode)) + end + end + push!(backends, AutoTracker()) + + intersect_backends = intersect(broken_backends, skip_backends) + if !isempty(intersect_backends) + throw(ArgumentError("`broken_backends` and `skip_backends` cannot contain the same \ + backends -- $(intersect_backends).")) + end + + # Test the gradients + ∂args_gt = gradient(f, backends[1], args...) # Should be Zygote in most cases + + @assert (backends[1] ∉ broken_backends)&&(backends[1] ∉ skip_backends) "first backend cannot be broken or skipped" + + @testset "gradtest($(f))" begin + @testset "$(nameof(typeof(backends[1])))() vs $(nameof(typeof(backend)))()" for backend in backends[2:end] + local_test_expr = :([$(nameof(typeof(backend)))] - $(test_expr)) + + result = if check_ad_backend_in(backend, skip_backends) + Broken(:skipped, local_test_expr) + elseif (soft_fail isa Bool && soft_fail) || + (soft_fail isa Vector && check_ad_backend_in(backend, soft_fail)) + try + ∂args = allow_unstable() do + return gradient(f, backend, args...) + end + matched = check_approx(∂args, ∂args_gt; kwargs...) + if matched + Pass(:test, local_test_expr, nothing, nothing, source) + else + Broken(:test, local_test_expr) + end + catch + Broken(:test, local_test_expr) + end + elseif check_ad_backend_in(backend, broken_backends) + try + ∂args = allow_unstable() do + return gradient(f, backend, args...) + end + matched = check_approx(∂args, ∂args_gt; kwargs...) + if matched + Error(:test_unbroken, local_test_expr, matched, nothing, source) + else + Broken(:test, local_test_expr) + end + catch + Broken(:test, local_test_expr) + end + else + try + ∂args = allow_unstable() do + return gradient(f, backend, args...) + end + matched = check_approx(∂args, ∂args_gt; kwargs...) + if matched + Pass(:test, local_test_expr, nothing, nothing, source) + else + context = "\n ∂args: $(∂args)\n∂args_gt: $(∂args_gt)" + Fail( + :test, local_test_expr, matched, nothing, context, source, false) + end + catch err + err isa InterruptException && rethrow() + Error(:test, local_test_expr, err, Base.current_exceptions(), source) + end + end + Test.record(get_testset(), result) + end + end +end + +""" + @test_gradients(f, args...; kwargs...) + +See the documentation of [`test_gradients`](@ref) for more details. This macro provides +correct line information for the failing tests. +""" +macro test_gradients(exprs...) + exs = reorder_macro_kw_params(exprs) + kwarg_idx = findfirst(ex -> Meta.isexpr(ex, :kw), exs) + if kwarg_idx === nothing + args = [exs...] + kwargs = [] + else + args = [exs[1:(kwarg_idx - 1)]...] + kwargs = [exs[kwarg_idx:end]...] 
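+        # (everything from the first `Expr(:kw, ...)` onward is treated as a keyword
+        # argument; the `source` and `test_expr` keywords appended below make failure
+        # reports point at the macro call site rather than at `test_gradients` internals)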
+ end + push!(kwargs, Expr(:kw, :source, QuoteNode(__source__))) + push!(kwargs, Expr(:kw, :test_expr, QuoteNode(:(test_gradients($(exs...)))))) + return esc(:($(test_gradients)($(args...); $(kwargs...)))) +end diff --git a/lib/LuxTestUtils/src/jet.jl b/lib/LuxTestUtils/src/jet.jl new file mode 100644 index 000000000..23963bdda --- /dev/null +++ b/lib/LuxTestUtils/src/jet.jl @@ -0,0 +1,90 @@ +# Testing using JET.jl +const JET_TARGET_MODULES = Ref{Union{Nothing, Vector{String}}}(nothing) + +""" + jet_target_modules!(list::Vector{String}; force::Bool=false) + +This sets `target_modules` for all JET tests when using [`@jet`](@ref). +""" +function jet_target_modules!(list::Vector{String}; force::Bool=false) + if JET_TARGET_MODULES[] === nothing || (force && JET_TARGET_MODULES[] !== nothing) + JET_TARGET_MODULES[] = list + @info "JET_TARGET_MODULES set to $list" + return list + else + @info "JET_TARGET_MODULES is already set to $(JET_TARGET_MODULES[]). No changes \ + made. Use `force=true` to force-set the target modules." + end +end + +""" + @jet f(args...) call_broken=false opt_broken=false + +Run JET tests on the function `f` with the arguments `args...`. If `JET.jl` fails to +compile, then the macro will be a no-op. + +## Keyword Arguments + + - `call_broken`: Marks the test_call as broken. + - `opt_broken`: Marks the test_opt as broken. + +All additional arguments will be forwarded to `JET.@test_call` and `JET.@test_opt`. + +!!! tip + + Instead of specifying `target_modules` with every call, you can set global target + modules using [`jet_target_modules!`](@ref). + + ```julia + using LuxTestUtils + + jet_target_modules!(["Lux", "LuxLib"]) # Expects Lux and LuxLib to be present in the module calling `@jet` + ``` + +## Example + +```jldoctest +julia> @jet sum([1, 2, 3]) target_modules=(Base, Core) +Test Passed + +julia> @jet sum(1, 1) target_modules=(Base, Core) opt_broken=true call_broken=true +Test Broken + Expression: #= REPL[21]:1 =# JET.@test_opt target_modules = (Base, Core) sum(1, 1) +``` +""" +macro jet(expr, args...) + !JET_TESTING_ENABLED && return :() + + all_args, call_extras, opt_extras = [], [], [] + target_modules_set = false + for kwexpr in args + if Meta.isexpr(kwexpr, :(=)) + if kwexpr.args[1] == :call_broken + push!(call_extras, :(broken = $(kwexpr.args[2]))) + elseif kwexpr.args[1] == :opt_broken + push!(opt_extras, :(broken = $(kwexpr.args[2]))) + elseif kwexpr.args[1] == :broken + throw(ArgumentError("`broken` keyword argument is ambiguous. Use `call_broken` or `opt_broken` instead.")) + else + kwexpr.args[1] == :target_modules && (target_modules_set = true) + push!(all_args, kwexpr) + end + else + push!(all_args, kwexpr) + end + end + + if !target_modules_set && JET_TARGET_MODULES[] !== nothing + target_modules = getproperty.((__module__,), Tuple(Symbol.(JET_TARGET_MODULES[]))) + push!(all_args, :(target_modules = $target_modules)) + end + + push!(all_args, expr) + + ex_call = JET.call_test_ex(:report_call, Symbol("@test_call"), + vcat(call_extras, all_args), __module__, __source__) + ex_opt = JET.call_test_ex(:report_opt, Symbol("@test_opt"), + vcat(opt_extras, all_args), __module__, __source__) + + return Expr(:block, ex_call, ex_opt) +end diff --git a/lib/LuxTestUtils/src/test_softfail.jl b/lib/LuxTestUtils/src/test_softfail.jl new file mode 100644 index 000000000..7e2c9a255 --- /dev/null +++ b/lib/LuxTestUtils/src/test_softfail.jl @@ -0,0 +1,40 @@ +# Based off of the official `@test` macro +""" + @test_softfail expr + +Evaluate `expr` and record a test result. 
If `expr` throws an exception, the test +result will be recorded as an error. If `expr` returns a value, and it is not a boolean, +the test result will be recorded as an error. + +If the test result is false then the test will be recorded as a broken test, else it will be +recorded as a pass. +""" +macro test_softfail(ex) + Test.test_expr!("@test_softfail", ex) # Build the test expression + result = Test.get_test_result(ex, __source__) + ex = Expr(:inert, ex) + result = quote + do_softfail_test($result, $ex) + end + return result +end + +function do_softfail_test(result, orig_expr) + if isa(result, Test.Returned) + value = result.value + testres = if isa(value, Bool) + if value + Pass(:test, orig_expr, result.data, value, result.source) + else + Broken(:test, orig_expr) + end + else + Error(:test_nonbool, orig_expr, value, nothing, result.source) + end + else + @assert isa(result, Threw) + testres = Error(:test_throws, orig_expr, result.exception, + result.backtrace::Vector{Any}, result.source) + end + Test.record(get_testset(), testres) +end diff --git a/lib/LuxTestUtils/src/utils.jl b/lib/LuxTestUtils/src/utils.jl new file mode 100644 index 000000000..432750409 --- /dev/null +++ b/lib/LuxTestUtils/src/utils.jl @@ -0,0 +1,131 @@ +# Taken from https://github.com/JuliaLang/julia/pull/54653 +struct Fix{N, F, T} <: Function + f::F + x::T + + function Fix{N}(f::F, x) where {N, F} + if N isa Int && N < 1 + throw(ArgumentError("expected `N` in `Fix{N}` to be integer greater than 0, \ + but got $N")) + elseif !(N isa Union{Int, Symbol}) + throw(ArgumentError("expected type parameter in `Fix` to be `Int` or `Symbol`, \ + but got `$N::$(typeof(N))`")) + end + return new{N, Base._stable_typeof(f), Base._stable_typeof(x)}(f, x) + end +end +function Fix(f::F; kws...) where {F} + length(kws) != 1 && + throw(ArgumentError("`Fix` expects exactly one argument or keyword argument, but \ + got keywords `$(keys(kws))`")) + return Fix{only(keys(kws))}(f, only(values(kws))) +end + +function (f::Fix{N})(args::Vararg{Any, M}; kws...) where {N, M} + if N isa Symbol + N in keys(kws) && + throw(ArgumentError("found duplicate keyword argument `$N` passed to a `Fix` \ + function")) + f_kws = NamedTuple{(N,)}((f.x,)) + return f.f(args...; f_kws..., kws...) + else # Int + M < N - 1 && + throw(ArgumentError("expected at least $(N-1) arguments to a `Fix` function with `N=$(N)`, but got $M")) + return f.f( + args[begin:(begin + (N - 2))]..., f.x, args[(begin + (N - 1)):end]...; kws...) + end +end + +# Special cases for improved constant propagation +(f::Fix{1})(arg; kws...) = f.f(f.x, arg; kws...) +(f::Fix{2})(arg; kws...) = f.f(arg, f.x; kws...) + +function partial_function(f::F, idx::Int, args...) where {F} + partial_f = f + for (i, arg) in enumerate(args) + i == idx && continue + i < idx && (partial_f = Fix{1}(partial_f, arg)) + i > idx && (partial_f = Fix{2}(partial_f, arg)) + end + return partial_f, args[idx] +end + +function flatten_gradient_computable(f, nt::NamedTuple) + if needs_gradient(nt) + _f = (x) -> f(NamedTuple(x)) + xxx = nt |> cpu_device() |> ComponentArray |> get_device(nt) + eltype(xxx) == Any && + error("eltype of the flattened vector is `Any`. Check your inputs.") + return _f, xxx + end + return nothing, nothing +end + +needs_gradient(y) = all(Fix{2}(isa, AbstractArray), Functors.fleaves(y)) + +__length(x) = 0 +__length(x::AbstractArray) = length(x) +__length(::Number) = 1 + +# Equality Checks +struct GradientComputationSkipped end + +@generated function check_approx(x::X, y::Y; kwargs...) 
where {X, Y} + device = cpu_device() + (X == GradientComputationSkipped || Y == GradientComputationSkipped) && return :(true) + hasmethod(isapprox, (X, Y)) && return :(isapprox($(device)(x), $(device)(y); kwargs...)) + return :($(device)(x) == $(device)(y)) +end + +check_approx(x::Tuple, y::Tuple; kwargs...) = all(check_approx.(x, y; kwargs...)) + +function check_approx( + nt1::NamedTuple{fields}, nt2::NamedTuple{fields}; kwargs...) where {fields} + _check_approx(xy) = check_approx(xy[1], xy[2]; kwargs...) + _check_approx(t::Tuple{Nothing, Nothing}) = true + return all(_check_approx, zip(values(nt1), values(nt2))) +end + +function check_approx(t1::NTuple{N, T}, t2::NTuple{N, T}; kwargs...) where {N, T} + _check_approx(xy) = check_approx(xy[1], xy[2]; kwargs...) + _check_approx(t::Tuple{Nothing, Nothing}) = true + return all(_check_approx, zip(t1, t2)) +end + +function check_approx(ca::ComponentArray, nt::NamedTuple; kwargs...) + return check_approx(NamedTuple(ca), nt; kwargs...) +end +function check_approx(nt::NamedTuple, ca::ComponentArray; kwargs...) + return check_approx(nt, NamedTuple(ca); kwargs...) +end + +check_approx(::Nothing, v::AbstractArray; kwargs...) = length(v) == 0 +check_approx(v::AbstractArray, ::Nothing; kwargs...) = length(v) == 0 +check_approx(v::NamedTuple, ::Nothing; kwargs...) = length(v) == 0 +check_approx(::Nothing, v::NamedTuple; kwargs...) = length(v) == 0 +check_approx(v::Tuple, ::Nothing; kwargs...) = length(v) == 0 +check_approx(::Nothing, v::Tuple; kwargs...) = length(v) == 0 +check_approx(x::AbstractArray, y::NamedTuple; kwargs...) = length(x) == 0 && length(y) == 0 +check_approx(x::NamedTuple, y::AbstractArray; kwargs...) = length(x) == 0 && length(y) == 0 +check_approx(x::AbstractArray, y::Tuple; kwargs...) = length(x) == 0 && length(y) == 0 +check_approx(x::Tuple, y::AbstractArray; kwargs...) = length(x) == 0 && length(y) == 0 + +# Taken from discourse. normalizes the order of keyword arguments in a macro +function reorder_macro_kw_params(exs) + exs = Any[exs...] + i = findfirst([(ex isa Expr && ex.head == :parameters) for ex in exs]) + if i !== nothing + extra_kw_def = exs[i].args + for ex in extra_kw_def + push!(exs, ex isa Symbol ? 
Expr(:kw, ex, ex) : ex) + end + deleteat!(exs, i) + end + return Tuple(exs) +end + +function check_ad_backend_in(backend, backends) + backends_type = map(ArrayInterface.parameterless_type ∘ typeof, backends) + backend_type = ArrayInterface.parameterless_type(typeof(backend)) + return backend_type in backends_type +end diff --git a/lib/LuxTestUtils/test/Project.toml b/lib/LuxTestUtils/test/Project.toml new file mode 100644 index 000000000..3701de4ff --- /dev/null +++ b/lib/LuxTestUtils/test/Project.toml @@ -0,0 +1,17 @@ +[deps] +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" +InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +MetaTesting = "9e32d19f-1e4f-477a-8631-b16c78aa0f56" +ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[compat] +CUDA = "5" +ComponentArrays = "0.15" +Hwloc = "3" +InteractiveUtils = "<0.0.1, 1" +MetaTesting = "0.1" +ReTestItems = "1.25" +Test = "1.10" diff --git a/lib/LuxTestUtils/test/runtests.jl b/lib/LuxTestUtils/test/runtests.jl new file mode 100644 index 000000000..365a77213 --- /dev/null +++ b/lib/LuxTestUtils/test/runtests.jl @@ -0,0 +1,8 @@ +using InteractiveUtils, Hwloc, ReTestItems, LuxTestUtils + +@info sprint(versioninfo) + +const RETESTITEMS_NWORKERS = parse( + Int, get(ENV, "RETESTITEMS_NWORKERS", string(min(Hwloc.num_physical_cores(), 16)))) + +ReTestItems.runtests(LuxTestUtils; nworkers=RETESTITEMS_NWORKERS) diff --git a/lib/LuxTestUtils/test/unit_tests.jl b/lib/LuxTestUtils/test/unit_tests.jl new file mode 100644 index 000000000..a76a1c135 --- /dev/null +++ b/lib/LuxTestUtils/test/unit_tests.jl @@ -0,0 +1,71 @@ +@testitem "@jet" begin + LuxTestUtils.jet_target_modules!(["LuxTestUtils"]) + + @jet sum([1, 2, 3]) target_modules=(Base, Core) +end + +@testitem "test_gradients" begin + using MetaTesting, ComponentArrays + + f(x, y, z) = x .+ sum(abs2, y.t) + sum(y.x.z) + + x = (; t=rand(10), x=(z=[2.0],)) + + test_gradients(f, 1.0, x, nothing) + + test_gradients(f, 1.0, x, nothing; skip_backends=[AutoTracker()]) + @test_gradients(f, 1.0, x, nothing; skip_backends=[AutoTracker()]) + + @test errors() do + test_gradients(f, 1.0, x, nothing; broken_backends=[AutoTracker()]) + end + + @test errors() do + @test_gradients(f, 1.0, x, nothing; broken_backends=[AutoTracker()]) + end + + @test_throws ArgumentError test_gradients( + f, 1.0, x, nothing; broken_backends=[AutoTracker()], + skip_backends=[AutoTracker(), AutoEnzyme()]) + @test_throws ArgumentError @test_gradients(f, 1.0, x, nothing; + broken_backends=[AutoTracker()], + skip_backends=[AutoTracker(), AutoEnzyme()]) + + test_gradients(f, 1.0, x, nothing; soft_fail=[AutoTracker()]) + @test_gradients(f, 1.0, x, nothing; soft_fail=[AutoTracker()]) + + test_gradients(f, 1.0, x, nothing; soft_fail=true) + @test_gradients(f, 1.0, x, nothing; soft_fail=true) + + x_ca = ComponentArray(x) + + test_gradients(f, 1.0, x_ca, nothing) + @test_gradients(f, 1.0, x_ca, nothing) + + x_2 = (; t=x.t', x=(z=x.x.z',)) + + test_gradients(f, 1.0, x_2, nothing) + @test_gradients(f, 1.0, x_2, nothing) +end + +@testitem "test_gradients (CUDA.jl)" skip=:(using CUDA; !CUDA.functional()) begin + using CUDA + + f(x, y, z) = x .+ sum(abs2, y.t) + sum(y.x.z) + + x = (; t=cu(rand(10)), x=(z=cu([2.0]),)) + + test_gradients(f, 1.0, x, nothing) + + test_gradients(f, 1.0, x, nothing; skip_backends=[AutoTracker()]) +end + +@testitem "@test_softfail" begin + using MetaTesting + + @test 
errors() do + @test_softfail 1 + 1 + end + @test_softfail 1 + 1 == 2 + @test_softfail 1 + 1 < 2 +end diff --git a/lib/MLDataDevices/LICENSE b/lib/MLDataDevices/LICENSE new file mode 100644 index 000000000..e87b80c0d --- /dev/null +++ b/lib/MLDataDevices/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Avik Pal and contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml new file mode 100644 index 000000000..68d43257b --- /dev/null +++ b/lib/MLDataDevices/Project.toml @@ -0,0 +1,68 @@ +name = "MLDataDevices" +uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" +authors = ["Avik Pal and contributors"] +version = "1.4.2" + +[deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +Preferences = "21216c6a-2e73-6563-6e65-726566657250" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[weakdeps] +AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +Metal = "dde4c033-4e86-420c-a63e-0dd931031962" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" +oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" + +[extensions] +MLDataDevicesAMDGPUExt = "AMDGPU" +MLDataDevicesCUDAExt = "CUDA" +MLDataDevicesChainRulesCoreExt = "ChainRulesCore" +MLDataDevicesFillArraysExt = "FillArrays" +MLDataDevicesGPUArraysExt = "GPUArrays" +MLDataDevicesMLUtilsExt = "MLUtils" +MLDataDevicesMetalExt = ["GPUArrays", "Metal"] +MLDataDevicesReactantExt = "Reactant" +MLDataDevicesRecursiveArrayToolsExt = "RecursiveArrayTools" +MLDataDevicesReverseDiffExt = "ReverseDiff" +MLDataDevicesSparseArraysExt = "SparseArrays" +MLDataDevicesTrackerExt = "Tracker" +MLDataDevicesZygoteExt = "Zygote" +MLDataDevicescuDNNExt = ["CUDA", "cuDNN"] +MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"] + +[compat] +AMDGPU = "0.9.6, 1" +Adapt = "4.1" +CUDA = "5.2" +ChainRulesCore = "1.23" +Compat = "4.15" +FillArrays = "1" +Functors = 
"0.4.8" +GPUArrays = "10, 11" +MLUtils = "0.4.4" +Metal = "1" +Preferences = "1.4" +Random = "1.10" +Reactant = "0.2" +RecursiveArrayTools = "3.8" +ReverseDiff = "1.15" +SparseArrays = "1.10" +Tracker = "0.2.34" +Zygote = "0.6.69" +cuDNN = "1.3" +julia = "1.10" +oneAPI = "1.5" diff --git a/lib/MLDataDevices/README.md b/lib/MLDataDevices/README.md new file mode 100644 index 000000000..2fda26602 --- /dev/null +++ b/lib/MLDataDevices/README.md @@ -0,0 +1,19 @@ +# MLDataDevices + +`MLDataDevices.jl` is a lightweight package defining rules for transferring data across +devices. It is used in deep learning frameworks such as [Lux.jl](https://lux.csail.mit.edu/) and [Flux.jl](https://fluxml.ai/). + +Currently we provide support for the following backends: + +1. `CPUDevice`: for CPUs -- no additional packages required. +2. `CUDADevice`: `CUDA.jl` for NVIDIA GPUs. +3. `AMDGPUDevice`: `AMDGPU.jl` for AMD ROCM GPUs. +4. `MetalDevice`: `Metal.jl` for Apple Metal GPUs. **(Experimental)** +5. `oneAPIDevice`: `oneAPI.jl` for Intel GPUs. **(Experimental)** +6. `XLADevice`: `Reactant.jl` for XLA Support. **(Experimental)** + +## Updating to v1.0 + + * Package was renamed from `LuxDeviceUtils.jl` to `MLDataDevices.jl`. + * `Lux(***)Device` has been renamed to `(***)Device`. + * `Lux(***)Adaptor` objects have been removed. Use `(***)Device` objects instead. diff --git a/lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl b/lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl new file mode 100644 index 000000000..ca275b55a --- /dev/null +++ b/lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl @@ -0,0 +1,97 @@ +module MLDataDevicesAMDGPUExt + +using Adapt: Adapt +using AMDGPU: AMDGPU +using MLDataDevices: MLDataDevices, Internal, AMDGPUDevice, CPUDevice, reset_gpu_device! +using Random: Random + +__init__() = reset_gpu_device!() + +# This code used to be in `LuxAMDGPU.jl`, but we no longer need that package. +const USE_AMD_GPU = Ref{Union{Nothing, Bool}}(nothing) + +function check_use_amdgpu!() + USE_AMD_GPU[] === nothing || return + + USE_AMD_GPU[] = AMDGPU.functional() + if USE_AMD_GPU[] && !AMDGPU.functional(:MIOpen) + @warn "MIOpen is not functional in AMDGPU.jl, some functionality will not be \ + available." 
maxlog=1 + end + return +end + +MLDataDevices.loaded(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}}) = true +function MLDataDevices.functional(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}})::Bool + check_use_amdgpu!() + return USE_AMD_GPU[] +end + +Internal.with_device(::Type{AMDGPUDevice}, ::Nothing) = AMDGPUDevice(nothing) +function Internal.with_device(::Type{AMDGPUDevice}, id::Integer) + id > length(AMDGPU.devices()) && + throw(ArgumentError("id = $id > length(AMDGPU.devices()) = $(length(AMDGPU.devices()))")) + old_dev = AMDGPU.device() + AMDGPU.device!(AMDGPU.devices()[id]) + device = AMDGPUDevice(AMDGPU.device()) + AMDGPU.device!(old_dev) + return device +end + +Internal.get_device_id(dev::AMDGPUDevice) = AMDGPU.device_id(dev.device) + +# Default RNG +MLDataDevices.default_device_rng(::AMDGPUDevice) = AMDGPU.rocrand_rng() + +# Query Device from Array +function Internal.get_device(x::AMDGPU.AnyROCArray) + parent_x = parent(x) + parent_x === x && return AMDGPUDevice(AMDGPU.device(x)) + return Internal.get_device(parent_x) +end +Internal.get_device(::AMDGPU.rocRAND.RNG) = AMDGPUDevice(AMDGPU.device()) + +Internal.get_device_type(::AMDGPU.AnyROCArray) = AMDGPUDevice +Internal.get_device_type(::AMDGPU.rocRAND.RNG) = AMDGPUDevice + +# Set Device +function MLDataDevices.set_device!(::Type{AMDGPUDevice}, dev::AMDGPU.HIPDevice) + return AMDGPU.device!(dev) +end +function MLDataDevices.set_device!(::Type{AMDGPUDevice}, id::Integer) + return MLDataDevices.set_device!(AMDGPUDevice, AMDGPU.devices()[id]) +end +function MLDataDevices.set_device!(::Type{AMDGPUDevice}, ::Nothing, rank::Integer) + id = mod1(rank + 1, length(AMDGPU.devices())) + return MLDataDevices.set_device!(AMDGPUDevice, id) +end + +# unsafe_free! +function Internal.unsafe_free_internal!(::Type{AMDGPUDevice}, x::AbstractArray) + AMDGPU.unsafe_free!(x) + return +end + +# Device Transfer +Adapt.adapt_storage(::AMDGPUDevice{Nothing}, x::AbstractArray) = AMDGPU.roc(x) +function Adapt.adapt_storage(to::AMDGPUDevice, x::AbstractArray) + old_dev = AMDGPU.device() # remember the current device + dev = MLDataDevices.get_device(x) + if !(dev isa AMDGPUDevice) + AMDGPU.device!(to.device) + x_new = AMDGPU.roc(x) + AMDGPU.device!(old_dev) + return x_new + elseif AMDGPU.device_id(dev.device) == AMDGPU.device_id(to.device) + return x + else + AMDGPU.device!(to.device) + x_new = copy(x) + AMDGPU.device!(old_dev) + return x_new + end +end + +Adapt.adapt_storage(::CPUDevice, rng::AMDGPU.rocRAND.RNG) = Random.default_rng() + +end diff --git a/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl b/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl new file mode 100644 index 000000000..9355b8171 --- /dev/null +++ b/lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl @@ -0,0 +1,90 @@ +module MLDataDevicesCUDAExt + +using Adapt: Adapt +using CUDA: CUDA +using CUDA.CUSPARSE: AbstractCuSparseMatrix, AbstractCuSparseVector, AbstractCuSparseArray +using MLDataDevices: MLDataDevices, Internal, CUDADevice, CPUDevice +using Random: Random + +Internal.with_device(::Type{CUDADevice}, ::Nothing) = CUDADevice(nothing) +function Internal.with_device(::Type{CUDADevice}, id::Integer) + id > length(CUDA.devices()) && + throw(ArgumentError("id = $id > length(CUDA.devices()) = $(length(CUDA.devices()))")) + old_dev = CUDA.device() + CUDA.device!(id - 1) + device = CUDADevice(CUDA.device()) + CUDA.device!(old_dev) + return device +end + +Internal.get_device_id(dev::CUDADevice) = CUDA.deviceid(dev.device) + 1 + +# Default RNG +MLDataDevices.default_device_rng(::CUDADevice) = CUDA.default_rng() 
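The AMDGPU and CUDA extensions above wire device selection, device queries, and cross-device copies into the generic MLDataDevices API. For orientation, here is a minimal usage sketch of that API; it is a hedged example that assumes the relevant trigger packages (for example `LuxCUDA.jl`) are loaded and a GPU is functional, otherwise `gpu_device()` simply falls back to a `CPUDevice`.

```julia
using MLDataDevices

cdev = cpu_device()
gdev = gpu_device()            # first functional backend, CPUDevice() as fallback

ps = (; weight=rand(Float32, 3, 3), bias=zeros(Float32, 3))

ps_dev = gdev(ps)              # structure-preserving transfer to the selected device
ps_host = cdev(ps_dev)         # and back to host memory

get_device(ps_dev)             # e.g. CUDADevice() when a CUDA GPU was selected
gdev2 = gpu_device(2)          # select by 1-indexed id; needs at least two visible GPUs
```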
+ +# Query Device from Array +function Internal.get_device(x::CUDA.AnyCuArray) + parent_x = parent(x) + parent_x === x && return CUDADevice(CUDA.device(x)) + return MLDataDevices.get_device(parent_x) +end +Internal.get_device(x::AbstractCuSparseArray) = CUDADevice(CUDA.device(x.nzVal)) +Internal.get_device(::CUDA.RNG) = CUDADevice(CUDA.device()) +Internal.get_device(::CUDA.CURAND.RNG) = CUDADevice(CUDA.device()) + +Internal.get_device_type(::Union{<:CUDA.AnyCuArray, <:AbstractCuSparseArray}) = CUDADevice +Internal.get_device_type(::CUDA.RNG) = CUDADevice +Internal.get_device_type(::CUDA.CURAND.RNG) = CUDADevice + +# Set Device +MLDataDevices.set_device!(::Type{CUDADevice}, dev::CUDA.CuDevice) = CUDA.device!(dev) +function MLDataDevices.set_device!(::Type{CUDADevice}, id::Integer) + return MLDataDevices.set_device!(CUDADevice, collect(CUDA.devices())[id]) +end +function MLDataDevices.set_device!(::Type{CUDADevice}, ::Nothing, rank::Integer) + id = mod1(rank + 1, length(CUDA.devices())) + return MLDataDevices.set_device!(CUDADevice, id) +end + +# unsafe_free! +function Internal.unsafe_free_internal!(::Type{CUDADevice}, x::AbstractArray) + CUDA.unsafe_free!(x) + return +end + +# Device Transfer +Adapt.adapt_storage(::CUDADevice{Nothing}, x::AbstractArray) = CUDA.cu(x) +function Adapt.adapt_storage(to::CUDADevice, x::AbstractArray) + old_dev = CUDA.device() # remember the current device + dev = MLDataDevices.get_device(x) + if !(dev isa CUDADevice) + CUDA.device!(to.device) + x_new = CUDA.cu(x) + CUDA.device!(old_dev) + return x_new + elseif dev.device == to.device + return x + else + CUDA.device!(to.device) + x_new = copy(x) + CUDA.device!(old_dev) + return x_new + end +end + +Adapt.adapt_storage(::CPUDevice, rng::CUDA.RNG) = Random.default_rng() + +# Defining as extensions seems to case precompilation errors +@static if isdefined(CUDA.CUSPARSE, :SparseArrays) + function Adapt.adapt_storage(::CPUDevice, x::AbstractCuSparseMatrix) + return CUDA.CUSPARSE.SparseArrays.SparseMatrixCSC(x) + end + function Adapt.adapt_storage(::CPUDevice, x::AbstractCuSparseVector) + return CUDA.CUSPARSE.SparseArrays.SparseVector(x) + end +else + @warn "CUDA.CUSPARSE seems to have removed SparseArrays as a dependency. Please open \ + an issue in MLDataDevices.jl repository." +end + +end diff --git a/lib/MLDataDevices/ext/MLDataDevicesChainRulesCoreExt.jl b/lib/MLDataDevices/ext/MLDataDevicesChainRulesCoreExt.jl new file mode 100644 index 000000000..518ff205d --- /dev/null +++ b/lib/MLDataDevices/ext/MLDataDevicesChainRulesCoreExt.jl @@ -0,0 +1,27 @@ +module MLDataDevicesChainRulesCoreExt + +using Adapt: Adapt +using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, @non_differentiable + +using MLDataDevices: AbstractDevice, UnknownDevice, get_device, get_device_type + +@non_differentiable get_device(::Any) +@non_differentiable get_device_type(::Any) + +function ChainRulesCore.rrule(::typeof(Adapt.adapt), to::AbstractDevice, x::AbstractArray) + dev = get_device(x) + y = Adapt.adapt_storage(to, x) + if dev === nothing || dev isa UnknownDevice + dev isa UnknownDevice && + @warn "`get_device(::$(typeof(x)))` returned `$(dev)`." 
maxlog=1 + ∇adapt_storage_unknown = Δ -> (NoTangent(), NoTangent(), Δ) + return y, ∇adapt_storage_unknown + else + ∇adapt_storage = let dev = dev, x = x + Δ -> (NoTangent(), NoTangent(), ProjectTo(x)(dev(Δ))) + end + return Adapt.adapt_storage(to, x), ∇adapt_storage + end +end + +end diff --git a/lib/MLDataDevices/ext/MLDataDevicesFillArraysExt.jl b/lib/MLDataDevices/ext/MLDataDevicesFillArraysExt.jl new file mode 100644 index 000000000..5a88241e6 --- /dev/null +++ b/lib/MLDataDevices/ext/MLDataDevicesFillArraysExt.jl @@ -0,0 +1,10 @@ +module MLDataDevicesFillArraysExt + +using Adapt: Adapt +using FillArrays: FillArrays, AbstractFill +using MLDataDevices: MLDataDevices, CPUDevice, AbstractDevice + +Adapt.adapt_structure(::CPUDevice, x::AbstractFill) = x +Adapt.adapt_structure(to::AbstractDevice, x::AbstractFill) = Adapt.adapt(to, collect(x)) + +end diff --git a/lib/MLDataDevices/ext/MLDataDevicesGPUArraysExt.jl b/lib/MLDataDevices/ext/MLDataDevicesGPUArraysExt.jl new file mode 100644 index 000000000..a09a3861f --- /dev/null +++ b/lib/MLDataDevices/ext/MLDataDevicesGPUArraysExt.jl @@ -0,0 +1,13 @@ +module MLDataDevicesGPUArraysExt + +using Adapt: Adapt +using GPUArrays: GPUArrays +using MLDataDevices: Internal, CPUDevice +using Random: Random + +Adapt.adapt_storage(::CPUDevice, rng::GPUArrays.RNG) = Random.default_rng() + +Internal.get_device(rng::GPUArrays.RNG) = Internal.get_device(rng.state) +Internal.get_device_type(rng::GPUArrays.RNG) = Internal.get_device_type(rng.state) + +end diff --git a/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl b/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl new file mode 100644 index 000000000..be3d285b0 --- /dev/null +++ b/lib/MLDataDevices/ext/MLDataDevicesMLUtilsExt.jl @@ -0,0 +1,38 @@ +module MLDataDevicesMLUtilsExt + +using MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice, + MetalDevice, oneAPIDevice, XLADevice, DeviceIterator +using MLUtils: MLUtils, DataLoader + +for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, XLADevice) + @eval function (D::$(dev))(dataloader::DataLoader) + if dataloader.parallel + if dataloader.buffer + @warn "Using `buffer=true` for parallel DataLoader with automatic device \ + transfer is currently not implemented. Ignoring `buffer=true`." + end + + # Mostly from https://github.com/JuliaML/MLUtils.jl/blob/main/src/eachobs.jl + data = MLUtils.ObsView(dataloader.data) + data = dataloader.shuffle ? MLUtils.shuffleobs(data) : data + data = if dataloader.batchsize > 0 + MLUtils.BatchView( + data; dataloader.batchsize, dataloader.partial, dataloader.collate) + else + data + end + + return DeviceIterator(identity, eachobsparallel(D, data)) + end + return DeviceIterator(D, dataloader) + end +end + +function eachobsparallel(dev::AbstractDevice, data) + return MLUtils.Loader(1:MLUtils.numobs(data)) do ch, i + obs = MLUtils.getobs(data, i) + put!(ch, dev(obs)) + end +end + +end diff --git a/lib/MLDataDevices/ext/MLDataDevicesMetalExt.jl b/lib/MLDataDevices/ext/MLDataDevicesMetalExt.jl new file mode 100644 index 000000000..e5eb16dd5 --- /dev/null +++ b/lib/MLDataDevices/ext/MLDataDevicesMetalExt.jl @@ -0,0 +1,30 @@ +module MLDataDevicesMetalExt + +using Adapt: Adapt +using GPUArrays: GPUArrays +using MLDataDevices: MLDataDevices, Internal, MetalDevice, reset_gpu_device! 
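The MLUtils extension above makes a device object callable on a `DataLoader`, returning a `DeviceIterator` so that each batch is transferred as it is produced and the previous batch is freed. A hedged sketch of the intended call pattern, assuming `MLUtils.jl` and a functional GPU backend are loaded:

```julia
using MLDataDevices, MLUtils

X = rand(Float32, 3, 32)
loader = DataLoader(X; batchsize=8, shuffle=true)

gdev = gpu_device()
for x in gdev(loader)   # batches arrive already on the device
    sum(abs2, x)        # the previous batch is freed when the next one is requested
end
```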
+using Metal: Metal, MtlArray + +__init__() = reset_gpu_device!() + +MLDataDevices.loaded(::Union{MetalDevice, Type{<:MetalDevice}}) = true +MLDataDevices.functional(::Union{MetalDevice, Type{<:MetalDevice}}) = Metal.functional() + +# Default RNG +MLDataDevices.default_device_rng(::MetalDevice) = GPUArrays.default_rng(MtlArray) + +# Query Device from Array +Internal.get_device(::MtlArray) = MetalDevice() + +Internal.get_device_type(::MtlArray) = MetalDevice + +# unsafe_free! +function Internal.unsafe_free_internal!(::Type{MetalDevice}, x::AbstractArray) + Metal.unsafe_free!(x) + return +end + +# Device Transfer +Adapt.adapt_storage(::MetalDevice, x::AbstractArray) = Metal.mtl(x) + +end diff --git a/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl b/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl new file mode 100644 index 000000000..3abc8fca2 --- /dev/null +++ b/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl @@ -0,0 +1,26 @@ +module MLDataDevicesReactantExt + +using Adapt: Adapt +using MLDataDevices: MLDataDevices, Internal, XLADevice, CPUDevice +using Reactant: Reactant, RArray + +MLDataDevices.loaded(::Union{XLADevice, Type{<:XLADevice}}) = true +MLDataDevices.functional(::Union{XLADevice, Type{<:XLADevice}}) = true + +# Default RNG: Forward to CPU, we will compile it +function MLDataDevices.default_device_rng(::XLADevice) + return MLDataDevices.default_device_rng(CPUDevice()) +end + +# Query Device from Array +Internal.get_device(::RArray) = XLADevice() + +Internal.get_device_type(::RArray) = XLADevice + +# unsafe_free! +Internal.unsafe_free_internal!(::Type{XLADevice}, x::AbstractArray) = nothing + +# Device Transfer +Adapt.adapt_storage(::XLADevice, x::AbstractArray) = Reactant.to_rarray(x) + +end diff --git a/lib/MLDataDevices/ext/MLDataDevicesRecursiveArrayToolsExt.jl b/lib/MLDataDevices/ext/MLDataDevicesRecursiveArrayToolsExt.jl new file mode 100644 index 000000000..f0b29a2d0 --- /dev/null +++ b/lib/MLDataDevices/ext/MLDataDevicesRecursiveArrayToolsExt.jl @@ -0,0 +1,24 @@ +module MLDataDevicesRecursiveArrayToolsExt + +using Adapt: Adapt, adapt +using MLDataDevices: MLDataDevices, Internal, AbstractDevice +using RecursiveArrayTools: VectorOfArray, DiffEqArray + +# We want to preserve the structure +function Adapt.adapt_structure(to::AbstractDevice, x::VectorOfArray) + return VectorOfArray(map(Base.Fix1(adapt, to), x.u)) +end + +function Adapt.adapt_structure(to::AbstractDevice, x::DiffEqArray) + # Don't move the `time` to the GPU + return DiffEqArray(map(Base.Fix1(adapt, to), x.u), x.t) +end + +for op in (:get_device, :get_device_type) + @eval function Internal.$(op)(x::Union{VectorOfArray, DiffEqArray}) + length(x.u) == 0 && return $(op == :get_device ? 
nothing : Nothing) + return mapreduce(Internal.$(op), Internal.combine_devices, x.u) + end +end + +end diff --git a/lib/MLDataDevices/ext/MLDataDevicesReverseDiffExt.jl b/lib/MLDataDevices/ext/MLDataDevicesReverseDiffExt.jl new file mode 100644 index 000000000..eeb944290 --- /dev/null +++ b/lib/MLDataDevices/ext/MLDataDevicesReverseDiffExt.jl @@ -0,0 +1,13 @@ +module MLDataDevicesReverseDiffExt + +using MLDataDevices: Internal +using ReverseDiff: ReverseDiff + +for op in (:get_device, :get_device_type) + @eval begin + Internal.$(op)(x::ReverseDiff.TrackedArray) = Internal.$(op)(ReverseDiff.value(x)) + Internal.$(op)(x::AbstractArray{<:ReverseDiff.TrackedReal}) = Internal.$(op)(ReverseDiff.value.(x)) + end +end + +end diff --git a/lib/MLDataDevices/ext/MLDataDevicesSparseArraysExt.jl b/lib/MLDataDevices/ext/MLDataDevicesSparseArraysExt.jl new file mode 100644 index 000000000..a52871f74 --- /dev/null +++ b/lib/MLDataDevices/ext/MLDataDevicesSparseArraysExt.jl @@ -0,0 +1,9 @@ +module MLDataDevicesSparseArraysExt + +using Adapt: Adapt +using MLDataDevices: CPUDevice +using SparseArrays: AbstractSparseArray + +Adapt.adapt_storage(::CPUDevice, x::AbstractSparseArray) = x + +end diff --git a/lib/MLDataDevices/ext/MLDataDevicesTrackerExt.jl b/lib/MLDataDevices/ext/MLDataDevicesTrackerExt.jl new file mode 100644 index 000000000..f9b90d9cb --- /dev/null +++ b/lib/MLDataDevices/ext/MLDataDevicesTrackerExt.jl @@ -0,0 +1,23 @@ +module MLDataDevicesTrackerExt + +using Adapt: Adapt +using MLDataDevices: Internal, AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice +using Tracker: Tracker + +for op in (:get_device, :get_device_type) + @eval Internal.$(op)(x::Tracker.TrackedArray) = Internal.$(op)(Tracker.data(x)) + @eval Internal.$(op)(x::AbstractArray{<:Tracker.TrackedReal}) = Internal.$(op)(Tracker.data.(x)) +end + +Internal.special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true + +for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice, + CUDADevice{Nothing}, MetalDevice, oneAPIDevice) + @eval function Adapt.adapt_storage(to::$(T), x::AbstractArray{<:Tracker.TrackedReal}) + @warn "AbstractArray{<:Tracker.TrackedReal} is not supported for $(to). Converting \ + to Tracker.TrackedArray." maxlog=1 + return to(Tracker.collect(x)) + end +end + +end diff --git a/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl new file mode 100644 index 000000000..1b705c582 --- /dev/null +++ b/lib/MLDataDevices/ext/MLDataDevicesZygoteExt.jl @@ -0,0 +1,10 @@ +module MLDataDevicesZygoteExt + +using Adapt: Adapt +using MLDataDevices: AbstractDevice, CPUDevice +using Zygote: OneElement + +Adapt.adapt_structure(::CPUDevice, x::OneElement) = x +Adapt.adapt_structure(to::AbstractDevice, x::OneElement) = Adapt.adapt(to, collect(x)) + +end diff --git a/lib/MLDataDevices/ext/MLDataDevicescuDNNExt.jl b/lib/MLDataDevices/ext/MLDataDevicescuDNNExt.jl new file mode 100644 index 000000000..a332c7ad3 --- /dev/null +++ b/lib/MLDataDevices/ext/MLDataDevicescuDNNExt.jl @@ -0,0 +1,36 @@ +module MLDataDevicescuDNNExt + +using CUDA: CUDA +using cuDNN: cuDNN +using MLDataDevices: MLDataDevices, CUDADevice, reset_gpu_device! + +__init__() = reset_gpu_device!() + +const USE_CUDA_GPU = Ref{Union{Nothing, Bool}}(nothing) + +function _check_use_cuda!() + USE_CUDA_GPU[] === nothing || return + + USE_CUDA_GPU[] = CUDA.functional() + if USE_CUDA_GPU[] + if !cuDNN.has_cudnn() + @warn """ + cuDNN is not functional. Some functionality will not be available. 
+ """ maxlog=1 + + # We make the device selectable only if cuDNN is functional + # to avoid issues with convolutions and other deep learning operations + USE_CUDA_GPU[] = false + end + end + return +end + +MLDataDevices.loaded(::Union{CUDADevice, Type{<:CUDADevice}}) = true + +function MLDataDevices.functional(::Union{CUDADevice, Type{<:CUDADevice}})::Bool + _check_use_cuda!() + return USE_CUDA_GPU[] +end + +end diff --git a/lib/MLDataDevices/ext/MLDataDevicesoneAPIExt.jl b/lib/MLDataDevices/ext/MLDataDevicesoneAPIExt.jl new file mode 100644 index 000000000..75fc2f035 --- /dev/null +++ b/lib/MLDataDevices/ext/MLDataDevicesoneAPIExt.jl @@ -0,0 +1,51 @@ +module MLDataDevicesoneAPIExt + +using Adapt: Adapt +using GPUArrays: GPUArrays +using MLDataDevices: MLDataDevices, Internal, oneAPIDevice, reset_gpu_device! +using oneAPI: oneAPI, oneArray, oneL0 + +const SUPPORTS_FP64 = Dict{oneL0.ZeDevice, Bool}() + +function __init__() + reset_gpu_device!() + for dev in oneAPI.devices() + SUPPORTS_FP64[dev] = oneL0.module_properties(dev).fp64flags & + oneL0.ZE_DEVICE_MODULE_FLAG_FP64 == + oneL0.ZE_DEVICE_MODULE_FLAG_FP64 + end +end + +MLDataDevices.loaded(::Union{oneAPIDevice, Type{<:oneAPIDevice}}) = true +function MLDataDevices.functional(::Union{oneAPIDevice, Type{<:oneAPIDevice}}) + return oneAPI.functional() +end + +# Default RNG +MLDataDevices.default_device_rng(::oneAPIDevice) = GPUArrays.default_rng(oneArray) + +# Query Device from Array +Internal.get_device(::oneArray) = oneAPIDevice() + +Internal.get_device_type(::oneArray) = oneAPIDevice + +# unsafe_free! +function Internal.unsafe_free_internal!(::Type{oneAPIDevice}, x::AbstractArray) + oneAPI.unsafe_free!(x) + return +end + +# Device Transfer +for (T1, T2) in ((Float64, Float32), (ComplexF64, ComplexF32)) + @eval function Adapt.adapt_storage(::oneAPIDevice, x::AbstractArray{$(T1)}) + if !SUPPORTS_FP64[oneAPI.device()] + @warn LazyString( + "Double type is not supported on this device. Using `", $(T2), "` instead.") + return oneArray{$(T2)}(x) + end + return oneArray(x) + end +end +Adapt.adapt_storage(::oneAPIDevice, x::AbstractArray) = oneArray(x) + +end diff --git a/lib/MLDataDevices/src/MLDataDevices.jl b/lib/MLDataDevices/src/MLDataDevices.jl new file mode 100644 index 000000000..108d8bf78 --- /dev/null +++ b/lib/MLDataDevices/src/MLDataDevices.jl @@ -0,0 +1,31 @@ +module MLDataDevices + +using Adapt: Adapt +using Functors: Functors, fleaves +using Preferences: @delete_preferences!, @load_preference, @set_preferences! +using Random: AbstractRNG, Random +using Compat: @compat + +abstract type AbstractDevice <: Function end +abstract type AbstractCPUDevice <: AbstractDevice end +abstract type AbstractAcceleratorDevice <: AbstractDevice end +abstract type AbstractGPUDevice <: AbstractAcceleratorDevice end + +include("public.jl") +include("iterator.jl") +include("internal.jl") + +export gpu_backend!, supported_gpu_backends, reset_gpu_device! 
+export default_device_rng +export gpu_device, cpu_device, xla_device + +export CPUDevice +export CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice +export XLADevice +export get_device, get_device_type + +export DeviceIterator + +@compat(public, (isleaf,)) + +end diff --git a/lib/MLDataDevices/src/internal.jl b/lib/MLDataDevices/src/internal.jl new file mode 100644 index 000000000..5da37ac20 --- /dev/null +++ b/lib/MLDataDevices/src/internal.jl @@ -0,0 +1,224 @@ +module Internal + +using Functors: fmap +using Preferences: load_preference +using Random: AbstractRNG + +using ..MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice, + MetalDevice, oneAPIDevice, XLADevice, UnknownDevice, + supported_gpu_backends, GPU_DEVICES, loaded, functional + +for dev in (CPUDevice, MetalDevice, oneAPIDevice) + msg = "`device_id` is not applicable for `$dev`." + @eval begin + with_device(::Type{$dev}, ::Nothing) = $dev() + function with_device(::Type{$dev}, device_id) + @warn $(msg) maxlog=1 + return $dev() + end + end +end + +for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) + tpkg = name === :CPU ? "" : string(name) + ldev = Symbol(name, :Device) + @eval begin + get_device_name(::Union{$ldev, Type{<:$ldev}}) = $(string(name)) + get_triggerpkg_name(::Union{$ldev, Type{<:$ldev}}) = $(tpkg) + end +end +get_device_name(::XLADevice) = "XLA" +get_triggerpkg_name(::XLADevice) = "Reactant" + +for T in (CPUDevice, CUDADevice{Nothing}, AMDGPUDevice{Nothing}, + MetalDevice, oneAPIDevice, XLADevice) + @eval get_device_id(::$(T)) = nothing +end + +struct DeviceSelectionException <: Exception + dev::String +end + +function Base.showerror(io::IO, d::DeviceSelectionException) + return print(io, "DeviceSelectionException: No functional $(d.dev) device found!") +end + +function get_gpu_device(; force::Bool) + backend = load_preference(MLDataDevices, "gpu_backend", nothing) + + # If backend set with preferences, use it + if backend !== nothing + allowed_backends = supported_gpu_backends() + if backend ∉ allowed_backends + @warn "`gpu_backend` preference is set to $backend, which is not a valid \ + backend. Valid backends are $allowed_backends. Defaulting to automatic \ + GPU Backend selection." maxlog=1 + else + @debug "Using GPU backend set in preferences: $backend." + idx = findfirst(isequal(backend), allowed_backends) + device = GPU_DEVICES[idx] + if !loaded(device) + @warn "Trying to use backend: $(get_device_name(device)) but the trigger \ + package $(get_triggerpkg_name(device)) is not loaded. Ignoring the \ + Preferences backend!!! Please load the package and call this \ + function again to respect the Preferences backend." maxlog=1 + else + if functional(device) + @debug "Using GPU backend: $(get_device_name(device))." + return device + else + @warn "GPU backend: $(get_device_name(device)) set via Preferences.jl \ + is not functional. Defaulting to automatic GPU Backend \ + selection." maxlog=1 + end + end + end + end + + @debug "Running automatic GPU backend selection..." + for device in GPU_DEVICES + if loaded(device) + @debug "Trying backend: $(get_device_name(device))." + if functional(device) + @debug "Using GPU backend: $(get_device_name(device))." + return device + end + @debug "GPU backend: $(get_device_name(device)) is not functional." + else + @debug "Trigger package for backend ($(get_device_name(device))): \ + $(get_triggerpkg_name(device)) not loaded." + end + end + + force && throw(DeviceSelectionException("GPU")) + @warn """No functional GPU backend found! 
Defaulting to CPU. + + 1. If no GPU is available, nothing needs to be done. + 2. If GPU is available, load the corresponding trigger package. + a. `CUDA.jl` and `cuDNN.jl` (or just `LuxCUDA.jl`) for NVIDIA CUDA Support. + b. `AMDGPU.jl` for AMD GPU ROCM Support. + c. `Metal.jl` for Apple Metal GPU Support. (Experimental) + d. `oneAPI.jl` for Intel oneAPI GPU Support. (Experimental)""" maxlog=1 + return CPUDevice +end + +special_aos(::AbstractArray) = false + +recursive_array_eltype(::Type{T}) where {T} = !isbitstype(T) && !(T <: Number) + +combine_devices(::Nothing, ::Nothing) = nothing +combine_devices(::Nothing, dev::AbstractDevice) = dev +combine_devices(dev::AbstractDevice, ::Nothing) = dev +function combine_devices(dev1::AbstractDevice, dev2::AbstractDevice) + dev1 == dev2 && return dev1 + dev1 isa UnknownDevice && return dev2 + dev2 isa UnknownDevice && return dev1 + throw(ArgumentError("Objects are on different devices: $(dev1) and $(dev2).")) +end + +combine_devices(::Type{Nothing}, ::Type{Nothing}) = Nothing +combine_devices(::Type{T}, ::Type{T}) where {T <: AbstractDevice} = T +combine_devices(::Type{T}, ::Type{Nothing}) where {T <: AbstractDevice} = T +combine_devices(::Type{T}, ::Type{UnknownDevice}) where {T <: AbstractDevice} = T +combine_devices(::Type{Nothing}, ::Type{T}) where {T <: AbstractDevice} = T +combine_devices(::Type{UnknownDevice}, ::Type{T}) where {T <: AbstractDevice} = T +combine_devices(::Type{UnknownDevice}, ::Type{UnknownDevice}) = UnknownDevice +function combine_devices(T1::Type{<:AbstractDevice}, T2::Type{<:AbstractDevice}) + throw(ArgumentError("Objects are on devices with different types: $(T1) and $(T2).")) +end + +for op in (:get_device, :get_device_type) + cpu_ret_val = op == :get_device ? CPUDevice() : CPUDevice + unknown_ret_val = op == :get_device ? UnknownDevice() : UnknownDevice + not_assigned_msg = "AbstractArray has some undefined references. Giving up, returning \ + $(unknown_ret_val)..." + + @eval begin + function $(op)(x::AbstractArray{T}) where {T} + if recursive_array_eltype(T) + if any(!isassigned(x, i) for i in eachindex(x)) + @warn $(not_assigned_msg) + return $(unknown_ret_val) + end + return mapreduce(MLDataDevices.$(op), combine_devices, x) + end + if hasmethod(parent, Tuple{typeof(x)}) + parent_x = parent(x) + parent_x === x && return $(cpu_ret_val) + return $(op)(parent_x) + end + return $(cpu_ret_val) + end + + function $(op)(x::Union{Tuple, NamedTuple}) + length(x) == 0 && return $(op == :get_device ? nothing : Nothing) + return unrolled_mapreduce(MLDataDevices.$(op), combine_devices, values(x)) + end + + function $(op)(f::F) where {F <: Function} + Base.issingletontype(F) && + return $(op == :get_device ? UnknownDevice() : UnknownDevice) + return unrolled_mapreduce(MLDataDevices.$(op), combine_devices, + map(Base.Fix1(getfield, f), fieldnames(F))) + end + end + + for T in (Number, AbstractRNG, Val, Symbol, String, Nothing, AbstractRange) + @eval $(op)(::$(T)) = $(op == :get_device ? 
nothing : Nothing) + end +end + +get_device(_) = UnknownDevice() +get_device_type(_) = UnknownDevice + +fast_structure(::AbstractArray) = true +fast_structure(::Union{Tuple, NamedTuple}) = true +for T in (Number, AbstractRNG, Val, Symbol, String, Nothing, AbstractRange) + @eval fast_structure(::$(T)) = true +end +fast_structure(::Function) = true +fast_structure(_) = false + +function unrolled_mapreduce(f::F, op::O, itr) where {F, O} + return unrolled_mapreduce(f, op, itr, static_length(itr)) +end + +function unrolled_mapreduce(::F, ::O, _, ::Val{0}) where {F, O} + error("Cannot unroll over an empty iterator.") +end + +unrolled_mapreduce(f::F, ::O, itr, ::Val{1}) where {F, O} = f(only(itr)) + +@generated function unrolled_mapreduce(f::F, op::O, itr, ::Val{N}) where {F, O, N} + syms = [gensym("f_itr_$(i)") for i in 1:N] + op_syms = [gensym("op_$(i)") for i in 1:(N - 1)] + f_applied = [:($(syms[i]) = f(itr[$i])) for i in 1:N] + combine_expr = [:($(op_syms[1]) = op($(syms[1]), $(syms[2])))] + for i in 2:(N - 1) + push!(combine_expr, :($(op_syms[i]) = op($(op_syms[i - 1]), $(syms[i + 1])))) + end + return quote + $(Expr(:meta, :inline)) + $(Expr(:inbounds, true)) + $(Expr(:block, f_applied...)) + $(Expr(:inbounds, :pop)) + $(Expr(:block, combine_expr...)) + return $(op_syms[end]) + end +end + +function unsafe_free_internal!(x::AbstractArray) + unsafe_free_internal!(MLDataDevices.get_device_type(x), x) + return +end +unsafe_free_internal!(::Type, x::AbstractArray) = nothing +unsafe_free_internal!(_) = nothing + +function unsafe_free!(x) + fmap(unsafe_free_internal!, x) + return +end + +static_length(t::Tuple) = Val(length(t)) + +end diff --git a/lib/MLDataDevices/src/iterator.jl b/lib/MLDataDevices/src/iterator.jl new file mode 100644 index 000000000..af3c08193 --- /dev/null +++ b/lib/MLDataDevices/src/iterator.jl @@ -0,0 +1,73 @@ +""" + DeviceIterator(dev::AbstractDevice, iterator) + +Create a `DeviceIterator` that iterates through the provided `iterator` via `iterate`. Upon +each iteration, the current batch is copied to the device `dev`, and the previous iteration +is marked as freeable from GPU memory (via `unsafe_free!`) (no-op for a CPU device). + +The conversion follows the same semantics as `dev()`. + +!!! tip "Similarity to `CUDA.CuIterator`" + + The design inspiration was taken from `CUDA.CuIterator` and was generalized to work with + other backends and more complex iterators (using `Functors`). + +!!! tip "`MLUtils.DataLoader`" + + Calling `dev(::MLUtils.DataLoader)` will automatically convert the dataloader to use the + same semantics as `DeviceIterator`. This is generally preferred over looping over the + dataloader directly and transferring the data to the device. + +## Examples + +The following was run on a computer with an NVIDIA GPU. 
+ +```julia-repl +julia> using MLDataDevices, MLUtils + +julia> X = rand(Float64, 3, 33); + +julia> dataloader = DataLoader(X; batchsize=13, shuffle=false); + +julia> for (i, x) in enumerate(dataloader) + @show i, summary(x) + end +(i, summary(x)) = (1, "3×13 Matrix{Float64}") +(i, summary(x)) = (2, "3×13 Matrix{Float64}") +(i, summary(x)) = (3, "3×7 Matrix{Float64}") + +julia> for (i, x) in enumerate(CUDADevice()(dataloader)) + @show i, summary(x) + end +(i, summary(x)) = (1, "3×13 CuArray{Float32, 2, CUDA.DeviceMemory}") +(i, summary(x)) = (2, "3×13 CuArray{Float32, 2, CUDA.DeviceMemory}") +(i, summary(x)) = (3, "3×7 CuArray{Float32, 2, CUDA.DeviceMemory}") +``` +""" +struct DeviceIterator{D <: Function, I} + dev::D + iterator::I +end + +function Base.iterate(c::DeviceIterator) + item = iterate(c.iterator) + item === nothing && return nothing + batch, next_state = item + dev_batch = c.dev(batch) + return dev_batch, (next_state, dev_batch) +end + +function Base.iterate(c::DeviceIterator, (state, prev_batch)) + item = iterate(c.iterator, state) + item === nothing && return nothing + batch, next_state = item + Internal.unsafe_free!(prev_batch) # free the previous batch + dev_batch = c.dev(batch) + return dev_batch, (next_state, dev_batch) +end + +Base.IteratorSize(::Type{DeviceIterator{D, I}}) where {D, I} = Base.IteratorSize(I) +Base.length(c::DeviceIterator) = length(c.iterator) +Base.axes(c::DeviceIterator) = axes(c.iterator) + +Base.IteratorEltype(::Type{DeviceIterator{D, I}}) where {D, I} = Base.EltypeUnknown() diff --git a/lib/MLDataDevices/src/public.jl b/lib/MLDataDevices/src/public.jl new file mode 100644 index 000000000..6440ddbe7 --- /dev/null +++ b/lib/MLDataDevices/src/public.jl @@ -0,0 +1,396 @@ +struct CPUDevice <: AbstractCPUDevice end + +@kwdef struct CUDADevice{D} <: AbstractGPUDevice + device::D = nothing +end +@kwdef struct AMDGPUDevice{D} <: AbstractGPUDevice + device::D = nothing +end +struct MetalDevice <: AbstractGPUDevice end +struct oneAPIDevice <: AbstractGPUDevice end + +# TODO: Later we might want to add the client field here? +struct XLADevice <: AbstractAcceleratorDevice end + +# Fallback for when we don't know the device type +struct UnknownDevice <: AbstractDevice end + +""" + functional(x::AbstractDevice) -> Bool + functional(::Type{<:AbstractDevice}) -> Bool + +Checks if the device is functional. This is used to determine if the device can be used for +computation. Note that even if the backend is loaded (as checked via +[`MLDataDevices.loaded`](@ref)), the device may not be functional. + +Note that while this function is not exported, it is considered part of the public API. +""" +functional(x) = false +functional(::Union{CPUDevice, Type{<:CPUDevice}}) = true + +""" + loaded(x::AbstractDevice) -> Bool + loaded(::Type{<:AbstractDevice}) -> Bool + +Checks if the trigger package for the device is loaded. Trigger packages are as follows: + + - `CUDA.jl` and `cuDNN.jl` (or just `LuxCUDA.jl`) for NVIDIA CUDA Support. + - `AMDGPU.jl` for AMD GPU ROCM Support. + - `Metal.jl` for Apple Metal GPU Support. + - `oneAPI.jl` for Intel oneAPI GPU Support. +""" +loaded(x) = false +loaded(::Union{CPUDevice, Type{<:CPUDevice}}) = true + +# Order is important here +const GPU_DEVICES = (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice) + +const GPU_DEVICE = Ref{Union{Nothing, AbstractDevice}}(nothing) + +""" + reset_gpu_device!() + +Resets the selected GPU device. This is useful when automatic GPU selection needs to be +run again. 
+""" +reset_gpu_device!() = (GPU_DEVICE[] = nothing) + +""" + supported_gpu_backends() -> Tuple{String, ...} + +Return a tuple of supported GPU backends. + +!!! warning + + This is not the list of functional backends on the system, but rather backends which + `MLDataDevices.jl` supports. +""" +supported_gpu_backends() = map(Internal.get_device_name, GPU_DEVICES) + +""" + gpu_device(device_id::Union{Nothing, Integer}=nothing; + force::Bool=false) -> AbstractDevice + +Selects GPU device based on the following criteria: + + 1. If `gpu_backend` preference is set and the backend is functional on the system, then + that device is selected. + 2. Otherwise, an automatic selection algorithm is used. We go over possible device + backends in the order specified by `supported_gpu_backends()` and select the first + functional backend. + 3. If no GPU device is functional and `force` is `false`, then `cpu_device()` is + invoked. + 4. If nothing works, an error is thrown. + +## Arguments + + - `device_id::Union{Nothing, Integer}`: The device id to select. If `nothing`, then we return + the last selected device or if none was selected then we run the autoselection and + choose the current device using `CUDA.device()` or `AMDGPU.device()` or similar. If + `Integer`, then we select the device with the given id. Note that this is `1`-indexed, in + contrast to the `0`-indexed `CUDA.jl`. For example, `id = 4` corresponds to + `CUDA.device!(3)`. + +!!! warning + + `device_id` is only applicable for `CUDA` and `AMDGPU` backends. For `Metal`, `oneAPI` + and `CPU` backends, `device_id` is ignored and a warning is printed. + +!!! warning + + `gpu_device` won't select a CUDA device unless both CUDA.jl and cuDNN.jl are loaded. + This is to ensure that deep learning operations work correctly. + Nonetheless, if cuDNN is not loaded you can still manually create a + `CUDADevice` object and use it (e.g. `dev = CUDADevice()`). + +## Keyword Arguments + + - `force::Bool`: If `true`, then an error is thrown if no functional GPU + device is found. +""" +function gpu_device(device_id::Union{Nothing, <:Integer}=nothing; force::Bool=false, + force_gpu_usage::Union{Missing, Bool}=missing)::AbstractDevice + if force_gpu_usage !== missing + Base.depwarn( + "`force_gpu_usage` is deprecated and will be removed in v2. Use \ + `force` instead.", :gpu_device) + force = force_gpu_usage + end + + device_id == 0 && throw(ArgumentError("`device_id` is 1-indexed.")) + + if GPU_DEVICE[] !== nothing + dev = GPU_DEVICE[] + if device_id === nothing + force && + !(dev isa AbstractGPUDevice) && + throw(Internal.DeviceSelectionException("GPU")) + return dev + else + selected_device_id = Internal.get_device_id(dev) + selected_device_id !== nothing && selected_device_id == device_id && return dev + end + end + + device_type = Internal.get_gpu_device(; force) + device = Internal.with_device(device_type, device_id) + GPU_DEVICE[] = device + + return device +end + +""" + gpu_backend!() = gpu_backend!("") + gpu_backend!(backend) = gpu_backend!(string(backend)) + gpu_backend!(backend::AbstractGPUDevice) + gpu_backend!(backend::String) + +Creates a `LocalPreferences.toml` file with the desired GPU backend. + +If `backend == ""`, then the `gpu_backend` preference is deleted. Otherwise, `backend` is +validated to be one of the possible backends and the preference is set to `backend`. + +If a new backend is successfully set, then the Julia session must be restarted for the +change to take effect. 
+""" +gpu_backend!(backend) = gpu_backend!(string(backend)) +gpu_backend!(backend::AbstractGPUDevice) = gpu_backend!(Internal.get_device_name(backend)) +gpu_backend!() = gpu_backend!("") +function gpu_backend!(backend::String) + if backend == "" + @delete_preferences!("gpu_backend") + @info "Deleted the local preference for `gpu_backend`. Restart Julia to use the \ + new backend." + return + end + + allowed_backends = supported_gpu_backends() + + set_backend = @load_preference("gpu_backend", nothing) + if set_backend == backend + @info "GPU backend is already set to $backend. No action is required." + return + end + + if backend ∉ allowed_backends + throw(ArgumentError("Invalid backend: $backend. Valid backends are $allowed_backends.")) + end + + @set_preferences!("gpu_backend"=>backend) + @info "GPU backend has been set to $backend. Restart Julia to use the new backend." + return +end + +""" + cpu_device() -> CPUDevice() + +Return a `CPUDevice` object which can be used to transfer data to CPU. +""" +cpu_device() = CPUDevice() + +""" + xla_device(; force::Bool=false) -> Union{XLADevice, CPUDevice} + +Return a `XLADevice` object if functional. Otherwise, throw an error if `force` is `true`. +Falls back to `CPUDevice` if `force` is `false`. + +!!! danger + + This is an experimental feature and might change without deprecations +""" +function xla_device(; force::Bool=false) + msg = "`XLADevice` is not loaded or not functional. Load `Reactant.jl` before calling \ + this function. Defaulting to CPU." + if loaded(XLADevice) + functional(XLADevice) && return XLADevice() + msg = "`XLADevice` is loaded but not functional. Defaulting to CPU." + end + force && throw(Internal.DeviceSelectionException("XLA")) + @warn msg maxlog=1 + return cpu_device() +end + +""" + default_device_rng(::AbstractDevice) + +Returns the default RNG for the device. This can be used to directly generate parameters +and states on the device using +[WeightInitializers.jl](https://github.com/LuxDL/WeightInitializers.jl). +""" +function default_device_rng(D::AbstractDevice) + return error("""`default_device_rng` not implemented for `$(typeof(D))`. This is \ + either because: + + 1. The default RNG for this device is not known / officially provided. + 2. The trigger package for the device ($(Internal.get_device_name(D)).jl) is \ + not loaded. + """) +end +default_device_rng(::CPUDevice) = Random.default_rng() + +const GET_DEVICE_ADMONITIONS = """ +!!! note + + Trigger Packages must be loaded for this to return the correct device. +""" + +# Query Device from Array +""" + get_device(x) -> dev::AbstractDevice | Exception | Nothing + +If all arrays (on the leaves of the structure) are on the same device, we return that +device. Otherwise, we throw an error. If the object is device agnostic, we return `nothing`. + +$(GET_DEVICE_ADMONITIONS) + +## Special Retuened Values + + - `nothing` -- denotes that the object is device agnostic. For example, scalar, abstract + range, etc. + - `UnknownDevice()` -- denotes that the device type is unknown + +See also [`get_device_type`](@ref) for a faster alternative that can be used for dispatch +based on device type. +""" +function get_device end + +""" + get_device_type(x) -> Type{<:AbstractDevice} | Exception | Type{Nothing} + +Similar to [`get_device`](@ref) but returns the type of the device instead of the device +itself. This value is often a compile time constant and is recommended to be used instead +of [`get_device`](@ref) where ever defining dispatches based on the device type. 
+
+$(GET_DEVICE_ADMONITIONS)
+
+## Special Returned Values
+
+  - `Nothing` -- denotes that the object is device agnostic. For example, scalar, abstract
+    range, etc.
+  - `UnknownDevice` -- denotes that the device type is unknown.
+"""
+function get_device_type end
+
+# Set the device
+const SET_DEVICE_DOCS = """
+Set the device for the given type. This is a no-op for `CPUDevice`. For `CUDADevice`
+and `AMDGPUDevice`, it prints a warning if the corresponding trigger package is not
+loaded.
+
+Currently, `MetalDevice` and `oneAPIDevice` don't support setting the device.
+"""
+
+const SET_DEVICE_DANGER = """
+!!! danger
+
+    This specific function should be considered experimental at this point and is currently
+    provided to support distributed training in Lux. As such, please use
+    `Lux.DistributedUtils` instead of using this function.
+"""
+
+"""
+    set_device!(T::Type{<:AbstractDevice}, dev_or_id)
+
+$SET_DEVICE_DOCS
+
+## Arguments
+
+  - `T::Type{<:AbstractDevice}`: The device type to set.
+  - `dev_or_id`: Can be the device from the corresponding package. For example, for CUDA it
+    can be a `CuDevice`. If it is an integer, it is the device id to set. This is
+    `1`-indexed.
+
+$SET_DEVICE_DANGER
+"""
+function set_device!(::Type{T}, dev_or_id) where {T <: AbstractDevice}
+    T === CUDADevice && @warn "`CUDA.jl` hasn't been loaded. Ignoring the device setting."
+    T === AMDGPUDevice &&
+        @warn "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting."
+    T === MetalDevice &&
+        @warn "Support for Multi Device Metal hasn't been implemented yet. Ignoring the device setting."
+    T === oneAPIDevice &&
+        @warn "Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting."
+    T === CPUDevice &&
+        @warn "Setting device for `CPUDevice` doesn't make sense. Ignoring the device setting."
+    T === XLADevice &&
+        @warn "Setting device for `XLADevice` hasn't been implemented yet. Ignoring the device setting."
+    return
+end
+
+"""
+    set_device!(T::Type{<:AbstractDevice}, ::Nothing, rank::Integer)
+
+$SET_DEVICE_DOCS
+
+## Arguments
+
+  - `T::Type{<:AbstractDevice}`: The device type to set.
+  - `rank::Integer`: Local rank of the process. This is applicable for distributed training and
+    must be `0`-indexed.
+
+$SET_DEVICE_DANGER
+"""
+function set_device!(::Type{T}, ::Nothing, rank::Integer) where {T <: AbstractDevice}
+    return set_device!(T, rank)
+end
+
+# Dispatches for Different Data Structures
+# Abstract Array / Tuples / NamedTuples have special fast paths to facilitate type stability
+# For all other types we rely on fmap, which means we lose type stability.
+# For Lux, models typically only use these 3 data structures, so we should be mostly fine.
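The comment above summarizes the transfer rules implemented by the method definitions that follow: arrays, tuples, and named tuples take fast paths, and everything else is walked with `Functors.fmap`. A hedged sketch of what this means for a user-defined container (the `Dataset` type below is purely illustrative); it assumes some GPU backend is functional, otherwise the arrays simply stay on the CPU:

```julia
using MLDataDevices, Functors

struct Dataset
    features::Matrix{Float32}
    labels::Vector{Int}
end
Functors.@functor Dataset     # make the struct traversable by fmap

dev = gpu_device()
ds = Dataset(rand(Float32, 4, 16), rand(1:10, 16))

ds_dev = dev(ds)              # leaf arrays are transferred; the Dataset wrapper is preserved
get_device_type(ds_dev)       # CUDADevice, AMDGPUDevice, ..., or CPUDevice as the fallback
```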
+for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :XLA) + ldev = Symbol(dev, :Device) + @eval begin + function (D::$(ldev))(x::AbstractArray{T}) where {T} + if isbitstype(T) || Internal.special_aos(x) || x isa Adapt.WrappedArray + return Adapt.adapt(D, x) + end + return map(D, x) + end + (D::$(ldev))(x::Union{Tuple, NamedTuple}) = map(D, x) + function (D::$(ldev))(x) + isleaf(x) && return Adapt.adapt(D, x) + return Functors.fmap(D, x; exclude=isleaf) + end + end +end + +for op in (:get_device, :get_device_type) + @eval function $(op)(x) + Internal.fast_structure(x) && return Internal.$(op)(x) + return mapreduce(Internal.$(op), Internal.combine_devices, fleaves(x)) + end +end + +# Adapt Interface +Adapt.adapt_storage(::CPUDevice, x::AbstractArray) = Adapt.adapt(Array, x) +Adapt.adapt_storage(::CPUDevice, rng::AbstractRNG) = rng + +for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice, XLADevice) + @eval begin + function Adapt.adapt_storage(to::$(T), ::Random.TaskLocalRNG) + return default_device_rng(to) + end + Adapt.adapt_storage(::$(T), rng::AbstractRNG) = rng + end +end + +""" + isleaf(x) -> Bool + +Returns `true` if `x` is a leaf node in the data structure. + +Defining `MLDataDevices.isleaf(x::T) = true` for custom types +can be used to customize the behavior the data movement behavior +when an object with nested structure containing the type is transferred to a device. + +`Adapt.adapt_structure(::AbstractDevice, x::T)` or +`Adapt.adapt_structure(::AbstractDevice, x::T)` will be called during +data movement if `isleaf(x::T) == true`. + +If `MLDataDevices.isleaf(x::T)` is not defined, then it will fall back to `Functors.isleaf(x)`. +""" +isleaf(x) = Functors.isleaf(x) + +isleaf(::AbstractArray{T}) where {T} = isbitstype(T) +isleaf(::Adapt.WrappedArray) = false diff --git a/lib/MLDataDevices/test/Project.toml b/lib/MLDataDevices/test/Project.toml new file mode 100644 index 000000000..9914e0f57 --- /dev/null +++ b/lib/MLDataDevices/test/Project.toml @@ -0,0 +1,41 @@ +[deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[compat] +Adapt = "4" +Aqua = "0.8.4" +ArrayInterface = "7.11" +ChainRulesTestUtils = "1.13.0" +ComponentArrays = "0.15.8" +ExplicitImports = "1.9.0" +FillArrays = "1" +ForwardDiff = "0.10.36" +Functors = "0.4.8" +MLUtils = "0.4" +Pkg = "1.10" +Random = "1.10" +RecursiveArrayTools = "3.8" +ReverseDiff = "1.15" +SafeTestsets = "0.1" +SparseArrays = "1.10" +Test = "1.10" +Tracker = "0.2.34" +Zygote = "0.6.69" diff --git a/lib/MLDataDevices/test/amdgpu_tests.jl b/lib/MLDataDevices/test/amdgpu_tests.jl new file mode 100644 
index 000000000..a771ada6e --- /dev/null +++ b/lib/MLDataDevices/test/amdgpu_tests.jl @@ -0,0 +1,192 @@ +using MLDataDevices, Random, Test +using ArrayInterface: parameterless_type + +@testset "CPU Fallback" begin + @test !MLDataDevices.functional(AMDGPUDevice) + @test cpu_device() isa CPUDevice + @test gpu_device() isa CPUDevice + @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; force=true) + @test_throws Exception default_device_rng(AMDGPUDevice(nothing)) + @test_logs (:warn, "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting.") MLDataDevices.set_device!( + AMDGPUDevice, nothing, 1) +end + +using AMDGPU + +@testset "Loaded Trigger Package" begin + @test MLDataDevices.GPU_DEVICE[] === nothing + + if MLDataDevices.functional(AMDGPUDevice) + @info "AMDGPU is functional" + @test gpu_device() isa AMDGPUDevice + @test gpu_device(; force=true) isa AMDGPUDevice + else + @info "AMDGPU is NOT functional" + @test gpu_device() isa CPUDevice + @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; + force=true) + end + @test MLDataDevices.GPU_DEVICE[] !== nothing +end + +using FillArrays, Zygote # Extensions + +@testset "Data Transfer" begin + ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, + d="string", mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types + range=1:10, + rng_default=Random.default_rng(), rng=MersenneTwister(), + one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) + + device = gpu_device() + aType = MLDataDevices.functional(AMDGPUDevice) ? ROCArray : Array + rngType = MLDataDevices.functional(AMDGPUDevice) ? AMDGPU.rocRAND.RNG : + Random.AbstractRNG + + ps_xpu = ps |> device + @test get_device(ps_xpu) isa AMDGPUDevice + @test get_device_type(ps_xpu) <: AMDGPUDevice + @test ps_xpu.a.c isa aType + @test ps_xpu.b isa aType + @test ps_xpu.a.d == ps.a.d + @test ps_xpu.mixed isa Vector + @test ps_xpu.mixed[1] isa Float32 + @test ps_xpu.mixed[2] isa Float64 + @test ps_xpu.mixed[3] isa aType + @test ps_xpu.range isa AbstractRange + @test ps_xpu.e == ps.e + @test ps_xpu.d == ps.d + @test ps_xpu.rng_default isa rngType + @test get_device(ps_xpu.rng_default) isa AMDGPUDevice + @test get_device_type(ps_xpu.rng_default) <: AMDGPUDevice + @test ps_xpu.rng == ps.rng + @test get_device(ps_xpu.rng) === nothing + @test get_device_type(ps_xpu.rng) <: Nothing + + if MLDataDevices.functional(AMDGPUDevice) + @test ps_xpu.one_elem isa ROCArray + @test ps_xpu.farray isa ROCArray + else + @test ps_xpu.one_elem isa Zygote.OneElement + @test ps_xpu.farray isa Fill + end + + ps_cpu = ps_xpu |> cpu_device() + @test get_device(ps_cpu) isa CPUDevice + @test get_device_type(ps_cpu) <: CPUDevice + @test ps_cpu.a.c isa Array + @test ps_cpu.b isa Array + @test ps_cpu.a.c == ps.a.c + @test ps_cpu.b == ps.b + @test ps_cpu.a.d == ps.a.d + @test ps_cpu.mixed isa Vector + @test ps_cpu.mixed[1] isa Float32 + @test ps_cpu.mixed[2] isa Float64 + @test ps_cpu.mixed[3] isa Array + @test ps_cpu.range isa AbstractRange + @test ps_cpu.e == ps.e + @test ps_cpu.d == ps.d + @test ps_cpu.rng_default isa Random.TaskLocalRNG + @test get_device(ps_cpu.rng_default) === nothing + @test get_device_type(ps_cpu.rng_default) <: Nothing + @test ps_cpu.rng == ps.rng + @test get_device(ps_cpu.rng) === nothing + @test get_device_type(ps_cpu.rng) <: Nothing + + if MLDataDevices.functional(AMDGPUDevice) + @test ps_cpu.one_elem isa Array + @test ps_cpu.farray isa Array + else + @test ps_cpu.one_elem isa Zygote.OneElement + @test ps_cpu.farray isa Fill + end + + 
ps_mixed = (; a=rand(2), b=device(rand(2))) + @test_throws ArgumentError get_device(ps_mixed) + + dev = gpu_device() + x = rand(Float32, 10, 2) + x_dev = x |> dev + @test get_device(x_dev) isa parameterless_type(typeof(dev)) + @test get_device_type(x_dev) <: parameterless_type(typeof(dev)) + + if MLDataDevices.functional(AMDGPUDevice) + dev2 = gpu_device(length(AMDGPU.devices())) + x_dev2 = x_dev |> dev2 + @test get_device(x_dev2) isa typeof(dev2) + @test get_device_type(x_dev2) <: parameterless_type(typeof(dev2)) + end + + @testset "get_device_type compile constant" begin + x = rand(10, 10) |> device + ps = (; weight=x, bias=x, d=(x, x)) + + return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work + @test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(device))} + end +end + +@testset "Functions" begin + if MLDataDevices.functional(AMDGPUDevice) + @test get_device(tanh) isa MLDataDevices.UnknownDevice + @test get_device_type(tanh) <: MLDataDevices.UnknownDevice + + f(x, y) = () -> (x, x .^ 2, y) + + ff = f([1, 2, 3], 1) + @test get_device(ff) isa CPUDevice + @test get_device_type(ff) <: CPUDevice + + ff_xpu = ff |> AMDGPUDevice() + @test get_device(ff_xpu) isa AMDGPUDevice + @test get_device_type(ff_xpu) <: AMDGPUDevice + + ff_cpu = ff_xpu |> cpu_device() + @test get_device(ff_cpu) isa CPUDevice + @test get_device_type(ff_cpu) <: CPUDevice + end +end + +@testset "Wrapped Arrays" begin + if MLDataDevices.functional(AMDGPUDevice) + x = rand(10, 10) |> AMDGPUDevice() + @test get_device(x) isa AMDGPUDevice + @test get_device_type(x) <: AMDGPUDevice + x_view = view(x, 1:5, 1:5) + @test get_device(x_view) isa AMDGPUDevice + @test get_device_type(x_view) <: AMDGPUDevice + end +end + +@testset "Multiple Devices AMDGPU" begin + if MLDataDevices.functional(AMDGPUDevice) + ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10)) + ps_cpu = deepcopy(ps) + cdev = cpu_device() + for idx in 1:length(AMDGPU.devices()) + amdgpu_device = gpu_device(idx) + @test typeof(amdgpu_device.device) <: AMDGPU.HIPDevice + @test AMDGPU.device_id(amdgpu_device.device) == idx + + ps = ps |> amdgpu_device + @test ps.weight isa ROCArray + @test ps.bias isa ROCArray + @test AMDGPU.device_id(AMDGPU.device(ps.weight)) == idx + @test AMDGPU.device_id(AMDGPU.device(ps.bias)) == idx + @test isequal(cdev(ps.weight), ps_cpu.weight) + @test isequal(cdev(ps.bias), ps_cpu.bias) + end + + ps = ps |> cdev + @test ps.weight isa Array + @test ps.bias isa Array + end +end + +@testset "setdevice!" begin + if MLDataDevices.functional(AMDGPUDevice) + for i in 1:10 + @test_nowarn MLDataDevices.set_device!(AMDGPUDevice, nothing, i) + end + end +end diff --git a/lib/MLDataDevices/test/cuda_tests.jl b/lib/MLDataDevices/test/cuda_tests.jl new file mode 100644 index 000000000..2fce4806a --- /dev/null +++ b/lib/MLDataDevices/test/cuda_tests.jl @@ -0,0 +1,244 @@ +using MLDataDevices, Random, Functors, Test +using ArrayInterface: parameterless_type + +@testset "CPU Fallback" begin + @test !MLDataDevices.functional(CUDADevice) + @test cpu_device() isa CPUDevice + @test gpu_device() isa CPUDevice + @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; force=true) + @test_throws Exception default_device_rng(CUDADevice(nothing)) + @test_logs (:warn, "`CUDA.jl` hasn't been loaded. 
Ignoring the device setting.") MLDataDevices.set_device!( + CUDADevice, nothing, 1) +end + +using LuxCUDA + +@testset "Loaded Trigger Package" begin + @test MLDataDevices.GPU_DEVICE[] === nothing + + if MLDataDevices.functional(CUDADevice) + @info "LuxCUDA is functional" + @test gpu_device() isa CUDADevice + @test gpu_device(; force=true) isa CUDADevice + else + @info "LuxCUDA is NOT functional" + @test gpu_device() isa CPUDevice + @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; + force=true) + end + @test MLDataDevices.GPU_DEVICE[] !== nothing +end + +using FillArrays, Zygote # Extensions + +@testset "Data Transfer" begin + ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, + d="string", mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types + range=1:10, + rng_default=Random.default_rng(), rng=MersenneTwister(), + one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) + + device = gpu_device() + aType = MLDataDevices.functional(CUDADevice) ? CuArray : Array + rngType = MLDataDevices.functional(CUDADevice) ? CUDA.RNG : Random.AbstractRNG + + ps_xpu = ps |> device + @test get_device(ps_xpu) isa CUDADevice + @test get_device_type(ps_xpu) <: CUDADevice + @test ps_xpu.a.c isa aType + @test ps_xpu.b isa aType + @test ps_xpu.a.d == ps.a.d + @test ps_xpu.mixed isa Vector + @test ps_xpu.mixed[1] isa Float32 + @test ps_xpu.mixed[2] isa Float64 + @test ps_xpu.mixed[3] isa aType + @test ps_xpu.range isa AbstractRange + @test ps_xpu.e == ps.e + @test ps_xpu.d == ps.d + @test ps_xpu.rng_default isa rngType + @test get_device(ps_xpu.rng_default) isa CUDADevice + @test get_device_type(ps_xpu.rng_default) <: CUDADevice + @test ps_xpu.rng == ps.rng + @test get_device(ps_xpu.rng) === nothing + @test get_device_type(ps_xpu.rng) <: Nothing + + if MLDataDevices.functional(CUDADevice) + @test ps_xpu.one_elem isa CuArray + @test ps_xpu.farray isa CuArray + else + @test ps_xpu.one_elem isa Zygote.OneElement + @test ps_xpu.farray isa Fill + end + + ps_cpu = ps_xpu |> cpu_device() + @test get_device(ps_cpu) isa CPUDevice + @test get_device_type(ps_cpu) <: CPUDevice + @test ps_cpu.a.c isa Array + @test ps_cpu.b isa Array + @test ps_cpu.a.c == ps.a.c + @test ps_cpu.b == ps.b + @test ps_cpu.a.d == ps.a.d + @test ps_cpu.mixed isa Vector + @test ps_cpu.mixed[1] isa Float32 + @test ps_cpu.mixed[2] isa Float64 + @test ps_cpu.mixed[3] isa Array + @test ps_cpu.range isa AbstractRange + @test ps_cpu.e == ps.e + @test ps_cpu.d == ps.d + @test ps_cpu.rng_default isa Random.TaskLocalRNG + @test get_device(ps_cpu.rng_default) === nothing + @test get_device_type(ps_cpu.rng_default) <: Nothing + @test ps_cpu.rng == ps.rng + @test get_device(ps_cpu.rng) === nothing + @test get_device_type(ps_cpu.rng) <: Nothing + + if MLDataDevices.functional(CUDADevice) + @test ps_cpu.one_elem isa Array + @test ps_cpu.farray isa Array + else + @test ps_cpu.one_elem isa Zygote.OneElement + @test ps_cpu.farray isa Fill + end + + struct MyStruct + x::Any + end + + Functors.@functor MyStruct + + data = MyStruct(rand(10)) + @test get_device(data) isa CPUDevice + @test get_device_type(data) <: CPUDevice + data_dev = data |> device + if MLDataDevices.functional(CUDADevice) + @test get_device(data_dev) isa CUDADevice + @test get_device_type(data_dev) <: CUDADevice + else + @test get_device(data_dev) isa CPUDevice + @test get_device_type(data_dev) <: CPUDevice + end + + ps_mixed = (; a=rand(2), c=(rand(2), 1), st=MyStruct(rand(2)), b=device(rand(2))) + @test get_device(ps_mixed.st) isa CPUDevice + @test 
get_device_type(ps_mixed.st) <: CPUDevice + @test get_device(ps_mixed.c) isa CPUDevice + @test get_device_type(ps_mixed.c) <: CPUDevice + @test_throws ArgumentError get_device(ps_mixed) + @test_throws ArgumentError get_device_type(ps_mixed) + + dev = gpu_device() + x = rand(Float32, 10, 2) + x_dev = x |> dev + @test get_device(x_dev) isa parameterless_type(typeof(dev)) + @test get_device_type(x_dev) <: parameterless_type(typeof(dev)) + + if MLDataDevices.functional(CUDADevice) + dev2 = gpu_device(length(CUDA.devices())) + x_dev2 = x_dev |> dev2 + @test get_device(x_dev2) isa typeof(dev2) + @test get_device_type(x_dev2) <: parameterless_type(typeof(dev2)) + end + + @testset "get_device_type compile constant" begin + x = rand(10, 10) |> device + ps = (; weight=x, bias=x, d=(x, x)) + + return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work + @test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(device))} + + return_val2(x) = Val(get_device(x)) + @test_throws ErrorException @inferred(return_val2(ps)) + end +end + +@testset "Functions" begin + if MLDataDevices.functional(CUDADevice) + @test get_device(tanh) isa MLDataDevices.UnknownDevice + @test get_device_type(tanh) <: MLDataDevices.UnknownDevice + + f(x, y) = () -> (x, x .^ 2, y) + + ff = f([1, 2, 3], 1) + @test get_device(ff) isa CPUDevice + @test get_device_type(ff) <: CPUDevice + + ff_xpu = ff |> CUDADevice() + @test get_device(ff_xpu) isa CUDADevice + @test get_device_type(ff_xpu) <: CUDADevice + + ff_cpu = ff_xpu |> cpu_device() + @test get_device(ff_cpu) isa CPUDevice + @test get_device_type(ff_cpu) <: CPUDevice + end +end + +@testset "Wrapped Arrays" begin + if MLDataDevices.functional(CUDADevice) + x = rand(10, 10) |> CUDADevice() + @test get_device(x) isa CUDADevice + @test get_device_type(x) <: CUDADevice + x_view = view(x, 1:5, 1:5) + @test get_device(x_view) isa CUDADevice + @test get_device_type(x_view) <: CUDADevice + end +end + +@testset "Multiple Devices CUDA" begin + if MLDataDevices.functional(CUDADevice) + ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10)) + ps_cpu = deepcopy(ps) + cdev = cpu_device() + for idx in 1:length(CUDA.devices()) + cuda_device = gpu_device(idx) + @test typeof(cuda_device.device) <: CUDA.CuDevice + @test cuda_device.device.handle == (idx - 1) + + ps = ps |> cuda_device + @test ps.weight isa CuArray + @test ps.bias isa CuArray + @test CUDA.device(ps.weight).handle == idx - 1 + @test CUDA.device(ps.bias).handle == idx - 1 + @test isequal(cdev(ps.weight), ps_cpu.weight) + @test isequal(cdev(ps.bias), ps_cpu.bias) + end + + ps = ps |> cdev + @test ps.weight isa Array + @test ps.bias isa Array + end +end + +using SparseArrays + +@testset "CUDA Sparse Arrays" begin + if MLDataDevices.functional(CUDADevice) + ps = (; weight=sprand(Float32, 10, 10, 0.1), bias=sprand(Float32, 10, 0.1)) + ps_cpu = deepcopy(ps) + cdev = cpu_device() + for idx in 1:length(CUDA.devices()) + cuda_device = gpu_device(idx) + @test typeof(cuda_device.device) <: CUDA.CuDevice + @test cuda_device.device.handle == (idx - 1) + + ps = ps |> cuda_device + @test ps.weight isa CUSPARSE.CuSparseMatrixCSC + @test ps.bias isa CUSPARSE.CuSparseVector + @test get_device(ps.weight).device.handle == idx - 1 + @test get_device(ps.bias).device.handle == idx - 1 + @test isequal(cdev(ps.weight), ps_cpu.weight) + @test isequal(cdev(ps.bias), ps_cpu.bias) + end + + ps = ps |> cdev + @test ps.weight isa SparseMatrixCSC + @test ps.bias isa SparseVector + end +end + +@testset 
"setdevice!" begin + if MLDataDevices.functional(CUDADevice) + for i in 1:10 + @test_nowarn MLDataDevices.set_device!(CUDADevice, nothing, i) + end + end +end diff --git a/lib/MLDataDevices/test/iterator_tests.jl b/lib/MLDataDevices/test/iterator_tests.jl new file mode 100644 index 000000000..132acd7de --- /dev/null +++ b/lib/MLDataDevices/test/iterator_tests.jl @@ -0,0 +1,129 @@ +using MLDataDevices, MLUtils + +const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "none")) + +if BACKEND_GROUP == "cuda" || BACKEND_GROUP == "all" + using LuxCUDA +end + +if BACKEND_GROUP == "amdgpu" || BACKEND_GROUP == "all" + using AMDGPU +end + +if BACKEND_GROUP == "metal" || BACKEND_GROUP == "all" + using Metal +end + +if BACKEND_GROUP == "oneapi" || BACKEND_GROUP == "all" + using oneAPI +end + +if BACKEND_GROUP == "xla" || BACKEND_GROUP == "all" + using Reactant + if "gpu" in keys(Reactant.XLA.backends) + Reactant.set_default_backend("gpu") + end +end + +DEVICES = [CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, XLADevice] + +freed_if_can_be_freed(x) = freed_if_can_be_freed(get_device_type(x), x) +freed_if_can_be_freed(::Type{CPUDevice}, x) = true +freed_if_can_be_freed(::Type{XLADevice}, x) = true +function freed_if_can_be_freed(::Type, x) + try + Array(x) + return false + catch err + err isa ArgumentError && return true + rethrow() + end +end + +@testset "Device Iterator: $(dev_type)" for dev_type in DEVICES + dev = dev_type() + + !MLDataDevices.functional(dev) && continue + + @info "Testing Device Iterator for $(dev)..." + + @testset "Basic Device Iterator" begin + datalist = [rand(10) for _ in 1:10] + + prev_batch = nothing + for data in DeviceIterator(dev, datalist) + prev_batch === nothing || @test freed_if_can_be_freed(prev_batch) + prev_batch = data + @test size(data) == (10,) + @test get_device_type(data) == dev_type + end + end + + @testset "DataLoader: parallel=$parallel" for parallel in (true, false) + @info "Testing DataLoader with parallel=$parallel" + X = rand(Float64, 3, 33) + post = DataLoader(X; batchsize=13, shuffle=false, parallel) |> dev + if dev_type === XLADevice + pre = post # XXX: deadlocks and other shenanigans + else + pre = DataLoader(dev(X); batchsize=13, shuffle=false, parallel) + end + + for epoch in 1:2 + prev_pre, prev_post = nothing, nothing + for (p, q) in zip(pre, post) + @test get_device_type(p) == dev_type + @test get_device_type(q) == dev_type + # Ordering is not guaranteed in parallel + !parallel && @test p ≈ q + + if dev_type === CPUDevice || dev_type === XLADevice + continue + end + + prev_pre === nothing || @test !freed_if_can_be_freed(prev_pre) + prev_pre = p + + prev_post === nothing || @test freed_if_can_be_freed(prev_post) + prev_post = q + end + end + + Y = rand(Float64, 1, 33) + post = DataLoader((; x=X, y=Y); batchsize=13, shuffle=false, parallel) |> dev + if dev_type === XLADevice + pre = post # XXX: deadlocks and other shenanigans + else + pre = DataLoader((; x=dev(X), y=dev(Y)); batchsize=13, shuffle=false, parallel) + end + + for epoch in 1:2 + prev_pre, prev_post = nothing, nothing + for (p, q) in zip(pre, post) + @test get_device_type(p.x) == dev_type + @test get_device_type(p.y) == dev_type + @test get_device_type(q.x) == dev_type + @test get_device_type(q.y) == dev_type + # Ordering is not guaranteed in parallel + !parallel && @test p.x ≈ q.x + !parallel && @test p.y ≈ q.y + + if dev_type === CPUDevice || dev_type === XLADevice + continue + end + + if prev_pre !== nothing + @test !freed_if_can_be_freed(prev_pre.x) + @test 
!freed_if_can_be_freed(prev_pre.y) + end + prev_pre = p + + if prev_post !== nothing + @test freed_if_can_be_freed(prev_post.x) + @test freed_if_can_be_freed(prev_post.y) + end + prev_post = q + end + end + end +end diff --git a/lib/MLDataDevices/test/metal_tests.jl b/lib/MLDataDevices/test/metal_tests.jl new file mode 100644 index 000000000..2bc884553 --- /dev/null +++ b/lib/MLDataDevices/test/metal_tests.jl @@ -0,0 +1,156 @@ +using MLDataDevices, Random, Test +using ArrayInterface: parameterless_type + +@testset "CPU Fallback" begin + @test !MLDataDevices.functional(MetalDevice) + @test cpu_device() isa CPUDevice + @test gpu_device() isa CPUDevice + @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; force=true) + @test_throws Exception default_device_rng(MetalDevice()) +end + +using Metal + +@testset "Loaded Trigger Package" begin + @test MLDataDevices.GPU_DEVICE[] === nothing + + if MLDataDevices.functional(MetalDevice) + @info "Metal is functional" + @test gpu_device() isa MetalDevice + @test gpu_device(; force=true) isa MetalDevice + else + @info "Metal is NOT functional" + @test gpu_device() isa MetalDevice + @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; + force=true) + end + @test MLDataDevices.GPU_DEVICE[] !== nothing +end + +using FillArrays, Zygote # Extensions + +@testset "Data Transfer" begin + ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, + d="string", mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types + range=1:10, + rng_default=Random.default_rng(), rng=MersenneTwister(), + one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) + + device = gpu_device() + aType = MLDataDevices.functional(MetalDevice) ? MtlArray : Array + rngType = MLDataDevices.functional(MetalDevice) ? 
Metal.GPUArrays.RNG : + Random.AbstractRNG + + ps_xpu = ps |> device + @test get_device(ps_xpu) isa MetalDevice + @test get_device_type(ps_xpu) <: MetalDevice + @test ps_xpu.a.c isa aType + @test ps_xpu.b isa aType + @test ps_xpu.a.d == ps.a.d + @test ps_xpu.mixed isa Vector + @test ps_xpu.mixed[1] isa Float32 + @test ps_xpu.mixed[2] isa Float64 + @test ps_xpu.mixed[3] isa aType + @test ps_xpu.range isa AbstractRange + @test ps_xpu.e == ps.e + @test ps_xpu.d == ps.d + @test ps_xpu.rng_default isa rngType + @test get_device(ps_xpu.rng_default) isa MetalDevice + @test get_device_type(ps_xpu.rng_default) <: MetalDevice + @test ps_xpu.rng == ps.rng + @test get_device(ps_xpu.rng) === nothing + @test get_device_type(ps_xpu.rng) <: Nothing + + if MLDataDevices.functional(MetalDevice) + @test ps_xpu.one_elem isa MtlArray + @test ps_xpu.farray isa MtlArray + else + @test ps_xpu.one_elem isa Zygote.OneElement + @test ps_xpu.farray isa Fill + end + + ps_cpu = ps_xpu |> cpu_device() + @test get_device(ps_cpu) isa CPUDevice + @test get_device_type(ps_cpu) <: CPUDevice + @test ps_cpu.a.c isa Array + @test ps_cpu.b isa Array + @test ps_cpu.a.c == ps.a.c + @test ps_cpu.b == ps.b + @test ps_cpu.a.d == ps.a.d + @test ps_cpu.mixed isa Vector + @test ps_cpu.mixed[1] isa Float32 + @test ps_cpu.mixed[2] isa Float64 + @test ps_cpu.mixed[3] isa Array + @test ps_cpu.range isa AbstractRange + @test ps_cpu.e == ps.e + @test ps_cpu.d == ps.d + @test ps_cpu.rng_default isa Random.TaskLocalRNG + @test get_device(ps_cpu.rng_default) === nothing + @test get_device_type(ps_cpu.rng_default) <: Nothing + @test ps_cpu.rng == ps.rng + @test get_device(ps_cpu.rng) === nothing + @test get_device_type(ps_cpu.rng) <: Nothing + + if MLDataDevices.functional(MetalDevice) + @test ps_cpu.one_elem isa Array + @test ps_cpu.farray isa Array + else + @test ps_cpu.one_elem isa Zygote.OneElement + @test ps_cpu.farray isa Fill + end + + ps_mixed = (; a=rand(2), b=device(rand(2))) + @test_throws ArgumentError get_device(ps_mixed) + @test_throws ArgumentError get_device_type(ps_mixed) + + @testset "get_device_type compile constant" begin + x = rand(10, 10) |> device + ps = (; weight=x, bias=x, d=(x, x)) + + return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work + @test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(device))} + + return_val2(x) = Val(get_device(x)) + @test @inferred(return_val2(ps)) isa Val{get_device(x)} + end +end + +@testset "Functions" begin + if MLDataDevices.functional(MetalDevice) + @test get_device(tanh) isa MLDataDevices.UnknownDevice + @test get_device_type(tanh) <: MLDataDevices.UnknownDevice + + f(x, y) = () -> (x, x .^ 2, y) + + ff = f([1, 2, 3], 1) + @test get_device(ff) isa CPUDevice + @test get_device_type(ff) <: CPUDevice + + ff_xpu = ff |> MetalDevice() + @test get_device(ff_xpu) isa MetalDevice + @test get_device_type(ff_xpu) <: MetalDevice + + ff_cpu = ff_xpu |> cpu_device() + @test get_device(ff_cpu) isa CPUDevice + @test get_device_type(ff_cpu) <: CPUDevice + end +end + +@testset "Wrapper Arrays" begin + if MLDataDevices.functional(MetalDevice) + x = rand(Float32, 10, 10) |> MetalDevice() + @test get_device(x) isa MetalDevice + @test get_device_type(x) <: MetalDevice + x_view = view(x, 1:5, 1:5) + @test get_device(x_view) isa MetalDevice + @test get_device_type(x_view) <: MetalDevice + end +end + +@testset "setdevice!" 
begin + if MLDataDevices.functional(MetalDevice) + @test_logs (:warn, + "Support for Multi Device Metal hasn't been implemented yet. Ignoring the device setting.") MLDataDevices.set_device!( + MetalDevice, nothing, 1) + end +end diff --git a/lib/MLDataDevices/test/misc_tests.jl b/lib/MLDataDevices/test/misc_tests.jl new file mode 100644 index 000000000..28275d3b7 --- /dev/null +++ b/lib/MLDataDevices/test/misc_tests.jl @@ -0,0 +1,221 @@ +using Adapt, MLDataDevices, ComponentArrays, Random +using ArrayInterface: parameterless_type +using ChainRulesTestUtils: test_rrule +using ReverseDiff, Tracker, ForwardDiff +using SparseArrays, FillArrays, Zygote, RecursiveArrayTools +using Functors: Functors + +@testset "Issues Patches" begin + @testset "#10 patch" begin + dev = CPUDevice() + ps = (; weight=randn(10, 1), bias=randn(1)) + + ps_ca = ps |> ComponentArray + + ps_ca_dev = ps_ca |> dev + + @test ps_ca_dev isa ComponentArray + + @test ps_ca_dev.weight == ps.weight + @test ps_ca_dev.bias == ps.bias + + @test ps_ca_dev == (ps |> dev |> ComponentArray) + end +end + +@testset "AD Types" begin + x = randn(Float32, 10) + + x_rdiff = ReverseDiff.track(x) + @test get_device(x_rdiff) isa CPUDevice + x_rdiff = ReverseDiff.track.(x) + @test get_device(x_rdiff) isa CPUDevice + + gdev = gpu_device() + + x_tracker = Tracker.param(x) + @test get_device(x_tracker) isa CPUDevice + x_tracker = Tracker.param.(x) + @test get_device(x_tracker) isa CPUDevice + x_tracker_dev = Tracker.param(x) |> gdev + @test get_device(x_tracker_dev) isa parameterless_type(typeof(gdev)) + x_tracker_dev = Tracker.param.(x) |> gdev + @test get_device(x_tracker_dev) isa parameterless_type(typeof(gdev)) + + x_fdiff = ForwardDiff.Dual.(x) + @test get_device(x_fdiff) isa CPUDevice + x_fdiff_dev = ForwardDiff.Dual.(x) |> gdev + @test get_device(x_fdiff_dev) isa parameterless_type(typeof(gdev)) +end + +@testset "CRC Tests" begin + dev = cpu_device() # Other devices don't work with FiniteDifferences.jl + test_rrule(Adapt.adapt, dev, randn(Float64, 10); check_inferred=true) + + gdev = gpu_device() + if !(gdev isa MetalDevice) # On intel devices causes problems + x = randn(10) + ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt, gdev, x) + @test ∂dev === nothing + @test ∂x ≈ ones(10) + + x = randn(10) |> gdev + ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt, cpu_device(), x) + @test ∂dev === nothing + @test ∂x ≈ gdev(ones(10)) + @test get_device(∂x) isa parameterless_type(typeof(gdev)) + end +end + +# The following just test for noops +@testset "NoOps CPU" begin + cdev = cpu_device() + + @test cdev(sprand(10, 10, 0.9)) isa SparseMatrixCSC + @test cdev(1:10) isa AbstractRange + @test cdev(Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4))) isa Zygote.OneElement +end + +@testset "RecursiveArrayTools" begin + gdev = gpu_device() + + diffeqarray = DiffEqArray([rand(10) for _ in 1:10], rand(10)) + @test get_device(diffeqarray) isa CPUDevice + + diffeqarray_dev = diffeqarray |> gdev + @test get_device(diffeqarray_dev) isa parameterless_type(typeof(gdev)) + + vecarray = VectorOfArray([rand(10) for _ in 1:10]) + @test get_device(vecarray) isa CPUDevice + + vecarray_dev = vecarray |> gdev + @test get_device(vecarray_dev) isa parameterless_type(typeof(gdev)) +end + +@testset "CPU default rng" begin + @test default_device_rng(CPUDevice()) isa Random.TaskLocalRNG +end + +@testset "CPU setdevice!" begin + @test_logs (:warn, + "Setting device for `CPUDevice` doesn't make sense. 
Ignoring the device setting.") MLDataDevices.set_device!( + CPUDevice, nothing, 1) +end + +@testset "get_device on Arrays" begin + x = rand(10, 10) + x_view = view(x, 1:5, 1:5) + + @test get_device(x) isa CPUDevice + @test get_device(x_view) isa CPUDevice + + struct MyArrayType <: AbstractArray{Float32, 2} + data::Array{Float32, 2} + end + + x_custom = MyArrayType(rand(10, 10)) + + @test get_device(x_custom) isa CPUDevice +end + +@testset "loaded and functional" begin + @test MLDataDevices.loaded(CPUDevice) + @test MLDataDevices.functional(CPUDevice) +end + +@testset "writing to preferences" begin + @test_logs (:info, + "Deleted the local preference for `gpu_backend`. Restart Julia to use the new backend.") gpu_backend!() + + for backend in (:CUDA, :AMDGPU, :oneAPI, :Metal, AMDGPUDevice(), + CUDADevice(), MetalDevice(), oneAPIDevice()) + backend_name = backend isa Symbol ? string(backend) : + MLDataDevices.Internal.get_device_name(backend) + @test_logs (:info, + "GPU backend has been set to $(backend_name). Restart Julia to use the new backend.") gpu_backend!(backend) + end + + gpu_backend!(:CUDA) + @test_logs (:info, "GPU backend is already set to CUDA. No action is required.") gpu_backend!(:CUDA) + + @test_throws ArgumentError gpu_backend!("my_backend") +end + +@testset "get_device_type compile constant" begin + x = rand(10, 10) + ps = (; weight=x, bias=x, d=(x, x)) + + return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work + @test @inferred(return_val(ps)) isa Val{typeof(cpu_device())} + + return_val2(x) = Val(get_device(x)) + @test @inferred(return_val2(ps)) isa Val{cpu_device()} +end + +@testset "undefined references array" begin + x = Matrix{Any}(undef, 10, 10) + + @test get_device(x) isa MLDataDevices.UnknownDevice + @test get_device_type(x) <: MLDataDevices.UnknownDevice +end + +@testset "isleaf" begin + @testset "basics" begin + # Functors.isleaf fallback + @test MLDataDevices.isleaf(rand(2)) + @test !MLDataDevices.isleaf((rand(2),)) + + struct Tleaf + x::Any + end + Functors.@functor Tleaf + MLDataDevices.isleaf(::Tleaf) = true + Adapt.adapt_structure(dev::CPUDevice, t::Tleaf) = Tleaf(2 .* dev(t.x)) + + cpu = cpu_device() + t = Tleaf(ones(2)) + y = cpu(t) + @test y.x == 2 .* ones(2) + y = cpu([(t,)]) + @test y[1][1].x == 2 .* ones(2) + end + + @testset "shared parameters" begin + x = rand(1) + m = (; a=x, b=x') + count = Ref(0) + mcopy = Functors.fmap(m; exclude=MLDataDevices.isleaf) do x + count[] += 1 + return copy(x) + end + @test count[] == 1 + @test mcopy.a === mcopy.b' + end + + @testset "bitstypes and wrapped types" begin + struct BitsType + x::Int32 + y::Float64 + end + + @testset for x in [1.0, 'a', BitsType(1, 2.0)] + @test MLDataDevices.isleaf([x]) + @test !MLDataDevices.isleaf([x]') + @test !MLDataDevices.isleaf(transpose([x])) + @test !MLDataDevices.isleaf(PermutedDimsArray([x;;], (1, 2))) + end + end +end + +@testset "Zygote.gradient(wrapped arrays)" begin + using Zygote + + x = rand(4, 4) + cdev = cpu_device() + + @test only(Zygote.gradient(x -> sum(abs2, cdev(x)), x')) isa Matrix{Float64} + + gdev = gpu_device() + + @test only(Zygote.gradient(x -> sum(abs2, gdev(x)), x')) isa Matrix{Float64} +end diff --git a/lib/MLDataDevices/test/oneapi_tests.jl b/lib/MLDataDevices/test/oneapi_tests.jl new file mode 100644 index 000000000..2169869d3 --- /dev/null +++ b/lib/MLDataDevices/test/oneapi_tests.jl @@ -0,0 +1,156 @@ +using MLDataDevices, Random, Test +using ArrayInterface: parameterless_type + +@testset "CPU Fallback" 
begin + @test !MLDataDevices.functional(oneAPIDevice) + @test cpu_device() isa CPUDevice + @test gpu_device() isa CPUDevice + @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; force=true) + @test_throws Exception default_device_rng(oneAPIDevice()) +end + +using oneAPI + +@testset "Loaded Trigger Package" begin + @test MLDataDevices.GPU_DEVICE[] === nothing + + if MLDataDevices.functional(oneAPIDevice) + @info "oneAPI is functional" + @test gpu_device() isa oneAPIDevice + @test gpu_device(; force=true) isa oneAPIDevice + else + @info "oneAPI is NOT functional" + @test gpu_device() isa oneAPIDevice + @test_throws MLDataDevices.Internal.DeviceSelectionException gpu_device(; + force=true) + end + @test MLDataDevices.GPU_DEVICE[] !== nothing +end + +using FillArrays, Zygote # Extensions + +@testset "Data Transfer" begin + ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, + d="string", mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types + range=1:10, + rng_default=Random.default_rng(), rng=MersenneTwister(), + one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) + + device = gpu_device() + aType = MLDataDevices.functional(oneAPIDevice) ? oneArray : Array + rngType = MLDataDevices.functional(oneAPIDevice) ? oneAPI.GPUArrays.RNG : + Random.AbstractRNG + + ps_xpu = ps |> device + @test get_device(ps_xpu) isa oneAPIDevice + @test get_device_type(ps_xpu) <: oneAPIDevice + @test ps_xpu.a.c isa aType + @test ps_xpu.b isa aType + @test ps_xpu.a.d == ps.a.d + @test ps_xpu.mixed isa Vector + @test ps_xpu.mixed[1] isa Float32 + @test ps_xpu.mixed[2] isa Float64 + @test ps_xpu.mixed[3] isa aType + @test ps_xpu.range isa AbstractRange + @test ps_xpu.e == ps.e + @test ps_xpu.d == ps.d + @test ps_xpu.rng_default isa rngType + @test get_device(ps_xpu.rng_default) isa oneAPIDevice + @test get_device_type(ps_xpu.rng_default) <: oneAPIDevice + @test ps_xpu.rng == ps.rng + @test get_device(ps_xpu.rng) === nothing + @test get_device_type(ps_xpu.rng) <: Nothing + + if MLDataDevices.functional(oneAPIDevice) + @test ps_xpu.one_elem isa oneArray + @test ps_xpu.farray isa oneArray + else + @test ps_xpu.one_elem isa Zygote.OneElement + @test ps_xpu.farray isa Fill + end + + ps_cpu = ps_xpu |> cpu_device() + @test get_device(ps_cpu) isa CPUDevice + @test get_device_type(ps_cpu) <: CPUDevice + @test ps_cpu.a.c isa Array + @test ps_cpu.b isa Array + @test ps_cpu.a.c == ps.a.c + @test ps_cpu.b == ps.b + @test ps_cpu.a.d == ps.a.d + @test ps_cpu.mixed isa Vector + @test ps_cpu.mixed[1] isa Float32 + @test ps_cpu.mixed[2] isa Float64 + @test ps_cpu.mixed[3] isa Array + @test ps_cpu.range isa AbstractRange + @test ps_cpu.e == ps.e + @test ps_cpu.d == ps.d + @test ps_cpu.rng_default isa Random.TaskLocalRNG + @test get_device(ps_cpu.rng_default) === nothing + @test get_device_type(ps_cpu.rng_default) <: Nothing + @test ps_cpu.rng == ps.rng + @test get_device(ps_cpu.rng) === nothing + @test get_device_type(ps_cpu.rng) <: Nothing + + if MLDataDevices.functional(oneAPIDevice) + @test ps_cpu.one_elem isa Array + @test ps_cpu.farray isa Array + else + @test ps_cpu.one_elem isa Zygote.OneElement + @test ps_cpu.farray isa Fill + end + + ps_mixed = (; a=rand(2), b=device(rand(2))) + @test_throws ArgumentError get_device(ps_mixed) + @test_throws ArgumentError get_device_type(ps_mixed) + + @testset "get_device_type compile constant" begin + x = rand(10, 10) |> device + ps = (; weight=x, bias=x, d=(x, x)) + + return_val(x) = Val(get_device_type(x)) # If it is a compile time constant 
then type inference will work + @test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(device))} + + return_val2(x) = Val(get_device(x)) + @test @inferred(return_val2(ps)) isa Val{get_device(x)} + end +end + +@testset "Functions" begin + if MLDataDevices.functional(oneAPIDevice) + @test get_device(tanh) isa MLDataDevices.UnknownDevice + @test get_device_type(tanh) <: MLDataDevices.UnknownDevice + + f(x, y) = () -> (x, x .^ 2, y) + + ff = f([1, 2, 3], 1) + @test get_device(ff) isa CPUDevice + @test get_device_type(ff) <: CPUDevice + + ff_xpu = ff |> oneAPIDevice() + @test get_device(ff_xpu) isa oneAPIDevice + @test get_device_type(ff_xpu) <: oneAPIDevice + + ff_cpu = ff_xpu |> cpu_device() + @test get_device(ff_cpu) isa CPUDevice + @test get_device_type(ff_cpu) <: CPUDevice + end +end + +@testset "Wrapper Arrays" begin + if MLDataDevices.functional(oneAPIDevice) + x = rand(10, 10) |> oneAPIDevice() + @test get_device(x) isa oneAPIDevice + @test get_device_type(x) <: oneAPIDevice + x_view = view(x, 1:5, 1:5) + @test get_device(x_view) isa oneAPIDevice + @test get_device_type(x_view) <: oneAPIDevice + end +end + +@testset "setdevice!" begin + if MLDataDevices.functional(oneAPIDevice) + @test_logs (:warn, + "Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting.") MLDataDevices.set_device!( + oneAPIDevice, nothing, 1) + end +end diff --git a/lib/MLDataDevices/test/qa_tests.jl b/lib/MLDataDevices/test/qa_tests.jl new file mode 100644 index 000000000..b5e4cb65a --- /dev/null +++ b/lib/MLDataDevices/test/qa_tests.jl @@ -0,0 +1,19 @@ +using Aqua, ExplicitImports, MLDataDevices, Test + +@testset "Aqua Tests" begin + Aqua.test_all(MLDataDevices) +end + +import FillArrays, RecursiveArrayTools, SparseArrays, Zygote + +@testset "Explicit Imports" begin + @test check_no_implicit_imports(MLDataDevices) === nothing + @test check_no_stale_explicit_imports(MLDataDevices) === nothing + @test check_no_self_qualified_accesses(MLDataDevices) === nothing + @test check_all_explicit_imports_via_owners(MLDataDevices) === nothing + @test check_all_qualified_accesses_via_owners( + MLDataDevices; ignore=(:SparseArrays,)) === nothing + # mostly upstream problems + @test_broken check_all_explicit_imports_are_public(MLDataDevices) === nothing + @test_broken check_all_qualified_accesses_are_public(MLDataDevices) === nothing +end diff --git a/lib/MLDataDevices/test/runtests.jl b/lib/MLDataDevices/test/runtests.jl new file mode 100644 index 000000000..4b02862e3 --- /dev/null +++ b/lib/MLDataDevices/test/runtests.jl @@ -0,0 +1,49 @@ +using Pkg: Pkg, PackageSpec +using SafeTestsets, Test + +const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "none")) + +const EXTRA_PKGS = PackageSpec[] +const EXTRA_DEV_PKGS = PackageSpec[] + +if (BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") + if isdir(joinpath(@__DIR__, "../../LuxCUDA")) + @info "Using local LuxCUDA" + push!(EXTRA_DEV_PKGS, PackageSpec(; path=joinpath(@__DIR__, "../../LuxCUDA"))) + else + push!(EXTRA_PKGS, PackageSpec(; name="LuxCUDA")) + end +end +(BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") && + push!(EXTRA_PKGS, PackageSpec(; name="AMDGPU")) +(BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi") && + push!(EXTRA_PKGS, PackageSpec(; name="oneAPI")) +(BACKEND_GROUP == "all" || BACKEND_GROUP == "metal") && + push!(EXTRA_PKGS, PackageSpec(; name="Metal")) +(BACKEND_GROUP == "all" || BACKEND_GROUP == "xla") && + push!(EXTRA_PKGS, PackageSpec(; name="Reactant")) + +if !isempty(EXTRA_PKGS) || 
!isempty(EXTRA_DEV_PKGS) + @info "Installing Extra Packages for testing" EXTRA_PKGS EXTRA_DEV_PKGS + isempty(EXTRA_PKGS) || Pkg.add(EXTRA_PKGS) + isempty(EXTRA_DEV_PKGS) || Pkg.develop(EXTRA_DEV_PKGS) + Pkg.update() + Base.retry_load_extensions() + Pkg.instantiate() +end + +@testset "MLDataDevices Tests" begin + all_files = ["cuda_tests.jl", "amdgpu_tests.jl", + "metal_tests.jl", "oneapi_tests.jl", "xla_tests.jl"] + file_names = BACKEND_GROUP == "all" ? all_files : + BACKEND_GROUP ∈ ("cpu", "none") ? [] : [BACKEND_GROUP * "_tests.jl"] + @testset "$(file_name)" for file_name in file_names + run(`$(Base.julia_cmd()) --color=yes --project=$(dirname(Pkg.project().path)) + --startup-file=no --code-coverage=user $(@__DIR__)/$file_name`) + Test.@test true + end + + @safetestset "Iterator Tests" include("iterator_tests.jl") + @safetestset "Misc Tests" include("misc_tests.jl") + @safetestset "QA Tests" include("qa_tests.jl") +end diff --git a/lib/MLDataDevices/test/xla_tests.jl b/lib/MLDataDevices/test/xla_tests.jl new file mode 100644 index 000000000..21466bd1d --- /dev/null +++ b/lib/MLDataDevices/test/xla_tests.jl @@ -0,0 +1,155 @@ +using MLDataDevices, Random, Test +using ArrayInterface: parameterless_type + +@testset "CPU Fallback" begin + @test !MLDataDevices.functional(XLADevice) + @test cpu_device() isa CPUDevice + @test xla_device() isa CPUDevice + @test_throws MLDataDevices.Internal.DeviceSelectionException xla_device(; force=true) + @test_throws Exception default_device_rng(XLADevice()) +end + +using Reactant +if "gpu" in keys(Reactant.XLA.backends) + Reactant.set_default_backend("gpu") +end + +@testset "Loaded Trigger Package" begin + if MLDataDevices.functional(XLADevice) + @info "Reactant is functional" + @test xla_device() isa XLADevice + @test xla_device(; force=true) isa XLADevice + else + @info "Reactant is NOT functional" + @test xla_device() isa CPUDevice + @test_throws MLDataDevices.Internal.DeviceSelectionException xla_device(; + force=true) + end +end + +using FillArrays, Zygote # Extensions + +@testset "Data Transfer" begin + ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, + d="string", mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types + range=1:10, + rng_default=Random.default_rng(), rng=MersenneTwister(), + one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) + + device = xla_device() + aType = MLDataDevices.functional(XLADevice) ? 
Reactant.ConcreteRArray : Array + rngType = Random.AbstractRNG + + ps_xpu = ps |> device + @test get_device(ps_xpu) isa XLADevice + @test get_device_type(ps_xpu) <: XLADevice + @test ps_xpu.a.c isa aType + @test ps_xpu.b isa aType + @test ps_xpu.a.d == ps.a.d + @test ps_xpu.mixed isa Vector + @test ps_xpu.mixed[1] isa Float32 + @test ps_xpu.mixed[2] isa Float64 + @test ps_xpu.mixed[3] isa aType + @test ps_xpu.range isa AbstractRange + @test ps_xpu.e == ps.e + @test ps_xpu.d == ps.d + @test ps_xpu.rng_default isa rngType + @test get_device(ps_xpu.rng_default) === nothing + @test get_device_type(ps_xpu.rng_default) <: Nothing + @test ps_xpu.rng == ps.rng + @test get_device(ps_xpu.rng) === nothing + @test get_device_type(ps_xpu.rng) <: Nothing + + if MLDataDevices.functional(XLADevice) + @test ps_xpu.one_elem isa Reactant.RArray + @test ps_xpu.farray isa Reactant.RArray + else + @test ps_xpu.one_elem isa Zygote.OneElement + @test ps_xpu.farray isa Fill + end + + ps_cpu = ps_xpu |> cpu_device() + @test get_device(ps_cpu) isa CPUDevice + @test get_device_type(ps_cpu) <: CPUDevice + @test ps_cpu.a.c isa Array + @test ps_cpu.b isa Array + @test ps_cpu.a.c == ps.a.c + @test ps_cpu.b == ps.b + @test ps_cpu.a.d == ps.a.d + @test ps_cpu.mixed isa Vector + @test ps_cpu.mixed[1] isa Float32 + @test ps_cpu.mixed[2] isa Float64 + @test ps_cpu.mixed[3] isa Array + @test ps_cpu.range isa AbstractRange + @test ps_cpu.e == ps.e + @test ps_cpu.d == ps.d + @test ps_cpu.rng_default isa Random.TaskLocalRNG + @test get_device(ps_cpu.rng_default) === nothing + @test get_device_type(ps_cpu.rng_default) <: Nothing + @test ps_cpu.rng == ps.rng + @test get_device(ps_cpu.rng) === nothing + @test get_device_type(ps_cpu.rng) <: Nothing + + if MLDataDevices.functional(XLADevice) + @test ps_cpu.one_elem isa Array + @test ps_cpu.farray isa Array + else + @test ps_cpu.one_elem isa Zygote.OneElement + @test ps_cpu.farray isa Fill + end + + ps_mixed = (; a=rand(2), b=device(rand(2))) + @test_throws ArgumentError get_device(ps_mixed) + @test_throws ArgumentError get_device_type(ps_mixed) + + @testset "get_device_type compile constant" begin + x = rand(10, 10) |> device + ps = (; weight=x, bias=x, d=(x, x)) + + return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work + @test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(device))} + + return_val2(x) = Val(get_device(x)) + @test @inferred(return_val2(ps)) isa Val{get_device(x)} + end +end + +@testset "Functions" begin + if MLDataDevices.functional(XLADevice) + @test get_device(tanh) isa MLDataDevices.UnknownDevice + @test get_device_type(tanh) <: MLDataDevices.UnknownDevice + + f(x, y) = () -> (x, x .^ 2, y) + + ff = f([1, 2, 3], 1) + @test get_device(ff) isa CPUDevice + @test get_device_type(ff) <: CPUDevice + + ff_xpu = ff |> XLADevice() + @test get_device(ff_xpu) isa XLADevice + @test get_device_type(ff_xpu) <: XLADevice + + ff_cpu = ff_xpu |> cpu_device() + @test get_device(ff_cpu) isa CPUDevice + @test get_device_type(ff_cpu) <: CPUDevice + end +end + +@testset "Wrapped Arrays" begin + if MLDataDevices.functional(XLADevice) + x = rand(10, 10) |> XLADevice() + @test get_device(x) isa XLADevice + @test get_device_type(x) <: XLADevice + x_view = view(x, 1:5, 1:5) + @test get_device(x_view) isa XLADevice + @test get_device_type(x_view) <: XLADevice + end +end + +@testset "setdevice!" begin + if MLDataDevices.functional(XLADevice) + @test_logs (:warn, + "Setting device for `XLADevice` hasn't been implemented yet. 
Ignoring the device setting.") MLDataDevices.set_device!( + XLADevice, nothing, 1) + end +end diff --git a/lib/WeightInitializers/LICENSE b/lib/WeightInitializers/LICENSE new file mode 100644 index 000000000..e87b80c0d --- /dev/null +++ b/lib/WeightInitializers/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Avik Pal and contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/lib/WeightInitializers/Project.toml b/lib/WeightInitializers/Project.toml new file mode 100644 index 000000000..bb39b7955 --- /dev/null +++ b/lib/WeightInitializers/Project.toml @@ -0,0 +1,45 @@ +name = "WeightInitializers" +uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" +authors = ["Avik Pal and contributors"] +version = "1.0.4" + +[deps] +ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" +ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[weakdeps] +AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +Metal = "dde4c033-4e86-420c-a63e-0dd931031962" +oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" + +[extensions] +WeightInitializersAMDGPUExt = ["AMDGPU", "GPUArrays"] +WeightInitializersCUDAExt = ["CUDA", "GPUArrays"] +WeightInitializersChainRulesCoreExt = "ChainRulesCore" +WeightInitializersGPUArraysExt = "GPUArrays" +WeightInitializersMetalExt = ["Metal", "GPUArrays"] +WeightInitializersoneAPIExt = ["oneAPI", "GPUArrays"] + +[compat] +AMDGPU = "0.9.6, 1" +ArgCheck = "2.3.0" +CUDA = "5.3.2" +ChainRulesCore = "1.23" +ConcreteStructs = "0.2.3" +GPUArrays = "10.2, 11" +GPUArraysCore = "0.1.6, 0.2" +LinearAlgebra = "1.10" +Metal = "1.3.0" +Random = "1.10" +SpecialFunctions = "2.4" +Statistics = "1.10" +julia = "1.10" +oneAPI = "1.5.0" diff --git a/lib/WeightInitializers/README.md b/lib/WeightInitializers/README.md new file mode 100644 index 000000000..14d3edba7 --- /dev/null +++ b/lib/WeightInitializers/README.md @@ -0,0 +1,69 @@ +# WeightInitializers + +This package is a light dependency providing common weight initialization schemes for deep +learning models. + +## Example + +These code snippets are just provided to give a high level overview of the functionalities +of the package. 
+ +```julia +using WeightInitializers, Random + +# Fixing rng +rng = MersenneTwister(42) + +# Explicit rng call +weights = kaiming_normal(rng, 2, 5) +#2×5 Matrix{Float32}: +# -0.351662 0.0171745 1.12442 -0.296372 -1.67094 +# -0.281053 -0.18941 -0.724099 0.0987538 0.634549 + +# Default rng call +weights = kaiming_normal(2, 5) +#2×5 Matrix{Float32}: +# -0.227513 -0.265372 0.265788 1.29955 -0.192836 +# 0.687611 0.454679 -0.433656 0.20548 0.292002 + +# Passing kwargs (if needed) with explicit rng call +weights_cl = kaiming_normal(rng; gain=1.0) +weights = weights_cl(rng, 2, 5) +#2×5 Matrix{Float32}: +# 0.484056 0.231723 0.164379 0.306147 0.18365 +# 0.0836414 0.666965 -0.396323 -0.711329 -0.382971 + +# Passing kwargs (if needed) with default rng call +weights_cl = kaiming_normal(; gain=1.0) +weights = weights_cl(2, 5) +#2×5 Matrix{Float32}: +# -0.160876 -0.187646 0.18794 0.918918 -0.136356 +# 0.486214 0.321506 -0.306641 0.145296 0.206476 +``` + +## API + +The package is meant to be working with deep learning libraries such as F/Lux. All the +methods take as input the chosen `rng` type and the dimension for the AbstractArray. + +```julia +weights = init(rng, dims...) +``` + +The `rng` is optional, if not specified a default one will be used. + +```julia +weights = init(dims...) +``` + +If there is the need to use keyword arguments the methods can be called with just the `rng` +(optionally) and the keywords to get in return a function behaving like the two examples +above. + +```julia +weights_init = init(rng; kwargs...) +weights = weights_init(rng, dims...) +# or +weights_init = init(; kwargs...) +weights = weights_init(dims...) +``` diff --git a/lib/WeightInitializers/ext/WeightInitializersAMDGPUExt.jl b/lib/WeightInitializers/ext/WeightInitializersAMDGPUExt.jl new file mode 100644 index 000000000..ad0fa20c5 --- /dev/null +++ b/lib/WeightInitializers/ext/WeightInitializersAMDGPUExt.jl @@ -0,0 +1,38 @@ +module WeightInitializersAMDGPUExt + +using AMDGPU: AMDGPU, ROCArray +using GPUArrays: RNG +using Random: Random +using WeightInitializers: DeviceAgnostic + +function DeviceAgnostic.zeros( + ::AMDGPU.rocRAND.RNG, ::Type{T}, dims::Integer...) where {T <: Number} + return AMDGPU.zeros(T, dims...) +end +function DeviceAgnostic.ones( + ::AMDGPU.rocRAND.RNG, ::Type{T}, dims::Integer...) where {T <: Number} + return AMDGPU.ones(T, dims...) +end + +function DeviceAgnostic.zeros( + ::RNG, ::ROCArray, ::Type{T}, dims::Integer...) where {T <: Number} + return AMDGPU.zeros(T, dims...) +end +function DeviceAgnostic.ones( + ::RNG, ::ROCArray, ::Type{T}, dims::Integer...) where {T <: Number} + return AMDGPU.ones(T, dims...) +end +function DeviceAgnostic.rand( + rng::RNG, ::ROCArray, ::Type{T}, dims::Integer...) where {T <: Number} + y = ROCArray{T}(undef, dims...) + Random.rand!(rng, y) + return y +end +function DeviceAgnostic.randn( + rng::RNG, ::ROCArray, ::Type{T}, dims::Integer...) where {T <: Number} + y = ROCArray{T}(undef, dims...) 
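+    # fill the freshly allocated ROCArray in-place with the wrapped GPUArrays RNG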
+ Random.randn!(rng, y) + return y +end + +end diff --git a/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl b/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl new file mode 100644 index 000000000..db7573f58 --- /dev/null +++ b/lib/WeightInitializers/ext/WeightInitializersCUDAExt.jl @@ -0,0 +1,40 @@ +module WeightInitializersCUDAExt + +using CUDA: CUDA, CURAND, CuArray +using GPUArrays: RNG +using Random: Random +using WeightInitializers: DeviceAgnostic + +const AbstractCuRNG = Union{CUDA.RNG, CURAND.RNG} + +function DeviceAgnostic.zeros( + ::AbstractCuRNG, ::Type{T}, dims::Integer...) where {T <: Number} + return CUDA.zeros(T, dims...) +end +function DeviceAgnostic.ones( + ::AbstractCuRNG, ::Type{T}, dims::Integer...) where {T <: Number} + return CUDA.ones(T, dims...) +end + +function DeviceAgnostic.zeros( + ::RNG, ::CuArray, ::Type{T}, dims::Integer...) where {T <: Number} + return CUDA.zeros(T, dims...) +end +function DeviceAgnostic.ones( + ::RNG, ::CuArray, ::Type{T}, dims::Integer...) where {T <: Number} + return CUDA.ones(T, dims...) +end +function DeviceAgnostic.rand( + rng::RNG, ::CuArray, ::Type{T}, dims::Integer...) where {T <: Number} + y = CuArray{T}(undef, dims...) + Random.rand!(rng, y) + return y +end +function DeviceAgnostic.randn( + rng::RNG, ::CuArray, ::Type{T}, dims::Integer...) where {T <: Number} + y = CuArray{T}(undef, dims...) + Random.randn!(rng, y) + return y +end + +end diff --git a/lib/WeightInitializers/ext/WeightInitializersChainRulesCoreExt.jl b/lib/WeightInitializers/ext/WeightInitializersChainRulesCoreExt.jl new file mode 100644 index 000000000..2b54893d3 --- /dev/null +++ b/lib/WeightInitializers/ext/WeightInitializersChainRulesCoreExt.jl @@ -0,0 +1,18 @@ +module WeightInitializersChainRulesCoreExt + +using ChainRulesCore: @non_differentiable +using WeightInitializers: WeightInitializers, DeviceAgnostic + +for f in [:zeros64, :ones64, :rand64, :randn64, :zeros32, :ones32, :rand32, :randn32, + :zeros16, :ones16, :rand16, :randn16, :zerosC64, :onesC64, :randC64, + :randnC64, :zerosC32, :onesC32, :randC32, :randnC32, :zerosC16, :onesC16, + :randC16, :randnC16, :glorot_normal, :glorot_uniform, :kaiming_normal, + :kaiming_uniform, :truncated_normal, :orthogonal, :sparse_init, :identity_init] + @eval @non_differentiable WeightInitializers.$(f)(::Any...) +end + +for f in (:zeros, :ones, :rand, :randn) + @eval @non_differentiable DeviceAgnostic.$(f)(::Any...) +end + +end diff --git a/lib/WeightInitializers/ext/WeightInitializersGPUArraysExt.jl b/lib/WeightInitializers/ext/WeightInitializersGPUArraysExt.jl new file mode 100644 index 000000000..78e0ec63a --- /dev/null +++ b/lib/WeightInitializers/ext/WeightInitializersGPUArraysExt.jl @@ -0,0 +1,24 @@ +module WeightInitializersGPUArraysExt + +using GPUArrays: RNG +using WeightInitializers: DeviceAgnostic + +for f in (:zeros, :ones, :rand, :randn) + @eval function DeviceAgnostic.$(f)( + rng::RNG, ::Type{T}, dims::Integer...) where {T <: Number} + return DeviceAgnostic.$(f)(rng, rng.state, T, dims...) + end +end + +## Certain backends don't support sampling Complex numbers, so we avoid hitting those +## dispatches +for f in (:rand, :randn) + @eval function DeviceAgnostic.$(f)( + rng::RNG, ::Type{<:Complex{T}}, args::Integer...) where {T <: Number} + real_part = DeviceAgnostic.$(f)(rng, rng.state, T, args...) + imag_part = DeviceAgnostic.$(f)(rng, rng.state, T, args...) 
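+        # assemble complex samples from the two independent real draws above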
+ return Complex{T}.(real_part, imag_part) + end +end + +end diff --git a/lib/WeightInitializers/ext/WeightInitializersMetalExt.jl b/lib/WeightInitializers/ext/WeightInitializersMetalExt.jl new file mode 100644 index 000000000..79e5b34da --- /dev/null +++ b/lib/WeightInitializers/ext/WeightInitializersMetalExt.jl @@ -0,0 +1,29 @@ +module WeightInitializersMetalExt + +using Metal: Metal, MtlArray +using GPUArrays: RNG +using Random: Random +using WeightInitializers: DeviceAgnostic + +function DeviceAgnostic.zeros( + ::RNG, ::MtlArray, ::Type{T}, dims::Integer...) where {T <: Number} + return Metal.zeros(T, dims...) +end +function DeviceAgnostic.ones( + ::RNG, ::MtlArray, ::Type{T}, dims::Integer...) where {T <: Number} + return Metal.ones(T, dims...) +end +function DeviceAgnostic.rand( + rng::RNG, ::MtlArray, ::Type{T}, dims::Integer...) where {T <: Number} + y = MtlArray{T}(undef, dims...) + Random.rand!(rng, y) + return y +end +function DeviceAgnostic.randn( + rng::RNG, ::MtlArray, ::Type{T}, dims::Integer...) where {T <: Number} + y = MtlArray{T}(undef, dims...) + Random.randn!(rng, y) + return y +end + +end diff --git a/lib/WeightInitializers/ext/WeightInitializersoneAPIExt.jl b/lib/WeightInitializers/ext/WeightInitializersoneAPIExt.jl new file mode 100644 index 000000000..e1827e115 --- /dev/null +++ b/lib/WeightInitializers/ext/WeightInitializersoneAPIExt.jl @@ -0,0 +1,29 @@ +module WeightInitializersoneAPIExt + +using oneAPI: oneAPI, oneArray +using GPUArrays: RNG +using Random: Random +using WeightInitializers: DeviceAgnostic + +function DeviceAgnostic.zeros( + ::RNG, ::oneArray, ::Type{T}, dims::Integer...) where {T <: Number} + return oneAPI.zeros(T, dims...) +end +function DeviceAgnostic.ones( + ::RNG, ::oneArray, ::Type{T}, dims::Integer...) where {T <: Number} + return oneAPI.ones(T, dims...) +end +function DeviceAgnostic.rand( + rng::RNG, ::oneArray, ::Type{T}, dims::Integer...) where {T <: Number} + y = oneArray{T}(undef, dims...) + Random.rand!(rng, y) + return y +end +function DeviceAgnostic.randn( + rng::RNG, ::oneArray, ::Type{T}, dims::Integer...) where {T <: Number} + y = oneArray{T}(undef, dims...) 
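+    # in-place fill keeps the generated samples on the oneAPI device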
+ Random.randn!(rng, y) + return y +end + +end diff --git a/lib/WeightInitializers/src/WeightInitializers.jl b/lib/WeightInitializers/src/WeightInitializers.jl new file mode 100644 index 000000000..6702f3fec --- /dev/null +++ b/lib/WeightInitializers/src/WeightInitializers.jl @@ -0,0 +1,22 @@ +module WeightInitializers + +using ArgCheck: @argcheck +using GPUArraysCore: @allowscalar +using LinearAlgebra: LinearAlgebra, Diagonal, qr +using Random: Random, AbstractRNG, shuffle +using SpecialFunctions: SpecialFunctions, erfinv # TODO: Move to Ext in v2.0 +using Statistics: Statistics, std + +include("partial.jl") +include("utils.jl") +include("initializers.jl") + +export zeros64, ones64, rand64, randn64, zeros32, ones32, rand32, randn32, zeros16, ones16, + rand16, randn16 +export zerosC64, onesC64, randC64, randnC64, zerosC32, onesC32, randC32, randnC32, zerosC16, + onesC16, randC16, randnC16 +export glorot_normal, glorot_uniform +export kaiming_normal, kaiming_uniform +export truncated_normal, orthogonal, sparse_init, identity_init + +end diff --git a/lib/WeightInitializers/src/initializers.jl b/lib/WeightInitializers/src/initializers.jl new file mode 100644 index 000000000..81de6a17c --- /dev/null +++ b/lib/WeightInitializers/src/initializers.jl @@ -0,0 +1,372 @@ +for T in ("16", "32", "64", "C16", "C32", "C64"), fname in (:ones, :zeros, :rand, :randn) + name = Symbol(fname, T) + docstring = Utils.generic_docstring(string(name)) + TP = Utils.NUM_TO_FPOINT[Symbol(T)] + + @eval begin + @doc $docstring function $(name)(rng::AbstractRNG, dims::Integer...; kwargs...) + return DeviceAgnostic.$(fname)(rng, $TP, dims...; kwargs...) + end + end +end + +""" + glorot_uniform([::AbstractRNG=Utils.default_rng()], [T=Float32], size...; + gain = 1) -> AbstractArray{T, length(size)} + +Return an `AbstractArray{T}` of the given `size` containing random numbers drawn from a +uniform distribution on the interval ``[-x, x]``, where +`x = gain * sqrt(6 / (fan_in + fan_out))`. This method is described in [1] and also known as +Xavier initialization. + +# References + +[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep +feedforward neural networks." _Proceedings of the thirteenth international conference on +artificial intelligence and statistics_. 2010. +""" +function glorot_uniform( + rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=1) where {T <: Number} + scale = T(gain) * sqrt(T(24) / sum(Utils.nfan(dims...))) + x = DeviceAgnostic.rand(rng, T, dims...) + half = T(0.5) + @. x = (x - half) * scale + return x +end + +""" + glorot_normal([::AbstractRNG=Utils.default_rng()], [T=Float32], size...; + gain = 1) -> AbstractArray{T, length(size)} + +Return an `AbstractArray{T}` of the given `size` containing random numbers drawn from a +normal distribution with standard deviation `gain * sqrt(2 / (fan_in + fan_out))`. This +method is described in [1] and also known as Xavier initialization. + +# References + +[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep +feedforward neural networks." _Proceedings of the thirteenth international conference on +artificial intelligence and statistics_. 2010. +""" +function glorot_normal( + rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=1) where {T <: Number} + std = T(gain) * sqrt(T(2) / sum(Utils.nfan(dims...))) + x = DeviceAgnostic.randn(rng, T, dims...) 
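+    # rescale the standard-normal draws to std = gain * sqrt(2 / (fan_in + fan_out))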
+ x .*= std + return x +end + +""" + kaiming_uniform([::AbstractRNG=Utils.default_rng()], [T=Float32], size...; + gain = √T(2)) -> AbstractArray{T, length(size)} + +Return an `AbstractArray{T}` of the given `size` containing random numbers drawn from a +uniform distribution on the interval `[-x, x]`, where `x = gain * sqrt(3/fan_in)`. + +# References + +[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on +imagenet classification." _Proceedings of the IEEE international conference on computer +vision_. 2015. +""" +function kaiming_uniform(rng::AbstractRNG, ::Type{T}, dims::Integer...; + gain::Number=√T(2)) where {T <: Number} + bound = √T(3) * T(gain) / sqrt(T(first(Utils.nfan(dims...)))) + x = DeviceAgnostic.rand(rng, T, dims...) + half = T(0.5) + @. x = (x - half) * 2 * bound + return x +end + +""" + kaiming_normal([::AbstractRNG=Utils.default_rng()], [T=Float32], size...; + gain = √T(2)) -> AbstractArray{T, length(size)} + +Return an `AbstractArray{T}` of the given `size` containing random numbers taken from a +normal distribution standard deviation `gain / sqrt(fan_in)` + +# References + +[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on +imagenet classification." _Proceedings of the IEEE international conference on computer +vision_. 2015. +""" +function kaiming_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; + gain::Number=√T(2)) where {T <: Number} + std = T(gain) / sqrt(T(first(Utils.nfan(dims...)))) + x = DeviceAgnostic.randn(rng, T, dims...) + x .*= std + return x +end + +""" + truncated_normal([::AbstractRNG=Utils.default_rng()], [T=Float32], size...; mean = 0, + std = 1, lo = -2, hi = 2) -> AbstractArray{T, length(size)} + +Return an `AbstractArray{T}` of the given `size` where each element is drawn from a +truncated normal distribution. The numbers are distributed like +`filter(x -> lo ≤ x ≤ hi, mean .+ std .* randn(100))`. +""" +function truncated_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; mean=T(0), + std=T(1), lo=-T(2), hi=T(2)) where {T <: Real} + if (mean < lo - 2 * std) || (mean > hi + 2 * std) + @warn "Mean is more than 2 std outside the limits in truncated_normal, so the \ + distribution of values may be inaccurate." + end + l = Utils.norm_cdf((T(lo) - T(mean)) / T(std)) + u = Utils.norm_cdf((T(hi) - T(mean)) / T(std)) + xs = DeviceAgnostic.rand(rng, T, dims...) + broadcast!(xs, xs) do x + x = x * 2(u - l) + (2l - one(T)) + x = erfinv(x) + return clamp(x * T(std) * √T(2) + T(mean), T(lo), T(hi)) + end + return xs +end + +""" + orthogonal([::AbstractRNG=Utils.default_rng()], [T=Float32], dims::Integer...; + gain = 1) -> AbstractArray{T, length(dims)} + +Return an `AbstractArray{T}` of the given dimensions (`dims`) which is a +(semi) orthogonal matrix, as described in [1]. + +The function constructs an orthogonal or semi-orthogonal matrix depending on the specified +dimensions. For two dimensions, it returns a matrix where `dims = (rows, cols)`. For more +than two dimensions, it computes an orthogonal matrix of size `prod(dims[1:(end - 1)])` by +`dims[end]` before reshaping it to the original dimensions. + +Cannot construct a vector, i.e., `length(dims) == 1` is forbidden. + +# Arguments + + - `rng::AbstractRNG`: Random number generator. + - `T::Type{<:Real}`: The type of the elements in the array. + - `dims::Integer...`: The dimensions of the array. + - `gain::Number`: Scaling factor for the elements of the orthogonal matrix. 
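+
+# Examples
+
+A brief, illustrative sketch (plain code block rather than a doctest, since the exact
+values depend on the RNG; `Xoshiro` is from the `Random` standard library):
+
+```julia
+using WeightInitializers, Random, LinearAlgebra
+
+W = orthogonal(Xoshiro(123), Float32, 5, 3)  # 5×3 matrix with orthonormal columns (gain = 1)
+W' * W ≈ I                                   # true up to floating-point error
+
+# More than two dims: built as a prod(dims[1:end-1]) × dims[end] matrix, then reshaped
+Wc = orthogonal(Float32, 3, 3, 8, 8)         # size (3, 3, 8, 8)
+```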
+ +# References + +[1] Saxe, McClelland, Ganguli. "Exact solutions to the nonlinear dynamics of learning in +deep linear neural networks", ICLR 2014, https://arxiv.org/abs/1312.6120 +""" +function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...; + gain::Number=T(1.0)) where {T <: Number} + @argcheck length(dims)>1 "Creating vectors (length(dims) == 1) is not allowed" + + rows, cols = length(dims) == 2 ? dims : (prod(dims[1:(end - 1)]), dims[end]) + rows < cols && return permutedims(orthogonal(rng, T, cols, rows; gain=T(gain))) + + mat = DeviceAgnostic.randn(rng, T, rows, cols) + Q, R = qr(mat) + mat .= Q * sign.(Diagonal(R)) .* T(gain) + + return length(dims) > 2 ? reshape(mat, dims) : mat +end + +""" + sparse_init([::AbstractRNG=Utils.default_rng()], [T=Float32], dims::Integer...; + sparsity::Number, std::Number=0.01) -> AbstractArray{T} + +Creates a sparsely initialized weight matrix with a specified proportion of zeroed elements, +using random numbers drawn from a normal distribution for the non-zero elements. This method +was introduced in [1]. + +!!! note + + The sparsity parameter controls the proportion of the matrix that will be zeroed. For + example, a sparsity of 0.3 means that approximately 30% of the elements will be set to + zero. The non-zero elements are distributed according to a normal distribution, scaled + by the std parameter. + +# Arguments + + - `rng::AbstractRNG`: The random number generator to use. + - `T::Type{<:Number}`: The numeric type of the elements in the returned array. + - `dims::Integer...`: The dimensions of the weight matrix to be generated. + - `sparsity::Number`: The proportion of elements to be zeroed. Must be between 0 and 1. + - `std::Number=0.01`: The standard deviation of the normal distribution before applying + `gain`. + +# Returns + + - `AbstractArray{T}`: A sparsely initialized weight matrix of dimensions `dims` and type + `T`. + +# Examples + +```jldoctest +julia> y = sparse_init(Xoshiro(123), Float32, 5, 5; sparsity=0.3, std=0.01); + +julia> y isa Matrix{Float32} +true + +julia> size(y) == (5, 5) +true +``` + +# References + +[1] Martens, J, "Deep learning via Hessian-free optimization" Proceedings of the 27th +International Conference on International Conference on Machine Learning. 2010. +""" +function sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; + sparsity::Number, std::Number=T(0.01)) where {T <: Number} + if length(dims) != 2 + throw(ArgumentError("Only 2-dimensional outputs are supported for sparse \ + initialization.")) + end + + rows, _ = dims + prop_zero = min(1.0, sparsity) + num_zeros = ceil(Integer, prop_zero * rows) + + sparse_array = DeviceAgnostic.randn(rng, T, dims...) + sparse_array .*= T(std) + fill!(view(sparse_array, 1:num_zeros, :), zero(T)) + + return @allowscalar mapslices(shuffle, sparse_array; dims=1) +end + +""" + identity_init([::AbstractRNG=Utils.default_rng()], [T=Float32], size...; gain::Number=1, + shift::Union{Integer, Tuple{Integer, Integer}}=0) -> AbstractArray{T} + +Constructs an array that aims to provide an identity mapping when used as parameters in +most layers of a neural network. The identity mapping is scaled by the `gain` parameter. + +# Behavior + + - 1D: Returns a `Vector` of zeros (useful for biases in layers where + `input_size == output_size`). + - 2D: Returns an identity matrix + (useful for fully connected layers with equal input and output sizes). 
+ - More than 2D: Returns a tensor where the central slice along the last + two dimensions is an identity matrix, and the rest are zeros + (useful for convolutional layers, simulating an identity convolution). + +# Caveats + + - Not all layers will result in an identity mapping when using this initializer. + Exceptions include recurrent and normalization layers. + - Layers must have `input_size == output_size` for a perfect identity mapping. + In cases where this condition is not met, the function pads extra dimensions with zeros. + - For convolutional layers to achieve an identity mapping, kernel sizes must be odd, + and appropriate padding must be applied to ensure the output feature maps are the same + size as the input feature maps. + +# Arguments + + - `rng::AbstractRNG`: An optional random number generator, + included for consistency with other initializers but ignored since the + output is deterministic. + - `T::Type{<:Number}`: The numeric type of the array elements. + - `size...`: The dimensions of the array to be initialized. + - `gain::Number=1`: A scaling factor applied to the identity mapping. + - `shift::Union{Integer, Tuple{Integer, Integer}}=0`: An integer or + a tuple specifying the circular shift applied to the output array. + +# Returns + + - `AbstractArray{T}`: An array initialized to represent an identity mapping, + scaled by `gain` and optionally shifted by `shift`. + +# Examples + +```jldoctest +julia> identity_init(Xoshiro(123), Float32, 5, 5) +5×5 Matrix{Float32}: + 1.0 1.0 1.0 1.0 1.0 + 1.0 1.0 1.0 1.0 1.0 + 1.0 1.0 1.0 1.0 1.0 + 1.0 1.0 1.0 1.0 1.0 + 1.0 1.0 1.0 1.0 1.0 + +julia> identity_init(Xoshiro(123), Float32, 3, 3, 1, 1; gain=1.5) +3×3×1×1 Array{Float32, 4}: +[:, :, 1, 1] = + 0.0 0.0 0.0 + 0.0 1.5 0.0 + 0.0 0.0 0.0 +``` +""" +function identity_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; + gain::Number=1, shift::Integer=0) where {T <: Number} + length(dims) == 1 && return DeviceAgnostic.zeros(rng, T, dims...) # Bias initialization + + if length(dims) == 2 + rows, cols = dims + mat = DeviceAgnostic.zeros(rng, T, rows, cols) + diag_indices = 1:min(rows, cols) + fill!(view(mat, diag_indices, diag_indices), T(gain)) + return circshift(mat, shift) + end + + # Convolution or more dimensions + nin, nout = dims[end - 1], dims[end] + centers = map(d -> cld(d, 2), dims[1:(end - 2)]) + weights = DeviceAgnostic.zeros(rng, T, dims...) + @allowscalar for i in 1:min(nin, nout) + index = (centers..., i, i) + weights[index...] = T(gain) + end + return circshift(weights, (ntuple(d -> 0, length(dims) - 2)..., shift, shift)) +end + +# Default Fallbacks for all functions +for initializer in (:glorot_uniform, :glorot_normal, :kaiming_uniform, :kaiming_normal, + :truncated_normal, :orthogonal, :sparse_init, :identity_init) + NType = ifelse(initializer === :truncated_normal, Real, Number) + @eval begin + function ($initializer)(dims::Integer...; kwargs...) + return $initializer(Utils.default_rng(), Float32, dims...; kwargs...) + end + function ($initializer)(rng::AbstractRNG, dims::Integer...; kwargs...) + return $initializer(rng, Float32, dims...; kwargs...) + end + function ($initializer)(::Type{T}, dims::Integer...; kwargs...) where {T <: $NType} + return $initializer(Utils.default_rng(), T, dims...; kwargs...) + end + + # Partial application + function ($initializer)(rng::AbstractRNG; kwargs...) + return PartialFunction.Partial{Nothing}($initializer, rng, kwargs) + end + function ($initializer)(::Type{T}; kwargs...) 
where {T <: $NType} + return PartialFunction.Partial{T}($initializer, nothing, kwargs) + end + function ($initializer)(rng::AbstractRNG, ::Type{T}; kwargs...) where {T <: $NType} + return PartialFunction.Partial{T}($initializer, rng, kwargs) + end + function ($initializer)(; kwargs...) + return PartialFunction.Partial{Nothing}($initializer, nothing, kwargs) + end + end +end + +for tp in ("16", "32", "64", "C16", "C32", "C64"), func in (:zeros, :ones, :randn, :rand) + initializer = Symbol(func, tp) + @eval begin + function ($initializer)(dims::Integer...; kwargs...) + return $initializer(Utils.default_rng(), dims...; kwargs...) + end + function ($initializer)(::Type{T}, dims::Integer...; kwargs...) where {T} + throw(ArgumentError(string($initializer) * " doesn't accept a type argument.")) + end + function ($initializer)( + ::AbstractRNG, ::Type{T}, dims::Integer...; kwargs...) where {T} + throw(ArgumentError(string($initializer) * " doesn't accept a type argument.")) + end + + # Partial application + function ($initializer)(rng::AbstractRNG; kwargs...) + return PartialFunction.Partial{Missing}($initializer, rng, kwargs) + end + function ($initializer)(rng::AbstractRNG, ::Type{T}; kwargs...) where {T} + throw(ArgumentError(string($initializer) * " doesn't accept a type argument.")) + end + function ($initializer)(; kwargs...) + return PartialFunction.Partial{Missing}($initializer, nothing, kwargs) + end + end +end diff --git a/lib/WeightInitializers/src/partial.jl b/lib/WeightInitializers/src/partial.jl new file mode 100644 index 000000000..52cde29a9 --- /dev/null +++ b/lib/WeightInitializers/src/partial.jl @@ -0,0 +1,51 @@ +module PartialFunction + +using ArgCheck: @argcheck +using ConcreteStructs: @concrete +using Random: AbstractRNG + +@concrete struct Partial{T} <: Function + f <: Function + rng <: Union{Nothing, AbstractRNG} + kwargs +end + +function Base.show(io::IO, ::MIME"text/plain", f::Partial{T}) where {T} + print(io, "$(f.f)(") + if f.rng !== nothing + print(io, "$(nameof(typeof(f.rng)))(...), ") + else + print(io, "rng, ") + end + if T === Nothing + print(io, "::Type{T}, ") + else + T !== Missing ? print(io, "$(T), ") : nothing + end + print(io, "dims...") + kwargs_str = String[] + for (k, v) in pairs(f.kwargs) + push!(kwargs_str, "$(k)=$(v)") + end + length(kwargs_str) > 0 && print(io, "; ", join(kwargs_str, ", ")) + print(io, ")") +end + +function (f::Partial{<:Union{Nothing, Missing}})(args...; kwargs...) + f.rng === nothing && return f.f(args...; f.kwargs..., kwargs...) + return f.f(f.rng, args...; f.kwargs..., kwargs...) +end +function (f::Partial{<:Union{Nothing, Missing}})(rng::AbstractRNG, args...; kwargs...) + @argcheck f.rng === nothing + return f.f(rng, args...; f.kwargs..., kwargs...) +end +function (f::Partial{T})(args...; kwargs...) where {T <: Number} + f.rng === nothing && return f.f(T, args...; f.kwargs..., kwargs...) + return f.f(f.rng, T, args...; f.kwargs..., kwargs...) +end +function (f::Partial{T})(rng::AbstractRNG, args...; kwargs...) where {T <: Number} + @argcheck f.rng === nothing + return f.f(rng, T, args...; f.kwargs..., kwargs...) 
+end + +end diff --git a/lib/WeightInitializers/src/utils.jl b/lib/WeightInitializers/src/utils.jl new file mode 100644 index 000000000..e2a3a363f --- /dev/null +++ b/lib/WeightInitializers/src/utils.jl @@ -0,0 +1,78 @@ +module Utils + +using Random: Xoshiro +using SpecialFunctions: erf + +nfan() = 1, 1 # fan_in, fan_out +nfan(n) = 1, n # A vector is treated as a n×1 matrix +nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices +nfan(dims::Tuple) = nfan(dims...) +nfan(dims...) = prod(dims[1:(end - 2)]) .* (dims[end - 1], dims[end]) # In case of convolution kernels + +norm_cdf(x::T) where {T} = T(0.5) * (1 + T(erf(x / √2))) # erf often doesn't respect the type + +default_rng() = Xoshiro(1234) + +#! format: off +const NAME_TO_DIST = Dict( + :zeros => "an AbstractArray of zeros", + :ones => "an AbstractArray of ones", + :randn => "random numbers from a standard normal distribution", + :rand => "random numbers from a uniform distribution" +) +const NUM_TO_FPOINT = Dict( + Symbol(16) => Float16, + Symbol(32) => Float32, + Symbol(64) => Float64, + :C16 => ComplexF16, + :C32 => ComplexF32, + :C64 => ComplexF64 +) +#! format: on + +function function_name(fname::String) + fp = fname[(end - 2):end] + Symbol(fp) in keys(NUM_TO_FPOINT) && return fname[1:(end - 3)], fp + return fname[1:(end - 2)], fname[(end - 1):end] +end + +function generic_docstring(fname::String) + funcname, fp = function_name(fname) + name = NAME_TO_DIST[Symbol(funcname)] + dist_type = NUM_TO_FPOINT[Symbol(fp)] + return """ + $fname([::AbstractRNG=Utils.default_rng()], size...; + kwargs...) -> AbstractArray{$(dist_type), length(size)} + + Return an `AbstractArray{$(dist_type)}` of the given `size` containing $(name). + """ +end + +end + +module DeviceAgnostic + +using Random: AbstractRNG + +# Helpers for device agnostic initializers +function zeros(::AbstractRNG, ::Type{T}, dims::Integer...) where {T <: Number} + return Base.zeros(T, dims...) +end +ones(::AbstractRNG, ::Type{T}, dims::Integer...) where {T <: Number} = Base.ones(T, dims...) +function rand(rng::AbstractRNG, ::Type{T}, args::Integer...) where {T <: Number} + return Base.rand(rng, T, args...) +end +function randn(rng::AbstractRNG, ::Type{T}, args::Integer...) where {T <: Number} + return Base.randn(rng, T, args...) +end + +## Certain backends don't support sampling Complex numbers, so we avoid hitting those +## dispatches +for f in (:rand, :randn) + @eval function $(f)( + rng::AbstractRNG, ::Type{<:Complex{T}}, args::Integer...) 
where {T <: Number} + return Complex{T}.($(f)(rng, T, args...), $(f)(rng, T, args...)) + end +end + +end diff --git a/lib/WeightInitializers/test/Project.toml b/lib/WeightInitializers/test/Project.toml new file mode 100644 index 000000000..ce6ba7994 --- /dev/null +++ b/lib/WeightInitializers/test/Project.toml @@ -0,0 +1,31 @@ +[deps] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" +Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" +InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[compat] +Aqua = "0.8.7" +Documenter = "1.5.0" +ExplicitImports = "1.9.0" +GPUArrays = "10.2" +GPUArraysCore = "0.1.6" +Hwloc = "3.3" +InteractiveUtils = "<0.0.1, 1" +LinearAlgebra = "1.10" +Pkg = "1.10" +Random = "1.10" +ReTestItems = "1.24.0" +StableRNGs = "1" +Statistics = "1.10" +Test = "1.10" diff --git a/lib/WeightInitializers/test/initializers_tests.jl b/lib/WeightInitializers/test/initializers_tests.jl new file mode 100644 index 000000000..8f09f3ab0 --- /dev/null +++ b/lib/WeightInitializers/test/initializers_tests.jl @@ -0,0 +1,350 @@ +@testitem "Warning: truncated_normal" begin + @test_warn "Mean is more than 2 std outside the limits in truncated_normal, so \ + the distribution of values may be inaccurate." truncated_normal(2; mean=-5.0f0) +end + +@testitem "Identity Initialization" begin + @testset "Non-identity sizes" begin + @test identity_init(2, 3)[:, end] == zeros(Float32, 2) + @test identity_init(3, 2; shift=1)[1, :] == zeros(Float32, 2) + @test identity_init(1, 1, 3, 4)[:, :, :, end] == zeros(Float32, 1, 1, 3) + @test identity_init(2, 1, 3, 3)[end, :, :, :] == zeros(Float32, 1, 3, 3) + @test identity_init(1, 2, 3, 3)[:, end, :, :] == zeros(Float32, 1, 3, 3) + end +end + +@testitem "Orthogonal Initialization" setup=[SharedTestSetup] begin + using GPUArraysCore, LinearAlgebra + + @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype, supports_fp64, backend) in RNGS_ARRTYPES + # A matrix of dim = (m,n) with m > n should produce a QR decomposition. + # In the other case, the transpose should be taken to compute the QR decomposition. + if backend == "oneapi" || backend == "metal" # `qr` not implemented + @test_broken orthogonal(rng, 3, 5) isa arrtype{Float32, 2} + continue + end + + for (rows, cols) in [(5, 3), (3, 5)] + v = orthogonal(rng, rows, cols) + GPUArraysCore.@allowscalar rows < cols ? (@test v * v' ≈ I(rows)) : + (@test v' * v ≈ I(cols)) + end + + for mat in [(3, 4, 5), (2, 2, 5)] + v = orthogonal(rng, mat...) + cols = mat[end] + rows = div(prod(mat), cols) + v = reshape(v, (rows, cols)) + GPUArraysCore.@allowscalar rows < cols ? 
(@test v * v' ≈ I(rows)) : + (@test v' * v ≈ I(cols)) + end + + @testset "Orthogonal Types $T" for T in (Float32, Float64) + !supports_fp64 && T == Float64 && continue + + @test eltype(orthogonal(rng, T, 3, 4; gain=1.5)) == T + @test eltype(orthogonal(rng, T, 3, 4, 5; gain=1.5)) == T + end + + @testset "Orthogonal AbstractArray Type $T" for T in (Float32, Float64) + !supports_fp64 && T == Float64 && continue + + @test orthogonal(rng, T, 3, 5) isa AbstractArray{T, 2} + @test orthogonal(rng, T, 3, 5) isa arrtype{T, 2} + + cl = orthogonal(rng) + display(cl) + @test cl(T, 3, 5) isa arrtype{T, 2} + + cl = orthogonal(rng, T) + display(cl) + @test cl(3, 5) isa arrtype{T, 2} + end + + @testset "Orthogonal Closure" begin + cl = orthogonal(;) + display(cl) + + # Sizes + @test size(cl(3, 4)) == (3, 4) + @test size(cl(rng, 3, 4)) == (3, 4) + @test size(cl(3, 4, 5)) == (3, 4, 5) + @test size(cl(rng, 3, 4, 5)) == (3, 4, 5) + + # Type + @test eltype(cl(4, 2)) == Float32 + @test eltype(cl(rng, 4, 2)) == Float32 + end + end +end + +@testitem "Sparse Initialization" setup=[SharedTestSetup] begin + using Statistics + + @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype, supports_fp64, backend) in RNGS_ARRTYPES + # sparse_init should yield an error for non 2-d dimensions + # sparse_init should yield no zero elements if sparsity < 0 + # sparse_init should yield all zero elements if sparsity > 1 + # sparse_init should yield exactly ceil(n_in * sparsity) elements in each column for + # other sparsity values + # sparse_init should yield a kernel in its non-zero elements consistent with the std + # parameter + + @test_throws ArgumentError sparse_init(3, 4, 5, sparsity=0.1) + @test_throws ArgumentError sparse_init(3, sparsity=0.1) + v = sparse_init(100, 100; sparsity=-0.1) + @test sum(v .== 0) == 0 + v = sparse_init(100, 100; sparsity=1.1) + @test sum(v .== 0) == length(v) + + for (n_in, n_out, sparsity, σ) in [(100, 100, 0.25, 0.1), (100, 400, 0.75, 0.01)] + expected_zeros = ceil(Integer, n_in * sparsity) + v = sparse_init(n_in, n_out; sparsity=sparsity, std=σ) + @test all([sum(v[:, col] .== 0) == expected_zeros for col in 1:n_out]) + @test 0.9 * σ < std(v[v .!= 0]) < 1.1 * σ + end + + @testset "sparse_init Type $T" for T in (Float16, Float32, Float64) + !supports_fp64 && T == Float64 && continue + + @test eltype(sparse_init(rng, T, 3, 4; sparsity=0.5)) == T + end + + @testset "sparse_init AbstractArray Type $T" for T in (Float16, Float32, Float64) + !supports_fp64 && T == Float64 && continue + + @test sparse_init(T, 3, 5; sparsity=0.5) isa AbstractArray{T, 2} + @test sparse_init(rng, T, 3, 5; sparsity=0.5) isa arrtype{T, 2} + + cl = sparse_init(rng; sparsity=0.5) + display(cl) + @test cl(T, 3, 5) isa arrtype{T, 2} + + cl = sparse_init(rng, T; sparsity=0.5) + display(cl) + @test cl(3, 5) isa arrtype{T, 2} + end + + @testset "sparse_init Closure" begin + cl = sparse_init(; sparsity=0.5) + display(cl) + + # Sizes + @test size(cl(3, 4)) == (3, 4) + @test size(cl(rng, 3, 4)) == (3, 4) + + # Type + @test eltype(cl(4, 2)) == Float32 + @test eltype(cl(rng, 4, 2)) == Float32 + end + end +end + +@testitem "Basic Initializations" setup=[SharedTestSetup] begin + using LinearAlgebra, Statistics + + @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype, supports_fp64, backend) in RNGS_ARRTYPES + @testset "Sizes and Types: $init" for init in [ + zeros32, ones32, rand32, randn32, kaiming_uniform, kaiming_normal, + glorot_uniform, glorot_normal, truncated_normal, identity_init] + 
!supports_fp64 && + (init === zeros32 || + init === ones32 || + init === rand32 || + init === randn32) && + continue + + if backend == "oneapi" && init === truncated_normal + @test_broken size(init(rng, 3)) == (3,) # `erfinv` not implemented + continue + end + + # Sizes + @test size(init(3)) == (3,) + @test size(init(rng, 3)) == (3,) + @test size(init(3, 4)) == (3, 4) + @test size(init(rng, 3, 4)) == (3, 4) + @test size(init(3, 4, 5)) == (3, 4, 5) + @test size(init(rng, 3, 4, 5)) == (3, 4, 5) + + # Type + @test eltype(init(rng, 4, 2)) == Float32 + @test eltype(init(4, 2)) == Float32 + + # RNG Closure + cl = init(rng) + display(cl) + @test cl(3) isa arrtype{Float32, 1} + @test cl(3, 5) isa arrtype{Float32, 2} + end + + @testset "Sizes and Types: $init" for (init, fp) in [ + (zeros16, Float16), (zerosC16, ComplexF16), (zeros32, Float32), + (zerosC32, ComplexF32), (zeros64, Float64), (zerosC64, ComplexF64), + (ones16, Float16), (onesC16, ComplexF16), (ones32, Float32), + (onesC32, ComplexF32), (ones64, Float64), (onesC64, ComplexF64), + (rand16, Float16), (randC16, ComplexF16), (rand32, Float32), + (randC32, ComplexF32), (rand64, Float64), (randC64, ComplexF64), + (randn16, Float16), (randnC16, ComplexF16), (randn32, Float32), + (randnC32, ComplexF32), (randn64, Float64), (randnC64, ComplexF64)] + !supports_fp64 && (fp == Float64 || fp == ComplexF64) && continue + + # Sizes + @test size(init(3)) == (3,) + @test size(init(rng, 3)) == (3,) + @test size(init(3, 4)) == (3, 4) + @test size(init(rng, 3, 4)) == (3, 4) + @test size(init(3, 4, 5)) == (3, 4, 5) + @test size(init(rng, 3, 4, 5)) == (3, 4, 5) + + # Type + @test eltype(init(rng, 4, 2)) == fp + @test eltype(init(4, 2)) == fp + + # RNG Closure + cl = init(rng) + display(cl) + @test cl(3) isa arrtype{fp, 1} + @test cl(3, 5) isa arrtype{fp, 2} + + # Kwargs closure + cl = init(;) + display(cl) + @test cl(rng, 3) isa arrtype{fp, 1} + @test cl(rng, 3, 5) isa arrtype{fp, 2} + + # throw error on type as input + @test_throws ArgumentError init(Float32) + @test_throws ArgumentError init(Float32, 3, 5) + @test_throws ArgumentError init(rng, Float32) + @test_throws ArgumentError init(rng, Float32, 3, 5) + end + + @testset "AbstractArray Type: $init $T" for init in [ + kaiming_uniform, kaiming_normal, glorot_uniform, + glorot_normal, truncated_normal, identity_init], + T in (Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64) + + !supports_fp64 && (T == Float64 || T == ComplexF64) && continue + + init === truncated_normal && !(T <: Real) && continue + + if backend == "oneapi" && init === truncated_normal && T == Float32 + @test_broken init(rng, T, 3) isa AbstractArray{T, 1} # `erfinv` not implemented + continue + end + + @test init(T, 3) isa AbstractArray{T, 1} + @test init(rng, T, 3) isa arrtype{T, 1} + @test init(T, 3, 5) isa AbstractArray{T, 2} + @test init(rng, T, 3, 5) isa arrtype{T, 2} + + cl = init(rng) + display(cl) + @test cl(T, 3) isa arrtype{T, 1} + @test cl(T, 3, 5) isa arrtype{T, 2} + + cl = init(rng, T) + display(cl) + @test cl(3) isa arrtype{T, 1} + @test cl(3, 5) isa arrtype{T, 2} + + cl = init(T) + display(cl) + @test cl(3) isa Array{T, 1} + @test cl(3, 5) isa Array{T, 2} + @test cl(rng, 3, 5) isa arrtype{T, 2} + end + + @testset "Closure: $init" for init in [ + kaiming_uniform, kaiming_normal, glorot_uniform, + glorot_normal, truncated_normal, identity_init] + if backend == "oneapi" && init === truncated_normal + @test_broken size(init(rng, 3)) == (3,) # `erfinv` not implemented + continue + end + + cl = init(;) + 
display(cl) + + # Sizes + @test size(cl(3)) == (3,) + @test size(cl(rng, 3)) == (3,) + @test size(cl(3, 4)) == (3, 4) + @test size(cl(rng, 3, 4)) == (3, 4) + @test size(cl(3, 4, 5)) == (3, 4, 5) + @test size(cl(rng, 3, 4, 5)) == (3, 4, 5) + + # Type + @test eltype(cl(4, 2)) == Float32 + @test eltype(cl(rng, 4, 2)) == Float32 + end + + @testset "Kwargs types" for T in ( + Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64) + !supports_fp64 && (T == Float64 || T == ComplexF64) && continue + + if (T <: Real) + @test eltype(truncated_normal(T, 2, 5; mean=0, std=1, lo=-2, hi=2)) == T + @test eltype(orthogonal(T, 2, 5; gain=1.0)) == T + end + @test eltype(glorot_uniform(T, 2, 5; gain=1.0)) == T + @test eltype(glorot_normal(T, 2, 5; gain=1.0)) == T + @test eltype(kaiming_uniform(T, 2, 5; gain=sqrt(2))) == T + @test eltype(kaiming_normal(T, 2, 5; gain=sqrt(2))) == T + @test eltype(identity_init(T, 2, 5; gain=1.0)) == T + @test eltype(sparse_init(T, 2, 5; sparsity=0.5, std=0.01)) == T + end + + @testset "kaiming" begin + # kaiming_uniform should yield a kernel in range [-sqrt(6/n_out), sqrt(6/n_out)] + # and kaiming_normal should yield a kernel with stddev ~= sqrt(2/n_out) + for (n_in, n_out) in [(100, 100), (100, 400)] + v = kaiming_uniform(rng, n_in, n_out) + σ2 = sqrt(6 / n_out) + @test -1σ2 < minimum(v) < -0.9σ2 + @test 0.9σ2 < maximum(v) < 1σ2 + + v = kaiming_normal(rng, n_in, n_out) + σ2 = sqrt(2 / n_out) + + if (backend == "cuda" || backend == "amdgpu") && rng isa GPUArrays.RNG + @test_broken 0.9σ2 < std(v) < 1.1σ2 + else + @test 0.9σ2 < std(v) < 1.1σ2 + end + end + # Type + @test eltype(kaiming_uniform(rng, 3, 4; gain=1.5f0)) == Float32 + @test eltype(kaiming_normal(rng, 3, 4; gain=1.5f0)) == Float32 + end + + @testset "glorot: $init" for init in [glorot_uniform, glorot_normal] + # glorot_uniform and glorot_normal should both yield a kernel with + # variance ≈ 2/(fan_in + fan_out) + for dims in [(1000,), (100, 100), (100, 400), (2, 3, 32, 64), (2, 3, 4, 32, 64)] + v = init(dims...) + fan_in, fan_out = WeightInitializers.Utils.nfan(dims...) + σ2 = 2 / (fan_in + fan_out) + @test 0.9σ2 < var(v) < 1.1σ2 + end + @test eltype(init(3, 4; gain=1.5)) == Float32 + end + + @testset "orthogonal" begin + # A matrix of dim = (m,n) with m > n should produce a QR decomposition. In the other case, the transpose should be taken to compute the QR decomposition. + for (rows, cols) in [(5, 3), (3, 5)] + v = orthogonal(rows, cols) + rows < cols ? (@test v * v' ≈ I(rows)) : (@test v' * v ≈ I(cols)) + end + for mat in [(3, 4, 5), (2, 2, 5)] + v = orthogonal(mat...) + cols = mat[end] + rows = div(prod(mat), cols) + v = reshape(v, (rows, cols)) + rows < cols ? 
(@test v * v' ≈ I(rows)) : (@test v' * v ≈ I(cols)) + end + @test eltype(orthogonal(3, 4; gain=1.5)) == Float32 + end + end +end diff --git a/lib/WeightInitializers/test/qa_tests.jl b/lib/WeightInitializers/test/qa_tests.jl new file mode 100644 index 000000000..63f52966f --- /dev/null +++ b/lib/WeightInitializers/test/qa_tests.jl @@ -0,0 +1,33 @@ +@testitem "Aqua: Quality Assurance" begin + using Aqua + + Aqua.test_all(WeightInitializers; ambiguities=false) + Aqua.test_ambiguities(WeightInitializers; recursive=false) +end + +@testitem "Explicit Imports: Quality Assurance" setup=[SharedTestSetup] begin + using ExplicitImports + + @test check_no_implicit_imports(WeightInitializers) === nothing + @test check_no_stale_explicit_imports(WeightInitializers) === nothing + @test check_no_self_qualified_accesses(WeightInitializers) === nothing + @test check_all_explicit_imports_via_owners(WeightInitializers) === nothing + @test check_all_qualified_accesses_via_owners(WeightInitializers) === nothing + @test_broken check_all_explicit_imports_are_public(WeightInitializers) === nothing # mostly upstream problems + + try # FIXME: Soft fail for now + acc = check_all_qualified_accesses_are_public(WeightInitializers) + @test acc === nothing + catch + @test_broken check_all_qualified_accesses_are_public(WeightInitializers) === nothing + end +end + +@testitem "doctests: Quality Assurance" begin + using Documenter + + doctestexpr = :(using Random, WeightInitializers) + + DocMeta.setdocmeta!(WeightInitializers, :DocTestSetup, doctestexpr; recursive=true) + doctest(WeightInitializers; manual=false) +end diff --git a/lib/WeightInitializers/test/runtests.jl b/lib/WeightInitializers/test/runtests.jl new file mode 100644 index 000000000..9de7d16bf --- /dev/null +++ b/lib/WeightInitializers/test/runtests.jl @@ -0,0 +1,30 @@ +using Pkg, ReTestItems, WeightInitializers +using InteractiveUtils, Hwloc + +@info sprint(versioninfo) + +const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "All")) + +const EXTRA_PKGS = String[] + +(BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") && push!(EXTRA_PKGS, "CUDA") +(BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") && push!(EXTRA_PKGS, "AMDGPU") +(BACKEND_GROUP == "all" || BACKEND_GROUP == "metal") && push!(EXTRA_PKGS, "Metal") +(BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi") && push!(EXTRA_PKGS, "oneAPI") + +if !isempty(EXTRA_PKGS) + @info "Installing Extra Packages for testing" EXTRA_PKGS=EXTRA_PKGS + Pkg.add(EXTRA_PKGS) + Pkg.update() + Base.retry_load_extensions() + Pkg.instantiate() +end + +const RETESTITEMS_NWORKERS = parse( + Int, get(ENV, "RETESTITEMS_NWORKERS", string(min(Hwloc.num_physical_cores(), 4)))) +const RETESTITEMS_NWORKER_THREADS = parse(Int, + get(ENV, "RETESTITEMS_NWORKER_THREADS", + string(max(Hwloc.num_virtual_cores() ÷ RETESTITEMS_NWORKERS, 1)))) + +ReTestItems.runtests(WeightInitializers; nworkers=RETESTITEMS_NWORKERS, + nworker_threads=RETESTITEMS_NWORKER_THREADS) diff --git a/lib/WeightInitializers/test/shared_testsetup.jl b/lib/WeightInitializers/test/shared_testsetup.jl new file mode 100644 index 000000000..8d7cb836a --- /dev/null +++ b/lib/WeightInitializers/test/shared_testsetup.jl @@ -0,0 +1,43 @@ +@testsetup module SharedTestSetup + +using GPUArrays, GPUArraysCore, Random, StableRNGs + +GPUArraysCore.allowscalar(false) + +const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "All")) + +RNGS_ARRTYPES = [] +if BACKEND_GROUP == "all" || BACKEND_GROUP == "cpu" + append!(RNGS_ARRTYPES, + [(StableRNG(12345), AbstractArray, true, 
"cpu"), + (Random.GLOBAL_RNG, AbstractArray, true, "cpu")]) +end +if BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda" + using CUDA + append!(RNGS_ARRTYPES, + [(CUDA.default_rng(), CuArray, true, "cuda"), + (GPUArrays.default_rng(CuArray), CuArray, true, "cuda")]) +end +if BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu" + using AMDGPU + append!(RNGS_ARRTYPES, + [(AMDGPU.rocrand_rng(), ROCArray, true, "amdgpu"), + (AMDGPU.gpuarrays_rng(), ROCArray, true, "amdgpu")]) +end +if BACKEND_GROUP == "all" || BACKEND_GROUP == "metal" + using Metal + push!(RNGS_ARRTYPES, (Metal.gpuarrays_rng(), MtlArray, false, "metal")) +end +if BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi" + using oneAPI + using oneAPI: oneL0 + + supports_fp64 = oneL0.module_properties(first(oneAPI.devices())).fp64flags & + oneL0.ZE_DEVICE_MODULE_FLAG_FP64 == oneL0.ZE_DEVICE_MODULE_FLAG_FP64 + + push!(RNGS_ARRTYPES, (oneAPI.gpuarrays_rng(), oneArray, supports_fp64, "oneapi")) +end + +export StableRNG, RNGS_ARRTYPES, BACKEND_GROUP, GPUArrays + +end diff --git a/lib/WeightInitializers/test/utils_tests.jl b/lib/WeightInitializers/test/utils_tests.jl new file mode 100644 index 000000000..027fd6d21 --- /dev/null +++ b/lib/WeightInitializers/test/utils_tests.jl @@ -0,0 +1,9 @@ +@testitem "Utils.nfan" begin + using WeightInitializers: Utils + + @test Utils.nfan() == (1, 1) # Fallback + @test Utils.nfan(4) == (1, 4) # Vector + @test Utils.nfan(4, 5) == (5, 4) # Matrix + @test Utils.nfan((4, 5, 6)) == Utils.nfan(4, 5, 6) # Tuple + @test Utils.nfan(4, 5, 6) == 4 .* (5, 6) # Convolution +end diff --git a/test/runtests.jl b/test/runtests.jl index 6d311c8aa..ae8fbc392 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -20,6 +20,7 @@ end @info "Running tests for group: $LUX_TEST_GROUP" const EXTRA_PKGS = Pkg.PackageSpec[] +const EXTRA_DEV_PKGS = Pkg.PackageSpec[] if ("all" in LUX_TEST_GROUP || "distributed" in LUX_TEST_GROUP) push!(EXTRA_PKGS, Pkg.PackageSpec("MPI")) @@ -34,14 +35,21 @@ if !Sys.iswindows() push!(EXTRA_PKGS, Pkg.PackageSpec("Reactant")) end -(BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") && - push!(EXTRA_PKGS, Pkg.PackageSpec("LuxCUDA")) +if (BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") + if isdir(joinpath(@__DIR__, "../lib/LuxCUDA")) + @info "Using local LuxCUDA" + push!(EXTRA_DEV_PKGS, Pkg.PackageSpec(; path=joinpath(@__DIR__, "../lib/LuxCUDA"))) + else + push!(EXTRA_PKGS, Pkg.PackageSpec("LuxCUDA")) + end +end (BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") && push!(EXTRA_PKGS, Pkg.PackageSpec("AMDGPU")) -if !isempty(EXTRA_PKGS) - @info "Installing Extra Packages for testing" EXTRA_PKGS=EXTRA_PKGS - Pkg.add(EXTRA_PKGS) +if !isempty(EXTRA_PKGS) || !isempty(EXTRA_DEV_PKGS) + @info "Installing Extra Packages for testing" EXTRA_PKGS EXTRA_DEV_PKGS + isempty(EXTRA_PKGS) || Pkg.add(EXTRA_PKGS) + isempty(EXTRA_DEV_PKGS) || Pkg.develop(EXTRA_DEV_PKGS) Pkg.update() Base.retry_load_extensions() Pkg.instantiate()